57 changes: 39 additions & 18 deletions test/test_ao_models.py
@@ -3,32 +3,53 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest
 
 import torch
+from torch.testing._internal import common_utils
 
 from torchao._models.llama.model import Transformer
 
-_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
-
 
 def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
     """Initialize and return a Transformer model with specified configuration."""
     model = Transformer.from_name(name)
     model.to(device=device, dtype=precision)
     return model.eval()
 
 
-@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
-@pytest.mark.parametrize("batch_size", [1, 4])
-@pytest.mark.parametrize("is_training", [True, False])
-def test_ao_llama_model_inference_mode(device, batch_size, is_training):
-    random_model = init_model(device=device)
-    seq_len = 16
-    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
-    input_pos = None if is_training else torch.arange(seq_len).to(device)
-    with torch.device(device):
-        random_model.setup_caches(
-            max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
-        )
-    for i in range(3):
-        out = random_model(input_ids, input_pos)
-        assert out is not None, "model failed to run"
+class TorchAOBasicTestCase(unittest.TestCase):
+    """Test suite for basic Transformer inference functionality."""
+
+    @common_utils.parametrize(
+        "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+    )
+    @common_utils.parametrize("batch_size", [1, 4])
+    @common_utils.parametrize("is_training", [True, False])
+    def test_ao_inference_mode(self, device, batch_size, is_training):
+        # Initialize model with specified device
+        random_model = init_model(device=device)
+
+        # Set up test input parameters
+        seq_len = 16
+        input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
+
+        # input_pos is None for training mode, tensor for inference mode
+        input_pos = None if is_training else torch.arange(seq_len).to(device)
+
+        # Set up model caches within the device context
+        with torch.device(device):
+            random_model.setup_caches(
+                max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
+            )
+
+        # Run multiple inference iterations to ensure consistency
+        for i in range(3):
+            out = random_model(input_ids, input_pos)
+            self.assertIsNotNone(out, f"Model failed to run on iteration {i}")
+
+
+common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
+
+if __name__ == "__main__":
+    unittest.main()
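For context, the rewritten test relies on PyTorch's internal parametrization helpers rather than pytest: common_utils.parametrize records parameter sets on each test method, and common_utils.instantiate_parametrized_tests(...) expands them into individual unittest methods (one per device/batch_size/is_training combination), so the file can be run directly with python test/test_ao_models.py. Below is a minimal, self-contained sketch of that pattern; the ToyTest class and test_double method are hypothetical, used only to illustrate the mechanism:

import unittest

from torch.testing._internal import common_utils


class ToyTest(unittest.TestCase):
    # parametrize stores the parameter values on the method; nothing
    # is expanded until instantiate_parametrized_tests is called.
    @common_utils.parametrize("x", [1, 2])
    def test_double(self, x):
        self.assertEqual(x * 2, x + x)


# Generates one concrete test method per parameter value
# (named roughly test_double_x_1 and test_double_x_2).
common_utils.instantiate_parametrized_tests(ToyTest)

if __name__ == "__main__":
    unittest.main()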