Skip to content

Commit dceee17

Browse files
rsareddy0329Roja Reddy Sareddy
andauthored
Add utils unit tests for training cli (#97)
* Bug fix: Fixed create command job error * Add utils unit tests for training cli --------- Co-authored-by: Roja Reddy Sareddy <[email protected]>
1 parent 5a80519 commit dceee17

File tree

4 files changed

+213
-7
lines changed

4 files changed

+213
-7
lines changed

test/unit_tests/cli/test_cli.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,3 @@ def test_cli_command_help_format(self):
8282
self.assertIn("HyperPod PyTorch Job CLI", result.output)
8383

8484

85-
if __name__ == "__main__":
86-
unittest.main()

test/unit_tests/cli/test_constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,3 @@ def test_help_text_content(self):
1414
self.assertIn("Usage:", HELP_TEXT)
1515

1616

17-
if __name__ == "__main__":
18-
unittest.main()

test/unit_tests/cli/test_training.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,3 @@ def test_pytorch_describe_error(self, mock_hyperpod_pytorch_job):
221221
self.assertNotEqual(result.exit_code, 0)
222222
self.assertIn("Failed to describe job", result.output)
223223

224-
225-
if __name__ == "__main__":
226-
unittest.main()
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import pytest
2+
import json
3+
import click
4+
from click.testing import CliRunner
5+
from unittest.mock import Mock, patch
6+
7+
from sagemaker.hyperpod.cli.training_utils import load_schema_for_version, generate_click_command
8+
9+
10+
class TestLoadSchemaForVersion:
11+
@patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data')
12+
def test_success(self, mock_get_data):
13+
"""Test successful schema loading"""
14+
data = {"properties": {"x": {"type": "string"}}}
15+
mock_get_data.return_value = json.dumps(data).encode()
16+
17+
result = load_schema_for_version('1.2', 'test_package')
18+
19+
assert result == data
20+
mock_get_data.assert_called_once_with('test_package.v1_2', 'schema.json')
21+
22+
@patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data')
23+
def test_schema_not_found(self, mock_get_data):
24+
"""Test handling of missing schema file"""
25+
mock_get_data.return_value = None
26+
27+
with pytest.raises(click.ClickException) as exc:
28+
load_schema_for_version('1.0', 'test_package')
29+
30+
assert "Could not load schema.json for version 1.0" in str(exc.value)
31+
32+
@patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data')
33+
def test_invalid_json_schema(self, mock_get_data):
34+
"""Test handling of invalid JSON in schema file"""
35+
mock_get_data.return_value = b'invalid json'
36+
37+
with pytest.raises(json.JSONDecodeError):
38+
load_schema_for_version('1.0', 'test_package')
39+
40+
class TestGenerateClickCommand:
41+
def setup_method(self):
42+
self.runner = CliRunner()
43+
44+
def test_missing_registry(self):
45+
"""Test that registry is required"""
46+
with pytest.raises(ValueError) as exc:
47+
generate_click_command(schema_pkg="test_package")
48+
assert "You must pass a registry mapping" in str(exc.value)
49+
50+
@patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data')
51+
def test_pytorch_json_flags(self, mock_get_data):
52+
"""Test handling of JSON flags for PyTorch config"""
53+
schema = {
54+
'properties': {
55+
'environment': {'type': 'object'},
56+
'label_selector': {'type': 'object'}
57+
}
58+
}
59+
mock_get_data.return_value = json.dumps(schema).encode()
60+
61+
class DummyModel:
62+
def __init__(self, **kwargs):
63+
self.__dict__.update(kwargs)
64+
def to_domain(self):
65+
return self
66+
67+
registry = {'1.0': DummyModel}
68+
69+
@click.command()
70+
@generate_click_command(
71+
schema_pkg="hyperpod_pytorchjob_config_schemas",
72+
registry=registry
73+
)
74+
def cmd(version, debug, config):
75+
click.echo(json.dumps({
76+
'environment': config.environment,
77+
'label_selector': config.label_selector
78+
}))
79+
80+
# Test valid JSON input
81+
result = self.runner.invoke(cmd, [
82+
'--environment', '{"VAR1":"val1"}',
83+
'--label_selector', '{"key":"value"}'
84+
])
85+
assert result.exit_code == 0
86+
output = json.loads(result.output)
87+
assert output == {
88+
'environment': {'VAR1': 'val1'},
89+
'label_selector': {'key': 'value'}
90+
}
91+
92+
# Test invalid JSON input
93+
result = self.runner.invoke(cmd, ['--environment', 'invalid'])
94+
assert result.exit_code == 2
95+
assert 'must be valid JSON' in result.output
96+
97+
@patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data')
98+
def test_list_parameters(self, mock_get_data):
99+
"""Test handling of list parameters"""
100+
schema = {
101+
'properties': {
102+
'command': {'type': 'array'},
103+
'args': {'type': 'array'}
104+
}
105+
}
106+
mock_get_data.return_value = json.dumps(schema).encode()
107+
108+
class DummyModel:
109+
def __init__(self, **kwargs):
110+
self.__dict__.update(kwargs)
111+
def to_domain(self):
112+
return self
113+
114+
registry = {'1.0': DummyModel}
115+
116+
@click.command()
117+
@generate_click_command(
118+
schema_pkg="hyperpod_pytorchjob_config_schemas",
119+
registry=registry
120+
)
121+
def cmd(version, debug, config):
122+
click.echo(json.dumps({
123+
'command': config.command,
124+
'args': config.args
125+
}))
126+
127+
# Test list input
128+
result = self.runner.invoke(cmd, [
129+
'--command', '[python, train.py]',
130+
'--args', '[--epochs, 10]'
131+
])
132+
assert result.exit_code == 0
133+
output = json.loads(result.output)
134+
assert output == {
135+
'command': ['python', 'train.py'],
136+
'args': ['--epochs', '10']
137+
}
138+
139+
@patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data')
140+
def test_version_handling(self, mock_get_data):
141+
"""Test version handling in command generation"""
142+
schema = {'properties': {}}
143+
mock_get_data.return_value = json.dumps(schema).encode()
144+
145+
class DummyModel:
146+
def __init__(self, **kwargs): pass
147+
148+
def to_domain(self): return self
149+
150+
registry = {'2.0': DummyModel}
151+
152+
@click.command()
153+
@generate_click_command(
154+
version_key='2.0',
155+
schema_pkg="test_package",
156+
registry=registry
157+
)
158+
def cmd(version, debug, config):
159+
click.echo(version)
160+
161+
result = self.runner.invoke(cmd, [])
162+
assert result.exit_code == 0
163+
assert result.output.strip() == '2.0'
164+
165+
@patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data')
166+
def test_type_conversion(self, mock_get_data):
167+
"""Test type conversion for different parameter types"""
168+
# Mock the schema with different types
169+
schema = {
170+
'properties': {
171+
'node_count': {'type': 'integer'},
172+
'deep_health_check_passed_nodes_only': {'type': 'boolean'},
173+
'tasks_per_node': {'type': 'integer'},
174+
'job_name': {'type': 'string'}
175+
}
176+
}
177+
mock_get_data.return_value = json.dumps(schema).encode()
178+
179+
class DummyModel:
180+
def __init__(self, **kwargs):
181+
self.__dict__.update(kwargs)
182+
183+
def to_domain(self):
184+
return self
185+
186+
registry = {'1.0': DummyModel}
187+
188+
@click.command()
189+
@generate_click_command(registry=registry)
190+
def cmd(version, debug, config):
191+
click.echo(json.dumps({
192+
'node_count': config.node_count,
193+
'deep_health_check_passed_nodes_only': config.deep_health_check_passed_nodes_only,
194+
'tasks_per_node': config.tasks_per_node,
195+
'job_name': config.job_name
196+
}))
197+
198+
# Test integer conversion
199+
result = self.runner.invoke(cmd, ['--node-count', '5'])
200+
assert result.exit_code == 0
201+
202+
# Test boolean conversion
203+
result = self.runner.invoke(cmd, ['--deep-health-check-passed-nodes-only', 'true'])
204+
assert result.exit_code == 0
205+
206+
# Test string conversion
207+
result = self.runner.invoke(cmd, ['--job-name', 'test-job'])
208+
assert result.exit_code == 0
209+
210+
# Test invalid type (should fail gracefully)
211+
result = self.runner.invoke(cmd, ['--node-count', 'not-a-number'])
212+
assert result.exit_code == 2
213+
assert "Invalid value" in result.output

0 commit comments

Comments
 (0)