Changes from all commits (22 commits):
4e611a9  head_size = getattr(model.config, "head_dim", model.config.emb_dim /… (Oct 6, 2025)
043926c  [dpp] store enforce_sizes in log name and added generic kwargs to get… (kcirred, Sep 26, 2025)
6d6541f  [utils] added doc string, refactor sample_key, added return_key to sq… (kcirred, Sep 29, 2025)
97aeea3  [dpp/validation] restore sample_key in logic after rebase of main (kcirred, Sep 30, 2025)
9efb6e7  [validation] Modified final file string to hash due to OSError name t… (kcirred, Oct 3, 2025)
ea1a8ac  [test_validation] remove unused line (kcirred, Oct 3, 2025)
b637ea5  [validation] removed enforce_sizes from find_validation_info_path (kcirred, Oct 3, 2025)
aef361e  [dpp] added handling of return_key for __custom_line_sampler (kcirred, Oct 7, 2025)
178b75b  updated llama model expectation tests using v1.0.0 aiu software stack… (JRosenkranz, Oct 7, 2025)
73f6551  [testing] changed get_default_validation_prefix to generic kwargs, fi… (kcirred, Sep 30, 2025)
86d60d6  fixed test_scripts program assertion (JRosenkranz, Oct 8, 2025)
f3dcee1  Add test case for caching (avery-blanchard, Apr 8, 2025)
f936eb5  Update cache test, add validation for cached run (avery-blanchard, Jul 19, 2025)
81870d5  Squashed commit of the following: (alex-jw-brooks, Jul 30, 2025)
ad3a584  don't skip save on aiu iter0 (alex-jw-brooks, Sep 11, 2025)
54cc09b  fix fp8 dtype, always use persistent model fixture (alex-jw-brooks, Sep 11, 2025)
75eb5fb  remove model path from get_cpu_model args (alex-jw-brooks, Sep 11, 2025)
f1a810e  fix casing error (alex-jw-brooks, Oct 6, 2025)
15cf522  Rebase fixes, linting (alex-jw-brooks, Oct 8, 2025)
a743968  fix input prep (alex-jw-brooks, Oct 8, 2025)
b25e44f  use setdefault for torch sendnn cache dir (alex-jw-brooks, Oct 8, 2025)
4609ad5  Added head_dim override option to inference.py (Oct 16, 2025)
24 changes: 24 additions & 0 deletions aiu_fms_testing_utils/testing/utils.py
@@ -0,0 +1,24 @@
from collections.abc import Iterable


def format_kwargs_to_string(**kwargs):
"""
Turns kwargs into a str: `_` in key names becomes `-`, key-value pairs are joined with `_`, and iterable values are joined with `,`
"""
formatted_pairs = []
for key, value in sorted(kwargs.items()):
formatted_value = None
if isinstance(value, str):
formatted_value = value
elif isinstance(value, Iterable):
formatted_value = ",".join(map(str, value))
elif value:
formatted_value = str(value)
# only append if formatted_value exists
if formatted_value:
# Keep previous convention of variable names with `-` instead of `_`
formatted_pairs.append(
f"{key.replace('_', '-')}-{formatted_value.replace('/', '--')}"
)

return "_".join(formatted_pairs)
31 changes: 21 additions & 10 deletions aiu_fms_testing_utils/testing/validation.py
@@ -5,6 +5,9 @@
from aiu_fms_testing_utils.utils.aiu_setup import dprint
from aiu_fms_testing_utils._version import version_tuple
import os
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

import hashlib


class LogitsExtractorHook(
@@ -125,13 +128,7 @@ def __len__(self):


def get_default_validation_prefix(
model_id: str,
max_new_tokens: int,
batch_size: int,
seq_length: int,
dtype: str,
attn_type: str,
aftu_version: str,
**kwargs,
):
"""
Args:
@@ -144,9 +141,17 @@
aftu_version (str): introduced in v0.3.0 to track changes in logs

Returns:
str: A prefix that will be prepended to the file name
str: A hashed prefix that will be prepended to the file name
"""
return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}_attn-type-{attn_type}.{aftu_version}"
aftu_version = kwargs.pop(
"aftu_version", ".".join([str(_) for _ in version_tuple[:3]])
)
kwargs_str = format_kwargs_to_string(**kwargs)

filename = f"{kwargs_str}"
hash_object = hashlib.sha256(filename.encode("utf-8"))
hex_digest = hash_object.hexdigest()
return f"{hex_digest}_{aftu_version}"


def load_validation_information(
Expand Down Expand Up @@ -416,11 +421,14 @@ def get_validation_info_path(
aftu_version: Optional[Tuple[int, int, int]] = None,
device_type: str = "cpu",
dtype: str = "fp16",
**kwargs,
):
if aftu_version is None:
aftu_version = version_tuple

validation_file_name = f"{get_default_validation_prefix(model_variant, max_new_tokens, batch_size, seq_length, dtype, attn_type, '.'.join([str(_) for _ in aftu_version[:3]]))}.{device_type}_validation_info.{seed}.out"
sample_key = kwargs.get("sample_key", None)

validation_file_name = f"{get_default_validation_prefix(aftu_version='.'.join([str(_) for _ in aftu_version[:3]]), model_id=model_variant, max_new_tokens=max_new_tokens, batch_size=batch_size, seq_length=seq_length, dtype=dtype, attn_type=attn_type, sample_key=sample_key)}.{device_type}_validation_info.{seed}.out"
full_path = os.path.join(validation_info_dir, validation_file_name)
return full_path

@@ -452,10 +460,12 @@ def find_validation_info_path(
version_allow_decrement: bool = False,
device_type: str = "cpu",
dtype: str = "fp16",
**kwargs,
):
"""
Find the validation info path if it exists, otherwise return None
"""
sample_key = kwargs.get("sample_key", None)

if aftu_version is None:
loc_version_tuple = version_tuple[:3]
@@ -476,6 +486,7 @@
loc_version_tuple,
device_type,
dtype,
sample_key=sample_key,
)
# if the path is found, we are done searching and can return
if os.path.exists(full_path):
62 changes: 59 additions & 3 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -11,6 +11,7 @@

from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

from fms.utils.generation import pad_input_ids
import torch
@@ -482,6 +483,7 @@ def sample_rag_factoid_requests(
enforce_sizes: List[int] = [],
truncation: bool = False,
pad_multiple: int = 64,
return_key: bool = False,
) -> List[Tuple[str, int]]:
if not os.path.exists(dataset_path):
print("error dataset does not exist")
@@ -492,7 +494,7 @@
for line in f:
dataset.append(line)

return __sample_requests(
sample_request = __sample_requests(
dataset,
num_requests,
tokenizer,
@@ -506,6 +508,24 @@
_cached_dataset_key=dataset_path,
)

if return_key:
sample_key: str = format_kwargs_to_string(
dataset="rag_factoid",
num_requests=num_requests,
tokenizer=tokenizer.name_or_path.replace("/", "--"),
prompt_length_min=prompt_length_min,
prompt_length_max=prompt_length_max,
seed=seed,
enforce_heterogeneous=enforce_heterogeneous,
enforce_sizes=enforce_sizes,
truncate=truncation,
pad_multiple=pad_multiple,
)

return sample_request, sample_key
else:
return sample_request


def sample_sharegpt_requests(
dataset_path: str,
@@ -518,6 +538,7 @@
enforce_sizes: List[int] | None = None,
truncation: bool = False,
pad_multiple: int = 64,
return_key: bool = False,
) -> List[Tuple[str, int]]:
if not os.path.exists(dataset_path):
print("downloading share-gpt dataset as it does not exist")
@@ -543,7 +564,7 @@
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]

return __sample_requests(
sample_request = __sample_requests(
dataset,
num_requests,
tokenizer,
@@ -557,6 +578,23 @@
_cached_dataset_key=dataset_path,
)

if return_key:
sample_key: str = format_kwargs_to_string(
dataset="sharegpt",
num_requests=num_requests,
tokenizer=tokenizer.name_or_path.replace("/", "--"),
prompt_length_min=prompt_length_min,
prompt_length_max=prompt_length_max,
seed=seed,
enforce_heterogeneous=enforce_heterogeneous,
enforce_sizes=enforce_sizes,
truncate=truncation,
pad_multiple=pad_multiple,
)
return sample_request, sample_key
else:
return sample_request


def sample_squad_v2_qa_requests(
dataset_path: str,
@@ -569,6 +607,7 @@
enforce_sizes: List[int] | None = None,
truncation: bool = False,
pad_multiple: int = 64,
return_key: bool = False,
) -> List[Tuple[str, int]]:
from datasets import load_dataset

@@ -582,7 +621,7 @@

ds = [f"{data['context']}\n{data['question']}" for data in ds]

return __sample_requests(
sample_request = __sample_requests(
ds,
num_requests,
tokenizer,
Expand All @@ -595,6 +634,23 @@ def sample_squad_v2_qa_requests(
pad_multiple,
)

if return_key:
sample_key: str = format_kwargs_to_string(
dataset="squad_v2",
num_requests=num_requests,
tokenizer=tokenizer.name_or_path.replace("/", "--"),
prompt_length_min=prompt_length_min,
prompt_length_max=prompt_length_max,
seed=seed,
enforce_heterogeneous=enforce_heterogeneous,
enforce_sizes=enforce_sizes,
truncate=truncation,
pad_multiple=pad_multiple,
)
return sample_request, sample_key
else:
return sample_request
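
All three samplers share the same opt-in contract, sketched below for the ShareGPT variant (the dataset path and model name are illustrative):

from transformers import AutoTokenizer
from aiu_fms_testing_utils.utils import sample_sharegpt_requests

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")

# New-style call: also returns a key describing exactly how the sample was drawn.
prompts_and_sizes, sample_key = sample_sharegpt_requests(
    "share_gpt.json",  # dataset_path; downloaded on first use
    4,                 # num_requests
    tokenizer,
    return_key=True,
)

# Existing call sites are untouched: return_key defaults to False.
prompts_and_sizes = sample_sharegpt_requests("share_gpt.json", 4, tokenizer)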


def prepare_inputs(
batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"
4 changes: 3 additions & 1 deletion aiu_fms_testing_utils/utils/paged.py
@@ -153,7 +153,9 @@ def generate(
raise ValueError("model must have a distributed_strategy")

kvheads = kvheads // tensor_parallel_size if kvheads > 1 else kvheads
head_size = model.config.emb_dim // nheads
head_size = getattr(
model.config, "head_dim", model.config.emb_dim // model.config.nheads
)
if "fp8" in kwargs["attn_name"]:
from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor

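To illustrate the intent of the head_size change with stand-in config objects (hypothetical, for demonstration only): an explicit head_dim on the config wins, and the previous emb_dim // nheads derivation remains the fallback.

class _CfgWithHeadDim:
    emb_dim, nheads = 4096, 32
    head_dim = 64  # explicit value, deliberately not emb_dim // nheads

class _CfgWithoutHeadDim:
    emb_dim, nheads = 4096, 32

for cfg in (_CfgWithHeadDim, _CfgWithoutHeadDim):
    print(getattr(cfg, "head_dim", cfg.emb_dim // cfg.nheads))  # 64, then 128
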
26 changes: 20 additions & 6 deletions scripts/drive_paged_programs.py
@@ -40,6 +40,7 @@
get_programs_prompts,
KVCACHE_NUM_BLOCKS_HINT,
)
from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string

parser = argparse.ArgumentParser(
description="Script which will drive paged programs for debugging"
@@ -195,6 +196,10 @@
custom_shape = (len(result), max([_[1] for _ in result]))

def __custom_line_sampler(*args, **kwargs):
return_key = kwargs.get("return_key", False)
sample_key = format_kwargs_to_string(**kwargs)
if return_key:
return result, sample_key
return result

sampler = __custom_line_sampler
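
With this change the debug sampler mirrors the dataset samplers' new contract: called with return_key=True it returns (result, sample_key), so downstream code can treat the custom-shape path and the dataset path identically.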
@@ -245,7 +250,7 @@

def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0):
start = time.time()
prompts_and_sizes = sampler(
prompts_and_sizes, sample_key = sampler(
DATASET_PATH,
batch_size,
tokenizer,
@@ -254,6 +259,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0
seed,
enforce_sizes=enforce_sizes,
truncation=allow_truncation,
return_key=True,
)
end = time.time()
if local_rank == 0:
@@ -274,7 +280,7 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, enforce_sizes=[], seed=0

input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)
return input_ids, extra_kwargs
return input_ids, extra_kwargs, sample_key


def __maybe_prepare_fp8_weights(model_in, is_fp8):
@@ -296,7 +302,9 @@ def __load_validation_info(
tokenizer,
seed,
attn_type: str,
**kwargs,
):
sample_key = kwargs.get("sample_key", None)
full_path = find_validation_info_path(
args.validation_info_outputs_dir,
model_variant,
Expand All @@ -307,6 +315,7 @@ def __load_validation_info(
attn_type,
version_allow_decrement=True,
dtype=CPU_DTYPE,
sample_key=sample_key,
)
if full_path is not None:
dprint(f"cpu validation info found for seed={seed} -- loading it")
@@ -367,13 +376,14 @@

# warmup with any input so compiler produces criteria json
# TODO: Swap this with __prepare_inputs once fix for shape_id is available
# input_ids, extra_kwargs = __prepare_inputs(2, max_tkv, tokenizer)
# input_ids, extra_kwargs, sample_key = __prepare_inputs(2, max_tkv, tokenizer)
prompt_list = [torch.arange(0, 64, dtype=torch.int64)]
# matching vllm warmup to pad to 2 on fp8, and no pad for fp16
if is_fp8:
prompt_list = prompt_list * 2
input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=64)
extra_kwargs["mask"] = extra_kwargs["mask"].to(torch.float16)

extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
@@ -494,7 +504,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
for valid_prompt_shape in valid_prompt_shapes:
if valid_prompt_shape == custom_shape:
enforce_sizes = [valid_prompt_shape[1]]
input_ids, extra_kwargs = __prepare_inputs(
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
@@ -506,6 +516,7 @@
custom_shape,
input_ids,
extra_kwargs,
sample_key,
)
]
break
@@ -566,7 +577,7 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
)
)
try:
input_ids, extra_kwargs = __prepare_inputs(
input_ids, extra_kwargs, sample_key = __prepare_inputs(
valid_prompt_shape[0],
valid_prompt_shape[1],
tokenizer,
@@ -578,6 +589,7 @@
valid_prompt_shape,
input_ids,
extra_kwargs,
sample_key,
)
)
used_keys.add(program_seq_key[0])
@@ -609,7 +621,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):

failed_cases = []
# for each program and valid prompt (batch size, sequence length)
for program_id, valid_prompt, input_ids, extra_kwargs in valid_prompts:
for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
extra_kwargs["attn_name"] = ATTN_NAME
if (
"granite-3.3-8b-instruct" in model_variant
@@ -634,6 +646,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
tokenizer,
seed=0,
attn_type=ATTN_NAME,
sample_key=sample_key,
)
# if the cpu validation info is not yet computed, compute it
if cpu_validation_info is None:
@@ -657,6 +670,7 @@
0,
ATTN_NAME,
dtype=CPU_DTYPE,
sample_key=sample_key,
)
)

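Taken together, these changes thread a single sample_key from the moment prompts are drawn through to validation-info lookup and creation: __prepare_inputs returns it alongside the padded inputs, it travels in each (program_id, valid_prompt, input_ids, extra_kwargs, sample_key) tuple, and it is forwarded to both find_validation_info_path and get_validation_info_path, so cached CPU validation info is only reused when it was produced from the same sampled prompts.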