 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import pytest
+import unittest
+
 import torch
+from torch.testing._internal import common_utils
 
 from torchao._models.llama.model import Transformer
 
-_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
-
 
 def init_model(name="stories15M", device="cpu", precision=torch.bfloat16):
+    """Initialize and return a Transformer model with the specified configuration."""
     model = Transformer.from_name(name)
     model.to(device=device, dtype=precision)
     return model.eval()
 
 
-@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
-@pytest.mark.parametrize("batch_size", [1, 4])
-@pytest.mark.parametrize("is_training", [True, False])
-def test_ao_llama_model_inference_mode(device, batch_size, is_training):
-    random_model = init_model(device=device)
-    seq_len = 16
-    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
-    input_pos = None if is_training else torch.arange(seq_len).to(device)
-    with torch.device(device):
-        random_model.setup_caches(
-            max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
-        )
-    for i in range(3):
-        out = random_model(input_ids, input_pos)
-        assert out is not None, "model failed to run"
+class TorchAOBasicTestCase(unittest.TestCase):
+    """Test suite for basic Transformer inference functionality."""
+
+    @common_utils.parametrize(
+        "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+    )
+    @common_utils.parametrize("batch_size", [1, 4])
+    @common_utils.parametrize("is_training", [True, False])
+    def test_ao_inference_mode(self, device, batch_size, is_training):
+        # Initialize the model on the specified device
+        random_model = init_model(device=device)
+
+        # Set up the test inputs
+        seq_len = 16
+        input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
+
+        # input_pos is None in training mode and a position tensor in inference mode
+        input_pos = None if is_training else torch.arange(seq_len).to(device)
+
+        # Set up model caches within the device context
+        with torch.device(device):
+            random_model.setup_caches(
+                max_batch_size=batch_size, max_seq_length=seq_len, training=is_training
+            )
+
+        # Run multiple inference iterations to ensure consistency
+        for i in range(3):
+            out = random_model(input_ids, input_pos)
+            self.assertIsNotNone(out, f"Model failed to run on iteration {i}")
+
+
+common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
+
+if __name__ == "__main__":
+    unittest.main()
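Note: instantiate_parametrized_tests rewrites the class in place, expanding test_ao_inference_mode into one concrete test method per (device, batch_size, is_training) combination, so the standard unittest loader discovers every variant. A minimal usage sketch follows, assuming the file is importable as test_ao_models (a hypothetical module name, not taken from this diff):

import unittest

# Hypothetical module name for illustration; adjust to the file's real path.
from test_ao_models import TorchAOBasicTestCase

# instantiate_parametrized_tests already ran at import time, so the class
# holds one generated method per parameter combination and the default
# loader collects all of them.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TorchAOBasicTestCase)
unittest.TextTestRunner(verbosity=2).run(suite)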