diff --git a/server/src/handlers/http.rs b/server/src/handlers/http.rs
index ad1fe855e..804317821 100644
--- a/server/src/handlers/http.rs
+++ b/server/src/handlers/http.rs
@@ -34,6 +34,7 @@ use self::middleware::{DisAllowRootUser, RouteExt};
mod about;
mod health_check;
mod ingest;
+mod llm;
mod logstream;
mod middleware;
mod query;
@@ -229,6 +230,21 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
.wrap(DisAllowRootUser),
),
);
+
+ let llm_query_api = web::scope("/llm")
+ .service(
+ web::resource("").route(
+ web::post()
+ .to(llm::make_llm_request)
+ .authorize(Action::Query),
+ ),
+ )
+ .service(
+ // to check if the API key for an LLM has been set up as env var
+ web::resource("isactive")
+ .route(web::post().to(llm::is_llm_active).authorize(Action::Query)),
+ );
+
// Deny request if username is same as the env variable P_USERNAME.
cfg.service(
// Base path "{url}/api/v1"
@@ -266,7 +282,8 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
logstream_api,
),
)
- .service(user_api),
+ .service(user_api)
+ .service(llm_query_api),
)
// GET "/" ==> Serve the static frontend directory
.service(ResourceFiles::new("/", generated).resolve_not_found_to_root());
diff --git a/server/src/handlers/http/llm.rs b/server/src/handlers/http/llm.rs
new file mode 100644
index 000000000..ef8feccc0
--- /dev/null
+++ b/server/src/handlers/http/llm.rs
@@ -0,0 +1,176 @@
+/*
+ * Parseable Server (C) 2022 - 2023 Parseable, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ *
+ */
+
+use actix_web::{http::header::ContentType, web, HttpResponse, Result};
+use http::{header, StatusCode};
+use itertools::Itertools;
+use reqwest;
+use serde_json::{json, Value};
+
+use crate::{
+ metadata::{error::stream_info::MetadataError, STREAM_INFO},
+ option::CONFIG,
+};
+
+const OPEN_AI_URL: &str = "https://api.openai.com/v1/chat/completions";
+
+// Deserialize types for OpenAI Response
+#[derive(serde::Deserialize, Debug)]
+struct ResponseData {
+ choices: Vec,
+}
+
+#[derive(serde::Deserialize, Debug)]
+struct Choice {
+ message: Message,
+}
+
+#[derive(serde::Deserialize, Debug)]
+struct Message {
+ content: String,
+}
+
+// Request body
+#[derive(serde::Deserialize, Debug)]
+pub struct AiPrompt {
+ prompt: String,
+ stream: String,
+}
+
+// Temperory type
+#[derive(Debug, serde::Serialize)]
+struct Field {
+ name: String,
+ data_type: String,
+}
+
+impl From<&arrow_schema::Field> for Field {
+ fn from(field: &arrow_schema::Field) -> Self {
+ Self {
+ name: field.name().clone(),
+ data_type: field.data_type().to_string(),
+ }
+ }
+}
+
+fn build_prompt(stream: &str, prompt: &str, schema_json: &str) -> String {
+ format!(
+ r#"I have a table called {}.
+It has the columns:\n{}
+Based on this, generate valid SQL for the query: "{}"
+Generate only SQL as output. Also add comments in SQL syntax to explain your actions.
+Don't output anything else.
+If it is not possible to generate valid SQL, output an SQL comment saying so."#,
+ stream, schema_json, prompt
+ )
+}
+
+fn build_request_body(ai_prompt: String) -> impl serde::Serialize {
+ json!({
+ "model": "gpt-3.5-turbo",
+ "messages": [{ "role": "user", "content": ai_prompt}],
+ "temperature": 0.6,
+ })
+}
+
+pub async fn make_llm_request(body: web::Json) -> Result {
+ let api_key = match &CONFIG.parseable.open_ai_key {
+ Some(api_key) if api_key.len() > 3 => api_key,
+ _ => return Err(LLMError::InvalidAPIKey),
+ };
+
+ let stream_name = &body.stream;
+ let schema = STREAM_INFO.schema(stream_name)?;
+ let filtered_schema = schema
+ .all_fields()
+ .into_iter()
+ .map(Field::from)
+ .collect_vec();
+
+ let schema_json =
+ serde_json::to_string(&filtered_schema).expect("always converted to valid json");
+
+ let prompt = build_prompt(stream_name, &body.prompt, &schema_json);
+ let body = build_request_body(prompt);
+
+ let client = reqwest::Client::new();
+ let response = client
+ .post(OPEN_AI_URL)
+ .header(header::CONTENT_TYPE, "application/json")
+ .bearer_auth(api_key)
+ .json(&body)
+ .send()
+ .await?;
+
+ if response.status().is_success() {
+ let body: ResponseData = response
+ .json()
+ .await
+ .expect("OpenAI response is always the same");
+ Ok(HttpResponse::Ok()
+ .content_type("application/json")
+ .json(&body.choices[0].message.content))
+ } else {
+ let body: Value = response.json().await?;
+ let message = body
+ .as_object()
+ .and_then(|body| body.get("error"))
+ .and_then(|error| error.as_object())
+ .and_then(|error| error.get("message"))
+ .map(|message| message.to_string())
+ .unwrap_or_else(|| "Error from OpenAI".to_string());
+
+ Err(LLMError::APIError(message))
+ }
+}
+
+pub async fn is_llm_active(_body: web::Json) -> HttpResponse {
+ let is_active = matches!(&CONFIG.parseable.open_ai_key, Some(api_key) if api_key.len() > 3);
+ HttpResponse::Ok()
+ .content_type("application/json")
+ .json(json!({"is_active": is_active}))
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum LLMError {
+ #[error("Either OpenAI key was not provided or was invalid")]
+ InvalidAPIKey,
+ #[error("Failed to call OpenAI endpoint: {0}")]
+ FailedRequest(#[from] reqwest::Error),
+ #[error("{0}")]
+ APIError(String),
+ #[error("{0}")]
+ StreamDoesNotExist(#[from] MetadataError),
+}
+
+impl actix_web::ResponseError for LLMError {
+ fn status_code(&self) -> http::StatusCode {
+ match self {
+ Self::InvalidAPIKey => StatusCode::INTERNAL_SERVER_ERROR,
+ Self::FailedRequest(_) => StatusCode::INTERNAL_SERVER_ERROR,
+ Self::APIError(_) => StatusCode::INTERNAL_SERVER_ERROR,
+ Self::StreamDoesNotExist(_) => StatusCode::INTERNAL_SERVER_ERROR,
+ }
+ }
+
+ fn error_response(&self) -> actix_web::HttpResponse {
+ actix_web::HttpResponse::build(self.status_code())
+ .insert_header(ContentType::plaintext())
+ .body(self.to_string())
+ }
+}
diff --git a/server/src/option.rs b/server/src/option.rs
index 14e389e3a..3da036d0c 100644
--- a/server/src/option.rs
+++ b/server/src/option.rs
@@ -184,6 +184,9 @@ pub struct Server {
/// Server should send anonymous analytics or not
pub send_analytics: bool,
+ /// Open AI access key
+ pub open_ai_key: Option,
+
/// Rows in Parquet Rowgroup
pub row_group_size: usize,
@@ -232,6 +235,7 @@ impl FromArgMatches for Server {
.get_one::(Self::SEND_ANALYTICS)
.cloned()
.expect("default for send analytics");
+ self.open_ai_key = m.get_one::(Self::OPEN_AI_KEY).cloned();
// converts Gib to bytes before assigning
self.query_memory_pool_size = m
.get_one::(Self::QUERY_MEM_POOL_SIZE)
@@ -271,6 +275,7 @@ impl Server {
pub const PASSWORD: &str = "password";
pub const CHECK_UPDATE: &str = "check-update";
pub const SEND_ANALYTICS: &str = "send-analytics";
+ pub const OPEN_AI_KEY: &str = "open-ai-key";
pub const QUERY_MEM_POOL_SIZE: &str = "query-mempool-size";
pub const ROW_GROUP_SIZE: &str = "row-group-size";
pub const PARQUET_COMPRESSION_ALGO: &str = "compression-algo";
@@ -351,6 +356,24 @@ impl Server {
.required(true)
.help("Password for the basic authentication on the server"),
)
+ .arg(
+ Arg::new(Self::SEND_ANALYTICS)
+ .long(Self::SEND_ANALYTICS)
+ .env("P_SEND_ANONYMOUS_USAGE_DATA")
+ .value_name("BOOL")
+ .required(false)
+ .default_value("true")
+ .value_parser(value_parser!(bool))
+ .help("Disable/Enable sending anonymous user data"),
+ )
+ .arg(
+ Arg::new(Self::OPEN_AI_KEY)
+ .long(Self::OPEN_AI_KEY)
+ .env("OPENAI_API_KEY")
+ .value_name("STRING")
+ .required(false)
+ .help("Set OpenAI key to enable llm feature"),
+ )
.arg(
Arg::new(Self::CHECK_UPDATE)
.long(Self::CHECK_UPDATE)
@@ -380,16 +403,6 @@ impl Server {
.value_parser(value_parser!(usize))
.help("Number of rows in a row groups"),
)
- .arg(
- Arg::new(Self::SEND_ANALYTICS)
- .long(Self::SEND_ANALYTICS)
- .env("P_SEND_ANONYMOUS_USAGE_DATA")
- .value_name("BOOL")
- .required(false)
- .default_value("true")
- .value_parser(value_parser!(bool))
- .help("Disable/Enable sending anonymous user data"),
- )
.arg(
Arg::new(Self::PARQUET_COMPRESSION_ALGO)
.long(Self::PARQUET_COMPRESSION_ALGO)