Skip to content

Commit 501b080

Browse files
committed
add unit test to ensure
1 parent 7a5b438 commit 501b080

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from pathlib import Path
9+
from unittest.mock import MagicMock, patch
10+
11+
12+
class TestLibTorchAoOpsLoader(unittest.TestCase):
13+
def test_find_and_load_success(self):
14+
mock_paths = [Path("/test/path1")]
15+
mock_lib = MagicMock()
16+
mock_lib.__str__.return_value = "/test/path1/libtorchao_ops_aten.so"
17+
18+
with patch("pathlib.Path.glob", return_value=[mock_lib]):
19+
with patch("torch.ops.load_library") as mock_load:
20+
from ..op_lib import find_and_load_libtorchao_ops
21+
22+
find_and_load_libtorchao_ops(mock_paths)
23+
24+
mock_load.assert_called_once_with("/test/path1/libtorchao_ops_aten.so")
25+
26+
def test_no_library_found(self):
27+
mock_paths = [Path("/test/path1"), Path("/test/path2")]
28+
29+
with patch("pathlib.Path.glob", return_value=[]):
30+
from ..op_lib import find_and_load_libtorchao_ops
31+
32+
with self.assertRaises(FileNotFoundError):
33+
find_and_load_libtorchao_ops(mock_paths)
34+
35+
def test_multiple_libraries_error(self):
36+
mock_paths = [Path("/test/path1")]
37+
mock_lib1 = MagicMock()
38+
mock_lib2 = MagicMock()
39+
mock_libs = [mock_lib1, mock_lib2]
40+
41+
with patch("pathlib.Path.glob", return_value=mock_libs):
42+
from ..op_lib import find_and_load_libtorchao_ops
43+
44+
try:
45+
find_and_load_libtorchao_ops(mock_paths)
46+
self.fail("Expected AssertionError was not raised")
47+
except AssertionError as e:
48+
expected_error_msg = f"Expected to find one libtorchao_ops_aten.* library at {mock_paths[0]}, but found 2"
49+
self.assertIn(expected_error_msg, str(e))
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

0 commit comments

Comments
 (0)