36 changes: 31 additions & 5 deletions config/settings.toml
@@ -1,10 +1,12 @@
[default]
region = 'us-east-1'
data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
# "./data"
deploy_bucket_name = 'llm-finetune-us-east-1-{aws_account}'
# region = 'us-east-1'
region = 'us-west-2'
deploy_bucket_name = 'llm-finetune-{region}-{aws_account}'
data_dir = 's3://llm-finetune-{region}-{aws_account}/eval/tones/'
# data_dir = './data/'
human_eval_dir = 's3://llm-finetune-{region}-{aws_account}/human_eval/tones/'
deploy_bucket_prefix = 'models'
sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'
sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-{region}'
endpoint_type = 'bedrock'
model = 'anthropic.claude-3-haiku-20240307-v1:0'

@@ -37,11 +39,35 @@ model = 'Phi-3-5-mini-instruct'
hf_name = 'microsoft/Phi-3.5-mini-instruct'
endpoint_type = 'sagemaker'

[qwen-3-0-6B]
model = 'Qwen3-0-6B'
hf_name = 'Qwen/Qwen3-0.6B' # for Qwen3 the default checkpoint is the instruct model; the base model has a '-Base' suffix
endpoint_type = 'sagemaker'
thinking = false

[qwen-2-5-1-5B]
model = 'Qwen2-5-1-5B-Instruct'
hf_name = 'Qwen/Qwen2.5-1.5B-Instruct'
endpoint_type = 'sagemaker'

[qwen-3-1-7B]
model = 'Qwen3-1-7B'
hf_name = 'Qwen/Qwen3-1.7B' # for Qwen3 the default checkpoint is the instruct model; the base model has a '-Base' suffix
endpoint_type = 'sagemaker'
thinking = false

[qwen-3-1-7B-async]
model = 'Qwen3-1-7B-async'
hf_name = 'Qwen/Qwen3-1.7B' # for Qwen3 the default checkpoint is the instruct model; the base model has a '-Base' suffix
endpoint_type = 'sagemaker'
thinking = false
asynchronous = true

[qwen-3-4B]
model = 'Qwen3-4B'
hf_name = 'Qwen/Qwen3-4B-Instruct-2507' # this finetune is non-thinking only
endpoint_type = 'sagemaker'

[phi-3-ollama]
model = 'phi3'
hf_name = 'microsoft/Phi-3.5-mini-instruct'
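The new [qwen-*] entries above introduce two optional per-model keys, `thinking` and `asynchronous`, that downstream actions check at runtime. A minimal sketch of how these keys can be read with Dynaconf; the env name comes from the diff above, while the exact settings-file path is an assumption about this repo's layout:

```python
# Sketch: reading the optional per-model flags added in settings.toml with Dynaconf.
from dynaconf import Dynaconf

settings = Dynaconf(
    settings_files=["config/settings.toml"],
    environments=True,           # each [section] becomes a selectable environment
    env="qwen-3-1-7B-async",     # the async Qwen entry added in this diff
)

# Optional keys are absent for most models, so guard with exists()/get().
use_async = settings.exists("asynchronous") and settings.asynchronous
thinking = settings.get("thinking", None)   # None means "leave the model default"

print(settings.model, settings.endpoint_type, use_async, thinking)
```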
2 changes: 0 additions & 2 deletions pyproject.toml
@@ -19,8 +19,6 @@ dependencies = [
"boto3",
"plotly~=5.24.1",
"transformers==4.51.0",
"datasets~=3.2.0",
"evaluate~=0.4.3",
"dynaconf~=3.2.7",
"torch",
"botocore",
13 changes: 9 additions & 4 deletions src/wraval/actions/action_deploy.py
@@ -6,6 +6,7 @@
import boto3
import torch
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

PACKAGE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -80,7 +81,7 @@ def write_model_to_s3(settings, model_name):
return s3_uri


def deploy_endpoint(s3_uri, role, endpoint_name):
def deploy_endpoint(s3_uri, role, endpoint_name, async_config=None):
env = {
"HF_TASK": "text-generation",
"HF_HUB_OFFLINE": "1",
@@ -100,13 +101,14 @@ def deploy_endpoint(s3_uri, role, endpoint_name):
initial_instance_count=1,
instance_type="ml.g5.2xlarge",
endpoint_name=endpoint_name,
async_inference_config=async_config,
)


def validate_deployment(predictor):
try:
sagemaker_runtime_client = boto3.client("sagemaker-runtime")
input_string = json.dumps({"inputs": "Hello, my dog is a little"})
input_string = json.dumps({"inputs": "<|im_start|>user\nHello, can you pass me the milk?<|im_end|>\n<|im_start|>assistant\n"})
response = sagemaker_runtime_client.invoke_endpoint(
EndpointName=predictor.endpoint_name,
Body=input_string.encode("utf-8"),
@@ -142,10 +144,13 @@ def cleanup_model_directory():
def deploy(settings):
validate_model_directory()
cleanup_model_directory()
sanitized_model_name = settings.hf_name.split("/")[1].replace(".", "-")
sanitized_model_name = settings.model.replace(".", "-")
load_artifacts(settings)
s3_uri = write_model_to_s3(settings, sanitized_model_name)
async_config = None
if settings.exists('asynchronous'):
async_config = AsyncInferenceConfig()
predictor = deploy_endpoint(
s3_uri, settings.sagemaker_execution_role_arn, sanitized_model_name
s3_uri, settings.sagemaker_execution_role_arn, sanitized_model_name, async_config
)
validate_deployment(predictor)
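The call above passes `AsyncInferenceConfig()` with defaults, which lets SageMaker pick the async output location. If explicit output/failure prefixes or a concurrency cap are wanted, the config accepts them; a hedged sketch, where the bucket and values are placeholders rather than part of this PR, and `failure_path` needs a reasonably recent sagemaker SDK:

```python
# Sketch only: an AsyncInferenceConfig with explicit S3 locations and concurrency.
# The bucket/prefix values are illustrative placeholders.
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig

async_config = AsyncInferenceConfig(
    output_path="s3://llm-finetune-us-west-2-123456789012/eval/async/output/",
    failure_path="s3://llm-finetune-us-west-2-123456789012/eval/async/failures/",
    max_concurrent_invocations_per_instance=4,
)
```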
7 changes: 4 additions & 3 deletions src/wraval/actions/action_inference.py
@@ -40,20 +40,21 @@ def run_inference(

tone_prompt = get_prompt(Tone(tone))

queries = results[results["tone"] == tone]["synthetic_data"]
queries = results[results["tone"] == tone]["synthetic_data"].unique()

print(f"Processing {len(queries)} inputs for tone: {tone}")

outputs = route_completion(settings, queries, tone_prompt)

cleaned_output = [o.strip().strip('"') for o in outputs]

if no_rewrite:
mask = results["tone"] == tone
results.loc[mask, "rewrite"] = cleaned_output
results.loc[mask, "inference_model"] = model_name
else:
new_results = pd.DataFrame(
{"synthetic_data": results[results["tone"] == tone]["synthetic_data"]}
{"synthetic_data": results[results["tone"] == tone]["synthetic_data"].unique()}
)
new_results["tone"] = tone
new_results["rewrite"] = cleaned_output
49 changes: 0 additions & 49 deletions src/wraval/actions/bleu_conf.py

This file was deleted.

166 changes: 160 additions & 6 deletions src/wraval/actions/completion.py
@@ -11,17 +11,30 @@
import boto3
import re
import requests
import uuid


# Function to extract last assistant response from each entry
def extract_last_assistant_response(data):
matches = re.findall(r"<\|assistant\|>(.*?)<\|end\|>", data, re.DOTALL)
# matches = re.findall(r"<assistant>(.*?)</assistant>", data, re.DOTALL)
if matches:
return matches[-1].strip()
else:
return data

if r"<\|assistant\|>" in data: # phi
assistant_part = data.split(r"<\|assistant\|>")[-1]
response = response.replace(r"<\|end\|>", "").strip()
return response

if r"<|im_start|>assistant" in data: # qwen
Reviewer comment (Contributor): Can we detect qwen and phi without searching the string for a specific pattern?

assistant_part = data.split(r"<|im_start|>assistant")[-1]

# Remove the thinking part if it exists
if r"<think>" in assistant_part:
# Extract everything after </think>
response = assistant_part.split(r"</think>")[-1]
else:
response = assistant_part
response = response.replace(r"<|im_end|>", "").strip()
return response

return data
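On the review question above: one option is to dispatch on an explicit model family (for example, taken from settings) instead of sniffing the raw string. A sketch under that assumption; the `model_family` argument is hypothetical and not part of this PR:

```python
# Sketch: dispatch on an explicit model family instead of probing the string.
# `model_family` would need to be threaded through from settings; it is illustrative only.
def extract_assistant_response(data: str, model_family: str) -> str:
    markers = {
        "phi": ("<|assistant|>", "<|end|>"),
        "qwen": ("<|im_start|>assistant", "<|im_end|>"),
    }
    if model_family not in markers:
        return data
    start_tag, end_tag = markers[model_family]
    response = data.split(start_tag)[-1].replace(end_tag, "")
    # Qwen 3 may emit a <think>...</think> block before the final answer.
    if "</think>" in response:
        response = response.split("</think>")[-1]
    return response.strip()
```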

def get_bedrock_completion(settings, prompt, system_prompt=None):
bedrock_client = boto3.client(
@@ -220,3 +233,144 @@ def invoke_ollama_endpoint(payload, endpoint_name, url="127.0.0.1:11434"):
lines.append(json.loads(r))

return "".join([l["response"] for l in lines])


def batch_invoke_sagemaker_endpoint(
Reviewer comment (Contributor): Can we verify that this has been tested?

payloads,
endpoint_name,
region="us-east-1",
s3_bucket=None,
s3_input_prefix="/eval/async/input/",
poll_interval_seconds=10,
timeout_seconds=600,
):
"""
Invoke a SageMaker async endpoint for a batch of payloads.

- payloads: list of JSON-serializable objects (each is one request)
- endpoint_name: name of the async SageMaker endpoint
- region: AWS region
- s3_bucket: S3 bucket to upload inputs (required)
- s3_input_prefix: S3 prefix for input uploads
- poll_interval_seconds: interval between checks for output readiness
- timeout_seconds: max time to wait for each result

Returns list of raw results (strings) in the same order as payloads.
"""
if s3_bucket is None:
raise ValueError("s3_bucket is required for async invocations")
if not isinstance(s3_bucket, str) or not s3_bucket.strip():
raise ValueError(
"s3_bucket must be a non-empty string (e.g., 'my-bucket-name'), got: "
f"{type(s3_bucket).__name__}"
)

sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=region)
s3_client = boto3.client("s3", region_name=region)

# Normalize prefix
input_prefix = s3_input_prefix.lstrip("/")

input_locations = []
output_locations = []
inference_ids = []

# 1) Upload all payloads and invoke async endpoint
for idx, payload in enumerate(payloads):
print(f"Submitting {idx + 1}/{len(payloads)} to async endpoint '{endpoint_name}'...")
request_id = str(uuid.uuid4())[:8]
input_key = f"{input_prefix}batch-{request_id}-{idx}.json"

# Ensure payload is in expected format for the model container
if isinstance(payload, str):
payload_to_upload = {"inputs": payload}
elif isinstance(payload, list) and all(isinstance(p, str) for p in payload):
payload_to_upload = {"inputs": payload}
elif isinstance(payload, dict):
payload_to_upload = payload
else:
# Fallback: wrap unknown types under inputs
payload_to_upload = {"inputs": payload}

s3_client.put_object(
Bucket=s3_bucket,
Key=input_key,
Body=json.dumps(payload_to_upload),
ContentType="application/json",
)

input_location = f"s3://{s3_bucket}/{input_key}"
input_locations.append(input_location)

response = sagemaker_runtime.invoke_endpoint_async(
EndpointName=endpoint_name,
InputLocation=input_location,
ContentType="application/json",
InvocationTimeoutSeconds=3600,
)

output_locations.append(response["OutputLocation"]) # s3 uri
inference_ids.append(response.get("InferenceId"))
print(f"Submitted {idx + 1}/{len(payloads)}. Output will be written to {response['OutputLocation']}")

# 2) Poll for each output and download results
results = []
for i, output_location in enumerate(output_locations):
start_time = time.time()

# Parse the S3 URI returned by invoke_endpoint_async; OutputLocation already points
# at the object that will hold this request's result, so poll that key directly.
uri = output_location.replace("s3://", "")
bucket, key = uri.split("/", 1)
inference_id = inference_ids[i]
expected_key = key
print(f"Polling for result object s3://{bucket}/{expected_key}")

while True:
try:
# Check whether the result object exists yet
s3_client.head_object(Bucket=bucket, Key=expected_key)
break
except Exception:
if time.time() - start_time > timeout_seconds:
print(f"Timed out waiting for result {i + 1}/{len(output_locations)} after {timeout_seconds}s")
results.append(None)
break
elapsed = int(time.time() - start_time)
print(f"Waiting for result {i + 1}/{len(output_locations)}... {elapsed}s elapsed")

# Try to detect async failure artifact: async-endpoint-failures/.../<InferenceId>-error.out
if isinstance(inference_id, str) and inference_id:
try:
candidates = s3_client.list_objects_v2(
Bucket=bucket,
Prefix="async-endpoint-failures/",
MaxKeys=1000,
)
for obj in candidates.get("Contents", []):
k = obj.get("Key", "")
if k.endswith(f"{inference_id}-error.out"):
err_obj = s3_client.get_object(Bucket=bucket, Key=k)
err_text = err_obj["Body"].read().decode("utf-8", errors="replace")
print(f"Error for request {i + 1}/{len(output_locations)} (InferenceId={inference_id}):\n{err_text}")
results.append(None)
# Stop waiting for this one
elapsed = int(time.time() - start_time)
print(f"Marking request {i + 1} as failed after {elapsed}s due to async failure artifact: s3://{bucket}/{k}")
# Break out of the polling loop
raise StopIteration
except StopIteration:
break
except Exception:
# Ignore listing errors silently and keep polling
pass
time.sleep(poll_interval_seconds)

if len(results) == i: # no timeout or failure was recorded for this request above
obj = s3_client.get_object(Bucket=bucket, Key=expected_key)
result_body = obj["Body"].read().decode("utf-8")
results.append(result_body)
total = int(time.time() - start_time)
print(f"Result ready for {i + 1}/{len(output_locations)} after {total}s")

return results
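For reference, a minimal usage sketch of the new helper. The endpoint name matches the `qwen-3-1-7B-async` entry in settings.toml, but the bucket, account id, and the shape of the container's JSON response are assumptions:

```python
# Sketch only: calling the async batch helper. Names and response parsing are illustrative.
import json

prompts = [
    "<|im_start|>user\nRewrite this professionally: gimme the report<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>user\nRewrite this casually: please find attached<|im_end|>\n<|im_start|>assistant\n",
]

raw_results = batch_invoke_sagemaker_endpoint(
    payloads=prompts,
    endpoint_name="Qwen3-1-7B-async",
    region="us-west-2",
    s3_bucket="llm-finetune-us-west-2-123456789012",   # placeholder account id
)

for raw in raw_results:
    if raw is None:
        continue   # timed out or failed; see the failure-artifact handling above
    # The HF text-generation container typically returns [{"generated_text": ...}].
    generated = json.loads(raw)[0]["generated_text"]
    print(extract_last_assistant_response(generated))
```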
7 changes: 5 additions & 2 deletions src/wraval/actions/format.py
@@ -6,7 +6,7 @@
import xml.dom.minidom


def format_prompt(usr_prompt, prompt=None, tokenizer=None, type="bedrock"):
def format_prompt(usr_prompt, prompt=None, tokenizer=None, type="bedrock", thinking=None):
"""
Format prompts according to each model's prompt guidelines (e.g. xml tags for Haiku).

Expand All @@ -18,7 +18,10 @@ def format_prompt(usr_prompt, prompt=None, tokenizer=None, type="bedrock"):

if type == "hf":
if prompt:
sys_prompt = [{"role": "system", "content": prompt.sys_prompt}]
if thinking is None or thinking is True:
sys_prompt = [{"role": "system", "content": prompt.sys_prompt}]
else:
sys_prompt = [{"role": "system", "content": prompt.sys_prompt + '/no_think'}]
messages = []
if prompt.examples:
for k, v in prompt.examples[0].items():
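
As a side note to the `/no_think` branch above: recent Qwen 3 tokenizers also expose an `enable_thinking` switch on `apply_chat_template`, which could achieve the same effect at template time. A hedged sketch, with the model name taken from the config above and behaviour depending on the installed tokenizer/chat template:

```python
# Sketch: disabling Qwen 3 thinking via the chat template instead of a '/no_think' suffix.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")
messages = [
    {"role": "system", "content": "Rewrite the user's text in a professional tone."},  # illustrative
    {"role": "user", "content": "gimme the report"},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,   # suppresses the <think>...</think> block
)
print(text)
```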