 import uuid
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from build.utils import device_sync
@@ -86,6 +86,9 @@ class StreamOptions:

     include_usage: bool = False

+@dataclass
+class ResponseFormat:
+    type: Optional[str] = None

 @dataclass
 class CompletionRequest:
@@ -94,25 +97,27 @@ class CompletionRequest:
     See the "Create Chat Completion >>> Request body" section of the OpenAI API docs for more details.
     """

+    messages: List[_AbstractMessage]
     model: str
-    prompt: str
-    messages: Optional[List[_AbstractMessage]]
-    frequency_penalty: float = 0.0
-    temperature: float = 0.0
-    stop: Optional[List[str]] = None
-    stream: bool = False
-    stream_options: Optional[StreamOptions] = None
-    echo: bool = False
-    frequency_penalty: float = 0.0
-    guided_decode_json_schema: str = None
-    guided_decode_json_schema_path: str = None
+    frequency_penalty: float = 0.0  # unimplemented
+    logit_bias: Optional[Dict[str, float]] = None  # unimplemented
+    logprobs: Optional[bool] = None  # unimplemented
+    top_logprobs: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None  # unimplemented
     n: int = 1
-    presence_penalty: float = 0
-    logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = None
-    top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
-
+    presence_penalty: float = 0  # unimplemented
+    response_format: Optional[ResponseFormat] = None  # unimplemented
+    seed: Optional[int] = None  # unimplemented
+    service_tier: Optional[str] = None  # unimplemented
+    stop: Optional[List[str]] = None  # unimplemented
+    stream: bool = False
+    stream_options: Optional[StreamOptions] = None  # unimplemented
+    temperature: Optional[float] = 1.0  # unimplemented
+    top_p: Optional[float] = 1.0  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
+    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    user: Optional[str] = None  # unimplemented

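For orientation, a minimal request under the new schema needs only `messages` and `model`; every field tagged `# unimplemented` is accepted for OpenAI compatibility but ignored by the server. A hedged sketch (the dict-shaped message mirrors the `messages[-1].get("content")` access used later in this diff; the model id is hypothetical):

```python
# Minimal sketch of constructing a CompletionRequest. Only messages and model
# are required; the unimplemented fields keep their OpenAI-compatible defaults.
request = CompletionRequest(
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    model="llama3",  # hypothetical id; the server's --model flag picks the real model
    stream=True,
)
```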
 @dataclass
 class CompletionChoice:
@@ -121,10 +126,10 @@ class CompletionChoice:
     See the "The chat completion object >>> choices" section of the OpenAI API docs for more details.
     """

-    finish_reason: str
     index: int
     message: AssistantMessage
-    logprobs: Optional[List[Any]]
+    finish_reason: Optional[str] = None
+    logprobs: Optional[List[Any]] = None


 @dataclass
@@ -150,10 +155,10 @@ class CompletionResponse:
     choices: List[CompletionChoice]
     created: int
     model: str
-    system_fingerprint: str
-    usage: UsageStats
-    object: str = "chat.completion"
+    system_fingerprint: str
     service_tier: Optional[str] = None
+    usage: Optional[UsageStats] = None
+    object: str = "chat.completion"

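The field reshuffle here is partly forced by Python itself: dataclass fields without defaults must precede fields with defaults, so once `usage` becomes `Optional` with a `None` default, the required `system_fingerprint` has to stay above the defaulted fields. A quick illustration of the rule (standalone toy class, not from this file):

```python
from dataclasses import dataclass

@dataclass
class Ok:
    system_fingerprint: str  # required fields first...
    usage: object = None     # ...then fields with defaults

# Flipping the order fails at class-creation time with:
#   TypeError: non-default argument 'system_fingerprint' follows default argument
```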
 @dataclass
@@ -193,8 +198,8 @@ class CompletionResponseChunk:
     created: int
     model: str
     system_fingerprint: str
-    object: str = "chat.completion.chunk"
     service_tier: Optional[str] = None
+    object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None

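Each `CompletionResponseChunk` mirrors OpenAI's `chat.completion.chunk` object, so a streaming endpoint would typically emit one server-sent event per chunk. A hedged sketch of that framing (the HTTP wiring is an assumption; it is not part of this diff):

```python
import json
from dataclasses import asdict

def stream_events(chunks):
    # Assumed SSE framing, following the OpenAI streaming convention:
    # one "data:" line per serialized chunk, then a [DONE] sentinel.
    for chunk in chunks:
        yield f"data: {json.dumps(asdict(chunk))}\n\n"
    yield "data: [DONE]\n\n"
```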
@@ -220,8 +225,13 @@ def __init__(self, *args, **kwargs):
             if self.draft_model is not None
             else self.model.config.max_seq_length
         )
+        # The system fingerprint is a unique identifier for the model and its configuration.
+        # Currently this is a placeholder: the device string plus the precision's type
+        # name, not a real hash of the model configuration.
+        self.system_fingerprint = (
+            self.builder_args.device + type(self.builder_args.precision).__name__
+        )

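Note what this expression actually evaluates to: if `builder_args.precision` is a `torch.dtype` (which the usage here suggests), then `type(...).__name__` is the literal string `"dtype"`, so the fingerprint does not encode the specific precision. A small check under that assumption:

```python
import torch

device = "cuda"            # stand-in for builder_args.device
precision = torch.float16  # stand-in for builder_args.precision
# type(torch.float16) is torch.dtype, so __name__ is "dtype":
print(device + type(precision).__name__)  # "cudadtype", not "cudafloat16"
```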
-    def completion(self, completion_request: CompletionRequest):
+    def chunked_completion(self, completion_request: CompletionRequest):
         """Handle a chat completion request and yield a chunked response.

         **Warning**: Not all arguments of the CompletionRequest are consumed, as the server isn't completely implemented.
@@ -230,7 +240,8 @@ def completion(self, completion_request: CompletionRequest):
         - messages: The server consumes the final element of the array as the prompt.
         - model: This has no impact on the server state, i.e. changing the model in the request
           will not change which model is responding. Instead, use the --model flag to select the model when starting the server.
-        - temperature: This is used to control the randomness of the response. The server will use the temperature
+        - temperature: This is used to control the randomness of the response.
+        - system_fingerprint: A unique identifier for the model and its configuration. Currently unimplemented; subject to change.

         See https://github.com/pytorch/torchchat/issues/973 for more details.

@@ -246,13 +257,16 @@ def completion(self, completion_request: CompletionRequest):

         # Initialize counters for chunk responses and encode the prompt.
         id = str(uuid.uuid4())
+
         idx = 0
         buffer = []
         encoded = self.encode_tokens(
-            completion_request.prompt, bos=True, device=self.builder_args.device
+            completion_request.messages[-1].get("content"),
+            bos=True,
+            device=self.builder_args.device,
         )
         generator_args = GeneratorArgs(
-            completion_request.prompt,
+            completion_request.messages[-1].get("content"),
             encoded_prompt=encoded,
             chat_mode=False,
         )
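As the docstring warns, only the final element of `messages` becomes the prompt; earlier turns are silently dropped. For example:

```python
# Only the last message's content reaches the tokenizer; the system turn is ignored.
messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Name one prime number."},
]
prompt = messages[-1].get("content")  # "Name one prime number."
```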
@@ -302,21 +316,45 @@ def callback(x, *, done_generating=False):
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
-                system_fingerprint=uuid.UUID(int=uuid.getnode()),
+                system_fingerprint=self.system_fingerprint,
             )
             yield chunk_response
             self.start_pos += y.size(0)
             idx += 1

         # Yield an ending chunk indicating the generation has completed.
-        end_chunk = CompletionChoiceChunk(ChunkDelta(None, None, None), idx, "eos")
+        end_chunk = CompletionChoiceChunk(
+            ChunkDelta(None, None, None), idx, finish_reason="stop"
+        )

         yield CompletionResponseChunk(
             id=str(id),
             choices=[end_chunk],
             created=int(time.time()),
             model=completion_request.model,
-            system_fingerprint=uuid.UUID(int=uuid.getnode()),
+            system_fingerprint=self.system_fingerprint,
         )
+
+    def sync_completion(self, request: CompletionRequest):
+        """Handle a chat completion request and return a single, non-chunked response."""
+        output = ""
+        for chunk in self.chunked_completion(request):
+            if not chunk.choices[0].finish_reason:
+                output += chunk.choices[0].delta.content
+
+        message = AssistantMessage(content=output)
+        return CompletionResponse(
+            id=str(uuid.uuid4()),
+            choices=[
+                CompletionChoice(
+                    finish_reason="stop",
+                    index=0,
+                    message=message,
+                )
+            ],
+            created=int(time.time()),
+            model=request.model,
+            system_fingerprint=self.system_fingerprint,
+        )
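Taken together, `chunked_completion` is the generator for streaming clients, and `sync_completion` drains that same generator into a single `CompletionResponse`. A hedged usage sketch (`server` stands in for an instance of this class; construction elided):

```python
request = CompletionRequest(
    messages=[{"role": "user", "content": "Hi there"}],
    model="llama3",  # hypothetical id; ignored in favor of the --model flag
)

if request.stream:
    for chunk in server.chunked_completion(request):
        print(chunk.choices[0].delta.content or "", end="", flush=True)
else:
    print(server.sync_completion(request).choices[0].message.content)
```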
     def _callback(self, x, *, buffer, done_generating):