Commit 4858f3b

Add an option to launch cacheflow without ray (#51)
1 parent a96d63c commit 4858f3b

File tree

  .gitignore
  benchmark/benchmark_latency.py
  benchmark/benchmark_text_completion.py
  cacheflow/http_frontend/fastapi_frontend.py
  cacheflow/master/server.py
  cacheflow/worker/controller.py
  simple_server.py

7 files changed: +102 -28 lines
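
The sketch below illustrates how the new `--use-ray` option is resolved. It mirrors the `process_server_arguments` helper added to cacheflow/master/server.py in this commit; the `resolve_use_ray` name and the hand-built `argparse.Namespace` are illustrative only. Ray is required only when more than one GPU is requested; a plain single-GPU launch now runs without it.

import argparse

# Illustration only (not cacheflow code): same rule as process_server_arguments.
def resolve_use_ray(args: argparse.Namespace) -> argparse.Namespace:
    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
        args.use_ray = True  # more than one GPU always goes through Ray
    return args

args = argparse.Namespace(use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=2)
print(resolve_use_ray(args).use_ray)  # True: tensor parallelism across 2 GPUs implies Ray

With both parallel sizes left at 1 and `--use-ray` omitted, the server starts in-process and the `try: import ray` guards added below simply leave `ray = None`.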

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -3,8 +3,11 @@
 *.egg-info/
 *.eggs/
 *.so
+*.log
+*.csv
 build/

 *.pkl
 *.png
 **/log.txt
+.vscode/

benchmark/benchmark_latency.py

Lines changed: 8 additions & 4 deletions
@@ -8,7 +8,8 @@

 from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory

@@ -20,8 +21,8 @@ def main(args: argparse.Namespace):

     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
-            address='local',
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -44,6 +45,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
     )

     # Create a frontend.
@@ -91,14 +93,16 @@ def profile_step(profile=False):


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
+    parser = argparse.ArgumentParser(
+        description='Benchmark the latency of decoding a single sentence.')
     parser = add_server_arguments(parser)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
     parser.add_argument('--n', type=int, default=1)
     parser.add_argument('--use-beam-search', action='store_true')
     args = parser.parse_args()
+    args = process_server_arguments(args)
     args.max_num_batched_tokens = max(
         args.max_num_batched_tokens, args.batch_size * args.input_len)
     print(args)

benchmark/benchmark_text_completion.py

Lines changed: 10 additions & 6 deletions
@@ -11,7 +11,8 @@
 from benchmark.trace import generate_text_completion_requests
 from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory

@@ -25,8 +26,8 @@ def main(args: argparse.Namespace):

     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
-            address='local',
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -49,6 +50,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
         collect_stats=True,
         do_memory_analysis=args.do_memory_analysis,
     )
@@ -134,7 +136,7 @@ def main(args: argparse.Namespace):
             finished.append({
                 'group_id': seq_group.group_id,
                 'seq_id': seq.seq_id,
-                'arrival_time': arrival_time,
+                'arrival_time': arrival_time,
                 'finish_time': finish_time,
                 'prompt_len': seq.prompt_len,
                 'output_len': output_len,
@@ -225,8 +227,9 @@ def get_sampling_dir_name(


 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
-    parser = add_server_arguments(parser)
+    parser = argparse.ArgumentParser(
+        description='Benchmark the performance on a series of requests.')
+    parser = add_server_arguments(parser)
     parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)

     parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
@@ -246,6 +249,7 @@ def get_sampling_dir_name(
     parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
     parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
     args = parser.parse_args()
+    args = process_server_arguments(args)
     if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0:
         raise ValueError('The ratios of requests must sum to 1.')


cacheflow/http_frontend/fastapi_frontend.py

Lines changed: 15 additions & 3 deletions
@@ -13,7 +13,8 @@
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

@@ -33,17 +34,22 @@ def __init__(
         seed: int,
         swap_space: int,
         max_num_batched_tokens: int,
+        max_num_sequences: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
         all_stage_devices: List[List[DeviceID]],
+        server_use_ray: bool,
     ):
         self.block_size = block_size

         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
-        remote_server_class = ray.remote(num_cpus=0)(Server)
+        if server_use_ray:
+            remote_server_class = ray.remote(num_cpus=0)(Server)
+        else:
+            remote_server_class = ray.remote(num_gpus=1)(Server)
         self.server = remote_server_class.remote(
             model=model,
             model_path=model_path,
@@ -55,12 +61,14 @@ def __init__(
             seed=seed,
             swap_space=swap_space,
             max_num_batched_tokens=max_num_batched_tokens,
+            max_num_sequences=max_num_sequences,
             num_nodes=num_nodes,
             num_devices_per_node=num_devices_per_node,
             distributed_init_method=distributed_init_method,
             all_stage_devices=all_stage_devices,
             gpu_memory=get_gpu_memory(),
             cpu_memory=get_cpu_memory(),
+            use_ray=server_use_ray,
         )

         self.running_seq_groups: Dict[int, SequenceGroup] = {}
@@ -149,14 +157,16 @@ async def generate_stream(request: Request):
     parser.add_argument("--port", type=int, default=10002)
     parser = add_server_arguments(parser)
     args = parser.parse_args()
+    args = process_server_arguments(args)

     # TODO(zhuohan): Support pipeline parallelism.
     assert args.pipeline_parallel_size == 1, (
         'Pipeline parallelism is not supported yet.')

     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
+        initialize_cluster(
+            use_ray=True,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -170,10 +180,12 @@ async def generate_stream(request: Request):
         seed=args.seed,
         swap_space=args.swap_space,
         max_num_batched_tokens=args.max_num_batched_tokens,
+        max_num_sequences=args.max_num_sequences,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
         all_stage_devices=all_stage_devices,
+        server_use_ray=args.use_ray,
     )

     uvicorn.run(app, host=args.host, port=args.port, log_level="info")

cacheflow/master/server.py

Lines changed: 39 additions & 4 deletions
@@ -1,8 +1,12 @@
 import argparse
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import random

-import ray
+import torch
+try:
+    import ray
+except ImportError:
+    ray = None

 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer
@@ -31,13 +35,18 @@ def __init__(
         all_stage_devices: List[List[DeviceID]],
         gpu_memory: int,
         cpu_memory: int,
+        use_ray: bool,
         collect_stats: bool = False,
         do_memory_analysis: bool = False,
     ):
         self.num_nodes = num_nodes
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size

+        if not use_ray:
+            assert self.world_size == 1, (
+                "Only support single GPU without Ray.")
+
         self.memory_analyzer = get_memory_analyzer(
             model_name=model,
             block_size=block_size,
@@ -72,6 +81,7 @@ def __init__(
                 model_path=model_path,
                 use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
+                use_ray=use_ray,
             )
             self.controllers.append(controller)

@@ -105,11 +115,30 @@ def has_unfinished_requests(self):
                 self.scheduler.swapped)


-def initialize_ray_cluster(
-    address: str = 'auto',
+def initialize_cluster(
+    use_ray: bool = False,
+    address: Optional[str] = None,
     pipeline_parallel_size: int = 1,
     tensor_parallel_size: int = 1,
 ) -> Tuple[int, int, str, List[List[DeviceID]]]:
+    # Initialize cluster locally.
+    if not use_ray:
+        assert pipeline_parallel_size * tensor_parallel_size == 1, (
+            "Only support single GPU without Ray.")
+        num_nodes = 1
+        num_devices_per_node = torch.cuda.device_count()
+        port = random.randint(10000, 20000)
+        # We need to setup the distributed init method to make sure
+        # the distributed megatron code (e.g., get world size) works correctly.
+        distributed_init_method = f"tcp://localhost:{port}"
+        all_stage_devices = [[(0, None, 0)]]
+        return (num_nodes, num_devices_per_node, distributed_init_method,
+                all_stage_devices)
+
+    assert ray is not None, (
+        "Ray is not installed. Please install Ray to use distributed "
+        "serving.")
+
     # Connect to a ray cluster.
     ray.init(address=address)

@@ -177,6 +206,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
+    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
@@ -190,3 +220,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
+
+def process_server_arguments(args: argparse.Namespace):
+    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
+        args.use_ray = True
+    return args
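
For reference, a minimal sketch of the new local path through `initialize_cluster`, assuming cacheflow is importable and a CUDA device is visible; the expected values follow directly from the code above.

from cacheflow.master.server import initialize_cluster

# Single-GPU initialization without Ray, added in this commit.
(num_nodes, num_devices_per_node, distributed_init_method,
 all_stage_devices) = initialize_cluster(use_ray=False)

# Per the implementation above:
#   num_nodes == 1
#   num_devices_per_node == torch.cuda.device_count()
#   distributed_init_method == "tcp://localhost:<random port in [10000, 20000]>"
#   all_stage_devices == [[(0, None, 0)]]   # one (rank, node_resource, device_id) entry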

cacheflow/worker/controller.py

Lines changed: 21 additions & 9 deletions
@@ -1,6 +1,9 @@
 from typing import Dict, List, Union, Tuple

-import ray
+try:
+    import ray
+except ImportError:
+    ray = None

 from cacheflow.master.scheduler import Scheduler
 from cacheflow.sequence import SequenceGroupInputs
@@ -29,24 +32,29 @@ def __init__(
         model_path: str,
         use_dummy_weights: bool,
         max_num_batched_tokens: int,
+        use_ray: bool,
     ) -> None:
         self.stage_id = stage_id
         self.stage_devices = stage_devices
         self.model_name = model_name
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
+        self.use_ray = use_ray

         # Which pipeline stage is this node assigned to?
         self.is_first_stage = stage_id == 0
         self.is_last_stage = False

         self.workers: List[Worker] = []
         for rank, node_resource, device_id in stage_devices:
-            worker_cls = ray.remote(num_cpus=0,
-                                    num_gpus=1,
-                                    resources={node_resource: 1e-5})(Worker)
-            worker = worker_cls.remote(
+            if self.use_ray:
+                worker_cls = ray.remote(num_cpus=0,
+                                        num_gpus=1,
+                                        resources={node_resource: 1e-5})(Worker).remote
+            else:
+                worker_cls = Worker
+            worker = worker_cls(
                 model_name=model_name,
                 block_size=block_size,
                 num_gpu_blocks=num_gpu_blocks,
@@ -78,17 +86,21 @@ def execute_stage(
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
-        futures = []
+        all_outputs = []
         for worker in self.workers:
-            future = worker.execute_stage.remote(
+            executor = (worker.execute_stage.remote
+                        if self.use_ray else worker.execute_stage)
+            output = executor(
                 input_seq_groups,
                 blocks_to_swap_in,
                 blocks_to_swap_out,
                 blocks_to_copy,
             )
-            futures.append(future)
+            all_outputs.append(output)
+
+        if self.use_ray:
+            all_outputs = ray.get(all_outputs)

-        all_outputs = ray.get(futures)
         # Make sure all workers have the same results.
         output = all_outputs[0]
         for other_output in all_outputs[1:]:
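
The `executor` trick in `execute_stage` (bind `worker.execute_stage.remote` under Ray, the plain method otherwise, and call `ray.get` only in the Ray case) is a general pattern. Below is a self-contained sketch with a toy `Adder` class and hypothetical `make_adder`/`call_add` helpers, not cacheflow code.

try:
    import ray
except ImportError:
    ray = None

class Adder:
    def add(self, x: int, y: int) -> int:
        return x + y

def make_adder(use_ray: bool):
    # Ray actor handle when use_ray, plain in-process object otherwise.
    return ray.remote(Adder).remote() if use_ray else Adder()

def call_add(adder, use_ray: bool, x: int, y: int) -> int:
    # Same call site either way: .remote returns a future, a direct call returns the value.
    executor = adder.add.remote if use_ray else adder.add
    out = executor(x, y)
    return ray.get(out) if use_ray else out

print(call_add(make_adder(use_ray=False), use_ray=False, x=2, y=3))  # 5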

simple_server.py

Lines changed: 6 additions & 2 deletions
@@ -3,7 +3,8 @@

 from cacheflow.master.simple_frontend import SimpleFrontend
 from cacheflow.master.server import (Server, add_server_arguments,
-                                     initialize_ray_cluster)
+                                     process_server_arguments,
+                                     initialize_cluster)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import get_gpu_memory, get_cpu_memory

@@ -14,7 +15,8 @@ def main(args: argparse.Namespace):

     (num_nodes, num_devices_per_node, distributed_init_method,
      all_stage_devices) = (
-        initialize_ray_cluster(
+        initialize_cluster(
+            use_ray=args.use_ray,
             pipeline_parallel_size=args.pipeline_parallel_size,
             tensor_parallel_size=args.tensor_parallel_size))

@@ -37,6 +39,7 @@ def main(args: argparse.Namespace):
         all_stage_devices=all_stage_devices,
         gpu_memory=get_gpu_memory(),
         cpu_memory=get_cpu_memory(),
+        use_ray=args.use_ray,
     )

     # Create a frontend.
@@ -70,4 +73,5 @@ def main(args: argparse.Namespace):
     parser = argparse.ArgumentParser(description='CacheFlow simple server.')
     parser = add_server_arguments(parser)
     args = parser.parse_args()
+    args = process_server_arguments(args)
     main(args)
