 import argparse
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import random
 
-import ray
+import torch
+try:
+    import ray
+except ImportError:
+    ray = None
 
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.models import get_memory_analyzer
@@ -31,13 +35,18 @@ def __init__(
         all_stage_devices: List[List[DeviceID]],
         gpu_memory: int,
         cpu_memory: int,
+        use_ray: bool,
         collect_stats: bool = False,
         do_memory_analysis: bool = False,
     ):
         self.num_nodes = num_nodes
         self.num_devices_per_node = num_devices_per_node
         self.world_size = pipeline_parallel_size * tensor_parallel_size
 
+        if not use_ray:
+            assert self.world_size == 1, (
+                "Only support single GPU without Ray.")
+
         self.memory_analyzer = get_memory_analyzer(
             model_name=model,
             block_size=block_size,
@@ -72,6 +81,7 @@ def __init__(
                 model_path=model_path,
                 use_dummy_weights=use_dummy_weights,
                 max_num_batched_tokens=max_num_batched_tokens,
+                use_ray=use_ray,
             )
             self.controllers.append(controller)
 
@@ -105,11 +115,30 @@ def has_unfinished_requests(self):
                 self.scheduler.swapped)
 
 
-def initialize_ray_cluster(
-    address: str = 'auto',
+def initialize_cluster(
+    use_ray: bool = False,
+    address: Optional[str] = None,
     pipeline_parallel_size: int = 1,
     tensor_parallel_size: int = 1,
 ) -> Tuple[int, int, str, List[List[DeviceID]]]:
+    # Initialize cluster locally.
+    if not use_ray:
+        assert pipeline_parallel_size * tensor_parallel_size == 1, (
+            "Only support single GPU without Ray.")
+        num_nodes = 1
+        num_devices_per_node = torch.cuda.device_count()
+        port = random.randint(10000, 20000)
+        # We need to setup the distributed init method to make sure
+        # the distributed megatron code (e.g., get world size) works correctly.
+        distributed_init_method = f"tcp://localhost:{port}"
+        all_stage_devices = [[(0, None, 0)]]
+        return (num_nodes, num_devices_per_node, distributed_init_method,
+                all_stage_devices)
+
+    assert ray is not None, (
+        "Ray is not installed. Please install Ray to use distributed "
+        "serving.")
+
     # Connect to a ray cluster.
     ray.init(address=address)
 
@@ -177,6 +206,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
+    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
     parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
@@ -190,3 +220,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
     parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
     return parser
+
+def process_server_arguments(args: argparse.Namespace):
+    if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
+        args.use_ray = True
+    return args
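For context, here is a minimal sketch of how the entry points touched by this diff could be chained by a server script after the change. Only `add_server_arguments`, `process_server_arguments`, and `initialize_cluster` come from this file; the import path `cacheflow.master.server` and the parser description are assumptions for illustration, and the actual `Server` construction is elided.

# Sketch (not part of the diff): wiring the reworked entry points together.
import argparse

# Assumed import path; adjust to wherever this module actually lives.
from cacheflow.master.server import (
    add_server_arguments, process_server_arguments, initialize_cluster)

parser = argparse.ArgumentParser(description='CacheFlow server')  # illustrative
parser = add_server_arguments(parser)
# process_server_arguments() turns --use-ray on automatically when
# pipeline_parallel_size * tensor_parallel_size > 1.
args = process_server_arguments(parser.parse_args())

# Single-GPU runs skip Ray entirely and get a local tcp:// init method;
# multi-GPU runs require Ray to be installed and a reachable cluster.
(num_nodes, num_devices_per_node, distributed_init_method,
 all_stage_devices) = initialize_cluster(
    use_ray=args.use_ray,
    pipeline_parallel_size=args.pipeline_parallel_size,
    tensor_parallel_size=args.tensor_parallel_size)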