generated from amazon-archives/__template_Apache-2.0
Thinking models #26
Open · gabriben wants to merge 18 commits into main from thinking_models
Changes from all commits (18 commits)
Commits (all by gabriben):
- 5d329ce deploy fix and add qwen 1.7
- 524e66c remove bleu score as not in use for now
- 2f4c637 move the comment in settings toml
- b5f05d6 using dynaconf variable interpolation
- 7519af6 using dynaconf variable interpolation for sagemaker role too
- 837d55f using dynaconf variable interpolation with only one nesting
- ba7ff24 dropping dynaconf variable interpolation but keeping region flexible …
- 0b1463f no_think input and output
- b58fd5c try async
- 8bdcd9c empty async config
- 9ac330f different async qwen name
- 7c34715 use model name and not hf name as sagemaker endpoint name
- e132f59 model name has no slash unlike hf name
- 019d561 add qwen 3 4B
- ba7bdbc async config fix
- 815a24e small fixes and first attempt at a batch endpoint
- 6d046ce add qwen 0.6B
- d5385a2 typpo in qwen hf name
```diff
@@ -11,17 +11,30 @@
 import boto3
 import re
 import requests
+import uuid
 
 
 # Function to extract last assistant response from each entry
 def extract_last_assistant_response(data):
-    matches = re.findall(r"<\|assistant\|>(.*?)<\|end\|>", data, re.DOTALL)
-    # matches = re.findall(r"<assistant>(.*?)</assistant>", data, re.DOTALL)
-    if matches:
-        return matches[-1].strip()
-    else:
-        return data
+    if "<|assistant|>" in data:  # phi
+        assistant_part = data.split("<|assistant|>")[-1]
+        response = assistant_part.replace("<|end|>", "").strip()
+        return response
+
+    if "<|im_start|>assistant" in data:  # qwen
+        assistant_part = data.split("<|im_start|>assistant")[-1]
+
+        # Remove the thinking part if it exists
+        if "<think>" in assistant_part:
+            # Extract everything after </think>
+            response = assistant_part.split("</think>")[-1]
+        else:
+            response = assistant_part
+        response = response.replace("<|im_end|>", "").strip()
+        return response
+
+    return data
 
 
 def get_bedrock_completion(settings, prompt, system_prompt=None):
     bedrock_client = boto3.client(
```
Review comment on `batch_invoke_sagemaker_endpoint`: "Can we verify that this has been tested?"

```diff
@@ -220,3 +233,144 @@ def invoke_ollama_endpoint(payload, endpoint_name, url="127.0.0.1:11434"):
         lines.append(json.loads(r))
 
     return "".join([l["response"] for l in lines])
+
+
+def batch_invoke_sagemaker_endpoint(
+    payloads,
+    endpoint_name,
+    region="us-east-1",
+    s3_bucket=None,
+    s3_input_prefix="/eval/async/input/",
+    poll_interval_seconds=10,
+    timeout_seconds=600,
+):
+    """
+    Invoke a SageMaker async endpoint for a batch of payloads.
+
+    - payloads: list of JSON-serializable objects (each is one request)
+    - endpoint_name: name of the async SageMaker endpoint
+    - region: AWS region
+    - s3_bucket: S3 bucket to upload inputs (required)
+    - s3_input_prefix: S3 prefix for input uploads
+    - poll_interval_seconds: interval between checks for output readiness
+    - timeout_seconds: max time to wait for each result
+
+    Returns a list of raw results (strings) in the same order as payloads;
+    an entry is None if that request failed or timed out.
+    """
+    if s3_bucket is None:
+        raise ValueError("s3_bucket is required for async invocations")
+    if not isinstance(s3_bucket, str) or not s3_bucket.strip():
+        raise ValueError(
+            "s3_bucket must be a non-empty string (e.g., 'my-bucket-name'), got: "
+            f"{type(s3_bucket).__name__}"
+        )
+
+    sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=region)
+    s3_client = boto3.client("s3", region_name=region)
+
+    # Normalize the prefix (S3 keys must not start with "/")
+    input_prefix = s3_input_prefix.lstrip("/")
+
+    input_locations = []
+    output_locations = []
+    inference_ids = []
+
+    # 1) Upload all payloads and invoke the async endpoint
+    for idx, payload in enumerate(payloads):
+        print(f"Submitting {idx + 1}/{len(payloads)} to async endpoint '{endpoint_name}'...")
+        request_id = str(uuid.uuid4())[:8]
+        input_key = f"{input_prefix}batch-{request_id}-{idx}.json"
+
+        # Ensure the payload is in the format the model container expects
+        if isinstance(payload, str):
+            payload_to_upload = {"inputs": payload}
+        elif isinstance(payload, list) and all(isinstance(p, str) for p in payload):
+            payload_to_upload = {"inputs": payload}
+        elif isinstance(payload, dict):
+            payload_to_upload = payload
+        else:
+            # Fallback: wrap unknown types under "inputs"
+            payload_to_upload = {"inputs": payload}
+
+        s3_client.put_object(
+            Bucket=s3_bucket,
+            Key=input_key,
+            Body=json.dumps(payload_to_upload),
+            ContentType="application/json",
+        )
+
+        input_location = f"s3://{s3_bucket}/{input_key}"
+        input_locations.append(input_location)
+
+        response = sagemaker_runtime.invoke_endpoint_async(
+            EndpointName=endpoint_name,
+            InputLocation=input_location,
+            ContentType="application/json",
+            InvocationTimeoutSeconds=3600,
+        )
+
+        output_locations.append(response["OutputLocation"])  # S3 URI
+        inference_ids.append(response.get("InferenceId"))
+        print(f"Submitted {idx + 1}/{len(payloads)}. Output will be written to {response['OutputLocation']}")
+
+    # 2) Poll for each output and download the results
+    results = []
+    for i, output_location in enumerate(output_locations):
+        start_time = time.time()
+
+        # Parse the S3 URI and derive the expected result key: <prefix>/<InferenceId>.out
+        uri = output_location.replace("s3://", "")
+        bucket, key = uri.split("/", 1)
+        inference_id = inference_ids[i]
+        expected_key = f"{key.rstrip('/')}/{inference_id}.out" if isinstance(inference_id, str) and inference_id else key
+        if expected_key != key:
+            print(f"Polling for result object s3://{bucket}/{expected_key}")
+
+        while True:
+            try:
+                # First, check the expected result key (<InferenceId>.out)
+                s3_client.head_object(Bucket=bucket, Key=expected_key)
+                break
+            except Exception:
+                if time.time() - start_time > timeout_seconds:
+                    print(f"Timed out waiting for result {i + 1}/{len(output_locations)} after {timeout_seconds}s")
+                    results.append(None)
+                    break
+                elapsed = int(time.time() - start_time)
+                print(f"Waiting for result {i + 1}/{len(output_locations)}... {elapsed}s elapsed")
+
+                # Try to detect an async failure artifact: async-endpoint-failures/.../<InferenceId>-error.out
+                if isinstance(inference_id, str) and inference_id:
+                    try:
+                        candidates = s3_client.list_objects_v2(
+                            Bucket=bucket,
+                            Prefix="async-endpoint-failures/",
+                            MaxKeys=1000,
+                        )
+                        for obj in candidates.get("Contents", []):
+                            k = obj.get("Key", "")
+                            if k.endswith(f"{inference_id}-error.out"):
+                                err_obj = s3_client.get_object(Bucket=bucket, Key=k)
+                                err_text = err_obj["Body"].read().decode("utf-8", errors="replace")
+                                print(f"Error for request {i + 1}/{len(output_locations)} (InferenceId={inference_id}):\n{err_text}")
+                                results.append(None)
+                                elapsed = int(time.time() - start_time)
+                                print(f"Marking request {i + 1} as failed after {elapsed}s due to async failure artifact: s3://{bucket}/{k}")
+                                # Break out of the polling loop
+                                raise StopIteration
+                    except StopIteration:
+                        break
+                    except Exception:
+                        # Ignore listing errors and keep polling
+                        pass
+                time.sleep(poll_interval_seconds)
+
+        # Download only if this request was not already marked failed or timed out
+        if len(results) == i:
+            obj = s3_client.get_object(Bucket=bucket, Key=expected_key)
+            result_body = obj["Body"].read().decode("utf-8")
+            results.append(result_body)
+            total = int(time.time() - start_time)
+            print(f"Result ready for {i + 1}/{len(output_locations)} after {total}s")
+
+    return results
```
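Two pieces of the new function are pure and easy to sanity-check in isolation: the payload normalization and the derivation of the expected result key from the `OutputLocation` URI. The helper names below are illustrative, not part of the PR, and the key derivation mirrors the PR's assumption that the result lands at `<OutputLocation key>/<InferenceId>.out` — whether that matches the endpoint's actual output layout is exactly what the review comment asks to verify:

```python
def normalize_payload(payload):
    # Strings and lists of strings are wrapped under "inputs";
    # dicts pass through; anything else falls back to wrapping.
    if isinstance(payload, str):
        return {"inputs": payload}
    if isinstance(payload, list) and all(isinstance(p, str) for p in payload):
        return {"inputs": payload}
    if isinstance(payload, dict):
        return payload
    return {"inputs": payload}


def expected_result_key(output_location, inference_id):
    # Split an s3://bucket/key URI, then (per the PR's assumption)
    # append <InferenceId>.out to locate the per-request result object.
    uri = output_location.replace("s3://", "")
    bucket, key = uri.split("/", 1)
    if isinstance(inference_id, str) and inference_id:
        return bucket, f"{key.rstrip('/')}/{inference_id}.out"
    return bucket, key


print(normalize_payload("hello"))  # {'inputs': 'hello'}
print(expected_result_key("s3://my-bucket/eval/async/output/abc", "id-123"))
# ('my-bucket', 'eval/async/output/abc/id-123.out')
```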
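The polling loop follows a common probe-with-timeout pattern: check whether the result object exists, give up after a deadline, and sleep between attempts. Stripped of the S3 specifics, a minimal sketch (the checker is a stand-in for `head_object`; names are illustrative):

```python
import time


def poll_until(check, timeout_seconds=5, poll_interval_seconds=0.01):
    """Return True once check() succeeds, or False after timeout_seconds."""
    start = time.time()
    while True:
        if check():
            return True
        if time.time() - start > timeout_seconds:
            return False
        time.sleep(poll_interval_seconds)


# Stand-in for head_object: succeeds on the third probe.
attempts = {"n": 0}
def fake_check():
    attempts["n"] += 1
    return attempts["n"] >= 3


print(poll_until(fake_check))  # True
```

The real function wraps `head_object` in a try/except because a missing object raises rather than returning falsy; otherwise the control flow is the same.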
Review comment: Can we detect qwen and phi without searching the string for a specific pattern?
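One possible direction for the question above, sketched purely as an assumption (the PR does not propose this): the model family is already known when the endpoint is deployed, so the caller could pass it explicitly and dispatch on it instead of sniffing the raw string:

```python
# Hypothetical token table keyed by model family; families and tokens
# mirror the two branches in the PR's extract_last_assistant_response.
STOP_TOKENS = {
    "phi": ("<|assistant|>", "<|end|>"),
    "qwen": ("<|im_start|>assistant", "<|im_end|>"),
}


def extract_response(data, family):
    # family comes from the endpoint/model config, so no pattern sniffing
    start, end = STOP_TOKENS[family]
    if start not in data:
        return data
    part = data.split(start)[-1]
    if "</think>" in part:  # drop any thinking block
        part = part.split("</think>")[-1]
    return part.replace(end, "").strip()


print(extract_response("<|im_start|>assistant\n<think>hm</think>\nParis<|im_end|>", "qwen"))  # Paris
```

This keeps the per-family tokens in one table, at the cost of threading a `family` argument through the call sites.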