File tree Expand file tree Collapse file tree 2 files changed +7
-4
lines changed
torchrec/distributed/planner Expand file tree Collapse file tree 2 files changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -77,9 +77,9 @@ def kernel_bw_lookup(
77
77
("cpu" , EmbeddingComputeKernel .QUANT .value ): 1 * ddr_mem_bw ,
78
78
# TODO: Determine the correct value later. MTIA uses values same as CPU's.
79
79
# MTIA
80
- ("mtia" , EmbeddingComputeKernel .DENSE .value ): 0.5 * ddr_mem_bw ,
81
- ("mtia" , EmbeddingComputeKernel .FUSED .value ): 1 * ddr_mem_bw ,
82
- ("mtia" , EmbeddingComputeKernel .QUANT .value ): 1 * ddr_mem_bw ,
80
+ ("mtia" , EmbeddingComputeKernel .DENSE .value ): 0.5 * hbm_mem_bw ,
81
+ ("mtia" , EmbeddingComputeKernel .FUSED .value ): 1 * hbm_mem_bw ,
82
+ ("mtia" , EmbeddingComputeKernel .QUANT .value ): 1 * hbm_mem_bw ,
83
83
# CUDA
84
84
("cuda" , EmbeddingComputeKernel .DENSE .value ): 0.5 * hbm_mem_bw ,
85
85
("cuda" , EmbeddingComputeKernel .FUSED .value ): 1 * hbm_mem_bw ,
Original file line number Diff line number Diff line change @@ -705,7 +705,10 @@ def _log_rank_mem_usage_and_perf(
705
705
used_hbm_gb = bytes_to_gb (used_hbm [rank ])
706
706
used_hbm_ratio = (
707
707
used_hbm [rank ] / ((1 - reserved_hbm_percent ) * device .storage .hbm )
708
- if topology .compute_device == "cuda"
708
+ if (
709
+ topology .compute_device == "cuda"
710
+ or topology .compute_device == "mtia"
711
+ )
709
712
and ((1 - reserved_hbm_percent ) * device .storage .hbm ) != 0
710
713
else 0
711
714
)
You can’t perform that action at this time.
0 commit comments