Skip to content

Commit 7a8b57a

Browse files
committed
update
Signed-off-by: shen-shanshan <[email protected]>
1 parent 6dff4c8 commit 7a8b57a

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

vllm_ascend/platform.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# This file is a part of the vllm-ascend project.
1616
#
1717

18+
import gc
1819
import logging
1920
import os
2021
from typing import TYPE_CHECKING, Optional, Tuple
@@ -244,3 +245,9 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
244245
model configuration.
245246
"""
246247
return True
248+
249+
@classmethod
250+
def clear_npu_memory(cls):
251+
gc.collect()
252+
torch.npu.empty_cache()
253+
torch.npu.reset_peak_memory_stats()

vllm_ascend/worker/worker.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,11 @@
1717
# Adapted from vllm-project/vllm/vllm/worker/worker.py
1818
#
1919

20-
import gc
2120
import os
2221
from typing import Dict, List, Optional, Set, Tuple, Type, Union
2322

2423
import msgpack # type: ignore
2524
import torch
26-
import torch.distributed
2725
import zmq
2826
from torch import nn
2927
from vllm import envs
@@ -209,9 +207,7 @@ def init_device(self) -> None:
209207
if self.device_config.device.type == "npu":
210208
self.device = torch.device(f"npu:{self.local_rank}")
211209
NPUPlatform.set_device(self.device)
212-
gc.collect()
213-
NPUPlatform.empty_cache()
214-
torch.npu.reset_peak_memory_stats()
210+
NPUPlatform.clear_npu_memory()
215211
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
216212
else:
217213
raise RuntimeError(
@@ -278,9 +274,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
278274
"""
279275
# Profile the memory usage of the model and get the maximum number of
280276
# cache blocks that can be allocated with the remaining free memory.
281-
gc.collect()
282-
NPUPlatform.empty_cache()
283-
torch.npu.reset_peak_memory_stats()
277+
NPUPlatform.clear_npu_memory()
284278

285279
# Execute a forward pass with dummy inputs to profile the memory usage
286280
# of the model.
@@ -306,10 +300,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
306300
cache_block_size)
307301
num_npu_blocks = max(num_npu_blocks, 0)
308302
num_cpu_blocks = max(num_cpu_blocks, 0)
309-
gc.collect()
310-
# TODO: don`t need impl this func after empty_cache in
311-
# Worker.determine_num_available_blocks() unified`
312-
NPUPlatform.empty_cache()
303+
304+
NPUPlatform.clear_npu_memory()
313305
return num_npu_blocks, num_cpu_blocks
314306

315307
def initialize_cache(self, num_gpu_blocks: int,

0 commit comments

Comments (0)