Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions smdebug/rules/action/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@


class Actions:
def __init__(self, actions_str="", rule_name=""):
def __init__(self, actions_str, rule_name):
self._actions = []
self._logger = get_logger()
actions_str = actions_str.strip() if actions_str is not None else ""
if actions_str == "":
self._logger.info(f"No action specified. Action str is {actions_str}")
self._logger.info(f"No action specified for rule {rule_name}.")
return
self._register_actions(actions_str, rule_name)

Expand Down Expand Up @@ -77,7 +77,7 @@ def _register_actions(self, actions_str="", rule_name=""):
f"Action :{action_dict['name']} not supported. Allowed action names are: {ALLOWED_ACTIONS}"
)

def invoke(self):
def invoke(self, message=""):
self._logger.info("Invoking actions")
for action in self._actions:
action.invoke()
action.invoke(message)
2 changes: 1 addition & 1 deletion smdebug/rules/action/message_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,5 @@ def _send_message(self, message):
self._last_send_mesg_response = response
return response

def invoke(self, message=None):
def invoke(self, message):
self._send_message(message)
12 changes: 8 additions & 4 deletions smdebug/rules/action/stop_training_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,19 @@ def _get_sm_tj_jobs_with_prefix(self):

return list(found_job_dict.keys())

def _stop_training_job(self):
def _stop_training_job(self, message):
if len(self._found_jobs) != 1:
return
self._logger.info(f"Invoking StopTrainingJob action on SM jobname:{self._found_jobs}")
if message != "":
message = f"with message {message}"
self._logger.info(
f"Invoking StopTrainingJob action on SM jobname {self._found_jobs} {message}"
)
try:
res = self._sm_client.stop_training_job(TrainingJobName=self._found_jobs[0])
self._logger.info(f"Stop Training job response:{res}")
except Exception as e:
self._logger.info(f"Got exception while stopping training job{self._found_jobs[0]}:{e}")

def invoke(self, message=None):
self._stop_training_job()
def invoke(self, message):
self._stop_training_job(message)
4 changes: 2 additions & 2 deletions smdebug/rules/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

# This is Rule interface
class Rule(ABC):
def __init__(self, base_trial, other_trials=None, action_str=""):
def __init__(self, base_trial, action_str, other_trials=None):
self.base_trial = base_trial
self.other_trials = other_trials

Expand All @@ -25,7 +25,7 @@ def __init__(self, base_trial, other_trials=None, action_str=""):

self.logger = get_logger()
self.rule_name = self.__class__.__name__
self._actions = Actions(actions_str=action_str, rule_name=self.rule_name)
self._actions = Actions(action_str, rule_name=self.rule_name)
self.report = {
"RuleTriggered": 0,
"Violations": 0,
Expand Down
2 changes: 1 addition & 1 deletion tests/analysis/rules/test_rule_no_refresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def test_no_refresh_invocation():
class TestRule(Rule):
def __init__(self, base_trial):
super().__init__(base_trial=base_trial)
super().__init__(base_trial=base_trial, action_str="")

def set_required_tensors(self, step):
for t in self.base_trial.tensor_names():
Expand Down
4 changes: 2 additions & 2 deletions tests/rules/action/test_message_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

def test_action_stop_training_job():
action_str = '{"name": "stoptraining" , "training_job_prefix":"training_prefix"}'
action = Actions(actions_str=action_str)
action = Actions(actions_str=action_str, rule_name="test_rule")
action.invoke()


def test_action_stop_training_job_invalid_params():
action_str = '{"name": "stoptraining" , "invalid_job_prefix":"training_prefix"}'
action = Actions(actions_str=action_str)
action = Actions(actions_str=action_str, rule_name="test_rule")
action.invoke()


Expand Down