From 52f0168063d5721ff1b2d645c6b0ebae85328a34 Mon Sep 17 00:00:00 2001 From: Hai Zheng Date: Tue, 23 Sep 2025 01:26:22 -0700 Subject: [PATCH] update mtia info in torchrec Differential Revision: D83037454 --- torchrec/distributed/planner/constants.py | 6 +++--- torchrec/distributed/planner/stats.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/torchrec/distributed/planner/constants.py b/torchrec/distributed/planner/constants.py index 4b301c052..56c7dc26f 100644 --- a/torchrec/distributed/planner/constants.py +++ b/torchrec/distributed/planner/constants.py @@ -77,9 +77,9 @@ def kernel_bw_lookup( ("cpu", EmbeddingComputeKernel.QUANT.value): 1 * ddr_mem_bw, # TODO: Determine the correct value later. MTIA uses values same as CPU's. # MTIA - ("mtia", EmbeddingComputeKernel.DENSE.value): 0.5 * ddr_mem_bw, - ("mtia", EmbeddingComputeKernel.FUSED.value): 1 * ddr_mem_bw, - ("mtia", EmbeddingComputeKernel.QUANT.value): 1 * ddr_mem_bw, + ("mtia", EmbeddingComputeKernel.DENSE.value): 0.5 * hbm_mem_bw, + ("mtia", EmbeddingComputeKernel.FUSED.value): 1 * hbm_mem_bw, + ("mtia", EmbeddingComputeKernel.QUANT.value): 1 * hbm_mem_bw, # CUDA ("cuda", EmbeddingComputeKernel.DENSE.value): 0.5 * hbm_mem_bw, ("cuda", EmbeddingComputeKernel.FUSED.value): 1 * hbm_mem_bw, diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index 431c336a8..d1aeb7761 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -705,7 +705,10 @@ def _log_rank_mem_usage_and_perf( used_hbm_gb = bytes_to_gb(used_hbm[rank]) used_hbm_ratio = ( used_hbm[rank] / ((1 - reserved_hbm_percent) * device.storage.hbm) - if topology.compute_device == "cuda" + if ( + topology.compute_device == "cuda" + or topology.compute_device == "mtia" + ) and ((1 - reserved_hbm_percent) * device.storage.hbm) != 0 else 0 )