diff --git a/smdebug/rules/action/action.py b/smdebug/rules/action/action.py index d7a4cd52a..989883c2d 100644 --- a/smdebug/rules/action/action.py +++ b/smdebug/rules/action/action.py @@ -12,12 +12,12 @@ class Actions: - def __init__(self, actions_str="", rule_name=""): + def __init__(self, actions_str, rule_name): self._actions = [] self._logger = get_logger() actions_str = actions_str.strip() if actions_str is not None else "" if actions_str == "": - self._logger.info(f"No action specified. Action str is {actions_str}") + self._logger.info(f"No action specified for rule {rule_name}.") return self._register_actions(actions_str, rule_name) @@ -77,7 +77,7 @@ def _register_actions(self, actions_str="", rule_name=""): f"Action :{action_dict['name']} not supported. Allowed action names are: {ALLOWED_ACTIONS}" ) - def invoke(self): + def invoke(self, message=""): self._logger.info("Invoking actions") for action in self._actions: - action.invoke() + action.invoke(message) diff --git a/smdebug/rules/action/message_action.py b/smdebug/rules/action/message_action.py index cdf9c39e2..36d1005b0 100644 --- a/smdebug/rules/action/message_action.py +++ b/smdebug/rules/action/message_action.py @@ -141,5 +141,5 @@ def _send_message(self, message): self._last_send_mesg_response = response return response - def invoke(self, message=None): + def invoke(self, message): self._send_message(message) diff --git a/smdebug/rules/action/stop_training_action.py b/smdebug/rules/action/stop_training_action.py index c60ac60fb..2b22c3f47 100644 --- a/smdebug/rules/action/stop_training_action.py +++ b/smdebug/rules/action/stop_training_action.py @@ -81,15 +81,19 @@ def _get_sm_tj_jobs_with_prefix(self): return list(found_job_dict.keys()) - def _stop_training_job(self): + def _stop_training_job(self, message): if len(self._found_jobs) != 1: return - self._logger.info(f"Invoking StopTrainingJob action on SM jobname:{self._found_jobs}") + if message != "": + message = f"with message {message}" + self._logger.info( + f"Invoking StopTrainingJob action on SM jobname {self._found_jobs} {message}" + ) try: res = self._sm_client.stop_training_job(TrainingJobName=self._found_jobs[0]) self._logger.info(f"Stop Training job response:{res}") except Exception as e: self._logger.info(f"Got exception while stopping training job{self._found_jobs[0]}:{e}") - def invoke(self, message=None): - self._stop_training_job() + def invoke(self, message): + self._stop_training_job(message) diff --git a/smdebug/rules/rule.py b/smdebug/rules/rule.py index 5e20272fd..6cbf27d9a 100644 --- a/smdebug/rules/rule.py +++ b/smdebug/rules/rule.py @@ -13,7 +13,7 @@ # This is Rule interface class Rule(ABC): - def __init__(self, base_trial, other_trials=None, action_str=""): + def __init__(self, base_trial, action_str, other_trials=None): self.base_trial = base_trial self.other_trials = other_trials @@ -25,7 +25,7 @@ def __init__(self, base_trial, other_trials=None, action_str=""): self.logger = get_logger() self.rule_name = self.__class__.__name__ - self._actions = Actions(actions_str=action_str, rule_name=self.rule_name) + self._actions = Actions(action_str, rule_name=self.rule_name) self.report = { "RuleTriggered": 0, "Violations": 0, diff --git a/tests/analysis/rules/test_rule_no_refresh.py b/tests/analysis/rules/test_rule_no_refresh.py index 61a16d8b5..1a5924c20 100644 --- a/tests/analysis/rules/test_rule_no_refresh.py +++ b/tests/analysis/rules/test_rule_no_refresh.py @@ -14,7 +14,7 @@ def test_no_refresh_invocation(): class TestRule(Rule): def __init__(self, base_trial): - super().__init__(base_trial=base_trial) + super().__init__(base_trial=base_trial, action_str="") def set_required_tensors(self, step): for t in self.base_trial.tensor_names(): diff --git a/tests/rules/action/test_message_action.py b/tests/rules/action/test_message_action.py index b1b0d44a1..f5c72eda4 100644 --- a/tests/rules/action/test_message_action.py +++ b/tests/rules/action/test_message_action.py @@ -6,13 +6,13 @@ def test_action_stop_training_job(): action_str = '{"name": "stoptraining" , "training_job_prefix":"training_prefix"}' - action = Actions(actions_str=action_str) + action = Actions(actions_str=action_str, rule_name="test_rule") action.invoke() def test_action_stop_training_job_invalid_params(): action_str = '{"name": "stoptraining" , "invalid_job_prefix":"training_prefix"}' - action = Actions(actions_str=action_str) + action = Actions(actions_str=action_str, rule_name="test_rule") action.invoke()