
Commit 676eca9

kausv authored and facebook-github-bot committed
Add support for Write Dist (#3347)
Summary:
X-link: #3347
https://docs.google.com/document/d/1N4Q8tdFRVB_qj2vbfadEVmJ-tHolBqL5QiCxa5I8zNk

Create Embedding Write Dist: the ability to update individual embeddings for specific feature IDs. For now, this is implemented only for the KVZCH compute kernel with row-wise (RW) sharding; if any other sharding type or kernel tries to call write, an exception is thrown.

Differential Revision: D78749760
1 parent e083ca6 commit 676eca9

11 files changed (+709, -15 lines)
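For orientation, here is a minimal sketch of how the new write path is meant to be driven at the sharded-module level, based on the write_dist/update methods added to the sharded EmbeddingCollection in torchrec/distributed/embedding.py below. The helper name write_embeddings and the inputs sharded_ec and update_kjt are illustrative assumptions, not part of this commit; only tables with enable_embedding_update (KVZCH kernel, row-wise sharding) accept writes.

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def write_embeddings(sharded_ec, update_kjt: KeyedJaggedTensor) -> None:
    # Hypothetical driver: `update_kjt` carries the feature IDs to overwrite as
    # values and the new embedding rows as (2-D) weights.
    ctx = sharded_ec.create_context()
    # The first wait exchanges output splits; the second completes the tensor AlltoAll.
    dist_input = sharded_ec.write_dist(ctx, update_kjt).wait().wait()
    # Apply the redistributed rows to the local embedding shards.
    sharded_ec.update(ctx, dist_input)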

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 4 additions & 0 deletions
@@ -1944,6 +1944,10 @@ def __init__(
        assert (
            config.is_using_virtual_table
        ), "Try to create ZeroCollisionKeyValueEmbedding for non virtual tables"
+        assert embedding_cache_mode == config.enable_embedding_update, (
+            f"Embedding_cache kernel is {embedding_cache_mode} "
+            f"but embedding config has enable_embedding_update {config.enable_embedding_update}"
+        )
        for table in config.embedding_tables:
            assert table.local_cols % 4 == 0, (
                f"table {table.name} has local_cols={table.local_cols} "

torchrec/distributed/dist_data.py

Lines changed: 117 additions & 2 deletions
@@ -368,8 +368,12 @@ def __init__(
            # https://github.com/pytorch/pytorch/issues/122788
            with record_function(f"## all2all_data:kjt {label} ##"):
                if self._pg._get_backend_name() == "custom":
+                    if input_tensor.dim() == 2:
+                        output_size = [sum(output_split), input_tensor.size(1)]
+                    else:
+                        output_size = [sum(output_split)]
                    output_tensor = torch.empty(
-                        sum(output_split),
+                        output_size,
                        device=self._device,
                        dtype=input_tensor.dtype,
                    )
@@ -391,8 +395,12 @@ def __init__(
                    )
                    self._output_tensors.append(output_tensor)
                else:
+                    if input_tensor.dim() == 2:
+                        output_size = [sum(output_split), input_tensor.size(1)]
+                    else:
+                        output_size = [sum(output_split)]
                    output_tensor = torch.empty(
-                        sum(output_split), device=self._device, dtype=input_tensor.dtype
+                        output_size, device=self._device, dtype=input_tensor.dtype
                    )
                    with record_function(f"## all2all_data:kjt {label} ##"):
                        awaitable = dist.all_to_all_single(
@@ -542,6 +550,113 @@ def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
        )


+class KJEAllToAll(nn.Module):
+    """
+    Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.
+
+    Implementation utilizes AlltoAll collective as part of torch.distributed.
+
+    The input provides the necessary tensors, embedding weights, and input splits to
+    distribute. The first collective call in `KJTAllToAllSplitsAwaitable` will transmit
+    output splits (to allocate correct space for tensors) and batch size per rank. The
+    following collective calls in `KJTAllToAllTensorsAwaitable` will transmit the actual
+    tensors asynchronously. This module is used for embedding updates, where the input
+    KJT's weights are written into the embedding tables.
+
+    Args:
+        pg (dist.ProcessGroup): ProcessGroup for AlltoAll communication.
+        splits (List[int]): List of len(pg.size()) which indicates how many features to
+            send to each pg.rank(). It is assumed the `KeyedJaggedTensor` is ordered by
+            destination rank. Same for all ranks.
+        stagger (int): stagger value to apply to recat tensor, see `_get_recat` function
+            for more detail.
+
+    Example::
+
+        keys=['A','B','C']
+        splits=[2,1]
+        kjeA2A = KJEAllToAll(pg, splits)
+        awaitable = kjeA2A(rank0_input)
+
+        # where:
+        # rank0_input is KeyedJaggedTensor holding
+
+        #         0       1       2
+        # 'A'    [A.V0]   None    [A.V1, A.V2]
+        # 'B'    None     [B.V0]  [B.V1]
+        # 'C'    [C.V0]   [C.V1]  None
+
+        # rank1_input is KeyedJaggedTensor holding
+
+        #         0       1       2
+        # 'A'    [A.V3]   [A.V4]  None
+        # 'B'    None     [B.V2]  [B.V3, B.V4]
+        # 'C'    [C.V2]   [C.V3]  None
+
+        # The output is None since this is a write operation, but it is still
+        # awaitable for synchronization.
+        awaitable.wait()
+
+        # where the input after the distribution is:
+        # rank0
+
+        #         0       1       2             3       4       5
+        # 'A'    [A.V0]   None    [A.V1, A.V2]  [A.V3]  [A.V4]  None
+        # 'B'    None     [B.V0]  [B.V1]        None    [B.V2]  [B.V3, B.V4]
+
+        # rank1
+        #         0       1       2       3       4       5
+        # 'C'    [C.V0]   [C.V1]  None    [C.V2]  [C.V3]  None
+    """
+
+    def __init__(
+        self,
+        pg: dist.ProcessGroup,
+        splits: List[int],
+        stagger: int = 1,
+    ) -> None:
+        super().__init__()
+        torch._check(len(splits) == pg.size())
+        self._pg: dist.ProcessGroup = pg
+        self._splits = splits
+        self._splits_cumsum: List[int] = [0] + list(itertools.accumulate(splits))
+        self._stagger = stagger
+
+    def forward(
+        self, input: KeyedJaggedTensor
+    ) -> Awaitable[KJTAllToAllTensorsAwaitable]:
+        """
+        Sends input to relevant `ProcessGroup` ranks.
+
+        The first wait will get the output splits for the provided tensors and issue
+        the tensors AlltoAll. The second wait will wait for the update.
+
+        Args:
+            input (KeyedJaggedTensor): `KeyedJaggedTensor` of values and weights to
+                distribute.
+
+        Returns:
+            Awaitable[KJTAllToAllTensorsAwaitable]: awaitable of a `KJTAllToAllTensorsAwaitable`.
+        """
+
+        with torch.no_grad():
+            assert len(input.keys()) == sum(self._splits)
+            rank = dist.get_rank(self._pg)
+            local_keys = input.keys()[
+                self._splits_cumsum[rank] : self._splits_cumsum[rank + 1]
+            ]
+
+            return KJTAllToAllSplitsAwaitable(
+                pg=self._pg,
+                input=input,
+                splits=self._splits,
+                labels=input.dist_labels(),
+                tensor_splits=input.dist_splits(self._splits),
+                input_tensors=input.dist_tensors(),
+                keys=local_keys,
+                device=input.device(),
+                stagger=self._stagger,
+            )
+
+
class KJTAllToAll(nn.Module):
    """
    Redistributes `KeyedJaggedTensor` to a `ProcessGroup` according to splits.
torchrec/distributed/embedding.py

Lines changed: 45 additions & 1 deletion
@@ -468,13 +468,19 @@ def __init__(
            for sharding_type, embedding_confings in sharding_type_to_sharding_infos.items()
        }

+        self.enable_embedding_update: bool = any(
+            config.enable_embedding_update for config in self._embedding_configs
+        )
        self._device = device
        self._input_dists: List[nn.Module] = []
+        self._write_dists: List[nn.Module] = []
        self._lookups: List[nn.Module] = []
+        self._updates: List[nn.Module] = []
        self._create_lookups()
        self._output_dists: List[nn.Module] = []
        self._create_output_dist()

+        self._write_splits: List[int] = []
        self._feature_splits: List[int] = []
        self._features_order: List[int] = []

@@ -631,6 +637,7 @@ def create_grouped_sharding_infos(
                    total_num_buckets=config.total_num_buckets,
                    use_virtual_table=config.use_virtual_table,
                    virtual_table_eviction_policy=config.virtual_table_eviction_policy,
+                    enable_embedding_update=config.enable_embedding_update,
                ),
                param_sharding=parameter_sharding,
                param=param,
@@ -1308,7 +1315,10 @@ def _create_input_dist(

    def _create_lookups(self) -> None:
        for sharding in self._sharding_type_to_sharding.values():
-            self._lookups.append(sharding.create_lookup())
+            lookup = sharding.create_lookup()
+            if self.enable_embedding_update and sharding.enable_embedding_update:
+                self._updates.append(sharding.create_update(lookup))
+            self._lookups.append(lookup)

    def _create_output_dist(
        self,
@@ -1627,6 +1637,40 @@ def fused_optimizer(self) -> KeyedOptimizer:
    def create_context(self) -> EmbeddingCollectionContext:
        return EmbeddingCollectionContext(sharding_contexts=[])

+    def _create_write_dist(self) -> None:
+        for sharding in self._sharding_type_to_sharding.values():
+            if sharding.enable_embedding_update:
+                self._write_dists.append(sharding.create_write_dist())
+                self._write_splits.append(sharding._get_num_writable_features())
+
+    # pyre-ignore [14]
+    def write_dist(
+        self, ctx: EmbeddingCollectionContext, embeddings: KeyedJaggedTensor
+    ) -> Awaitable[Awaitable[KJTList]]:
+        if not self.enable_embedding_update:
+            raise ValueError("enable_embedding_update is False for this collection")
+        if not self._write_dists:
+            self._create_write_dist()
+        with torch.no_grad():
+            embeddings_by_shards = embeddings.split(self._write_splits)
+            awaitables = []
+            for write_dist, embeddings in zip(self._write_dists, embeddings_by_shards):
+                awaitables.append(write_dist(embeddings))
+
+            return KJTListSplitsAwaitable(
+                awaitables,
+                ctx,
+                self._module_fqn,
+                list(self._sharding_type_to_sharding.keys()),
+            )
+
+    def update(self, ctx: EmbeddingCollectionContext, dist_input: KJTList) -> None:
+        for update, embeddings in zip(
+            self._updates,
+            dist_input,
+        ):
+            update(embeddings)
+

class EmbeddingCollectionSharder(BaseEmbeddingSharder[EmbeddingCollection]):
    def __init__(

torchrec/distributed/embedding_lookup.py

Lines changed: 44 additions & 7 deletions
@@ -10,7 +10,7 @@
import logging
from abc import ABC
from collections import OrderedDict
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -39,6 +39,7 @@
    BatchedFusedEmbeddingBag,
    KeyValueEmbedding,
    KeyValueEmbeddingBag,
+    ZeroCollisionEmbeddingCache,
    ZeroCollisionKeyValueEmbedding,
    ZeroCollisionKeyValueEmbeddingBag,
)
@@ -49,6 +50,7 @@
from torchrec.distributed.embedding_kernel import BaseEmbedding
from torchrec.distributed.embedding_types import (
    BaseEmbeddingLookup,
+    BaseEmbeddingUpdate,
    BaseGroupedFeatureProcessor,
    EmbeddingComputeKernel,
    GroupedEmbeddingConfig,
@@ -249,12 +251,20 @@ def _create_embedding_kernel(
            )
        elif config.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE:
            # for dram kv
-            return ZeroCollisionKeyValueEmbedding(
-                config=config,
-                pg=pg,
-                device=device,
-                backend_type=BackendType.DRAM,
-            )
+            if config.enable_embedding_update:
+                return ZeroCollisionEmbeddingCache(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    backend_type=BackendType.DRAM,
+                )
+            else:
+                return ZeroCollisionKeyValueEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    backend_type=BackendType.DRAM,
+                )
        else:
            raise ValueError(f"Compute kernel not supported {config.compute_kernel}")

@@ -411,6 +421,33 @@ def purge(self) -> None:
            emb_module.purge()


+class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
+    """
+    Update modules for Sequence embeddings (i.e. Embeddings)
+    """
+
+    def __init__(
+        self,
+        grouped_embeddings_lookup: GroupedEmbeddingsLookup,
+    ) -> None:
+        super().__init__()
+        self._emb_modules: List[BaseEmbedding] = []
+        self._feature_splits: List[int] = []
+        for emb_module in grouped_embeddings_lookup._emb_modules:
+            emb_module = cast(BaseBatchedEmbedding[torch.Tensor], emb_module)
+            if emb_module.config.enable_embedding_update:
+                self._feature_splits.append(emb_module.config.num_features())
+                self._emb_modules.append(emb_module)
+
+    def forward(self, embeddings: KeyedJaggedTensor) -> None:
+        features_by_group = embeddings.split(
+            self._feature_splits,
+        )
+        for emb_module, features in zip(self._emb_modules, features_by_group):
+            # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
+            emb_module.update(features)
+
+
class CommOpGradientScaling(torch.autograd.Function):
    @staticmethod
    # pyre-ignore
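GroupedEmbeddingsUpdate routes an incoming KJT to the kernels that accept writes by splitting it on per-module feature counts, mirroring the lookup path. Below is a small, runnable illustration of that split pattern; the keys, lengths, and split sizes are made up for the example, and in the write path the KJT's weights would additionally carry the new embedding rows.

import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Three features with batch size 2; the first grouped module owns two features
# and the second owns one, mirroring the `_feature_splits` bookkeeping above.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([10, 11, 12, 13, 14]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 0]),
)
per_module = kjt.split([2, 1])
print([k.keys() for k in per_module])  # [['f1', 'f2'], ['f3']]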
