From c02aac47a3fb58b812aa01b5d4d98016a6c4fb6f Mon Sep 17 00:00:00 2001 From: Jeff Kim Date: Tue, 23 Sep 2025 11:54:40 -0700 Subject: [PATCH] Implement tensor padding for local shards wrapper (#3382) Summary: X-link: https://github.com/pytorch/pytorch/pull/163183 This diff implements the constant padding functionality (aten.constant_pad_nd.default) for `LocalShardsWrapper`. The method applies constant padding to the local shards based on the provided padding specification. Depending on the sharding type (RW, CW), the padding on [left, right, top, bottom] directions will be either applied to the first/last shard, or all local shards. New unit tests cover: - 1D (RW) top/bottom paddings - 2D (CW) left, right, top, bottom paddings - empty shards, number of dimensions > 2 Differential Revision: D82663766 --- torchrec/distributed/shards_wrapper.py | 222 ++++++++++- .../distributed/tests/test_shards_wrapper.py | 357 ++++++++++++++++++ 2 files changed, 573 insertions(+), 6 deletions(-) diff --git a/torchrec/distributed/shards_wrapper.py b/torchrec/distributed/shards_wrapper.py index 6475452f7..d4ea4a726 100644 --- a/torchrec/distributed/shards_wrapper.py +++ b/torchrec/distributed/shards_wrapper.py @@ -9,6 +9,7 @@ # COPY of the code from torch.distributed._tensor._shards_wrapper - for package compat +import logging from typing import Any, List, Tuple import torch @@ -24,6 +25,7 @@ WriteItemType, ) +logger: logging.Logger = logging.getLogger(__name__) aten = torch.ops.aten # pyre-ignore[5] @@ -73,7 +75,7 @@ def __new__( cat_tensor_shape[1] += shard.size()[1] # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension - if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding + if len(local_shards) > 1 and local_shards[0].ndim == 1: # row-wise sharding for shard in local_shards[1:]: cat_tensor_shape[0] += shard.size()[0] @@ -119,6 +121,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): aten.copy_.default: cls.handle_copy_, aten.zeros_like.default: cls.handle_zeros_like, aten.empty_like.default: cls.handle_empty_like, + aten.constant_pad_nd.default: cls.handle_constant_pad_nd, } if func in dispatcher: @@ -162,12 +165,14 @@ def handle_copy_(args, kwargs): # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def handle_all_gather_into_tensor(args, kwargs): - dim = args[0].local_sizes()[0][1] - cat_tensor = torch.cat( - [t.view(-1) for t in args[0].local_shards()], dim=0 - ).view(-1, dim) + local_shards = args[0].local_shards() + if len(local_shards) == 1: + result_tensor = local_shards[0] + # 2D CW sharding: concat columns, 1D RW sharding: concat rows + result_tensor = torch.cat(local_shards, dim=-1) + logger.info(f"resulting tensor before all gather: {result_tensor}") return torch.ops._c10d_functional.all_gather_into_tensor.default( - cat_tensor, *args[1:], **kwargs + result_tensor, *args[1:], **kwargs ) @staticmethod @@ -279,6 +284,211 @@ def handle_new_empty(args, kwargs): self_ls.local_offsets(), ) + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_constant_pad_nd(args, kwargs): + """ + Apply constant padding to LocalShardsWrapper. + + The padding is based off of the following ideas: + - The resulting wrapper represents the padded version of the logical tensor. + - Each shard is padded based on the sharding type + dimension that is padded. 
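+        - pad_spec follows the aten.constant_pad_nd convention: [left, right] for 1D
+          (RW) shards and [left, right, top, bottom] for 2D (CW) shards; a 2-element
+          spec on a 2D tensor pads the last (column) dimension.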
+ - For instance, CW shards padded on the left most col will have only padding on the first CW shard. + - Padding the top row will apply to all CW shards. + """ + self_lsw = args[0] + pad_spec = args[1] + pad_value = args[2] if len(args) > 2 else 0.0 + logger.info( + f"padding {self_lsw} with {pad_spec} and value: {pad_value}, current shards: {self_lsw.local_shards()} with offsets: {self_lsw.local_offsets()}. tensor storage metadata: {self_lsw.storage_metadata()}" + ) + + if len(self_lsw.local_shards()) == 0: + raise NotImplementedError( + "Padding empty LocalShardsWrapper is not supported." + ) + + local_shards = self_lsw.local_shards() + + if len(local_shards) == 1: + padded_shard = torch.nn.functional.pad( + local_shards[0], pad_spec, mode="constant", value=pad_value + ) + return LocalShardsWrapper([padded_shard], self_lsw.local_offsets()) + + padded_shards = list(local_shards) + + if local_shards[0].ndim == 2: + # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom] + if len(pad_spec) == 2: + # Single dimension padding happens on the left most column + pad_spec = pad_spec + [0, 0] + + if len(pad_spec) != 4: + raise ValueError( + f"Padding spec must be of length 4 for 2D tensors, got {len(pad_spec)}" + ) + + pad_left, pad_right, pad_top, pad_bottom = ( + pad_spec[0], + pad_spec[1], + pad_spec[2], + pad_spec[3], + ) + + if pad_top > 0: + padded_shards = [ + torch.nn.functional.pad( + shard, [0, 0, pad_top, 0], mode="constant", value=pad_value + ) + for shard in padded_shards + ] + if pad_bottom > 0: + padded_shards = [ + torch.nn.functional.pad( + shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value + ) + for shard in padded_shards + ] + if pad_left > 0: + padded_shards[0] = torch.nn.functional.pad( + padded_shards[0], + [pad_left, 0, 0, 0], + mode="constant", + value=pad_value, + ) + if pad_right > 0: + padded_shards[-1] = torch.nn.functional.pad( + padded_shards[-1], + [0, pad_right, 0, 0], + mode="constant", + value=pad_value, + ) + elif local_shards[0].ndim == 1: + # 1D Row-wise sharding: [pad_top, pad_bottom] + if len(pad_spec) != 2: + raise ValueError( + f"Padding spec must be of length 2 for 1D tensors, got {len(pad_spec)}" + ) + pad_top, pad_bottom = pad_spec[0], pad_spec[1] + + if pad_top > 0: + padded_shards[0] = torch.nn.functional.pad( + padded_shards[0], [pad_top, 0], mode="constant", value=pad_value + ) + if pad_bottom > 0: + padded_shards[-1] = torch.nn.functional.pad( + padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value + ) + else: + raise NotImplementedError( + f"Padding for {local_shards[0].ndim}D tensors is not supported. " + f"Only 1D and 2D tensors are currently supported." + ) + + # Update offsets and storage metadata + original_storage = self_lsw.storage_metadata() + updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata( + original_storage, + self_lsw.local_offsets(), + pad_spec, + local_shards[0].ndim, + padded_shards, + ) + + result = LocalShardsWrapper(padded_shards, updated_offsets) + result._storage_meta = updated_storage + return result + + @staticmethod + def _compute_updated_metadata( + original_storage: TensorStorageMetadata, + original_offsets: list[torch.Size], + pad_spec: list[int], + ndim: int, + padded_shards: list[torch.Tensor], + ) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]: + """ + Compute updated offsets and storage metadata after padding is applied. 
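+        The first shard keeps its offsets (it absorbs the top/left padding); offsets of
+        subsequent shards shift by the top (1D) or left (2D) padding amount, and the
+        global size grows by the total padding. For example (illustrative sizes): two 1D
+        shards of length 2 at offsets [(0,), (2,)] padded with pad_spec [1, 1] end up at
+        offsets [(0,), (3,)] with a global size of [6].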
+ + Args: + original_storage: Original storage metadata + original_offsets: Original shard offsets + pad_spec: Padding specification + ndim: Number of dimensions (1=RW or 2=CW) + padded_shards: Padded shard tensors + + Returns: + Tuple of (updated_offsets, updated_storage_metadata) + """ + if ndim == 1: # 1D RW + pad_top, pad_bottom = pad_spec[0], pad_spec[1] + + updated_offsets = [] + for i, offset in enumerate(original_offsets): + if i == 0: + # First shard: offset stays the same (absorbs top padding) + updated_offsets.append(tuple(offset)) + else: + # Subsequent shards: shift by top padding amount + new_offset = (offset[0] + pad_top,) + updated_offsets.append(new_offset) + + new_global_size = torch.Size( + [original_storage.size[0] + pad_top + pad_bottom] + ) + + elif ndim == 2: # 2D CW + pad_left, pad_right, pad_top, pad_bottom = ( + pad_spec[0], + pad_spec[1], + pad_spec[2], + pad_spec[3], + ) + + updated_offsets = [] + for i, offset in enumerate(original_offsets): + row_offset = offset[0] + col_offset = offset[1] + + # Top/bottom padding doesn't affect offsets + # Left padding affects column offsets + if i == 0: + # First shard: column offset stays the same (absorbs left padding) + new_2d_offset = (row_offset, col_offset) + else: + # Subsequent shards: shift column offset by left padding amount + new_2d_offset = (row_offset, col_offset + pad_left) + + updated_offsets.append(new_2d_offset) + + new_global_size = torch.Size( + [ + original_storage.size[0] + pad_top + pad_bottom, + original_storage.size[1] + pad_left + pad_right, + ] + ) + + else: + raise NotImplementedError(f"Metadata computation for {ndim}D not supported") + + updated_chunks = [ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=shard.size(), + ) + for offset, shard in zip(updated_offsets, padded_shards) + ] + + updated_storage = TensorStorageMetadata( + properties=original_storage.properties, + size=new_global_size, + chunks=updated_chunks, + ) + + return updated_offsets, updated_storage + @property def device(self) -> torch._C.device: # type: ignore[override] return ( diff --git a/torchrec/distributed/tests/test_shards_wrapper.py b/torchrec/distributed/tests/test_shards_wrapper.py index 7199552dd..a93026998 100644 --- a/torchrec/distributed/tests/test_shards_wrapper.py +++ b/torchrec/distributed/tests/test_shards_wrapper.py @@ -94,6 +94,363 @@ def all_gather_object( ) +class LocalShardsWrapperPaddingTest(unittest.TestCase): + """Test cases for constant padding functionality in LocalShardsWrapper.""" + + def test_empty_shards_padding(self) -> None: + """Test padding with empty shards list.""" + lsw = LocalShardsWrapper([], []) + pad_spec = [1, 2, 3, 4] + pad_value = 5.0 + + self.assertRaises( + Exception, + torch.ops.aten.constant_pad_nd.default, + lsw, + pad_spec, + pad_value, + ) + + def test_invalid_1d_rw_padding(self) -> None: + """Test invalid padding on 1D tensor throws ValueError.""" + shard1 = torch.tensor([1.0, 2.0]) + shard2 = torch.tensor([3.0, 4.0]) + lsw = LocalShardsWrapper([shard1, shard2], [(2, 0)]) + pad_spec = [1] # invalid padding spec + pad_value = 5.0 + + self.assertRaises( + ValueError, + torch.ops.aten.constant_pad_nd.default, + lsw, + pad_spec, + pad_value, + ) + + def test_invalid_2d_cw_padding(self) -> None: + """Test invalid padding on 2D tensor throws ValueError.""" + shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) + shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]]) + lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)]) + pad_spec = [1, 2, 3] # invalid padding spec + 
pad_value = 5.0 + + self.assertRaises( + ValueError, + torch.ops.aten.constant_pad_nd.default, + lsw, + pad_spec, + pad_value, + ) + + pad_spec = [1] + + self.assertRaises( + ValueError, + torch.ops.aten.constant_pad_nd.default, + lsw, + pad_spec, + pad_value, + ) + + def test_single_shard_padding_2d(self) -> None: + """Test padding with single 2D shard.""" + tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + lsw = LocalShardsWrapper([tensor], [(0, 0)]) + pad_spec = [1, 2, 3, 4] # [left=1, right=2, top=3, bottom=4] + pad_value = 0.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + expected = torch.nn.functional.pad( + tensor, pad_spec, mode="constant", value=pad_value + ) + self.assertIsInstance(result, LocalShardsWrapper) + self.assertEqual(len(result.local_shards()), 1) + torch.testing.assert_close(result.local_shards()[0], expected) + + def test_single_shard_padding_1d(self) -> None: + """Test padding with single 1D shard.""" + tensor = torch.tensor([1.0, 2.0, 3.0]) + lsw = LocalShardsWrapper([tensor], [(0,)]) + pad_spec = [2, 1] # [top=2, bottom=1] + pad_value = -1.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + self.assertIsInstance(result, LocalShardsWrapper) + self.assertEqual(len(result.local_shards()), 1) + + expected = torch.nn.functional.pad( + tensor, pad_spec, mode="constant", value=pad_value + ) + torch.testing.assert_close(result.local_shards()[0], expected) + + def test_2d_cw_sharding_top_padding(self) -> None: + """Test column-wise sharding with top padding (affects all shards).""" + shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) + shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]]) + lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)]) + pad_spec = [0, 0, 2, 0] # top=2 + pad_value = 0.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + self.assertEqual(len(result.local_shards()), 2) + # Both shards should have 2 rows added at top + expected_shape = (4, 2) + self.assertEqual(result.local_shards()[0].shape, expected_shape) + self.assertEqual(result.local_shards()[1].shape, expected_shape) + + torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2, 2)) + torch.testing.assert_close(result.local_shards()[1][:2], torch.zeros(2, 2)) + torch.testing.assert_close(result.local_shards()[0][2:], shard1) + torch.testing.assert_close(result.local_shards()[1][2:], shard2) + + def test_2d_cw_sharding_bottom_padding(self) -> None: + """Test column-wise sharding with bottom padding (affects all shards).""" + shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) + shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]]) + lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)]) + pad_spec = [0, 0, 0, 1] # bottom=1 + pad_value = -1.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + self.assertEqual(len(result.local_shards()), 2) + expected_shape = (3, 2) + self.assertEqual(result.local_shards()[0].shape, expected_shape) + self.assertEqual(result.local_shards()[1].shape, expected_shape) + + torch.testing.assert_close(result.local_shards()[0][:2], shard1) + torch.testing.assert_close(result.local_shards()[1][:2], shard2) + torch.testing.assert_close( + result.local_shards()[0][2:], torch.full((1, 2), -1.0) + ) + torch.testing.assert_close( + result.local_shards()[1][2:], torch.full((1, 2), -1.0) + ) + + def test_2d_cw_sharding_left_padding(self) -> None: + """Test column-wise sharding with left padding (affects first shard only).""" + shard1 = 
torch.tensor([[1.0, 2.0], [5.0, 6.0]]) + shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]]) + lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)]) + pad_spec = [3, 0, 0, 0] # left=3 + pad_value = 2.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + self.assertEqual(len(result.local_shards()), 2) + # First shard should have 3 columns added at left + self.assertEqual(result.local_shards()[0].shape, (2, 5)) + self.assertEqual(result.local_shards()[1].shape, (2, 2)) + + # Check content + torch.testing.assert_close( + result.local_shards()[0][:, :3], torch.full((2, 3), 2.0) + ) + torch.testing.assert_close(result.local_shards()[0][:, 3:], shard1) + torch.testing.assert_close(result.local_shards()[1], shard2) + + def test_2d_cw_sharding_right_padding(self) -> None: + """Test column-wise sharding with right padding (affects last shard only).""" + shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) + shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]]) + lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)]) + pad_spec = [0, 2, 0, 0] # right=2 + pad_value = 3.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + # Second shard should have 2 columns added at right + expected_shard_1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) + expected_shard_2 = torch.tensor([[3.0, 4.0, 3.0, 3.0], [7.0, 8.0, 3.0, 3.0]]) + self.assertEqual(len(result.local_shards()), 2) + torch.testing.assert_close(result.local_shards()[0], expected_shard_1) + torch.testing.assert_close(result.local_shards()[1], expected_shard_2) + + # 1D padding on 2D pads the last dimension + pad_spec_2 = [0, 2] # right=2 + result_2 = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec_2, pad_value) + torch.testing.assert_close(result_2.local_shards()[0], expected_shard_1) + torch.testing.assert_close(result_2.local_shards()[1], expected_shard_2) + + def test_2d_cw_sharding_mixed_padding(self) -> None: + """Test column-wise sharding with mixed padding directions.""" + shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) + shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]]) + lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)]) + pad_spec = [1, 2, 1, 1] # [left=1, right=2, top=1, bottom=1] + pad_value = 0.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + expected_shard_1 = torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 1.0, 2.0], [0.0, 5.0, 6.0], [0.0, 0.0, 0.0]], + ) + + expected_shard_2 = torch.tensor( + [ + [0.0, 0.0, 0.0, 0.0], + [3.0, 4.0, 0.0, 0.0], + [7.0, 8.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + ) + + self.assertEqual(len(result.local_shards()), 2) + torch.testing.assert_close(result.local_shards()[0], expected_shard_1) + torch.testing.assert_close(result.local_shards()[1], expected_shard_2) + + def test_1d_rw_sharding_top_padding(self) -> None: + """Test row-wise sharding with top padding (affects first shard only).""" + shard1 = torch.tensor([1.0, 2.0, 3.0]) + shard2 = torch.tensor([4.0, 5.0, 6.0]) + lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)]) + pad_spec = [2, 0] # top=2 + pad_value = 0.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + expected_shard_1 = torch.tensor( + [0.0, 0.0, 1.0, 2.0, 3.0], + ) + expected_shard_2 = torch.tensor( + [4.0, 5.0, 6.0], + ) + + self.assertEqual(len(result.local_shards()), 2) + torch.testing.assert_close(result.local_shards()[0], expected_shard_1) + torch.testing.assert_close(result.local_shards()[1], expected_shard_2) + + def 
test_1d_rw_sharding_bottom_padding(self) -> None: + """Test row-wise sharding with bottom padding (affects last shard only).""" + shard1 = torch.tensor([1.0, 2.0, 3.0]) + shard2 = torch.tensor([4.0, 5.0, 6.0]) + lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)]) + pad_spec = [0, 1] # bottom=1 + pad_value = -1.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + expected_shard_1 = torch.tensor( + [1.0, 2.0, 3.0], + ) + expected_shard_2 = torch.tensor( + [4.0, 5.0, 6.0, -1.0], + ) + + self.assertEqual(len(result.local_shards()), 2) + torch.testing.assert_close(result.local_shards()[0], expected_shard_1) + torch.testing.assert_close(result.local_shards()[1], expected_shard_2) + + def test_1d_rw_sharding_mixed_padding(self) -> None: + """Test row-wise sharding with mixed top/bottom padding.""" + shard1 = torch.tensor([1.0, 2.0]) + shard2 = torch.tensor([3.0, 4.0]) + lsw = LocalShardsWrapper([shard1, shard2], [(0,), (2,)]) + pad_spec = [1, 2] # [top=1, bottom=2] + pad_value = 5.0 + + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + expected_shard_1 = torch.tensor( + [5.0, 1.0, 2.0], + ) + expected_shard_2 = torch.tensor( + [3.0, 4.0, 5.0, 5.0], + ) + + self.assertEqual(len(result.local_shards()), 2) + torch.testing.assert_close(result.local_shards()[0], expected_shard_1) + torch.testing.assert_close(result.local_shards()[1], expected_shard_2) + + def test_higher_dimensions_not_implemented(self) -> None: + """Test that higher dimensional tensors raise NotImplementedError.""" + tensor_3d = torch.rand(2, 3, 4) # 3D tensor + lsw = LocalShardsWrapper([tensor_3d, tensor_3d], [(0, 0, 0), (2, 0, 0)]) + pad_spec = [1, 1, 1, 1, 1, 1] # 3D padding spec + pad_value = 0.0 + + with self.assertRaises(NotImplementedError) as cm: + torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value) + + self.assertIn("3D tensors is not supported", str(cm.exception)) + self.assertIn( + "Only 1D and 2D tensors are currently supported", str(cm.exception) + ) + + def test_offsets_and_storage_metadata_after_padding_1d_rw(self) -> None: + # Test 1D RW sharding with top+bottom padding + shard1 = torch.tensor([1.0, 2.0]) + shard2 = torch.tensor([3.0, 4.0]) + original_offsets = [(0,), (2,)] + lsw = LocalShardsWrapper([shard1, shard2], original_offsets) + + # Check original storage metadata + original_storage = lsw.storage_metadata() + self.assertEqual(original_storage.size, torch.Size([4])) # [1,2,3,4] + self.assertEqual(len(original_storage.chunks), 2) + self.assertEqual(original_storage.chunks[0].offsets, torch.Size([0])) + self.assertEqual(original_storage.chunks[0].sizes, torch.Size([2])) + self.assertEqual(original_storage.chunks[1].offsets, torch.Size([2])) + self.assertEqual(original_storage.chunks[1].sizes, torch.Size([2])) + + pad_spec = [1, 1] # add 1 element at top and bottom + result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, 0.0) + + expected_offsets = [ + torch.Size([0]), + torch.Size([3]), + ] # Second shard's offset shifted by 1 + self.assertEqual(result.local_offsets(), expected_offsets) + + result_storage = result.storage_metadata() + + # Global tensor should be: [0, 1, 2, 3, 4, 0] shape=[6] + expected_global_size = torch.Size([6]) + self.assertEqual(result_storage.size, expected_global_size) + + self.assertEqual(len(result_storage.chunks), 2) + + # First chunk: [3] elements at offset [0] (size increased by top padding) + # Second chunk: [3] elements at offset [3] (size increased by bottom padding, offset shifted) + 
self.assertEqual(result_storage.chunks[0].offsets, torch.Size([0])) + self.assertEqual(result_storage.chunks[0].sizes, torch.Size([3])) + self.assertEqual(result_storage.chunks[1].offsets, torch.Size([3])) + self.assertEqual(result_storage.chunks[1].sizes, torch.Size([3])) + + def test_offsets_and_storage_metadata_after_padding_2d_cw(self) -> None: + # Test 2D CW sharding with left+right padding + shard1_2d = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) # [2, 2] columns 0-1 + shard2_2d = torch.tensor([[3.0, 4.0], [7.0, 8.0]]) # [2, 2] columns 2-3 + original_offsets_2d = [(0, 0), (0, 2)] + lsw_2d = LocalShardsWrapper([shard1_2d, shard2_2d], original_offsets_2d) + + pad_spec_2d = [1, 1, 0, 0] # [left=1, right=1, top=0, bottom=0] + result_2d = torch.ops.aten.constant_pad_nd.default(lsw_2d, pad_spec_2d, 0.0) + + expected_offsets_2d = [ + torch.Size([0, 0]), + torch.Size([0, 3]), + ] + self.assertEqual(result_2d.local_offsets(), expected_offsets_2d) + + result_storage_2d = result_2d.storage_metadata() + + # Global tensor should go from [2,4] to [2,6] (add 1 left + 1 right) + expected_global_size_2d = torch.Size([2, 6]) # [2, 4+1+1] + self.assertEqual(result_storage_2d.size, expected_global_size_2d) + + # First chunk: [2,3] at offset [0,0] (size increased by left padding) + # Second chunk: [2,3] at offset [0,3] (size increased by right padding, offset shifted) + self.assertEqual(result_storage_2d.chunks[0].offsets, torch.Size([0, 0])) + self.assertEqual(result_storage_2d.chunks[0].sizes, torch.Size([2, 3])) + self.assertEqual(result_storage_2d.chunks[1].offsets, torch.Size([0, 3])) + self.assertEqual(result_storage_2d.chunks[1].sizes, torch.Size([2, 3])) + + @skip_if_asan_class class LocalShardsWrapperDistributedTest(MultiProcessTestBase): @seed_and_log
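
Usage sketch (illustrative, mirroring the unit tests above; the shard values, shapes, and
offsets are assumptions chosen for the example, the import path and operator call are taken
from the patch):

    import torch
    from torchrec.distributed.shards_wrapper import LocalShardsWrapper

    # Two CW shards covering columns 0-1 and 2-3 of a [2, 4] logical tensor.
    shard0 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
    shard1 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
    lsw = LocalShardsWrapper([shard0, shard1], [(0, 0), (0, 2)])

    # pad_spec [left=1, right=1, top=0, bottom=0]: the left column lands on the first
    # shard only, the right column on the last shard; offsets and storage metadata
    # are updated accordingly.
    padded = torch.ops.aten.constant_pad_nd.default(lsw, [1, 1, 0, 0], 0.0)
    assert padded.local_shards()[0].shape == (2, 3)
    assert padded.local_shards()[1].shape == (2, 3)
    assert padded.local_offsets() == [torch.Size([0, 0]), torch.Size([0, 3])]
    assert padded.storage_metadata().size == torch.Size([2, 6])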