1 change: 1 addition & 0 deletions launch/dynamo-run/src/input/http.rs
@@ -32,6 +32,7 @@ pub async fn run(
.port(flags.http_port)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true)
.with_request_template(template)
.build()?;
match engine_config {
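With `enable_embeddings_endpoints(true)` set on the HTTP frontend builder, `dynamo-run` now exposes an OpenAI-compatible `/v1/embeddings` route alongside chat and completions. A minimal client sketch against that route; the port, base URL, and model name below are placeholders, not values taken from this PR:

```python
# Hedged usage sketch: base_url, api_key, and model name are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="not-used")

resp = client.embeddings.create(
    model="my-embedding-model",              # placeholder: whatever model was registered
    input=["Dynamo serves embeddings now"],  # one or more strings to embed
)

print(len(resp.data), len(resp.data[0].embedding))  # vector count, vector width
print(resp.usage.prompt_tokens, resp.usage.total_tokens)
```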
50 changes: 46 additions & 4 deletions launch/dynamo-run/src/subprocess/sglang_inc.py
@@ -77,6 +77,42 @@ async def generate(self, request):
num_output_tokens_so_far = next_total_toks


class EmbeddingRequestHandler(RequestHandler):
    """
    Request handler for the embedding endpoint
    """

    def __init__(self, engine: sglang.Engine, model_name: str):
        super().__init__(engine)
        self._model_name = model_name

    async def generate(self, request):
        gen = await self.engine_client.async_encode(prompt=request["input"])
        tokens = 0
        embeddings = []
        for idx, res in enumerate(gen):
            embeddings.append(
                {
                    "index": idx,
                    "object": "embedding",
                    "embedding": res["embedding"],
                }
            )
            tokens += res["meta_info"]["prompt_tokens"]

        out = {
            "object": "list",
            "model": self._model_name,
            "data": embeddings,
            "usage": {
                "prompt_tokens": tokens,
                "total_tokens": tokens,
            },
        }

        yield out


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, cmd_line_args())
@@ -129,13 +165,20 @@ async def init(runtime: DistributedRuntime, config: Config):
await component.create_service()

endpoint = component.endpoint(config.endpoint)
await register_llm(
ModelType.Backend, endpoint, config.model_path, config.model_name
model_type = (
ModelType.Backend if not engine_args.is_embedding else ModelType.Embedding
)
await register_llm(model_type, endpoint, config.model_path, config.model_name)

# the server will gracefully shut down (i.e., keep open TCP streams until they finish)
# after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
await endpoint.serve_endpoint(
RequestHandler(engine_client).generate
if not engine_args.is_embedding
else EmbeddingRequestHandler(
engine_client, model_name=config.model_name or config.model_path
).generate
)


def cmd_line_args():
@@ -230,7 +273,6 @@ def cmd_line_args():
config.node_rank = args.node_rank
config.dist_init_addr = args.dist_init_addr
config.extra_engine_args = args.extra_engine_args

return config


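For reference, the request and response shapes that `EmbeddingRequestHandler.generate` works with, reconstructed from the handler above; the concrete values are illustrative only:

```python
# Illustrative shapes only; the numbers and model name are made up.
request = {"input": ["first text", "second text"]}  # what the handler reads

# What the handler yields after calling engine_client.async_encode(...):
response = {
    "object": "list",
    "model": "my-embedding-model",
    "data": [
        {"index": 0, "object": "embedding", "embedding": [0.01, -0.02]},  # truncated vectors
        {"index": 1, "object": "embedding", "embedding": [0.03, 0.04]},
    ],
    "usage": {"prompt_tokens": 7, "total_tokens": 7},  # summed prompt_tokens across inputs
}
```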
4 changes: 1 addition & 3 deletions lib/llm/src/discovery/model_manager.rs
@@ -129,9 +129,7 @@ impl ModelManager {
clients.remove(model)
}

// TODO: Remove this allow once `embeddings` is implemented in lib/llm/src/http/service/openai.rs
#[allow(dead_code)]
fn get_embeddings_engine(
pub fn get_embeddings_engine(
&self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
57 changes: 53 additions & 4 deletions lib/llm/src/http/service/openai.rs
@@ -27,7 +27,7 @@ use super::{
service_v2, RouteDoc,
};

use crate::protocols::openai::embeddings::NvCreateEmbeddingRequest;
use crate::protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse};
use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
};
@@ -204,10 +204,59 @@ async fn completions(

#[tracing::instrument(skip_all)]
async fn embeddings(
State(_state): State<Arc<service_v2::State>>,
Json(_request): Json<NvCreateEmbeddingRequest>,
State(state): State<Arc<service_v2::State>>,
Json(request): Json<NvCreateEmbeddingRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
unimplemented!("embeddings are not supported yet");
// return a 503 if the service is not ready
check_ready(&state)?;

// todo - extract distributed tracing id and context id from headers
let request_id = uuid::Uuid::new_v4().to_string();

// Embeddings are typically not streamed, so we default to non-streaming
let streaming = false;

// todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default
let model = &request.inner.model;

// todo - error handling should be more robust
let engine = state
.manager()
.get_embeddings_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?;

// this will increment the inflight gauge for the model
let mut inflight =
state
.metrics_clone()
.create_inflight_guard(model, Endpoint::Embeddings, streaming);

// setup context
// todo - inherit request_id from distributed trace details
let request = Context::with_id(request, request_id.clone());

// issue the generate call on the engine
let stream = engine
.generate(request)
.await
.map_err(|e| ErrorResponse::from_anyhow(e, "Failed to generate embeddings"))?;

// Embeddings are typically returned as a single response (non-streaming)
// so we fold the stream into a single response
let response = NvCreateEmbeddingResponse::from_annotated_stream(stream.into())
.await
.map_err(|e| {
tracing::error!(
"Failed to fold embeddings stream for {}: {:?}",
request_id,
e
);
ErrorResponse::internal_server_error("Failed to fold embeddings stream")
})?;

inflight.mark_ok();
Ok(Json(response).into_response())
}

/// OpenAI Chat Completions Request Handler
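The handler above treats embeddings as non-streaming: the engine still returns an annotated stream, and `NvCreateEmbeddingResponse::from_annotated_stream` folds it into a single response. A rough Python analogue of that fold, purely to illustrate the idea (the function below is invented for this sketch and is not part of the codebase):

```python
# Conceptual sketch of folding a response stream into one embeddings response;
# this is not the real DeltaAggregator added by this PR.
import asyncio


async def fold_embedding_stream(stream):
    """Collect every chunk from an async stream into a single response dict."""
    data, prompt_tokens = [], 0
    async for chunk in stream:
        data.extend(chunk.get("data", []))
        prompt_tokens += chunk.get("usage", {}).get("prompt_tokens", 0)
    return {
        "object": "list",
        "data": data,
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }


async def _demo():
    async def chunks():
        yield {
            "data": [{"index": 0, "object": "embedding", "embedding": [0.1, 0.2]}],
            "usage": {"prompt_tokens": 3},
        }

    print(await fold_embedding_stream(chunks()))


asyncio.run(_demo())
```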
2 changes: 1 addition & 1 deletion lib/llm/src/http/service/service_v2.rs
@@ -75,7 +75,7 @@ pub struct HttpServiceConfig {
#[builder(default = "true")]
enable_cmpl_endpoints: bool,

#[builder(default = "false")]
#[builder(default = "true")]
enable_embeddings_endpoints: bool,

#[builder(default = "None")]
5 changes: 4 additions & 1 deletion lib/llm/src/protocols/openai/embeddings.rs
@@ -16,9 +16,12 @@
use serde::{Deserialize, Serialize};
use validator::Validate;

mod aggregator;
mod nvext;

pub use nvext::{NvExt, NvExtProvider};
// pub use delta::DeltaGenerator;
pub use aggregator::DeltaAggregator;

use dynamo_runtime::protocols::annotated::AnnotationsProvider;

@@ -59,7 +62,7 @@ impl NvCreateEmbeddingResponse {
}
}

/// Implements `NvExtProvider` for `NvCr eateEmbeddingRequest`,
/// Implements `NvExtProvider` for `NvCreateEmbeddingRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateEmbeddingRequest {
/// Returns a reference to the optional `NvExt` extension, if available.