diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index a449fa614..cbcd0d78b 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -1944,6 +1944,10 @@ def __init__(
         assert (
             config.is_using_virtual_table
         ), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables"
+        assert embedding_cache_mode == config.enable_embedding_update, (
+            f"Embedding_cache kernel is {embedding_cache_mode} "
+            f"but embedding config has enable_embedding_update {config.enable_embedding_update}"
+        )
         for table in config.embedding_tables:
             assert table.local_cols % 4 == 0, (
                 f"table {table.name} has local_cols={table.local_cols} "
diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index 87352faca..b58ca1b12 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -368,8 +368,12 @@ def __init__(
             # https://github.com/pytorch/pytorch/issues/122788
             with record_function(f"## all2all_data:kjt {label} ##"):
                 if self._pg._get_backend_name() == "custom":
+                    if input_tensor.dim() == 2:
+                        output_size = [sum(output_split), input_tensor.size(1)]
+                    else:
+                        output_size = [sum(output_split)]
                     output_tensor = torch.empty(
-                        sum(output_split),
+                        output_size,
                         device=self._device,
                         dtype=input_tensor.dtype,
                     )
@@ -391,8 +395,12 @@ def __init__(
                     )
                     self._output_tensors.append(output_tensor)
                 else:
+                    if input_tensor.dim() == 2:
+                        output_size = [sum(output_split), input_tensor.size(1)]
+                    else:
+                        output_size = [sum(output_split)]
                     output_tensor = torch.empty(
-                        sum(output_split), device=self._device, dtype=input_tensor.dtype
+                        output_size, device=self._device, dtype=input_tensor.dtype
                     )
                     with record_function(f"## all2all_data:kjt {label} ##"):
                         awaitable = dist.all_to_all_single(
@@ -542,6 +550,113 @@ def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
         )
 
 
+class KJEAllToAll(nn.Module):
+    """
+    Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.
+
+    Implementation utilizes AlltoAll collective as part of torch.distributed.
+
+    The input provides the necessary tensors, embedding weights and input splits to distribute.
+    The first collective call in `KJTAllToAllSplitsAwaitable` will transmit output
+    splits (to allocate correct space for tensors) and batch size per rank. The
+    following collective calls in `KJTAllToAllTensorsAwaitable` will transmit the actual
+    tensors asynchronously.
+    This module is used for embedding updates wherein input KJT weights are updated into the embedding tables.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
+        splits (List[int]): List of len(pg.size()) which indicates how many features to
+            send to each pg.rank(). It is assumed the `KeyedJaggedTensor` is ordered by
+            destination rank. Same for all ranks.
+        stagger (int): stagger value to apply to recat tensor, see `_get_recat` function
+            for more detail.
+
+    Example::
+
+        keys=['A','B','C']
+        splits=[2,1]
+        kjeA2A = KJEAllToAll(pg, splits)
+        awaitable = kjeA2A(rank0_input)
+
+        # where:
+        # rank0_input is KeyedJaggedTensor holding
+
+        #         0       1       2
+        # 'A'    [A.V0]   None    [A.V1, A.V2]
+        # 'B'    None     [B.V0]  [B.V1]
+        # 'C'    [C.V0]   [C.V1]  None
+
+        # rank1_input is KeyedJaggedTensor holding
+
+        #         0       1       2
+        # 'A'    [A.V3]   [A.V4]  None
+        # 'B'    None     [B.V2]  [B.V3, B.V4]
+        # 'C'    [C.V2]   [C.V3]  None
+
+        # The output is None since this is a write operation, but it is still
+        # awaitable for synchronization.
+        awaitable.wait()
+
+        # where input after the distribution is:
+        # rank0
+
+        #         0       1       2             3       4       5
+        # 'A'    [A.V0]   None    [A.V1, A.V2]  [A.V3]  [A.V4]  None
+        # 'B'    None     [B.V0]  [B.V1]        None    [B.V2]  [B.V3, B.V4]
+
+        # rank1
+        #         0       1       2       3       4       5
+        # 'C'    [C.V0]   [C.V1]  None    [C.V2]  [C.V3]  None
+    """
+
+    def __init__(
+        self,
+        pg: dist.ProcessGroup,
+        splits: List[int],
+        stagger: int = 1,
+    ) -> None:
+        super().__init__()
+        torch._check(len(splits) == pg.size())
+        self._pg: dist.ProcessGroup = pg
+        self._splits = splits
+        self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))
+        self._stagger = stagger
+
+    def forward(
+        self, input: KeyedJaggedTensor
+    ) -> Awaitable[KJTAllToAllTensorsAwaitable]:
+        """
+        Sends input to relevant `ProcessGroup` ranks.
+
+        The first wait will get the output splits for the provided tensors and issue
+        tensors AlltoAll. The second wait will wait for the update.
+
+        Args:
+            input (KeyedJaggedTensor): `KeyedJaggedTensor` of values and weights to distribute.
+
+        Returns:
+            Awaitable[KJTAllToAllTensorsAwaitable]: awaitable of a `KJTAllToAllTensorsAwaitable`.
+        """
+
+        with torch.no_grad():
+            assert len(input.keys()) == sum(self._splits)
+            rank = dist.get_rank(self._pg)
+            local_keys = input.keys()[
+                self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
+            ]
+
+            return KJTAllToAllSplitsAwaitable(
+                pg=self._pg,
+                input=input,
+                splits=self._splits,
+                labels=input.dist_labels(),
+                tensor_splits=input.dist_splits(self._splits),
+                input_tensors=input.dist_tensors(),
+                keys=local_keys,
+                device=input.device(),
+                stagger=self._stagger,
+            )
+
+
 class KJTAllToAll(nn.Module):
     """
     Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.
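[Editor's note, not part of the patch] As a reading aid for the `KJEAllToAll` module added above, a minimal, hypothetical driver is sketched below; the helper name and the 2-rank split are illustrative.

import torch.distributed as dist
from torchrec.distributed.dist_data import KJEAllToAll
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def redistribute_for_write(
    pg: dist.ProcessGroup, embeddings: KeyedJaggedTensor
) -> KeyedJaggedTensor:
    # splits must have pg.size() entries and sum to the number of keys in
    # `embeddings`, ordered by destination rank (here: 2 features to rank 0,
    # 1 feature to rank 1).
    kje_a2a = KJEAllToAll(pg=pg, splits=[2, 1])
    # The first wait exchanges output splits / batch sizes; the second wait moves
    # the tensors and yields the locally owned slice of the KJT (IDs plus the
    # embedding rows carried as weights).
    return kje_a2a(embeddings).wait().wait()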
diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 24b0a4885..673fbccae 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -468,13 +468,19 @@ def __init__(
             for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items()
         }
+        self.enable_embedding_update: bool = any(
+            config.enable_embedding_update for config in self._embedding_configs
+        )
         self._device = device
         self._input_dists: List[nn.Module] = []
+        self._write_dists: List[nn.Module] = []
         self._lookups: List[nn.Module] = []
+        self._updates: List[nn.Module] = []
         self._create_lookups()
         self._output_dists: List[nn.Module] = []
         self._create_output_dist()
+        self._write_splits: List[int] = []
 
         self._feature_splits: List[int] = []
         self._features_order: List[int] = []
@@ -631,6 +637,7 @@ def create_grouped_sharding_infos(
                     total_num_buckets=config.total_num_buckets,
                     use_virtual_table=config.use_virtual_table,
                     virtual_table_eviction_policy=config.virtual_table_eviction_policy,
+                    enable_embedding_update=config.enable_embedding_update,
                 ),
                 param_sharding=parameter_sharding,
                 param=param,
@@ -1308,7 +1315,10 @@ def _create_input_dist(
 
     def _create_lookups(self) -> None:
         for sharding in self._sharding_type_to_sharding.values():
-            self._lookups.append(sharding.create_lookup())
+            lookup = sharding.create_lookup()
+            if self.enable_embedding_update and sharding.enable_embedding_update:
+                self._updates.append(sharding.create_update(lookup))
+            self._lookups.append(lookup)
 
     def _create_output_dist(
         self,
@@ -1627,6 +1637,40 @@ def fused_optimizer(self) -> KeyedOptimizer:
     def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])
 
+    def _create_write_dist(self) -> None:
+        for sharding in self._sharding_type_to_sharding.values():
+            if sharding.enable_embedding_update:
+                self._write_dists.append(sharding.create_write_dist())
+                self._write_splits.append(sharding._get_num_writable_features())
+
+    # pyre-ignore [14]
+    def write_dist(
+        self, ctx: EmbeddingCollectionContext, embeddings: KeyedJaggedTensor
+    ) -> Awaitable[Awaitable[KJTList]]:
+        if not self.enable_embedding_update:
+            raise ValueError("enable_embedding_update is False for this collection")
+        if not self._write_dists:
+            self._create_write_dist()
+        with torch.no_grad():
+            embeddings_by_shards = embeddings.split(self._write_splits)
+            awaitables = []
+            for write_dist, embeddings in zip(self._write_dists, embeddings_by_shards):
+                awaitables.append(write_dist(embeddings))
+
+            return KJTListSplitsAwaitable(
+                awaitables,
+                ctx,
+                self._module_fqn,
+                list(self._sharding_type_to_sharding.keys()),
+            )
+
+    def update(self, ctx: EmbeddingCollectionContext, dist_input: KJTList) -> None:
+        for update, embeddings in zip(
+            self._updates,
+            dist_input,
+        ):
+            update(embeddings)
+
 
 class EmbeddingCollectionSharder(BaseEmbeddingSharder[EmbeddingCollection]):
     def __init__(
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index b1ade4834..891dd0b02 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -39,6 +39,7 @@
     BatchedFusedEmbeddingBag,
     KeyValueEmbedding,
     KeyValueEmbeddingBag,
+    ZeroCollisionEmbeddingCache,
     ZeroCollisionKeyValueEmbedding,
     ZeroCollisionKeyValueEmbeddingBag,
 )
@@ -49,6 +50,7 @@
 from torchrec.distributed.embedding_kernel import BaseEmbedding
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingLookup,
+    BaseEmbeddingUpdate,
     BaseGroupedFeatureProcessor,
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
@@ -249,12 +251,20 @@ def _create_embedding_kernel(
             )
         elif config.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE:
             # for dram kv
-            return ZeroCollisionKeyValueEmbedding(
-                config=config,
-                pg=pg,
-                device=device,
-                backend_type=BackendType.DRAM,
-            )
+            if config.enable_embedding_update:
+                return ZeroCollisionEmbeddingCache(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    backend_type=BackendType.DRAM,
+                )
+            else:
+                return ZeroCollisionKeyValueEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    backend_type=BackendType.DRAM,
+                )
         else:
             raise ValueError(f"Compute kernel not supported {config.compute_kernel}")
 
@@ -411,6 +421,33 @@ def purge(self) -> None:
             emb_module.purge()
 
 
+class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
+    """
+    Update modules for Sequence embeddings (i.e Embeddings)
+    """
+
+    def __init__(
+        self,
+        grouped_embeddings_lookup: GroupedEmbeddingsLookup,
+    ) -> None:
+        super().__init__()
+        self._emb_modules: List[BaseEmbedding] = []
+        self._feature_splits: List[int] = []
+        for emb_module in grouped_embeddings_lookup._emb_modules:
+            emb_module = cast(BaseBatchedEmbedding[torch.Tensor], emb_module)
+            if emb_module.config.enable_embedding_update:
+                self._feature_splits.append(emb_module.config.num_features())
+                self._emb_modules.append(emb_module)
+
+    def forward(self, embeddings: KeyedJaggedTensor) -> None:
+        features_by_group = embeddings.split(
+            self._feature_splits,
+        )
+        for emb_module, features in zip(self._emb_modules, features_by_group):
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
+            emb_module.update(features)
+
+
 class CommOpGradientScaling(torch.autograd.Function):
     @staticmethod
     # pyre-ignore
diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index ba0d522d1..80fdcb05a 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -26,6 +26,7 @@
 )
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingLookup,
+    BaseEmbeddingUpdate,
     BaseGroupedFeatureProcessor,
     EmbeddingComputeKernel,
     FeatureShardingMixIn,
@@ -42,7 +43,7 @@
     QuantizedCommCodecs,
     ShardMetadata,
 )
-from torchrec.distributed.utils import maybe_annotate_embedding_event
+from torchrec.distributed.utils import maybe_annotate_embedding_event, none_throws
 from torchrec.fx.utils import assert_fx_safe
 from torchrec.modules.embedding_configs import EmbeddingTableConfig
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -708,6 +709,77 @@ def _split(flat_list: List[T], splits: List[int]) -> List[List[T]]:
     ]
 
 
+def bucketize_embeddings_before_all2all_write(
+    kjt: KeyedJaggedTensor,
+    num_buckets: int,
+    block_sizes: torch.Tensor,
+    total_num_blocks: Optional[torch.Tensor] = None,
+    output_permute: bool = False,
+    bucketize_pos: bool = False,
+    block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
+    keep_original_indices: bool = False,
+) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]:
+    """
+    Bucketize embeddings before writing to HBM/DRAM.
+
+    Args:
+        kjt (KeyedJaggedTensor): input embeddings, with weights holding the embedding rows to write.
+        num_buckets (int): number of buckets to bucketize the values into.
+        block_sizes (torch.Tensor): bucket sizes for the keyed dimension.
+        total_num_blocks (Optional[torch.Tensor]): optional number of blocks per feature.
+        output_permute (bool): whether to output the location mapping from unbucketized to bucketized values.
+        bucketize_pos (bool): whether to output the changed position of the bucketized values.
+        block_bucketize_row_pos (Optional[List[torch.Tensor]]): optional shard-size offsets per feature.
+        keep_original_indices (bool): whether to keep the original (global) indices.
+
+    Returns:
+        Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: the bucketized `KeyedJaggedTensor` and the optional unbucketize permute tensor.
+    """
+    num_features = len(kjt.keys())
+    assert_fx_safe(
+        block_sizes.numel() == num_features,
+        f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
+    )
+
+    (
+        bucketized_lengths,
+        bucketized_indices,
+        bucketized_weights,
+        pos,
+        unbucketize_permute,
+    ) = torch.ops.fbgemm.block_bucketize_sparse_features_2d_weights(
+        lengths=kjt.lengths().view(-1),
+        indices=kjt.values(),
+        bucketize_pos=bucketize_pos,
+        sequence=output_permute,
+        block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()),
+        total_num_blocks=(
+            _fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values())
+            if total_num_blocks is not None
+            else None
+        ),
+        my_size=num_buckets,
+        weights=kjt.weights_or_none(),
+        weights_dim=(none_throws(kjt.weights_or_none()).size(1)),
+        batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
+        max_B=_fx_wrap_max_B(kjt),
+        block_bucketize_pos=(
+            [
+                _fx_wrap_tensor_to_device_dtype(pos, kjt.values())
+                for pos in block_bucketize_row_pos
+            ]
+            if block_bucketize_row_pos is not None
+            else None
+        ),
+        keep_orig_idx=keep_original_indices,
+    )
+    return (
+        KeyedJaggedTensor(
+            # duplicate keys will be resolved by AllToAll
+            keys=_fx_wrap_gen_list_n_times(kjt.keys(), num_buckets),
+            values=bucketized_indices,
+            weights=pos if bucketize_pos else bucketized_weights,
+            lengths=bucketized_lengths.view(-1),
+            offsets=None,
+            stride=_fx_wrap_stride(kjt),
+            stride_per_key_per_rank=_fx_wrap_stride_per_key_per_rank(kjt, num_buckets),
+            length_per_key=None,
+            offset_per_key=None,
+            index_per_key=None,
+        ),
+        unbucketize_permute,
+    )
+
+
 class KJTListSplitsAwaitable(Awaitable[Awaitable[KJTList]], Generic[C]):
     """
     Awaitable of Awaitable of KJTList.
@@ -958,6 +1030,24 @@ def forward(
         pass
 
 
+class BaseSparseFeaturesWriteDist(abc.ABC, nn.Module, Generic[F]):
+    """
+    Converts input from data-parallel to model-parallel.
+    """
+
+    @abc.abstractmethod
+    def forward(
+        self,
+        embeddings: F,
+    ) -> Union[Awaitable[Awaitable[F]], F]:
+        """
+        Writes the input embeddings to the embedding table.
+
+        Args:
+            embeddings (F): KJT containing ID values and weights as embeddings to write into the embedding table.
+        """
+        pass
+
+
 class BaseEmbeddingDist(abc.ABC, nn.Module, Generic[C, T, W]):
     """
     Converts output of EmbeddingLookup from model-parallel to data-parallel.
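[Editor's note, not part of the patch] For readers of `bucketize_embeddings_before_all2all_write` above: row-wise write distribution assigns each ID to a bucket by integer-dividing by a per-bucket block size, which `RwSparseFeaturesWriteDist` (later in this patch) computes as ceil(hash_size / num_buckets). A small, self-contained illustration of that arithmetic:

# Illustration only: mirrors feature_block_sizes.append((hash_size + d - 1) // d)
# from RwSparseFeaturesWriteDist and the block bucketization done by the fbgemm op.
def bucket_of(embedding_id: int, hash_size: int, num_buckets: int) -> int:
    block_size = (hash_size + num_buckets - 1) // num_buckets  # ceil division
    return embedding_id // block_size


assert bucket_of(0, hash_size=8000, num_buckets=2) == 0
assert bucket_of(3999, hash_size=8000, num_buckets=2) == 0
assert bucket_of(4000, hash_size=8000, num_buckets=2) == 1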
@@ -983,6 +1073,7 @@ def __init__(
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
     ) -> None:
         self._qcomm_codecs_registry = qcomm_codecs_registry
+        self.enable_embedding_update: bool = False
 
     @property
     def qcomm_codecs_registry(self) -> Optional[Dict[str, QuantizedCommCodecs]]:
@@ -1011,6 +1102,20 @@ def create_lookup(
     ) -> BaseEmbeddingLookup[F, T]:
         pass
 
+    def create_write_dist(
+        self, device: Optional[torch.device] = None
+    ) -> BaseSparseFeaturesWriteDist[F]:
+        raise NotImplementedError()
+
+    def create_update(
+        self,
+        grouped_embeddings_lookup: BaseEmbeddingLookup[F, T],
+    ) -> BaseEmbeddingUpdate[F]:
+        raise NotImplementedError()
+
+    def _get_num_writable_features(self) -> int:
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def embedding_dims(self) -> List[int]:
         pass
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index e9f28b31f..5434a3203 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -325,6 +325,19 @@ def forward(
         pass
 
 
+class BaseEmbeddingUpdate(abc.ABC, nn.Module, Generic[F]):
+    """
+    Interface implemented by different embedding implementations for updating the weights
+    """
+
+    @abc.abstractmethod
+    def forward(
+        self,
+        embeddings: F,
+    ) -> None:
+        pass
+
+
 class FeatureShardingMixIn:
     """
     Feature Sharding Interface to provide sharding-aware feature metadata.
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index f66465603..ebc76b976 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -17,12 +17,15 @@
 )
 from torchrec.distributed.embedding_lookup import (
     GroupedEmbeddingsLookup,
+    GroupedEmbeddingsUpdate,
     InferGroupedEmbeddingsLookup,
 )
 from torchrec.distributed.embedding_sharding import (
     BaseEmbeddingDist,
     BaseEmbeddingLookup,
+    BaseEmbeddingUpdate,
     BaseSparseFeaturesDist,
+    BaseSparseFeaturesWriteDist,
 )
 from torchrec.distributed.embedding_types import (
     BaseGroupedFeatureProcessor,
@@ -33,6 +36,7 @@
     get_embedding_shard_metadata,
     InferRwSparseFeaturesDist,
     RwSparseFeaturesDist,
+    RwSparseFeaturesWriteDist,
 )
 from torchrec.distributed.sharding.sequence_sharding import (
     InferSequenceShardingContext,
@@ -163,6 +167,30 @@ def create_output_dist(
             qcomm_codecs_registry=self.qcomm_codecs_registry,
         )
 
+    def create_write_dist(
+        self, device: Optional[torch.device] = None
+    ) -> BaseSparseFeaturesWriteDist[KeyedJaggedTensor]:
+        num_features = self._get_num_writable_features()
+        feature_hash_sizes = self._get_writable_feature_hash_sizes()
+        return RwSparseFeaturesWriteDist(
+            # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
+            #  `Optional[ProcessGroup]`.
            pg=self._pg,
+            num_features=num_features,
+            feature_hash_sizes=feature_hash_sizes,
+            device=device if device is not None else self._device,
+            is_sequence=True,
+        )
+
+    # pyre-ignore [14]
+    def create_update(
+        self,
+        grouped_embeddings_lookup: GroupedEmbeddingsLookup,
+    ) -> BaseEmbeddingUpdate[KeyedJaggedTensor]:
+        return GroupedEmbeddingsUpdate(
+            grouped_embeddings_lookup=grouped_embeddings_lookup,
+        )
+
 
 class InferRwSequenceEmbeddingDist(
     BaseEmbeddingDist[
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index cb16822c1..d310127c0 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -17,6 +17,7 @@
 from torch.distributed._tensor.placement_types import Replicate, Shard
 from torchrec.distributed.dist_data import (
     EmbeddingsAllToOneReduce,
+    KJEAllToAll,
     KJTAllToAll,
     KJTOneToAll,
     PooledEmbeddingsReduceScatter,
@@ -30,6 +31,8 @@
     BaseEmbeddingDist,
     BaseEmbeddingLookup,
     BaseSparseFeaturesDist,
+    BaseSparseFeaturesWriteDist,
+    bucketize_embeddings_before_all2all_write,
     bucketize_kjt_before_all2all,
     bucketize_kjt_inference,
     EmbeddingSharding,
@@ -143,6 +146,10 @@ def __init__(
         self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = (
             self._grouped_embedding_configs_per_rank[self._rank]
         )
+        self.enable_embedding_update: bool = any(
+            grouped_config.enable_embedding_update
+            for grouped_config in self._grouped_embedding_configs
+        )
         self._has_feature_processor: bool = False
 
         for group_config in self._grouped_embedding_configs:
@@ -590,6 +597,113 @@ def create_output_dist(
         )
 
 
+class RwSparseFeaturesWriteDist(BaseSparseFeaturesWriteDist[KeyedJaggedTensor]):
+    """
+    Accepts sparse feature embedding weights in RW fashion and then redistributes with an AlltoAll
+    collective operation.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
+        num_features (int): total number of features.
+        feature_hash_sizes (List[int]): hash sizes of features.
+        feature_total_num_buckets (Optional[List[int]]): total number of buckets, if provided will be >= world size.
+        device (Optional[torch.device]): device on which buffers will be allocated.
+        is_sequence (bool): if this is for a sequence embedding.
+        keep_original_indices (bool): whether to keep the original (global) indices when bucketizing.
+    """
+
+    def __init__(
+        self,
+        pg: dist.ProcessGroup,
+        num_features: int,
+        feature_hash_sizes: List[int],
+        feature_total_num_buckets: Optional[List[int]] = None,
+        device: Optional[torch.device] = None,
+        is_sequence: bool = False,
+        keep_original_indices: bool = False,
+    ) -> None:
+        super().__init__()
+        self._world_size: int = pg.size()
+        self._num_features = num_features
+
+        feature_block_sizes: List[int] = []
+
+        for i, hash_size in enumerate(feature_hash_sizes):
+            block_divisor = self._world_size
+            if feature_total_num_buckets is not None:
+                assert feature_total_num_buckets[i] % self._world_size == 0
+                block_divisor = feature_total_num_buckets[i]
+            feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)
+
+        self.register_buffer(
+            "_feature_block_sizes_tensor",
+            torch.tensor(
+                feature_block_sizes,
+                device=device,
+                dtype=torch.int64,
+            ),
+            persistent=False,
+        )
+        self._has_multiple_blocks_per_shard: bool = (
+            feature_total_num_buckets is not None
+        )
+        if self._has_multiple_blocks_per_shard:
+            self.register_buffer(
+                "_feature_total_num_blocks_tensor",
+                torch.tensor(
+                    [feature_total_num_buckets],
+                    device=device,
+                    dtype=torch.int64,
+                ),
+                persistent=False,
+            )
+
+        self._dist = KJEAllToAll(
+            pg=pg,
+            splits=[self._num_features] * self._world_size,
+        )
+        self._is_sequence = is_sequence
+        self.unbucketize_permute_tensor: Optional[torch.Tensor] = None
+        self._keep_original_indices = keep_original_indices
+
+    def forward(
+        self,
+        embeddings: KeyedJaggedTensor,
+    ) -> Awaitable[Awaitable[KeyedJaggedTensor]]:
+        """
+        Bucketizes sparse feature values into world size number of buckets and then
+        performs AlltoAll operation.
+
+        Args:
+            embeddings (KeyedJaggedTensor): embedding rows (carried as KJT weights) to
+                bucketize and redistribute.
+
+        Returns:
+            Awaitable[Awaitable[KeyedJaggedTensor]]: awaitable of awaitable of KeyedJaggedTensor.
+        """
+
+        (
+            bucketized_features,
+            self.unbucketize_permute_tensor,
+        ) = bucketize_embeddings_before_all2all_write(
+            embeddings,
+            num_buckets=self._world_size,
+            block_sizes=self._feature_block_sizes_tensor,
+            total_num_blocks=(
+                self._feature_total_num_blocks_tensor
+                if self._has_multiple_blocks_per_shard
+                else None
+            ),
+            output_permute=self._is_sequence,
+            bucketize_pos=False,
+            keep_original_indices=self._keep_original_indices,
+        )
+
+        return self._dist(bucketized_features)
+
+
 @overload
 def convert_tensor(t: torch.Tensor, feature: KeyedJaggedTensor) -> torch.Tensor: ...
 @overload
diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py
index bc4ffc6af..81e4fad8e 100644
--- a/torchrec/distributed/sharding_plan.py
+++ b/torchrec/distributed/sharding_plan.py
@@ -512,7 +512,8 @@ def _parameter_sharding_generator(
 
 
 def row_wise(
-    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None
+    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None,
+    compute_kernel: Optional[str] = None,
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan.
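[Editor's note, not part of the patch] For context on the `compute_kernel` argument added to `row_wise` above: the test added later in this diff uses it to pin virtual tables to the DRAM kernel, roughly as follows.

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.sharding_plan import row_wise

# Per-parameter sharding spec passed to construct_module_sharding_plan; the table
# names below come from the added test.
per_param_sharding = {
    "table_0": row_wise(compute_kernel=EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value),
    "table_1": row_wise(compute_kernel=EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value),
}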
@@ -538,6 +539,10 @@ def row_wise(
         ), "sizes_placement and device per placement (in case of sharding "
         "across HBM and CPU host) must have the same length"
 
+    compute_kernel = (
+        EmbeddingComputeKernel.QUANT.value if sizes_placement else compute_kernel
+    )
+
     def _parameter_sharding_generator(
         param: nn.Parameter,
         local_size: int,
@@ -598,9 +603,7 @@ def _parameter_sharding_generator(
             device_type,
             sharder,
             placements=placements if sizes_placement else None,
-            compute_kernel=(
-                EmbeddingComputeKernel.QUANT.value if sizes_placement else None
-            ),
+            compute_kernel=compute_kernel,
         )
 
     return _parameter_sharding_generator
diff --git a/torchrec/distributed/tests/test_embedding_update.py b/torchrec/distributed/tests/test_embedding_update.py
new file mode 100644
index 000000000..d2b049926
--- /dev/null
+++ b/torchrec/distributed/tests/test_embedding_update.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#!/usr/bin/env python3
+
+# pyre-strict
+
+from typing import Dict, List, Optional
+
+import torch
+
+import torch.nn as nn
+from torchrec.distributed import DistributedModelParallel
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.global_settings import set_propogate_device
+from torchrec.distributed.sharding_plan import (
+    construct_module_sharding_plan,
+    data_parallel,
+    EmbeddingCollectionSharder,
+    row_wise,
+)
+
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
+from torchrec.distributed.types import ShardingEnv, ShardingPlan
+from torchrec.modules.embedding_configs import EmbeddingConfig, NoEvictionPolicy
+from torchrec.modules.embedding_modules import EmbeddingCollection
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+class TestECModel(nn.Module):
+    def __init__(self, tables: List[EmbeddingConfig], device: torch.device) -> None:
+        super().__init__()
+        self.ec = EmbeddingCollection(tables=tables, device=device)
+
+    def forward(self, features: KeyedJaggedTensor) -> Dict[str, torch.Tensor]:
+        return self.ec(features)
+
+
+class TestEmbeddingUpdate(MultiProcessTestBase):
+
+    def test_sharded_embedding_update_disabled_in_oss_compatibility(
+        self,
+        # sharding_type: str,
+        # kernel_type: str,
+    ) -> None:
+        if torch.cuda.device_count() <= 1:
+            self.skipTest("Not enough GPUs, this test requires at least two GPUs")
+        WORLD_SIZE = 2
+        tables = [
+            EmbeddingConfig(
+                num_embeddings=8000,
+                embedding_dim=64,
+                name="table_0",
+                feature_names=["feature_0", "feature_1"],
+                total_num_buckets=20,
+                use_virtual_table=True,
+                enable_embedding_update=True,
+                virtual_table_eviction_policy=NoEvictionPolicy(),
+            ),
+            EmbeddingConfig(
+                num_embeddings=8000,
+                embedding_dim=64,
+                name="table_1",
+                feature_names=["feature_2"],
+                total_num_buckets=40,
+                use_virtual_table=True,
+                enable_embedding_update=True,
+                virtual_table_eviction_policy=NoEvictionPolicy(),
+            ),
+            EmbeddingConfig(
+                num_embeddings=8000,
+                embedding_dim=64,
+                name="table_2",
+                feature_names=["feature_3"],
+            ),
+        ]
+        backend = "nccl"
+        inputs_per_rank = [  # noqa
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2", "feature_3"],
+                values=torch.randint(0, 8000, (13,)),
+                lengths=torch.LongTensor([2, 1, 1, 1, 1, 1, 2, 0, 1, 1, 2, 0]),
+            ),
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2", "feature_3"],
+                values=torch.randint(0, 8000, (12,)),
+                lengths=torch.LongTensor([1, 1, 1, 1, 1, 1, 0, 0, 3, 1, 0, 2]),
+            ),
+        ]
+        embeddings_per_rank = [
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=["feature_0", "feature_1", "feature_2"],
+                values=torch.cat(
+                    (
+                        input["feature_0"].values(),
+                        input["feature_1"].values(),
+                        input["feature_2"].values(),
+                    )
+                ),
+                lengths=input.lengths()[: -input["feature_3"].lengths().size(0)],
+                weights=torch.rand(
+                    int(
+                        torch.sum(
+                            input.lengths()[: -input["feature_3"].lengths().size(0)]
+                        ).item()
+                    ),
+                    64,
+                    dtype=torch.float32,
+                ),
+            )
+            for input in inputs_per_rank
+        ]
+        self._run_multi_process_test(
+            callable=sharded_embedding_update,
+            world_size=WORLD_SIZE,
+            tables=tables,
+            backend=backend,
+            inputs_per_rank=inputs_per_rank,
+            embeddings_per_rank=embeddings_per_rank,
+        )
+
+
+def sharded_embedding_update(
+    rank: int,
+    world_size: int,
+    tables: List[EmbeddingConfig],
+    backend: str,
+    embeddings_per_rank: List[KeyedJaggedTensor],
+    inputs_per_rank: List[KeyedJaggedTensor],
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        assert ctx.pg is not None
+        model = TestECModel(
+            tables=tables,
+            device=ctx.device,
+        )
+
+        sharder = EmbeddingCollectionSharder()
+        per_param_sharding = {
+            "table_0": row_wise(
+                compute_kernel=EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value
+            ),
+            "table_1": row_wise(
+                compute_kernel=EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value
+            ),
+            "table_2": data_parallel(),
+        }
+        sharding_plan = construct_module_sharding_plan(
+            model.ec,
+            per_param_sharding=per_param_sharding,
+            local_size=local_size,
+            world_size=world_size,
+            device_type=ctx.device.type,
+            sharder=sharder,  # pyre-ignore
+        )
+
+        set_propogate_device(True)
+        sharded_model = DistributedModelParallel(
+            model,
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore
+            plan=ShardingPlan({"ec": sharding_plan}),
+            sharders=[sharder],  # pyre-ignore[6]
+            device=ctx.device,
+        )
+
+        kjts = inputs_per_rank[rank]
+        sharded_model(kjts.to(ctx.device))
+        torch.cuda.synchronize()
+        # pyre-ignore [16]
+        sharded_model._dmp_wrapped_module.ec.write(
+            embeddings_per_rank[rank].to(ctx.device)
+        )
+        torch.cuda.synchronize()
+        expected_embeddings = {
+            key: embeddings_per_rank[rank][key].weights()
+            for key in embeddings_per_rank[rank].keys()
+        }
+        embeddings = None
+        embeddings = sharded_model(kjts.to(ctx.device))
+        for key, values in expected_embeddings.items():
+            torch.testing.assert_close(
+                torch.cat(embeddings[key].to_dense()),
+                values.to_dense().to(ctx.device),
+                rtol=1e-3,
+                atol=1e-3,
+            )
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index bd13515b4..46521ca6c 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -1078,6 +1078,41 @@ def forward(self, *input, **kwargs) -> LazyAwaitable[Out]:
         dist_input = self.input_dist(ctx, *input, **kwargs).wait().wait()
         return self.compute_and_output_dist(ctx, dist_input)
 
+    def update(self, ctx: ShrdCtx, dist_input: CompIn) -> None:
+        """
+        Updates the sharded module with the given input.
+
+        Args:
+            ctx (ShrdCtx): sharding context.
+            dist_input (CompIn): distributed input.
+        """
+        raise NotImplementedError(
+            "The update method is not implemented for this collection. Please make sure you are using the correct compute kernel and sharding type."
+ ) + + def write_dist( + self, ctx: ShrdCtx, *input, **kwargs # pyre-ignore[2] + ) -> Awaitable[Awaitable[CompIn]]: + raise NotImplementedError( + "The write_dist method is not implemented for this collection. Please make sure you are using the correct compute kernel and sharding type." + ) + + # pyre-ignore[2] + def write(self, *input, **kwargs) -> None: + """ + Executes the write dist and update steps. + + Args: + *input: input. + **kwargs: keyword arguments. + + Returns: + LazyAwaitable[Out]: awaitable of output from output dist. + """ + ctx = self.create_context() + dist_input = self.write_dist(ctx, *input, **kwargs).wait().wait() + self.update(ctx, dist_input) + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: for key, _ in self.named_parameters(prefix): yield key