Skip to content

Commit 3dd15a3

Browse files
Fix bug for actions and improve design (#448)
1 parent 64fb95f commit 3dd15a3

File tree

6 files changed: 18 additions and 14 deletions.

smdebug/rules/action/action.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212

1313

1414
class Actions:
15-
def __init__(self, actions_str="", rule_name=""):
15+
def __init__(self, actions_str, rule_name):
1616
self._actions = []
1717
self._logger = get_logger()
1818
actions_str = actions_str.strip() if actions_str is not None else ""
1919
if actions_str == "":
20-
self._logger.info(f"No action specified. Action str is {actions_str}")
20+
self._logger.info(f"No action specified for rule {rule_name}.")
2121
return
2222
self._register_actions(actions_str, rule_name)
2323

@@ -77,7 +77,7 @@ def _register_actions(self, actions_str="", rule_name=""):
7777
f"Action :{action_dict['name']} not supported. Allowed action names are: {ALLOWED_ACTIONS}"
7878
)
7979

80-
def invoke(self):
80+
def invoke(self, message=""):
8181
self._logger.info("Invoking actions")
8282
for action in self._actions:
83-
action.invoke()
83+
action.invoke(message)

smdebug/rules/action/message_action.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,5 +141,5 @@ def _send_message(self, message):
141141
self._last_send_mesg_response = response
142142
return response
143143

144-
def invoke(self, message=None):
144+
def invoke(self, message):
145145
self._send_message(message)

smdebug/rules/action/stop_training_action.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,19 @@ def _get_sm_tj_jobs_with_prefix(self):
8181

8282
return list(found_job_dict.keys())
8383

84-
def _stop_training_job(self):
84+
def _stop_training_job(self, message):
8585
if len(self._found_jobs) != 1:
8686
return
87-
self._logger.info(f"Invoking StopTrainingJob action on SM jobname:{self._found_jobs}")
87+
if message != "":
88+
message = f"with message {message}"
89+
self._logger.info(
90+
f"Invoking StopTrainingJob action on SM jobname {self._found_jobs} {message}"
91+
)
8892
try:
8993
res = self._sm_client.stop_training_job(TrainingJobName=self._found_jobs[0])
9094
self._logger.info(f"Stop Training job response:{res}")
9195
except Exception as e:
9296
self._logger.info(f"Got exception while stopping training job{self._found_jobs[0]}:{e}")
9397

94-
def invoke(self, message=None):
95-
self._stop_training_job()
98+
def invoke(self, message):
99+
self._stop_training_job(message)

smdebug/rules/rule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
# This is Rule interface
1515
class Rule(ABC):
16-
def __init__(self, base_trial, other_trials=None, action_str=""):
16+
def __init__(self, base_trial, action_str, other_trials=None):
1717
self.base_trial = base_trial
1818
self.other_trials = other_trials
1919

@@ -25,7 +25,7 @@ def __init__(self, base_trial, other_trials=None, action_str=""):
2525

2626
self.logger = get_logger()
2727
self.rule_name = self.__class__.__name__
28-
self._actions = Actions(actions_str=action_str, rule_name=self.rule_name)
28+
self._actions = Actions(action_str, rule_name=self.rule_name)
2929
self.report = {
3030
"RuleTriggered": 0,
3131
"Violations": 0,

tests/analysis/rules/test_rule_no_refresh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
def test_no_refresh_invocation():
1515
class TestRule(Rule):
1616
def __init__(self, base_trial):
17-
super().__init__(base_trial=base_trial)
17+
super().__init__(base_trial=base_trial, action_str="")
1818

1919
def set_required_tensors(self, step):
2020
for t in self.base_trial.tensor_names():

tests/rules/action/test_message_action.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
def test_action_stop_training_job():
88
action_str = '{"name": "stoptraining" , "training_job_prefix":"training_prefix"}'
9-
action = Actions(actions_str=action_str)
9+
action = Actions(actions_str=action_str, rule_name="test_rule")
1010
action.invoke()
1111

1212

1313
def test_action_stop_training_job_invalid_params():
1414
action_str = '{"name": "stoptraining" , "invalid_job_prefix":"training_prefix"}'
15-
action = Actions(actions_str=action_str)
15+
action = Actions(actions_str=action_str, rule_name="test_rule")
1616
action.invoke()
1717

1818

0 commit comments

Comments
 (0)