Skip to content

Commit 8abfd90

Browse files
optimisea authored and facebook-github-bot committed
update mtia info in torchrec (#3391)
Summary:
Pull Request resolved: #3391
Reviewed By: egienvalue
Differential Revision: D83037454
fbshipit-source-id: 824b101749fe6e9022b54323e072828f4675cdb2
1 parent 97a1534 commit 8abfd90

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

torchrec/distributed/planner/constants.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -77,9 +77,9 @@ def kernel_bw_lookup(
7777
("cpu", EmbeddingComputeKernel.QUANT.value): 1 * ddr_mem_bw,
7878
# TODO: Determine the correct value later. MTIA uses values same as CPU's.
7979
# MTIA
80-
("mtia", EmbeddingComputeKernel.DENSE.value): 0.5 * ddr_mem_bw,
81-
("mtia", EmbeddingComputeKernel.FUSED.value): 1 * ddr_mem_bw,
82-
("mtia", EmbeddingComputeKernel.QUANT.value): 1 * ddr_mem_bw,
80+
("mtia", EmbeddingComputeKernel.DENSE.value): 0.5 * hbm_mem_bw,
81+
("mtia", EmbeddingComputeKernel.FUSED.value): 1 * hbm_mem_bw,
82+
("mtia", EmbeddingComputeKernel.QUANT.value): 1 * hbm_mem_bw,
8383
# CUDA
8484
("cuda", EmbeddingComputeKernel.DENSE.value): 0.5 * hbm_mem_bw,
8585
("cuda", EmbeddingComputeKernel.FUSED.value): 1 * hbm_mem_bw,

torchrec/distributed/planner/stats.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -705,7 +705,10 @@ def _log_rank_mem_usage_and_perf(
705705
used_hbm_gb = bytes_to_gb(used_hbm[rank])
706706
used_hbm_ratio = (
707707
used_hbm[rank] / ((1 - reserved_hbm_percent) * device.storage.hbm)
708-
if topology.compute_device == "cuda"
708+
if (
709+
topology.compute_device == "cuda"
710+
or topology.compute_device == "mtia"
711+
)
709712
and ((1 - reserved_hbm_percent) * device.storage.hbm) != 0
710713
else 0
711714
)

0 commit comments

Comments
 (0)