diff --git a/config/settings.toml b/config/settings.toml
index fbf1d08..594fd24 100644
--- a/config/settings.toml
+++ b/config/settings.toml
@@ -1,10 +1,12 @@
[default]
-region = 'us-east-1'
-data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
-# "./data"
-deploy_bucket_name = 'llm-finetune-us-east-1-{aws_account}'
+# region = 'us-east-1'
+region = 'us-west-2'
+deploy_bucket_name = 'llm-finetune-{region}-{aws_account}'
+data_dir = 's3://llm-finetune-{region}-{aws_account}/eval/tones/'
+# data_dir = './data/'
+human_eval_dir = 's3://llm-finetune-{region}-{aws_account}/human_eval/tones/'
deploy_bucket_prefix = 'models'
-sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'
+sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-{region}'
endpoint_type = 'bedrock'
model = 'anthropic.claude-3-haiku-20240307-v1:0'
@@ -37,11 +39,35 @@ model = 'Phi-3-5-mini-instruct'
hf_name = 'microsoft/Phi-3.5-mini-instruct'
endpoint_type = 'sagemaker'
+[qwen-3-0-6B]
+model = 'Qwen3-0-6B'
+hf_name = 'Qwen/Qwen3-0.6B' # Qwen3 uses this name for the instruct-style model; the base model carries a '-Base' suffix
+endpoint_type = 'sagemaker'
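+# 'thinking = false' appends '/no_think' to the system prompt at inference time (see format.py)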
+thinking = false
+
[qwen-2-5-1-5B]
model = 'Qwen2-5-1-5B-Instruct'
hf_name = 'Qwen/Qwen2.5-1.5B-Instruct'
endpoint_type = 'sagemaker'
+[qwen-3-1-7B]
+model = 'Qwen3-1-7B'
+hf_name = 'Qwen/Qwen3-1.7B' # Qwen3 uses this name for the instruct-style model; the base model carries a '-Base' suffix
+endpoint_type = 'sagemaker'
+thinking = false
+
+[qwen-3-1-7B-async]
+model = 'Qwen3-1-7B-async'
+hf_name = 'Qwen/Qwen3-1.7B' # Qwen3 uses this name for the instruct-style model; the base model carries a '-Base' suffix
+endpoint_type = 'sagemaker'
+thinking = false
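+# 'asynchronous' deploys the model behind a SageMaker async endpoint and routes inference through S3-based batch invocations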
+asynchronous = true
+
+[qwen-3-4B]
+model = 'Qwen3-4B'
+hf_name = 'Qwen/Qwen3-4B-Instruct-2507' # this finetune is non-thinking only
+endpoint_type = 'sagemaker'
+
[phi-3-ollama]
model = 'phi3'
hf_name = 'microsoft/Phi-3.5-mini-instruct'
diff --git a/pyproject.toml b/pyproject.toml
index f12bce3..2c9593f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,8 +19,6 @@ dependencies = [
"boto3",
"plotly~=5.24.1",
"transformers==4.51.0",
- "datasets~=3.2.0",
- "evaluate~=0.4.3",
"dynaconf~=3.2.7",
"torch",
"botocore",
diff --git a/src/wraval/actions/action_deploy.py b/src/wraval/actions/action_deploy.py
index 8cd2aa0..1ff18ff 100644
--- a/src/wraval/actions/action_deploy.py
+++ b/src/wraval/actions/action_deploy.py
@@ -6,6 +6,7 @@
import boto3
import torch
from sagemaker.huggingface import HuggingFaceModel
+from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
PACKAGE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -80,7 +81,7 @@ def write_model_to_s3(settings, model_name):
return s3_uri
-def deploy_endpoint(s3_uri, role, endpoint_name):
+def deploy_endpoint(s3_uri, role, endpoint_name, async_config=None):
env = {
"HF_TASK": "text-generation",
"HF_HUB_OFFLINE": "1",
@@ -100,13 +101,14 @@ def deploy_endpoint(s3_uri, role, endpoint_name):
initial_instance_count=1,
instance_type="ml.g5.2xlarge",
endpoint_name=endpoint_name,
+ async_inference_config=async_config,
)
def validate_deployment(predictor):
try:
sagemaker_runtime_client = boto3.client("sagemaker-runtime")
- input_string = json.dumps({"inputs": "Hello, my dog is a little"})
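+        # Smoke-test payload in the Qwen chat-template format; assumes a Qwen chat model is being validated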
+ input_string = json.dumps({"inputs": "<|im_start|>user\nHello, can you pass me the milk?<|im_end|>\n<|im_start|>assistant\n"})
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName=predictor.endpoint_name,
Body=input_string.encode("utf-8"),
@@ -142,10 +144,13 @@ def cleanup_model_directory():
def deploy(settings):
validate_model_directory()
cleanup_model_directory()
- sanitized_model_name = settings.hf_name.split("/")[1].replace(".", "-")
+ sanitized_model_name = settings.model.replace(".", "-")
load_artifacts(settings)
s3_uri = write_model_to_s3(settings, sanitized_model_name)
+ async_config = None
+ if settings.exists('asynchronous'):
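+        # Default AsyncInferenceConfig; an explicit output_path (S3 URI) could be passed here if needed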
+ async_config = AsyncInferenceConfig()
predictor = deploy_endpoint(
- s3_uri, settings.sagemaker_execution_role_arn, sanitized_model_name
+ s3_uri, settings.sagemaker_execution_role_arn, sanitized_model_name, async_config
)
validate_deployment(predictor)
diff --git a/src/wraval/actions/action_inference.py b/src/wraval/actions/action_inference.py
index b9ee95d..b8f0c19 100644
--- a/src/wraval/actions/action_inference.py
+++ b/src/wraval/actions/action_inference.py
@@ -40,20 +40,21 @@ def run_inference(
tone_prompt = get_prompt(Tone(tone))
- queries = results[results["tone"] == tone]["synthetic_data"]
+ queries = results[results["tone"] == tone]["synthetic_data"].unique()
print(f"Processing {len(queries)} inputs for tone: {tone}")
outputs = route_completion(settings, queries, tone_prompt)
-
+
cleaned_output = [o.strip().strip('"') for o in outputs]
+
if no_rewrite:
mask = results["tone"] == tone
results.loc[mask, "rewrite"] = cleaned_output
results.loc[mask, "inference_model"] = model_name
else:
new_results = pd.DataFrame(
- {"synthetic_data": results[results["tone"] == tone]["synthetic_data"]}
+ {"synthetic_data": results[results["tone"] == tone]["synthetic_data"].unique()}
)
new_results["tone"] = tone
new_results["rewrite"] = cleaned_output
diff --git a/src/wraval/actions/bleu_conf.py b/src/wraval/actions/bleu_conf.py
deleted file mode 100644
index 5738d6e..0000000
--- a/src/wraval/actions/bleu_conf.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# // SPDX-License-Identifier: Apache-2.0
-#
-import random
-import numpy as np
-from evaluate import load
-
-# Load BLEU metric
-bleu = load("bleu")
-
-
-def compute_bleu_with_ci(
- predictions, references, num_bootstrap_samples=1000, confidence_level=0.95
-):
- # Compute the original BLEU score
- original_bleu = bleu.compute(predictions=predictions, references=references)["bleu"]
-
- # Bootstrap sampling
- bootstrap_scores = []
- n = len(predictions)
-
- for _ in range(num_bootstrap_samples):
- # Sample indices with replacement
- indices = [random.randint(0, n - 1) for _ in range(n)]
- sampled_predictions = [predictions[i] for i in indices]
- sampled_references = [references[i] for i in indices]
-
- # Compute BLEU for the bootstrap sample
- score = bleu.compute(
- predictions=sampled_predictions, references=sampled_references
- )["bleu"]
- bootstrap_scores.append(score)
-
- # Calculate confidence intervals
- lower_bound = np.percentile(bootstrap_scores, (1 - confidence_level) / 2 * 100)
- upper_bound = np.percentile(bootstrap_scores, (1 + confidence_level) / 2 * 100)
-
- return {"bleu": original_bleu, "confidence_interval": (lower_bound, upper_bound)}
-
-
-# Example usage
-predictions = ["This is a test", "Another sentence"]
-references = [["This is a test"], ["Another sentence"]]
-
-results = compute_bleu_with_ci(predictions, references)
-
-print(f"BLEU Score: {results['bleu']}")
-print(f"95% Confidence Interval: {results['confidence_interval']}")
diff --git a/src/wraval/actions/completion.py b/src/wraval/actions/completion.py
index 41951bf..e2699b0 100644
--- a/src/wraval/actions/completion.py
+++ b/src/wraval/actions/completion.py
@@ -11,17 +11,30 @@
import boto3
import re
import requests
+import time
+import uuid
# Function to extract last assistant response from each entry
def extract_last_assistant_response(data):
- matches = re.findall(r"<\|assistant\|>(.*?)<\|end\|>", data, re.DOTALL)
- # matches = re.findall(r"(.*?)", data, re.DOTALL)
- if matches:
- return matches[-1].strip()
- else:
- return data
+ if r"<\|assistant\|>" in data: # phi
+ assistant_part = data.split(r"<\|assistant\|>")[-1]
+ response = response.replace(r"<\|end\|>", "").strip()
+ return response
+
+ if r"<|im_start|>assistant" in data: # qwen
+ assistant_part = data.split(r"<|im_start|>assistant")[-1]
+
+ # Remove the thinking part if it exists
+ if r"" in assistant_part:
+ # Extract everything after
+ response = assistant_part.split(r"")[-1]
+ else:
+ response = assistant_part
+ response = response.replace(r"<|im_end|>", "").strip()
+ return response
+
+ return data
def get_bedrock_completion(settings, prompt, system_prompt=None):
bedrock_client = boto3.client(
@@ -220,3 +233,144 @@ def invoke_ollama_endpoint(payload, endpoint_name, url="127.0.0.1:11434"):
lines.append(json.loads(r))
return "".join([l["response"] for l in lines])
+
+
+def batch_invoke_sagemaker_endpoint(
+ payloads,
+ endpoint_name,
+ region="us-east-1",
+ s3_bucket=None,
+ s3_input_prefix="/eval/async/input/",
+ poll_interval_seconds=10,
+ timeout_seconds=600,
+):
+ """
+ Invoke a SageMaker async endpoint for a batch of payloads.
+
+ - payloads: list of JSON-serializable objects (each is one request)
+ - endpoint_name: name of the async SageMaker endpoint
+ - region: AWS region
+ - s3_bucket: S3 bucket to upload inputs (required)
+ - s3_input_prefix: S3 prefix for input uploads
+ - poll_interval_seconds: interval between checks for output readiness
+ - timeout_seconds: max time to wait for each result
+
+ Returns list of raw results (strings) in the same order as payloads.
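+
+    Example (hypothetical endpoint and bucket names):
+        results = batch_invoke_sagemaker_endpoint(
+            payloads=["Hello, can you pass me the milk?"],
+            endpoint_name="Qwen3-1-7B-async",
+            region="us-west-2",
+            s3_bucket="llm-finetune-us-west-2-123456789012",
+        )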
+ """
+ if s3_bucket is None:
+ raise ValueError("s3_bucket is required for async invocations")
+ if not isinstance(s3_bucket, str) or not s3_bucket.strip():
+ raise ValueError(
+ "s3_bucket must be a non-empty string (e.g., 'my-bucket-name'), got: "
+ f"{type(s3_bucket).__name__}"
+ )
+
+ sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=region)
+ s3_client = boto3.client("s3", region_name=region)
+
+ # Normalize prefix
+ input_prefix = s3_input_prefix.lstrip("/")
+
+ input_locations = []
+ output_locations = []
+ inference_ids = []
+
+ # 1) Upload all payloads and invoke async endpoint
+ for idx, payload in enumerate(payloads):
+ print(f"Submitting {idx + 1}/{len(payloads)} to async endpoint '{endpoint_name}'...")
+ request_id = str(uuid.uuid4())[:8]
+ input_key = f"{input_prefix}batch-{request_id}-{idx}.json"
+
+ # Ensure payload is in expected format for the model container
+ if isinstance(payload, str):
+ payload_to_upload = {"inputs": payload}
+ elif isinstance(payload, list) and all(isinstance(p, str) for p in payload):
+ payload_to_upload = {"inputs": payload}
+ elif isinstance(payload, dict):
+ payload_to_upload = payload
+ else:
+ # Fallback: wrap unknown types under inputs
+ payload_to_upload = {"inputs": payload}
+
+ s3_client.put_object(
+ Bucket=s3_bucket,
+ Key=input_key,
+ Body=json.dumps(payload_to_upload),
+ ContentType="application/json",
+ )
+
+ input_location = f"s3://{s3_bucket}/{input_key}"
+ input_locations.append(input_location)
+
+ response = sagemaker_runtime.invoke_endpoint_async(
+ EndpointName=endpoint_name,
+ InputLocation=input_location,
+ ContentType="application/json",
+ InvocationTimeoutSeconds=3600,
+ )
+
+ output_locations.append(response["OutputLocation"]) # s3 uri
+ inference_ids.append(response.get("InferenceId"))
+ print(f"Submitted {idx + 1}/{len(payloads)}. Output will be written to {response['OutputLocation']}")
+
+ # 2) Poll for each output and download results
+ results = []
+ for i, output_location in enumerate(output_locations):
+ start_time = time.time()
+
+        # Parse the S3 URI and derive the expected result key: <output-prefix>/<InferenceId>.out
+ uri = output_location.replace("s3://", "")
+ bucket, key = uri.split("/", 1)
+ inference_id = inference_ids[i]
+ expected_key = f"{key.rstrip('/')}/{inference_id}.out" if isinstance(inference_id, str) and inference_id else key
+ if expected_key != key:
+ print(f"Polling for result object s3://{bucket}/{expected_key}")
+
+ while True:
+ try:
+ # First, check expected result key (InferenceId.out)
+ s3_client.head_object(Bucket=bucket, Key=expected_key)
+ break
+ except Exception:
+ if time.time() - start_time > timeout_seconds:
+ print(f"Timed out waiting for result {i + 1}/{len(output_locations)} after {timeout_seconds}s")
+ results.append(None)
+ break
+ elapsed = int(time.time() - start_time)
+ print(f"Waiting for result {i + 1}/{len(output_locations)}... {elapsed}s elapsed")
+
+            # Try to detect an async failure artifact: async-endpoint-failures/.../<InferenceId>-error.out
+ if isinstance(inference_id, str) and inference_id:
+ try:
+ candidates = s3_client.list_objects_v2(
+ Bucket=bucket,
+ Prefix="async-endpoint-failures/",
+ MaxKeys=1000,
+ )
+ for obj in candidates.get("Contents", []):
+ k = obj.get("Key", "")
+ if k.endswith(f"{inference_id}-error.out"):
+ err_obj = s3_client.get_object(Bucket=bucket, Key=k)
+ err_text = err_obj["Body"].read().decode("utf-8", errors="replace")
+ print(f"Error for request {i + 1}/{len(output_locations)} (InferenceId={inference_id}):\n{err_text}")
+ results.append(None)
+ # Stop waiting for this one
+ elapsed = int(time.time() - start_time)
+ print(f"Marking request {i + 1} as failed after {elapsed}s due to async failure artifact: s3://{bucket}/{k}")
+ # Break out of the polling loop
+ raise StopIteration
+ except StopIteration:
+ break
+ except Exception:
+ # Ignore listing errors silently and keep polling
+ pass
+ time.sleep(poll_interval_seconds)
+
+        # Only download the result when no failure or timeout was recorded for this request
+        if len(results) == i:
+            obj = s3_client.get_object(Bucket=bucket, Key=expected_key)
+ result_body = obj["Body"].read().decode("utf-8")
+ results.append(result_body)
+ total = int(time.time() - start_time)
+ print(f"Result ready for {i + 1}/{len(output_locations)} after {total}s")
+
+ return results
diff --git a/src/wraval/actions/format.py b/src/wraval/actions/format.py
index d789682..f9bc374 100644
--- a/src/wraval/actions/format.py
+++ b/src/wraval/actions/format.py
@@ -6,7 +6,7 @@
import xml.dom.minidom
-def format_prompt(usr_prompt, prompt=None, tokenizer=None, type="bedrock"):
+def format_prompt(usr_prompt, prompt=None, tokenizer=None, type="bedrock", thinking=None):
"""
Format prompts according to each model's prompt guidelines (e.g. xml tags for Haiku).
@@ -18,7 +18,10 @@ def format_prompt(usr_prompt, prompt=None, tokenizer=None, type="bedrock"):
if type == "hf":
if prompt:
- sys_prompt = [{"role": "system", "content": prompt.sys_prompt}]
+ if thinking is None or thinking is True:
+ sys_prompt = [{"role": "system", "content": prompt.sys_prompt}]
+ else:
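+                # '/no_think' is Qwen3's soft switch that disables the thinking block for this request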
+ sys_prompt = [{"role": "system", "content": prompt.sys_prompt + '/no_think'}]
messages = []
if prompt.examples:
for k, v in prompt.examples[0].items():
diff --git a/src/wraval/actions/model_router.py b/src/wraval/actions/model_router.py
index 40a8a77..b05c1f0 100644
--- a/src/wraval/actions/model_router.py
+++ b/src/wraval/actions/model_router.py
@@ -2,6 +2,7 @@
batch_get_bedrock_completions,
invoke_sagemaker_endpoint,
invoke_ollama_endpoint,
+ batch_invoke_sagemaker_endpoint,
)
from .format import format_prompt
from transformers import AutoTokenizer
@@ -45,14 +46,30 @@ class SageMakerRouter(HuggingFaceModelRouter):
def __init__(self, master_sys_prompt, settings):
super().__init__(master_sys_prompt, settings)
self.model_name = settings.model
+ self.region = settings.region
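+        # Optional per-model flags from settings.toml: 'thinking' (Qwen3 soft switch) and 'asynchronous' (async SageMaker endpoint)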
+ self.thinking = None
+ if settings.exists('thinking'):
+ self.thinking = settings.thinking
+ self.async_config = False
+ if settings.exists('asynchronous'):
+ self.async_config = settings.asynchronous
+ self.deploy_bucket_name = settings.deploy_bucket_name
def get_completion(self, queries: List[str]) -> List[str]:
prompts = [
- format_prompt(text, self.master_sys_prompt, self.tokenizer, type="hf")
+ format_prompt(text, self.master_sys_prompt, self.tokenizer, "hf", self.thinking)
for text in queries
]
+ if self.async_config:
+ return batch_invoke_sagemaker_endpoint(prompts,
+ self.model_name,
+ self.region,
+ self.deploy_bucket_name)
return [
- invoke_sagemaker_endpoint({"inputs": prompt}, self.model_name) for prompt in tqdm(prompts)
+ invoke_sagemaker_endpoint({"inputs": prompt},
+ self.model_name,
+ self.region)
+ for prompt in tqdm(prompts)
]
diff --git a/src/wraval/main.py b/src/wraval/main.py
index b0e269e..3a81d49 100644
--- a/src/wraval/main.py
+++ b/src/wraval/main.py
@@ -62,6 +62,11 @@ def get_settings(
if local_tokenizer_path:
settings.local_tokenizer_path = local_tokenizer_path
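+    # Resolve {region} and {aws_account} placeholders in the S3 and IAM role settings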
+ settings.deploy_bucket_name = settings.deploy_bucket_name.format(region=settings.region, aws_account=settings.aws_account)
+ settings.data_dir = settings.data_dir.format(region=settings.region, aws_account=settings.aws_account)
+ settings.human_eval_dir = settings.human_eval_dir.format(region=settings.region, aws_account=settings.aws_account)
+ settings.sagemaker_execution_role_arn = settings.sagemaker_execution_role_arn.format(region=settings.region, aws_account=settings.aws_account)
+
# Format settings with AWS account
settings.model = settings.model.format(aws_account=settings.aws_account)
settings.data_dir = settings.data_dir.format(aws_account=settings.aws_account)
@@ -228,7 +233,13 @@ def human_judge_upload(
@app.command()
-def deploy(
+def human_judge_parsing():
+ """Parse human judgments, merge it to the original results table and create a plot."""
+ settings = get_settings()
+ parse_human_judgements(settings)
+
+@app.command()
+def deploy_model(
model: str = typer.Option("haiku-3", "--model", "-m", help="Model to deploy"),
cleanup_endpoints: bool = typer.Option(
False,