diff --git a/core/basic_models/actions/basic_actions.py b/core/basic_models/actions/basic_actions.py index 6d7b2f3a..23f91efd 100644 --- a/core/basic_models/actions/basic_actions.py +++ b/core/basic_models/actions/basic_actions.py @@ -1,4 +1,5 @@ # coding: utf-8 +import asyncio import random from typing import Union, Dict, List, Any, Optional @@ -25,8 +26,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.id = id self.version = items.get("version", -1) - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: raise NotImplementedError def on_run_error(self, text_preprocessing_result, user): @@ -51,9 +52,9 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.request_type = items.get("request_type") or self.DEFAULT_REQUEST_TYPE self.request_data = items.get("request_data") - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: - super(CommandAction, self).run(user, text_preprocessing_result, params) + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + await super(CommandAction, self).run(user, text_preprocessing_result, params) return None @@ -66,8 +67,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(DoingNothingAction, self).__init__(items, id) self.nodes = items.get("nodes") or {} - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: commands = [Command(self.command, self.nodes, self.id, request_type=self.request_type, request_data=self.request_data)] return commands @@ -98,11 +99,56 @@ def build_requirement(self): def build_internal_item(self): return self._item - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: result = None - if self.requirement.check(text_preprocessing_result, user, params): - result = self.internal_item.run(user, text_preprocessing_result, params) + if await self.requirement.check(text_preprocessing_result, user, params): + result = await self.internal_item.run(user, text_preprocessing_result, params) + return result + + +class GatherChoiceAction(Action): + version: Optional[int] + requirement_actions: RequirementAction + else_action: Action + + FIELD_REQUIREMENT_KEY = "requirement_actions" + FIELD_ELSE_KEY = "else_action" + + def __init__(self, items: Dict[str, Any], id: Optional[str] = None): + super(GatherChoiceAction, self).__init__(items, id) + self._requirement_items = items[self.FIELD_REQUIREMENT_KEY] + self._else_item = items.get(self.FIELD_ELSE_KEY) + + self.items = self.build_items() + + if self._else_item: + self.else_item = self.build_else_item() + else: + self.else_item = None + + @list_factory(RequirementAction) + def build_items(self): + return self._requirement_items + + @factory(Action) + def build_else_item(self): + return self._else_item + + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + result = None + choice_is_made = False + check_results = await asyncio.gather( + item.requirement.check(text_preprocessing_result, user, params) for item in self.items) + for i, checked in enumerate(check_results): + if checked: + item = self.items[i] + result = await item.internal_item.run(user, text_preprocessing_result, params) + choice_is_made = True + break + if not choice_is_made and self._else_item: + result = await self.else_item.run(user, text_preprocessing_result, params) return result @@ -134,18 +180,17 @@ def build_items(self): def build_else_item(self): return self._else_item - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: result = None choice_is_made = False for item in self.items: - checked = item.requirement.check(text_preprocessing_result, user, params) - if checked: - result = item.internal_item.run(user, text_preprocessing_result, params) + if await item.requirement.check(text_preprocessing_result, user, params): + result = await item.internal_item.run(user, text_preprocessing_result, params) choice_is_made = True break if not choice_is_made and self._else_item: - result = self.else_item.run(user, text_preprocessing_result, params) + result = await self.else_item.run(user, text_preprocessing_result, params) return result @@ -182,13 +227,13 @@ def build_item(self): def build_else_item(self): return self._else_item - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Optional[Dict[str, Union[str, float, int]]]] = None) -> Optional[List[Command]]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Optional[Dict[str, Union[str, float, int]]]] = None) -> Optional[List[Command]]: result = None - if self.requirement.check(text_preprocessing_result, user, params): - result = self.item.run(user, text_preprocessing_result, params) + if await self.requirement.check(text_preprocessing_result, user, params): + result = await self.item.run(user, text_preprocessing_result, params) elif self._else_item: - result = self.else_item.run(user, text_preprocessing_result, params) + result = await self.else_item.run(user, text_preprocessing_result, params) return result @@ -205,11 +250,11 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): def build_actions(self): return self._actions - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: commands = [] for action in self.actions: - action_result = action.run(user, text_preprocessing_result, params) + action_result = await action.run(user, text_preprocessing_result, params) if action_result: commands += action_result return commands @@ -225,8 +270,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self._actions_count = len(items["actions"]) self._last_action_ids_storage = items["last_action_ids_storage"] - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: last_ids = user.last_action_ids[self._last_action_ids_storage] all_indexes = list(range(self._actions_count)) max_last_ids_count = self._actions_count - 1 @@ -236,7 +281,7 @@ def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingRe action_index = random.choice(available_indexes) action = self.actions[action_index] last_ids.add(action_index) - result = action.run(user, text_preprocessing_result, params) + result = await action.run(user, text_preprocessing_result, params) return result @@ -251,8 +296,8 @@ def __init__(self, items, id=None): def build_actions(self): return self._raw_actions - def run(self, user, text_preprocessing_result, params=None): + async def run(self, user, text_preprocessing_result, params=None): pos = random.randint(0, len(self._raw_actions) - 1) action = self.actions[pos] - command_list = action.run(user, text_preprocessing_result, params=params) + command_list = await action.run(user, text_preprocessing_result, params=params) return command_list diff --git a/core/basic_models/actions/client_profile_actions.py b/core/basic_models/actions/client_profile_actions.py index 0aa53e9e..2e87b29d 100644 --- a/core/basic_models/actions/client_profile_actions.py +++ b/core/basic_models/actions/client_profile_actions.py @@ -54,8 +54,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): {"memoryPartition": key, "tags": val} for key, val in self._nodes["memory"].items() ] - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: self._nodes["consumer"] = {"projectId": user.settings["template_settings"]["project_id"]} settings_kafka_key = user.settings["template_settings"].get("client_profile_kafka_key") @@ -70,7 +70,7 @@ def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult user.behaviors.add(user.message.generate_new_callback_id(), self.behavior, scenario_id, text_preprocessing_result.raw, action_params=pickle_deepcopy(params)) - commands = super().run(user, text_preprocessing_result, params) + commands = await super().run(user, text_preprocessing_result, params) return commands @@ -144,8 +144,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.kafka_key = items.get("kafka_key") self._nodes["root_nodes"] = {"protocolVersion": items.get("protocolVersion") or 3} - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: self._nodes["consumer"] = {"projectId": user.settings["template_settings"]["project_id"]} settings_kafka_key = user.settings["template_settings"].get("client_profile_kafka_key") @@ -155,5 +155,5 @@ def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult "kafka_replyTopic": user.settings["template_settings"]["consumer_topic"] } - commands = super().run(user, text_preprocessing_result, params) + commands = await super().run(user, text_preprocessing_result, params) return commands diff --git a/core/basic_models/actions/counter_actions.py b/core/basic_models/actions/counter_actions.py index 5aeb8b25..c5d4260b 100644 --- a/core/basic_models/actions/counter_actions.py +++ b/core/basic_models/actions/counter_actions.py @@ -18,20 +18,20 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.value = items.get("value", 1) self.lifetime = items.get("lifetime") - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.counters[self.key].inc(self.value, self.lifetime) class CounterDecrementAction(CounterIncrementAction): - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.counters[self.key].dec(-self.value, self.lifetime) class CounterClearAction(CounterIncrementAction): - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.counters.clear(self.key) @@ -48,8 +48,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.reset_time = items.get("reset_time", False) self.time_shift = items.get("time_shift", 0) - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.counters[self.key].set(self.value, self.reset_time, self.time_shift) @@ -61,7 +61,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.reset_time = items.get("reset_time", False) self.time_shift = items.get("time_shift", 0) - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: value = user.counters[self.src].value user.counters[self.dst].set(value, self.reset_time, self.time_shift) diff --git a/core/basic_models/actions/external_actions.py b/core/basic_models/actions/external_actions.py index b4cddca0..e96495c0 100644 --- a/core/basic_models/actions/external_actions.py +++ b/core/basic_models/actions/external_actions.py @@ -24,8 +24,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(ExternalAction, self).__init__(items, id) self._action_key = items["action"] - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: action = user.descriptions["external_actions"][self._action_key] - commands = action.run(user, text_preprocessing_result, params) + commands = await action.run(user, text_preprocessing_result, params) return commands diff --git a/core/basic_models/actions/push_action.py b/core/basic_models/actions/push_action.py index d9174716..288fc51c 100644 --- a/core/basic_models/actions/push_action.py +++ b/core/basic_models/actions/push_action.py @@ -64,8 +64,8 @@ def _render_request_data(self, action_params): } return request_data - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: params = params or {} command_params = { "surface": self.surface, diff --git a/core/basic_models/actions/string_actions.py b/core/basic_models/actions/string_actions.py index 1918579c..1432252b 100644 --- a/core/basic_models/actions/string_actions.py +++ b/core/basic_models/actions/string_actions.py @@ -1,4 +1,5 @@ # coding: utf-8 +import asyncio import random from copy import copy from typing import Union, Dict, List, Any, Optional @@ -80,8 +81,8 @@ def _get_rendered_tree_recursive(self, value, params, no_empty=False): result = value return result - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: raise NotImplementedError @@ -101,12 +102,14 @@ class StringAction(NodeAction): } } """ + def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(StringAction, self).__init__(items, id) def _generate_command_context(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Dict: + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Dict: command_params = dict() + params = params or {} collected = user.parametrizer.collect(text_preprocessing_result, filter_params={"command": self.command}) params.update(collected) @@ -116,8 +119,8 @@ def _generate_command_context(self, user: BaseUser, text_preprocessing_result: B command_params[key] = rendered return command_params - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: # Example: Command("ANSWER_TO_USER", {"answer": {"key1": "string1", "keyN": "stringN"}}) params = params or {} command_params = self._generate_command_context(user, text_preprocessing_result, params) @@ -145,8 +148,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(AfinaAnswerAction, self).__init__(items, id) self.command: str = ANSWER_TO_USER - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: params = user.parametrizer.collect(text_preprocessing_result, filter_params={"command": self.command}) answer_params = dict() result = [] @@ -216,13 +219,15 @@ class SDKAnswer(NodeAction): карточки на андроиде требуют sdk_version не ниже "20.03.0.0" """ INDEX_WILDCARD = "*index*" - RANDOM_PATH = [['items', INDEX_WILDCARD, 'bubble', 'text'], ['pronounceText'], ['suggestions', 'buttons', INDEX_WILDCARD, 'title']] + RANDOM_PATH = [['items', INDEX_WILDCARD, 'bubble', 'text'], ['pronounceText'], + ['suggestions', 'buttons', INDEX_WILDCARD, 'title']] def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(SDKAnswer, self).__init__(items, id) self.command: str = ANSWER_TO_USER if self._nodes == {}: - self._nodes = {i: items.get(i) for i in items if i not in ['random_paths', 'same_ans', 'type', 'support_templates', 'no_empty_nodes']} + self._nodes = {i: items.get(i) for i in items if + i not in ['random_paths', 'same_ans', 'type', 'support_templates', 'no_empty_nodes']} # функция идет по RANDOM_PATH, числа в нем считает индексами массива, # INDEX_WILDCARD - произвольным индексом массива, прочее - ключами словаря @@ -243,8 +248,8 @@ def random_by_path(self, input_dict, nested_key): return last_dict[k] = random.choice(last_dict[k]) - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: result = [] params = user.parametrizer.collect(text_preprocessing_result, filter_params={"command": self.command}) rendered = self._get_rendered_tree(self.nodes, params, self.no_empty_nodes) @@ -263,6 +268,191 @@ def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingRe return result +class GatherSDKAnswerToUser(NodeAction): + """ + Example: + { + "type": "sdk_answer_to_user", + "root": + [ + { + "type": "pronounce_text", + "text": "ans" + } + ], + "static": { + "ios_card": { + "type": "list_card", + "cells": [ + { + "ios_params": "ios" + } + ] + }, + "android_card": { + "type": "list_card", + "cells": [ + { + "android_params": "android" + } + ] + }, + "tittle1": "static tittle1", + "tittle2": "static tittle2", + "sg_dl": "www.www.www", + "sg_text": "static suggest text" + }, + "random_choice": [ + { + "ans": "random text1" + }, + { + "ans": "random text2" + } + ], + "items": [ + { + "type": "item_card", + "value": "ios_card", + "requirement": { + "type": "external", + "requirement": "OCTOBER_iOS" + } + }, + { + "type": "item_card", + "value": "android_card", + "requirement": { + "type": "external", + "requirement": "OCTOBER_android" + } + }, + { + "type": "bubble_text", + "text": "ans" + } + ], + "suggestions": [ + { + "type": "suggest_text", + "text": "sg_text", + "title": "tittle1" + }, + { + "type": "suggest_deeplink", + "text": "sg_text", + "deep_link": "tittle1" + } + ] + } + + Output: + { + "items": + [{ + "card": {"type": "list_card", "cells": [{"ios_params": "ios"}]}}, + {"bubble": {"text": "random texti", "markdown": True} + }], + "suggestions": + {"buttons": + [ + {"title": "static tittle1", "action": {"text": "static suggest text", "type": "text"}}, + {"title": "static tittle2", "action": {"deep_link": "www.www.www", "type": "deep_link"}} + ] + }, + "pronounceText": "random texti" + } + ответ c карточками с случайным выбором текстов из random_choice + карточки на андроиде требуют sdk_version не ниже "20.03.0.0" + """ + + ITEMS = "items" + SUGGESTIONS = "suggestions" + SUGGESTIONS_TEMPLATE = "suggestions_template" + BUTTONS = "buttons" + STATIC = "static" + RANDOM_CHOICE = "random_choice" + COMMAND = "command" + ROOT = "root" + + def __init__(self, items: Dict[str, Any], id: Optional[str] = None): + super(GatherSDKAnswerToUser, self).__init__(items, id) + self.command: str = ANSWER_TO_USER + self._nodes[self.STATIC] = items.get(self.STATIC, {}) + self._nodes[self.RANDOM_CHOICE] = items.get(self.RANDOM_CHOICE, {}) + self._nodes[self.SUGGESTIONS] = items.get(self.SUGGESTIONS, {}) + self._nodes[self.SUGGESTIONS_TEMPLATE] = items.get(self.SUGGESTIONS_TEMPLATE, {}) + self._items = items.get(self.ITEMS, {}) + self._suggests = items.get(self.SUGGESTIONS, {}) + self._suggests_template = items.get(self.SUGGESTIONS_TEMPLATE) + self._root = items.get(self.ROOT, {}) + + self.items = self.build_items() + self.suggests = self.build_suggests() + self.root = self.build_root() + + @list_factory(SdkAnswerItem) + def build_items(self): + return self._items + + @list_factory(SdkAnswerItem) + def build_suggests(self): + return self._suggests + + @list_factory(SdkAnswerItem) + def build_root(self): + return self._root + + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + + result = [] + params = user.parametrizer.collect(text_preprocessing_result, filter_params={self.COMMAND: self.command}) + rendered = self._get_rendered_tree(self.nodes[self.STATIC], params, self.no_empty_nodes) + if self._nodes[self.RANDOM_CHOICE]: + random_node = random.choice(self.nodes[self.RANDOM_CHOICE]) + rendered_random = self._get_rendered_tree(random_node, params, self.no_empty_nodes) + rendered.update(rendered_random) + out = {} + check_results = await asyncio.gather( + item.requirement.check(text_preprocessing_result, user, params) for item in self.items) + for i, check in enumerate(check_results): + item = self.items[i] + if check: + out.setdefault(self.ITEMS, []).append(item.render(rendered)) + + if self._suggests_template is not None: + out[self.SUGGESTIONS] = self._get_rendered_tree(self.nodes[self.SUGGESTIONS_TEMPLATE], params, + self.no_empty_nodes) + else: + check_results = await asyncio.gather( + suggest.requirement.check(text_preprocessing_result, user, params) for suggest in self.suggests) + for i, check in enumerate(check_results): + suggest = self.suggests[i] + if check: + data_dict = out.setdefault(self.SUGGESTIONS, {self.BUTTONS: []}) + buttons = data_dict[self.BUTTONS] + rendered_text = suggest.render(rendered) + buttons.append(rendered_text) + check_results = await asyncio.gather( + part.requirement.check(text_preprocessing_result, user) for part in self.root) + for i, check in enumerate(check_results): + part = self.root[i] + if check: + out.update(part.render(rendered)) + if rendered or not self.no_empty_nodes: + result = [ + Command( + self.command, + out, + self.id, + request_type=self.request_type, + request_data=self.request_data, + ) + ] + return result + + class SDKAnswerToUser(NodeAction): """ Example: @@ -398,8 +588,8 @@ def build_suggests(self): def build_root(self): return self._root - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> List[Command]: result = [] params = user.parametrizer.collect(text_preprocessing_result, filter_params={self.COMMAND: self.command}) @@ -410,7 +600,7 @@ def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingRe rendered.update(rendered_random) out = {} for item in self.items: - if item.requirement.check(text_preprocessing_result, user, params): + if await item.requirement.check(text_preprocessing_result, user, params): out.setdefault(self.ITEMS, []).append(item.render(rendered)) if self._suggests_template is not None: @@ -418,13 +608,13 @@ def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingRe self.no_empty_nodes) else: for suggest in self.suggests: - if suggest.requirement.check(text_preprocessing_result, user, params): + if await suggest.requirement.check(text_preprocessing_result, user, params): data_dict = out.setdefault(self.SUGGESTIONS, {self.BUTTONS: []}) buttons = data_dict[self.BUTTONS] rendered_text = suggest.render(rendered) buttons.append(rendered_text) for part in self.root: - if part.requirement.check(text_preprocessing_result, user): + if await part.requirement.check(text_preprocessing_result, user): out.update(part.render(rendered)) if rendered or not self.no_empty_nodes: result = [ diff --git a/core/basic_models/actions/variable_actions.py b/core/basic_models/actions/variable_actions.py index 580d4c56..4bf79e18 100644 --- a/core/basic_models/actions/variable_actions.py +++ b/core/basic_models/actions/variable_actions.py @@ -32,8 +32,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): value: str = items["value"] self.template: UnifiedTemplate = UnifiedTemplate(value) - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: params = user.parametrizer.collect(text_preprocessing_result) try: # if path is wrong, it may fail with UndefinedError @@ -61,8 +61,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(DeleteVariableAction, self).__init__(items, id) self.key: str = items["key"] - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.variables.delete(self.key) @@ -72,6 +72,6 @@ class ClearVariablesAction(Action): def __init__(self, items: Dict[str, Any] = None, id: Optional[str] = None): super(ClearVariablesAction, self).__init__(items, id) - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.variables.clear() diff --git a/core/basic_models/requirement/basic_requirements.py b/core/basic_models/requirement/basic_requirements.py index 5c3b2f61..1781fe05 100644 --- a/core/basic_models/requirement/basic_requirements.py +++ b/core/basic_models/requirement/basic_requirements.py @@ -1,3 +1,4 @@ +import asyncio import hashlib from datetime import datetime, timezone from random import random @@ -38,8 +39,8 @@ def _log_params(self): "requirement": self.__class__.__name__ } - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return True def on_check_error(self, text_preprocessing_result, user): @@ -63,20 +64,44 @@ def build_requirements(self): return self._requirements +class GatherAndRequirement(CompositeRequirement): + + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: + check_results = await asyncio.gather( + requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, params=params) + for requirement in self.requirements) + return all(check_results) + + class AndRequirement(CompositeRequirement): - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: - return all(requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, params=params) - for requirement in self.requirements) + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: + return all( + [await requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, params=params) + for requirement in self.requirements] + ) + + +class GatherOrRequirement(CompositeRequirement): + + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: + check_results = await asyncio.gather( + requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, params=params) + for requirement in self.requirements) + return any(check_results) class OrRequirement(CompositeRequirement): - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: - return any(requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, params=params) - for requirement in self.requirements) + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: + return any( + [await requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, params=params) + for requirement in self.requirements] + ) class NotRequirement(Requirement): @@ -91,9 +116,10 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: def build_requirement(self): return self._requirement - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: - return not self.requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, params=params) + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: + return not await self.requirement.check(text_preprocessing_result=text_preprocessing_result, user=user, + params=params) class ComparisonRequirement(Requirement): @@ -116,8 +142,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(RandomRequirement, self).__init__(items, id) self.percent = items["percent"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: result = random() * 100 return result < self.percent @@ -129,8 +155,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(TopicRequirement, self).__init__(items, id) self.topics = items["topics"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return user.message.topic_key in self.topics @@ -139,8 +165,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(TemplateRequirement, self).__init__(items, id) self._template = UnifiedTemplate(items["template"]) - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: params = params or {} collected = user.parametrizer.collect(text_preprocessing_result) params.update(collected) @@ -149,7 +175,7 @@ def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: Ba return True if render_result == "False": return False - raise TypeError(f'Template result should be "True" or "False", got: ', + raise TypeError(f'Template result should be "True" or "False", got: ' f'{render_result} for template {self.items["template"]}') @@ -160,8 +186,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(RollingRequirement, self).__init__(items, id) self.percent = items["percent"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: id = user.id s = id.encode('utf-8') hash = int(hashlib.sha256(s).hexdigest(), 16) @@ -173,7 +199,7 @@ class TimeRequirement(ComparisonRequirement): def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super().__init__(items, id) - def check( + async def check( self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, @@ -199,7 +225,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super().__init__(items, id) self.match_cron = items['match_cron'] - def check( + async def check( self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, @@ -226,14 +252,14 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: id, ) - def check( + async def check( self, text_preprocessing_result: TextPreprocessingResult, user: User, params: Dict[str, Any] = None ) -> bool: result = bool( - self.filler.extract(text_preprocessing_result, user, params), + await self.filler.extract(text_preprocessing_result, user, params), ) return result @@ -252,8 +278,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: def classifier(self) -> Classifier: return ExternalClassifier(self._classifier) - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: check_res = True classifier = self.classifier with StatsTimer() as timer: @@ -279,8 +305,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: self.field_name = items["field_name"] self.value = items["value"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: return user.forms[self.form_name].fields[self.field_name].value == self.value @@ -300,8 +326,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: # Если среда исполнения задана, то проверям, что среда в списке возможных значений для сценария, иначе - False self.check_result = self.environment in self.values if self.environment else False - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return self.check_result @@ -314,8 +340,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(CharacterIdRequirement, self).__init__(items=items, id=id) self.values = items["values"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: return user.message.payload["character"]["id"] in self.values @@ -328,6 +354,6 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(FeatureToggleRequirement, self).__init__(items=items, id=id) self.toggle_name = items["toggle_name"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: return user.settings["template_settings"].get(self.toggle_name, False) diff --git a/core/basic_models/requirement/counter_requirements.py b/core/basic_models/requirement/counter_requirements.py index 1619a8b8..dce4a532 100644 --- a/core/basic_models/requirement/counter_requirements.py +++ b/core/basic_models/requirement/counter_requirements.py @@ -18,8 +18,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: items = items or {} self.key = items["key"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: counter = user.counters[self.key] return self.operator.compare(counter) @@ -34,7 +34,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: self.key = items["key"] self.fallback_value = items.get("fallback_value") or False - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: _time = user.counters[self.key].update_time return self.operator.compare(time() - _time) if _time else self.fallback_value diff --git a/core/basic_models/requirement/device_requirements.py b/core/basic_models/requirement/device_requirements.py index 8323b89d..2ba2f71d 100644 --- a/core/basic_models/requirement/device_requirements.py +++ b/core/basic_models/requirement/device_requirements.py @@ -23,8 +23,8 @@ def descr_to_check_in(self): def get_field(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser): return NotImplementedError - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return self.get_field(text_preprocessing_result, user) in self.descr_to_check_in @@ -51,8 +51,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: items = items or {} self.platfrom_type = items["platfrom_type"] - def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return user.message.device.platform_type == self.platfrom_type @@ -67,8 +67,8 @@ def build_operator(self): class PlatformVersionRequirement(BasicVersionRequirement): - def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: platform_version = convert_version_to_list_of_int(user.message.device.platform_version) return self.operator.compare(platform_version) if platform_version is not None else False @@ -80,15 +80,15 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: items = items or {} self.surface = items["surface"] - def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return user.message.device.surface == self.surface class SurfaceVersionRequirement(BasicVersionRequirement): - def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: surface_version = convert_version_to_list_of_int(user.message.device.surface_version) return self.operator.compare(surface_version) if surface_version is not None else False @@ -100,8 +100,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: items = items or {} self.app_type = items["app_type"] - def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return self.app_type in user.message.device.features.get("appTypes", []) @@ -112,6 +112,6 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: items = items or {} self.property_type = items["property_type"] - def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: TextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return user.message.device.capabilities.get(self.property_type, {}).get("available", False) diff --git a/core/basic_models/requirement/external_requirements.py b/core/basic_models/requirement/external_requirements.py index e5aace78..a0660ce0 100644 --- a/core/basic_models/requirement/external_requirements.py +++ b/core/basic_models/requirement/external_requirements.py @@ -20,7 +20,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(ExternalRequirement, self).__init__(items, id) self.requirement = items["requirement"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: requirement = user.descriptions["external_requirements"][self.requirement] - return requirement.check(text_preprocessing_result, user, params) + return await requirement.check(text_preprocessing_result, user, params) diff --git a/core/basic_models/requirement/project_requirements.py b/core/basic_models/requirement/project_requirements.py index 3114ce1d..5f1203b5 100644 --- a/core/basic_models/requirement/project_requirements.py +++ b/core/basic_models/requirement/project_requirements.py @@ -14,6 +14,6 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: self._key = items["key"] self._value = items["value"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: return user.settings[self._config][self._key] == self._value diff --git a/core/basic_models/requirement/user_text_requirements.py b/core/basic_models/requirement/user_text_requirements.py index 8cbbf415..dec25d58 100644 --- a/core/basic_models/requirement/user_text_requirements.py +++ b/core/basic_models/requirement/user_text_requirements.py @@ -18,8 +18,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super(AnySubstringInLoweredTextRequirement, self).__init__(items, id) self.substrings = self.items["substrings"] - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: lowered_text = text_preprocessing_result.lower() if isinstance(text_preprocessing_result, str) \ else text_preprocessing_result.raw["original_text"].lower() return any(s.lower() in lowered_text for s in self.substrings) @@ -49,8 +49,8 @@ class IntersectionWithTokensSetRequirement(NormalizedInputWordsRequirement): Слова из input_words также проходят нормализацию перед сравнением. """ - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: words_normalized_set = set([ token["lemma"] for token in text_preprocessing_result.raw["tokenized_elements_list_pymorphy"] if not token.get("token_type") == "SENTENCE_ENDPOINT_TOKEN" @@ -71,8 +71,8 @@ class NormalizedTextInSetRequirement(NormalizedInputWordsRequirement): нормализованных строк из input_words, иначе - False. """ - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: normalized_text = text_preprocessing_result.raw["normalized_text"].replace(".", "").strip() result = normalized_text in self.normalized_input_words if result: @@ -94,8 +94,8 @@ class PhoneNumberNumberRequirement(ComparisonRequirement): def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: super().__init__(items, id) - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: BaseUser, + params: Dict[str, Any] = None) -> bool: len_phone_number_token = len(text_preprocessing_result.get_token_values_by_type("PHONE_NUMBER_TOKEN")) result = self.operator.compare(len_phone_number_token) if result: @@ -114,7 +114,7 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: self.min_num = float(items["min_num"]) self.max_num = float(items["max_num"]) - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: num = float(text_preprocessing_result.num_token_values) return self.min_num <= num <= self.max_num if num else False diff --git a/core/basic_models/scenarios/base_scenario.py b/core/basic_models/scenarios/base_scenario.py index 5a92d9c3..33bfa5c3 100644 --- a/core/basic_models/scenarios/base_scenario.py +++ b/core/basic_models/scenarios/base_scenario.py @@ -46,23 +46,23 @@ def build_actions(self): def build_available_requirement(self): return self._available_requirement - def check_available(self, text_preprocessing_result, user): + async def check_available(self, text_preprocessing_result, user): if not self.switched_off: - return self.available_requirement.check(text_preprocessing_result, user) + return await self.available_requirement.check(text_preprocessing_result, user) return False def _log_params(self): return {log_const.KEY_NAME: log_const.SCENARIO_VALUE} - def text_fits(self, text_preprocessing_result, user): + async def text_fits(self, text_preprocessing_result, user): return False - def get_no_commands_action(self, user, text_preprocessing_result, params: Dict[str, Any] = None): + async def get_no_commands_action(self, user, text_preprocessing_result, params: Dict[str, Any] = None): log_params = {log_const.KEY_NAME: scenarios_log_const.CHOSEN_ACTION_VALUE, scenarios_log_const.CHOSEN_ACTION_VALUE: self._empty_answer} log(scenarios_log_const.CHOSEN_ACTION_MESSAGE, user, log_params) try: - empty_answer = self.empty_answer.run(user, text_preprocessing_result, params) or [] + empty_answer = await self.empty_answer.run(user, text_preprocessing_result, params) or [] except KeyError: log_params = {log_const.KEY_NAME: scenarios_log_const.CHOSEN_ACTION_VALUE} log("Scenario has empty answer, but empty_answer action isn't defined", @@ -70,11 +70,11 @@ def get_no_commands_action(self, user, text_preprocessing_result, params: Dict[s empty_answer = [] return empty_answer - def get_action_results(self, user, text_preprocessing_result, + async def get_action_results(self, user, text_preprocessing_result, actions: List[Action], params: Dict[str, Any] = None) -> List[Command]: results = [] for action in actions: - result = action.run(user, text_preprocessing_result, params) + result = await action.run(user, text_preprocessing_result, params) log_params = self._log_params() log_params["class"] = action.__class__.__name__ log("called action: %(class)s", user, log_params) @@ -96,5 +96,5 @@ def get_action_results(self, user, text_preprocessing_result, def history(self): return {"scenario_path": [{"scenario": self.id, "node": None}]} - def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): - return self.get_action_results(user, text_preprocessing_result, self.actions, params) + async def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): + return await self.get_action_results(user, text_preprocessing_result, self.actions, params) diff --git a/core/db_adapter/aioredis_adapter.py b/core/db_adapter/aioredis_adapter.py index e80e88a9..b85d6084 100644 --- a/core/db_adapter/aioredis_adapter.py +++ b/core/db_adapter/aioredis_adapter.py @@ -3,15 +3,14 @@ import aioredis import typing -from core.db_adapter.db_adapter import DBAdapter +from core.db_adapter.db_adapter import AsyncDBAdapter from core.db_adapter import error -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.logging.logger_utils import log -class AIORedisAdapter(DBAdapter): - IS_ASYNC = True +class AIORedisAdapter(AsyncDBAdapter): def __init__(self, config=None): super().__init__(config) @@ -22,20 +21,20 @@ def __init__(self, config=None): except KeyError: pass - @monitoring.got_histogram_decorate("save_time") + @monitoring.monitoring.got_histogram_decorate("save_time") async def save(self, id, data): - return await self._run(self._save, id, data) + return await self._async_run(self._save, id, data) - @monitoring.got_histogram_decorate("save_time") + @monitoring.monitoring.got_histogram_decorate("save_time") async def replace_if_equals(self, id, sample, data): - return await self._run(self._replace_if_equals, id, sample, data) + return await self._async_run(self._replace_if_equals, id, sample, data) - @monitoring.got_histogram_decorate("get_time") + @monitoring.monitoring.got_histogram_decorate("get_time") async def get(self, id): - return await self._run(self._get, id) + return await self._async_run(self._get, id) async def path_exists(self, path): - return await self._run(self._path_exists, path) + return await self._async_run(self._path_exists, path) async def connect(self): print("Here is the content of REDIS_CONFIG:", self.config) @@ -61,14 +60,14 @@ async def _get(self, id): data = await self._redis.get(id) return data - def _list_dir(self, path): + async def _list_dir(self, path): raise error.NotSupportedOperation - def _glob(self, path, pattern): + async def _glob(self, path, pattern): raise error.NotSupportedOperation async def _path_exists(self, path): return await self._redis.exists(path) - def _on_prepare(self): + async def _on_prepare(self): pass diff --git a/core/db_adapter/aioredis_sentinel_adapter.py b/core/db_adapter/aioredis_sentinel_adapter.py index 070d6572..5d01a59f 100644 --- a/core/db_adapter/aioredis_sentinel_adapter.py +++ b/core/db_adapter/aioredis_sentinel_adapter.py @@ -1,17 +1,15 @@ import copy -import aioredis import typing from aioredis.sentinel import Sentinel -from core.db_adapter.db_adapter import DBAdapter +from core.db_adapter.db_adapter import AsyncDBAdapter from core.db_adapter import error -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.logging.logger_utils import log -class AIORedisSentinelAdapter(DBAdapter): - IS_ASYNC = True +class AIORedisSentinelAdapter(AsyncDBAdapter): def __init__(self, config=None): super().__init__(config) @@ -24,21 +22,21 @@ def __init__(self, config=None): except KeyError: pass - @monitoring.got_histogram_decorate("save_time") + @monitoring.monitoring.got_histogram_decorate("save_time") async def save(self, id, data): - return await self._run(self._save, id, data) + return await self._async_run(self._save, id, data) - @monitoring.got_histogram_decorate("save_time") + @monitoring.monitoring.got_histogram_decorate("save_time") async def replace_if_equals(self, id, sample, data): - return await self._run(self._replace_if_equals, id, sample, data) + return await self._async_run(self._replace_if_equals, id, sample, data) - @monitoring.got_histogram_decorate("get_time") + @monitoring.monitoring.got_histogram_decorate("get_time") async def get(self, id): - return await self._run(self._get, id) + return await self._async_run(self._get, id) async def path_exists(self, path): - return await self._run(self._path_exists, path) + return await self._async_run(self._path_exists, path) async def connect(self): @@ -57,7 +55,7 @@ async def connect(self): sentinels_tuples.append(tuple(sent)) self._sentinel = Sentinel(sentinels_tuples, **config) - def _open(self, filename, *args, **kwargs): + async def _open(self, filename, *args, **kwargs): pass async def _save(self, id, data): @@ -73,15 +71,15 @@ async def _get(self, id): data = await redis.get(id) return data - def _list_dir(self, path): + async def _list_dir(self, path): raise error.NotSupportedOperation - def _glob(self, path, pattern): + async def _glob(self, path, pattern): raise error.NotSupportedOperation async def _path_exists(self, path): redis = await self._sentinel.master_for(self.service_name, socket_timeout=self.socket_timeout) return await redis.exists(path) - def _on_prepare(self): + async def _on_prepare(self): pass diff --git a/core/db_adapter/ceph/ceph_adapter.py b/core/db_adapter/ceph/ceph_adapter.py index 97d91bb2..ce75c8b9 100644 --- a/core/db_adapter/ceph/ceph_adapter.py +++ b/core/db_adapter/ceph/ceph_adapter.py @@ -9,7 +9,7 @@ from core.db_adapter.ceph.ceph_io import CephIO from core.db_adapter.db_adapter import DBAdapter from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring ssl._create_default_https_context = ssl._create_unverified_context @@ -35,7 +35,7 @@ def connect(self): params={log_const.KEY_NAME: log_const.HANDLED_EXCEPTION_VALUE}, level="ERROR", exc_info=True) - monitoring.got_counter("ceph_connection_exception") + monitoring.monitoring.got_counter("ceph_connection_exception") raise @property diff --git a/core/db_adapter/db_adapter.py b/core/db_adapter/db_adapter.py index 9ac6979a..9582dd85 100644 --- a/core/db_adapter/db_adapter.py +++ b/core/db_adapter/db_adapter.py @@ -1,7 +1,11 @@ # coding: utf-8 +import asyncio + +import core.logging.logger_constants as log_const +from core.logging.logger_utils import log from core.model.factory import build_factory from core.model.registered import Registered -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.utils.rerunable import Rerunable db_adapters = Registered() @@ -61,15 +65,15 @@ def path_exists(self, path): def mtime(self, path): return self._run(self._mtime, path) - @monitoring.got_histogram_decorate("save_time") + @monitoring.monitoring.got_histogram_decorate("save_time") def save(self, id, data): return self._run(self._save, id, data) - @monitoring.got_histogram_decorate("save_time") + @monitoring.monitoring.got_histogram_decorate("save_time") def replace_if_equals(self, id, sample, data): return self._run(self._replace_if_equals, id, sample, data) - @monitoring.got_histogram_decorate("get_time") + @monitoring.monitoring.got_histogram_decorate("get_time") def get(self, id): return self._run(self._get, id) @@ -82,3 +86,63 @@ def _handled_exception(self): def _on_all_tries_fail(self): raise + + +class AsyncDBAdapter(DBAdapter): + IS_ASYNC = True + + async def _on_all_tries_fail(self): + raise + + async def _save(self, id, data): + raise NotImplementedError + + async def _replace_if_equals(self, id, sample, data): + raise NotImplementedError + + async def _get(self, id): + raise NotImplementedError + + async def _path_exists(self, path): + raise NotImplementedError + + async def path_exists(self, path): + return await self._async_run(self._path_exists, path) + + @monitoring.monitoring.got_histogram("save_time") + async def save(self, id, data): + return await self._async_run(self._save, id, data) + + @monitoring.monitoring.got_histogram("save_time") + async def replace_if_equals(self, id, sample, data): + return await self._async_run(self._replace_if_equals, id, sample, data) + + @monitoring.monitoring.got_histogram("get_time") + async def get(self, id): + return await self._async_run(self._get, id) + + async def _async_run(self, action, *args, _try_count=None, **kwargs): + if _try_count is None: + _try_count = self.try_count + if _try_count <= 0: + await self._on_all_tries_fail() + _try_count = _try_count - 1 + try: + result = await action(*args, **kwargs) if asyncio.iscoroutinefunction(action) \ + else action(*args, **kwargs) + except self._handled_exception as e: + params = { + "class_name": str(self.__class__), + "exception": str(e), + "try_count": _try_count, + log_const.KEY_NAME: log_const.HANDLED_EXCEPTION_VALUE + } + log("%(class_name)s run failed with %(exception)s.\n Got %(try_count)s tries left.", + params=params, + level="ERROR") + self._on_prepare() + result = await self._async_run(action, *args, _try_count=_try_count, **kwargs) + counter_name = self._get_counter_name() + if counter_name: + monitoring.monitoring.got_counter(f"{counter_name}_exception") + return result diff --git a/core/db_adapter/ignite_adapter.py b/core/db_adapter/ignite_adapter.py index 3e87db43..53d6795c 100644 --- a/core/db_adapter/ignite_adapter.py +++ b/core/db_adapter/ignite_adapter.py @@ -1,17 +1,22 @@ # coding: utf-8 import random +from concurrent.futures._base import CancelledError import pyignite +from pyignite import AioClient +from pyignite.aio_cache import AioCache from pyignite.exceptions import ReconnectError, SocketError import core.logging.logger_constants as log_const from core.db_adapter import error -from core.db_adapter.db_adapter import DBAdapter +from core.db_adapter.db_adapter import AsyncDBAdapter from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring -class IgniteAdapter(DBAdapter): +class IgniteAdapter(AsyncDBAdapter): + _client: AioClient + _cache = AioCache def __init__(self, config): self._init_params = config.get("init_params", {}) @@ -22,23 +27,23 @@ def __init__(self, config): self._cache = None super(IgniteAdapter, self).__init__(config) - def _open(self, filename, *args, **kwargs): + async def _open(self, filename, *args, **kwargs): pass - def _list_dir(self, path): + async def _list_dir(self, path): raise error.NotSupportedOperation - def _glob(self, path, pattern): + async def _glob(self, path, pattern): raise error.NotSupportedOperation - def _path_exists(self, path): + async def _path_exists(self, path): raise error.NotSupportedOperation - def connect(self): + async def connect(self): try: - self._client = pyignite.Client(**self._init_params) - self._client.connect(self._url) - self._cache = self._client.get_or_create_cache(self._cache_name) + self._client = pyignite.aio_client.AioClient(**self._init_params) + await self._client.connect(self._url) + self._cache = await self._client.get_or_create_cache(self._cache_name) logger_args = { log_const.KEY_NAME: log_const.IGNITE_VALUE, "pyignite_args": str(self._init_params), @@ -47,37 +52,39 @@ def connect(self): log("IgniteAdapter to servers %(pyignite_addresses)s created", params=logger_args, level="WARNING") except Exception: log("IgniteAdapter connect error", - params={log_const.KEY_NAME: log_const.HANDLED_EXCEPTION_VALUE}, - level="ERROR", - exc_info=True) - monitoring.got_counter("ignite_connection_exception") + params={log_const.KEY_NAME: log_const.HANDLED_EXCEPTION_VALUE}, + level="ERROR", + exc_info=True) + monitoring.monitoring.got_counter("ignite_connection_exception") raise - def _save(self, id, data): - return self.cache.put(id, data) + async def _save(self, id, data): + cache = await self.get_cache() + return await cache.put(id, data) - def _replace_if_equals(self, id, sample, data): - return self._cache.replace_if_equals(id, sample, data) + async def _replace_if_equals(self, id, sample, data): + cache = await self.get_cache() + return await cache.replace_if_equals(id, sample, data) - def _get(self, id): - data = self.cache.get(id) + async def _get(self, id): + cache = await self.get_cache() + data = await cache.get(id) return data - @property - def cache(self): - if self._cache is None: + async def get_cache(self): + if self._client is None: log('Attempt to recreate ignite instance', level="WARNING") - self.connect() - monitoring.got_counter("ignite_reconnection") + await self.connect() + monitoring.monitoring.got_counter("ignite_reconnection") return self._cache @property def _handled_exception(self): # TypeError is raised during reconnection if all nodes are exhausted - return OSError, SocketError, ReconnectError + return OSError, SocketError, ReconnectError, CancelledError - def _on_prepare(self): - self._cache = None + async def _on_prepare(self): + self._client = None - def _get_counter_name(self): - return "ignite_adapter" + async def _get_counter_name(self): + return "ignite_async_adapter" diff --git a/core/db_adapter/ignite_thread_adapter.py b/core/db_adapter/ignite_thread_adapter.py deleted file mode 100644 index addc8b81..00000000 --- a/core/db_adapter/ignite_thread_adapter.py +++ /dev/null @@ -1,98 +0,0 @@ -# coding: utf-8 -import random -import threading - -import pyignite -from pyignite.exceptions import ReconnectError, SocketError - -import core.logging.logger_constants as log_const -from core.db_adapter import error -from core.db_adapter.db_adapter import DBAdapter -from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring - - -class IgniteThreadAdapter(DBAdapter): - - def __init__(self, config): - self._init_params = config.get("init_params", {}) - self._url = config["url"] - if config.get("randomize_url"): - random.shuffle(self._url) - self._cache_name = config["cache_name"] - self._clients = {} - self._caches = {} - super(IgniteThreadAdapter, self).__init__(config) - - def _open(self, filename, *args, **kwargs): - pass - - def _list_dir(self, path): - raise error.NotSupportedOperation - - def _glob(self, path, pattern): - raise error.NotSupportedOperation - - def _path_exists(self, path): - raise error.NotSupportedOperation - - def connect(self): - self._get_cache() - - def _connect_thread(self, thread_id): - try: - client = pyignite.Client(**self._init_params) - client.connect(self._url) - cache = client.get_or_create_cache(self._cache_name) - self._clients[thread_id] = client - self._caches[thread_id] = cache - logger_args = { - log_const.KEY_NAME: log_const.IGNITE_VALUE, - "pyignite_args": str(self._init_params), - "pyignite_addresses": str(self._url) - } - log("IgniteAdapter to servers %(pyignite_addresses)s created", params=logger_args, level="WARNING") - except Exception: - log( - "IgniteAdapter connect error", - params={log_const.KEY_NAME: log_const.HANDLED_EXCEPTION_VALUE}, - level="ERROR", - exc_info=True - ) - monitoring.got_counter("ignite_connection_exception") - raise - - def _get_cache(self): - thread_id = threading.get_ident() - if thread_id not in self._caches: - self._connect_thread(thread_id) - return self._caches[thread_id] - - def _save(self, id, data): - return self._get_cache().put(id, data) - - def _replace_if_equals(self, id, sample, data): - return self._get_cache().replace_if_equals(id, sample, data) - - def _get(self, id): - data = self._get_cache().get(id) - return data - - @property - def cache(self): - if self._get_cache() is None: - log('Attempt to recreate ignite instance', level="WARNING") - self.connect() - monitoring.got_counter("ignite_reconnection") - return self._get_cache() - - @property - def _handled_exception(self): - # TypeError is raised during reconnection if all nodes are exhausted - return OSError, SocketError, ReconnectError - - def _on_prepare(self): - self._cache = None - - def _get_counter_name(self): - return "ignite_adapter" diff --git a/core/db_adapter/memory_adapter.py b/core/db_adapter/memory_adapter.py index 5491b7cc..6f0f3360 100644 --- a/core/db_adapter/memory_adapter.py +++ b/core/db_adapter/memory_adapter.py @@ -1,41 +1,41 @@ from core.db_adapter import error -from core.db_adapter.db_adapter import DBAdapter +from core.db_adapter.db_adapter import AsyncDBAdapter -class MemoryAdapter(DBAdapter): +class MemoryAdapter(AsyncDBAdapter): def __init__(self, config=None): - super(DBAdapter, self).__init__(config) + super(AsyncDBAdapter, self).__init__(config) self.memory_storage = {} - def _glob(self, path, pattern): + async def _glob(self, path, pattern): raise error.NotSupportedOperation - def _path_exists(self, path): + async def _path_exists(self, path): raise error.NotSupportedOperation - def _on_prepare(self): + async def _on_prepare(self): pass - def connect(self): + async def connect(self): pass - def _open(self, filename, *args, **kwargs): + async def _open(self, filename, *args, **kwargs): pass - def _save(self, id, data): + async def _save(self, id, data): self.memory_storage[id] = data - def _replace_if_equals(self, id, sample, data): + async def _replace_if_equals(self, id, sample, data): stored_data = self.memory_storage.get(id) if stored_data == sample: self.memory_storage[id] = data return True return False - def _get(self, id): + async def _get(self, id): data = self.memory_storage.get(id) return data - def _list_dir(self, path): + async def _list_dir(self, path): pass diff --git a/core/db_adapter/redis_adapter.py b/core/db_adapter/redis_adapter.py deleted file mode 100644 index 5f5afd75..00000000 --- a/core/db_adapter/redis_adapter.py +++ /dev/null @@ -1,43 +0,0 @@ -import redis -import typing -from core.db_adapter.db_adapter import DBAdapter -from core.db_adapter import error - - -class RedisAdapter(DBAdapter): - def __init__(self, config=None): - super().__init__(config) - self._redis: typing.Optional[redis.Redis] = None - - try: - del self.config["type"] - except KeyError: - pass - - def connect(self): - self._redis = redis.Redis(**self.config) - - def _open(self, filename, *args, **kwargs): - pass - - def _save(self, id, data): - return self._redis.set(id, data) - - def _replace_if_equals(self, id, sample, data): - return self._redis.set(id, data) - - def _get(self, id): - data = self._redis.get(id) - return data.decode() if data else None - - def _list_dir(self, path): - raise error.NotSupportedOperation - - def _glob(self, path, pattern): - raise error.NotSupportedOperation - - def _path_exists(self, path): - self._redis.exists(path) - - def _on_prepare(self): - pass diff --git a/core/message/from_message.py b/core/message/from_message.py index a0a80607..df034773 100644 --- a/core/message/from_message.py +++ b/core/message/from_message.py @@ -226,6 +226,10 @@ def callback_id(self): def callback_id(self, value): self._callback_id = value + @property + def has_callback_id(self): + return self._callback_id is not None or self.headers.get(self._callback_id_header_name) is not None + # noinspection PyMethodMayBeStatic def generate_new_callback_id(self): return str(uuid.uuid4()) diff --git a/core/mq/kafka/async_kafka_publisher.py b/core/mq/kafka/async_kafka_publisher.py new file mode 100644 index 00000000..5bf271ae --- /dev/null +++ b/core/mq/kafka/async_kafka_publisher.py @@ -0,0 +1,66 @@ +# coding: utf-8 +from threading import Thread + +import core.logging.logger_constants as log_const +from core.logging.logger_utils import log +from core.monitoring import monitoring +from core.mq.kafka.kafka_publisher import KafkaPublisher + + +class AsyncKafkaPublisher(KafkaPublisher): + def __init__(self, config): + super().__init__(config) + self._cancelled = False + self._poll_thread = Thread(target=self._poll_for_callbacks) + self._poll_thread.start() + + def send(self, value, key=None, topic_key=None, headers=None): + try: + topic = self._config["topic"] + if topic_key is not None: + topic = topic[topic_key] + producer_params = dict() + if key is not None: + producer_params["key"] = key + self._producer.produce(topic=topic, value=value, headers=headers or [], **producer_params) + except BufferError as e: + params = { + "queue_amount": len(self._producer), + log_const.KEY_NAME: log_const.EXCEPTION_VALUE + } + log("KafkaProducer: Local producer queue is full (%(queue_amount)s messages awaiting delivery):" + " try again\n", params=params, level="ERROR") + monitoring.monitoring.got_counter("kafka_producer_exception") + + def send_to_topic(self, value, key=None, topic=None, headers=None): + try: + if topic is None: + params = { + "message": str(value), + log_const.KEY_NAME: log_const.EXCEPTION_VALUE + } + log("KafkaProducer: Failed sending message %{message}s. Topic is not defined", params=params, + level="ERROR") + producer_params = dict() + if key is not None: + producer_params["key"] = key + self._producer.produce(topic=topic, value=value, headers=headers or [], **producer_params) + except BufferError as e: + params = { + "queue_amount": len(self._producer), + log_const.KEY_NAME: log_const.EXCEPTION_VALUE + } + log("KafkaProducer: Local producer queue is full (%(queue_amount)s messages awaiting delivery):" + " try again\n", params=params, level="ERROR") + monitoring.monitoring.got_counter("kafka_producer_exception") + + def _poll_for_callbacks(self): + poll_timeout = self._config.get("poll_timeout", 1) + while not self._cancelled: + self._producer.poll(poll_timeout) + + def close(self): + self._producer.flush(self._config["flush_timeout"]) + self._cancelled = True + self._poll_thread.join() + log(f"KafkaProducer.close: producer to {self._config['topic']} flushed, poll_thread joined.") diff --git a/core/mq/kafka/kafka_consumer.py b/core/mq/kafka/kafka_consumer.py index dfdc3c6b..6e8d22d1 100644 --- a/core/mq/kafka/kafka_consumer.py +++ b/core/mq/kafka/kafka_consumer.py @@ -10,7 +10,7 @@ import core.logging.logger_constants as log_const from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.mq.kafka.base_kafka_consumer import BaseKafkaConsumer @@ -59,10 +59,11 @@ def on_assign_log(consumer, partitions): log("KafkaConsumer.subscribe: assign %(partitions)s %(log_level)s", params=params, level=log_level) def subscribe(self, topics=None): - topics = topics or list(self._config["topics"].values()) + topics = list(set(topics or list(self._config["topics"].values()))) self._consumer.subscribe(topics, - on_assign=self.get_on_assign_callback() if self.assign_offset_end else KafkaConsumer.on_assign_log) + on_assign=self.get_on_assign_callback() if self.assign_offset_end else + KafkaConsumer.on_assign_log) def get_on_assign_callback(self): if "cooperative" in self._config["conf"].get("partition.assignment.strategy", ""): @@ -101,7 +102,7 @@ def _error_callback(self, err): log_const.KEY_NAME: log_const.EXCEPTION_VALUE } log("KafkaConsumer: Error: %(error)s", params=params, level="WARNING") - monitoring.got_counter("kafka_consumer_exception") + monitoring.monitoring.got_counter("kafka_consumer_exception") # noinspection PyMethodMayBeStatic def _process_message(self, msg: KafkaMessage): @@ -110,7 +111,7 @@ def _process_message(self, msg: KafkaMessage): if err.code() == KafkaError._PARTITION_EOF: return None else: - monitoring.got_counter("kafka_consumer_exception") + monitoring.monitoring.got_counter("kafka_consumer_exception") params = { "code": err.code(), "pid": os.getpid(), diff --git a/core/mq/kafka/kafka_publisher.py b/core/mq/kafka/kafka_publisher.py index 51ba44d7..f87e6dca 100644 --- a/core/mq/kafka/kafka_publisher.py +++ b/core/mq/kafka/kafka_publisher.py @@ -7,7 +7,7 @@ import core.logging.logger_constants as log_const from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.mq.kafka.base_kafka_publisher import BaseKafkaPublisher @@ -41,7 +41,7 @@ def send(self, value, key=None, topic_key=None, headers=None): } log("KafkaProducer: Local producer queue is full (%(queue_amount)s messages awaiting delivery):" " try again\n", params=params, level="ERROR") - monitoring.got_counter("kafka_producer_exception") + monitoring.monitoring.got_counter("kafka_producer_exception") self._poll() def send_to_topic(self, value, key=None, topic=None, headers=None): @@ -64,7 +64,7 @@ def send_to_topic(self, value, key=None, topic=None, headers=None): } log("KafkaProducer: Local producer queue is full (%(queue_amount)s messages awaiting delivery):" " try again\n", params=params, level="ERROR") - monitoring.got_counter("kafka_producer_exception") + monitoring.monitoring.got_counter("kafka_producer_exception") self._poll() def _poll(self): @@ -81,7 +81,7 @@ def _error_callback(self, err): log_const.KEY_NAME: log_const.EXCEPTION_VALUE } log("KafkaProducer: Error: %(error)s", params=params, level="ERROR") - monitoring.got_counter("kafka_producer_exception") + monitoring.monitoring.got_counter("kafka_producer_exception") def _delivery_callback(self, err, msg): if err: @@ -101,7 +101,7 @@ def _delivery_callback(self, err, msg): log_const.KEY_NAME: log_const.EXCEPTION_VALUE}, level="ERROR", exc_info=True) - monitoring.got_counter("kafka_producer_exception") + monitoring.monitoring.got_counter("kafka_producer_exception") def close(self): self._producer.flush(self._config["flush_timeout"]) diff --git a/core/request/rest_request.py b/core/request/rest_request.py index b92b0018..f787614e 100644 --- a/core/request/rest_request.py +++ b/core/request/rest_request.py @@ -2,7 +2,7 @@ from timeout_decorator import timeout_decorator from core.request.base_request import BaseRequest from core.utils.exception_handlers import exc_handler -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring class RestNoMethodSpecifiedException(Exception): @@ -37,7 +37,7 @@ def run(self, data, params=None): return method(data) def on_timeout_error(self, *args, **kwarg): - monitoring.got_counter("core_rest_run_timeout") + monitoring.monitoring.got_counter("core_rest_run_timeout") def _requests_get(self, params): return requests.get(self.url, params=params, **self.rest_args).text diff --git a/core/unified_template/unified_template.py b/core/unified_template/unified_template.py index 1832ba03..3cd2191d 100644 --- a/core/unified_template/unified_template.py +++ b/core/unified_template/unified_template.py @@ -5,7 +5,7 @@ import core.logging.logger_constants as log_const from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring UNIFIED_TEMPLATE_TYPE_NAME = "unified_template" @@ -49,7 +49,7 @@ def render(self, *args, **kwargs): "params_dict_str": str(params_dict)}, level="ERROR", exc_info=True) - monitoring.got_counter("core_jinja_template_error") + monitoring.monitoring.got_counter("core_jinja_template_error") raise return result diff --git a/core/utils/exception_handlers.py b/core/utils/exception_handlers.py index f272d4c1..4a0c58e5 100644 --- a/core/utils/exception_handlers.py +++ b/core/utils/exception_handlers.py @@ -1,3 +1,4 @@ +import asyncio import sys from functools import wraps @@ -6,20 +7,39 @@ def exc_handler(on_error_obj_method_name=None, handled_exceptions=None): handled_exceptions = tuple(handled_exceptions) if handled_exceptions else (Exception,) def exc_handler_decorator(funct): - @wraps(funct) - def _wrapper(obj, *args, **kwarg): - result = None - try: - result = funct(obj, *args, **kwarg) - except handled_exceptions: + if asyncio.iscoroutinefunction(funct): + @wraps(funct) + async def _wrapper(obj, *args, **kwarg): + result = None try: - on_error = getattr(obj, on_error_obj_method_name) if \ - on_error_obj_method_name else (lambda *x: None) - result = on_error(*args, **kwarg) - except: - print(sys.exc_info()) - return result + result = await funct(obj, *args, **kwarg) + except handled_exceptions: + try: + on_error = getattr(obj, on_error_obj_method_name) if \ + on_error_obj_method_name else (lambda *x: None) + result = on_error(*args, **kwarg) if not asyncio.iscoroutinefunction(on_error) else \ + await on_error(*args, **kwarg) + except: + print(sys.exc_info()) + return result - return _wrapper + return _wrapper + + else: + @wraps(funct) + def _wrapper(obj, *args, **kwarg): + result = None + try: + result = funct(obj, *args, **kwarg) + except handled_exceptions: + try: + on_error = getattr(obj, on_error_obj_method_name) if \ + on_error_obj_method_name else (lambda *x: None) + result = on_error(*args, **kwarg) + except: + print(sys.exc_info()) + return result + + return _wrapper return exc_handler_decorator diff --git a/core/utils/memstats.py b/core/utils/memstats.py index 501c2c8c..040c99a5 100644 --- a/core/utils/memstats.py +++ b/core/utils/memstats.py @@ -1,4 +1,6 @@ import os +import tracemalloc + import objgraph import psutil import time @@ -22,6 +24,30 @@ def get_leaking_objects(file=None, limit=5): objgraph.show_refs(roots[:limit], refcounts=True, shortnames=False, output=file) +def get_top_malloc(trace_limit=3): + snapshot = tracemalloc.take_snapshot() + snapshot = snapshot.filter_traces(( + tracemalloc.Filter(False, ""), + tracemalloc.Filter(False, ""), + )) + top_stats = snapshot.statistics('traceback') + msg = "" + + if trace_limit > 0: + msg += f"Top malloc {trace_limit} lines\n" + for index, stat in enumerate(top_stats[:trace_limit], 1): + msg += f"#{index}: {stat.size // 1024} KB, {stat.count} times\n" + for line in stat.traceback.format(limit=16): + msg += f"{line}\n" + other = top_stats[trace_limit:] + if other: + size = sum(stat.size for stat in other) + msg += f"{len(other)} other: {size // 1024} KB\n" + total = sum(stat.size for stat in top_stats) + msg += f"Total allocated size: {total // 1024 // 1024} MB" + return msg + + if __name__ == "__main__": while 1 : print(show_most_common_types()) diff --git a/core/utils/rerunable.py b/core/utils/rerunable.py index 07703910..431cf7b3 100644 --- a/core/utils/rerunable.py +++ b/core/utils/rerunable.py @@ -1,6 +1,8 @@ +import asyncio + import core.logging.logger_constants as log_const from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring class Rerunable(): @@ -36,13 +38,13 @@ def _run(self, action, *args, _try_count=None, **kwargs): log_const.KEY_NAME: log_const.HANDLED_EXCEPTION_VALUE } log("%(class_name)s run failed with %(exception)s.\n Got %(try_count)s tries left.", - params=params, - level="ERROR") + params=params, + level="ERROR") self._on_prepare() result = self._run(action, *args, _try_count=_try_count, **kwargs) counter_name = self._get_counter_name() if counter_name: - monitoring.got_counter(f"{counter_name}_exception") + monitoring.monitoring.got_counter(f"{counter_name}_exception") return result def _get_counter_name(self): diff --git a/core/utils/utils.py b/core/utils/utils.py index 273191be..26670d50 100644 --- a/core/utils/utils.py +++ b/core/utils/utils.py @@ -1,14 +1,18 @@ # coding=utf-8 import datetime +import gc import json import os import re +import weakref from collections import OrderedDict from math import isnan, isinf from typing import Optional from time import time +from scenarios.user.user_model import User + def convert_version_to_list_of_int(version): if not version or re.search("[0-9]", version) is None: diff --git a/scenarios/actions/action.py b/scenarios/actions/action.py index 6c5309e5..88e2ac45 100644 --- a/scenarios/actions/action.py +++ b/scenarios/actions/action.py @@ -1,3 +1,4 @@ +import asyncio import collections import copy import json @@ -22,7 +23,7 @@ import scenarios.logging.logger_constants as log_const from scenarios.actions.action_params_names import TO_MESSAGE_NAME, TO_MESSAGE_PARAMS, SAVED_MESSAGES, \ - REQUEST_FIELD, LOCAL_VARS + REQUEST_FIELD from scenarios.user.parametrizer import Parametrizer from scenarios.user.user_model import User from scenarios.scenario_models.history import Event @@ -38,8 +39,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(ClearFormAction, self).__init__(items, id) self.form = items["form"] - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.forms.remove_item(self.form) @@ -53,8 +54,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(ClearInnerFormAction, self).__init__(items, id) self.inner_form = items["inner_form"] - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: form = user.forms[self.form] if form: form.forms.remove_item(self.inner_form) @@ -71,8 +72,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.form = items["form"] self.field = items["field"] - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: form = user.forms[self.form] form.fields.remove_item(self.field) @@ -88,8 +89,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(RemoveCompositeFormFieldAction, self).__init__(items, id) self.inner_form = items["inner_form"] - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: form = user.forms[self.form] inner_form = form.forms[self.inner_form] inner_form.fields.remove_item(self.field) @@ -102,8 +103,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(BreakScenarioAction, self).__init__(items, id) self.scenario_id = items.get("scenario_id") - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: scenario_id = self.scenario_id if self.scenario_id is not None else user.last_scenarios.last_scenario_name user.scenario_models[scenario_id].set_break() @@ -119,8 +120,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): self.behavior = items["behavior"] self.check_scenario = items.get("check_scenario", True) - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: scenario_id = None if self.check_scenario: scenario_id = user.last_scenarios.last_scenario_name @@ -151,15 +152,15 @@ def command_action(self) -> StringAction: def _check(self, user): return not user.behaviors.check_got_saved_id(self.behavior_action.behavior) - def _run(self, user, text_preprocessing_result, params=None): - self.behavior_action.run(user, text_preprocessing_result, params) - command_action_result = self.command_action.run(user, text_preprocessing_result, params) or [] + async def _run(self, user, text_preprocessing_result, params=None): + await self.behavior_action.run(user, text_preprocessing_result, params) + command_action_result = await self.command_action.run(user, text_preprocessing_result, params) or [] return command_action_result - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: if self._check(user): - return self._run(user, text_preprocessing_result, params) + return await self._run(user, text_preprocessing_result, params) class BaseSetVariableAction(Action): @@ -178,8 +179,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): def _set(self, user, value): raise NotImplemented - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: params = user.parametrizer.collect(text_preprocessing_result) try: # if path is wrong, it may fail with UndefinedError @@ -224,8 +225,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(DeleteVariableAction, self).__init__(items, id) self.key: str = items["key"] - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.variables.delete(self.key) @@ -235,8 +236,8 @@ class ClearVariablesAction(Action): def __init__(self, items: Dict[str, Any] = None, id: Optional[str] = None): super(ClearVariablesAction, self).__init__(items, id) - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.variables.clear() @@ -262,8 +263,8 @@ def _fill(self, user, data): def _get_data(self, params): return self.template.render(params) - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: params = user.parametrizer.collect(text_preprocessing_result) data = self._get_data(params) self._fill(user, data) @@ -292,8 +293,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(RunScenarioAction, self).__init__(items, id) self.scenario: UnifiedTemplate = UnifiedTemplate(items["scenario"]) - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: if params is None: params = user.parametrizer.collect(text_preprocessing_result) else: @@ -301,16 +302,16 @@ def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult scenario_id = self.scenario.render(params) scenario = user.descriptions["scenarios"].get(scenario_id) if scenario: - return scenario.run(text_preprocessing_result, user, params) + return await scenario.run(text_preprocessing_result, user, params) class RunLastScenarioAction(Action): - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: last_scenario_id = user.last_scenarios.last_scenario_name scenario = user.descriptions["scenarios"].get(last_scenario_id) if scenario: - return scenario.run(text_preprocessing_result, user, params) + return await scenario.run(text_preprocessing_result, user, params) class ChoiceScenarioAction(Action): @@ -339,20 +340,65 @@ def build_requirement_items(self): def build_else_item(self): return self._else_item - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: result = None choice_is_made = False for scenario, requirement in zip(self._scenarios, self.requirement_items): - check_res = requirement.check(text_preprocessing_result, user, params) + check_res = await requirement.check(text_preprocessing_result, user, params) if check_res: - result = RunScenarioAction(items=scenario).run(user, text_preprocessing_result, params) + result = await RunScenarioAction(items=scenario).run(user, text_preprocessing_result, params) choice_is_made = True break if not choice_is_made and self._else_item: - result = self.else_item.run(user, text_preprocessing_result, params) + result = await self.else_item.run(user, text_preprocessing_result, params) + + return result + + +class GatherChoiceScenarioAction(Action): + FIELD_SCENARIOS_KEY = "scenarios" + FIELD_ELSE_KEY = "else_action" + FIELD_REQUIREMENT_KEY = "requirement" + + def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: + super(GatherChoiceScenarioAction, self).__init__(items, id) + self._else_item = items.get(self.FIELD_ELSE_KEY) + self._scenarios = items[self.FIELD_SCENARIOS_KEY] + self._requirements = [scenario.pop(self.FIELD_REQUIREMENT_KEY) for scenario in self._scenarios] + + self.requirement_items = self.build_requirement_items() + + if self._else_item: + self.else_item = self.build_else_item() + else: + self.else_item = None + + @list_factory(Requirement) + def build_requirement_items(self): + return self._requirements + + @factory(Action) + def build_else_item(self): + return self._else_item + + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: + result = None + choice_is_made = False + + check_results = await asyncio.gather(requirement.check(text_preprocessing_result, user, params) + for requirement in self.requirement_items) + for scenario, check_res in zip(self._scenarios, check_results): + if check_res: + result = await RunScenarioAction(items=scenario).run(user, text_preprocessing_result, params) + choice_is_made = True + break + + if not choice_is_made and self._else_item: + result = await self.else_item.run(user, text_preprocessing_result, params) return result @@ -364,8 +410,8 @@ def _clear_scenario(self, user, scenario_id): user.last_scenarios.delete(scenario_id) user.forms.remove_item(scenario.form_type) - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: last_scenario_id = user.last_scenarios.last_scenario_name if last_scenario_id: self._clear_scenario(user, last_scenario_id) @@ -373,7 +419,7 @@ def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult class ClearAllScenariosAction(Action): - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: user.last_scenarios.clear_all() @@ -387,8 +433,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): super(ClearScenarioByIdAction, self).__init__(items, id) self.scenario_id = items.get("scenario_id") - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: if self.scenario_id: self._clear_scenario(user, self.scenario_id) @@ -397,7 +443,7 @@ class ClearCurrentScenarioFormAction(Action): def __init__(self, items, id=None): super().__init__(items, id) - def run(self, user, text_preprocessing_result, params=None): + async def run(self, user, text_preprocessing_result, params=None): last_scenario_id = user.last_scenarios.last_scenario_name if last_scenario_id: user.forms.clear_form(last_scenario_id) @@ -408,7 +454,7 @@ def __init__(self, items, id=None): super().__init__(items, id) self.node_id = items.get('node_id', None) - def run(self, user, text_preprocessing_result, params=None): + async def run(self, user, text_preprocessing_result, params=None): last_scenario_id = user.last_scenarios.last_scenario_name if last_scenario_id: user.scenario_models[last_scenario_id].current_node = self.node_id @@ -428,8 +474,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None): for k, v in self.event_content.items(): self.event_content[k] = UnifiedTemplate(v) - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: last_scenario_id = user.last_scenarios.last_scenario_name scenario = user.descriptions["scenarios"].get(last_scenario_id) if scenario: @@ -453,8 +499,8 @@ def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult class EmptyAction(Action): - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: log("%(class_name)s.run: action do nothing.", params={log_const.KEY_NAME: "empty_action", "class_name": self.__class__.__name__}, user=user) return None @@ -462,12 +508,12 @@ def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult class RunScenarioByProjectNameAction(Action): - def run(self, user: User, text_preprocessing_result: TextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: + async def run(self, user: User, text_preprocessing_result: TextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Union[None, str, List[Command]]: scenario_id = user.message.project_name scenario = user.descriptions["scenarios"].get(scenario_id) if scenario: - return scenario.run(text_preprocessing_result, user, params) + return await scenario.run(text_preprocessing_result, user, params) else: log("%(class_name)s warning: %(scenario_id)s isn't exist", params={log_const.KEY_NAME: "warning_in_RunScenarioByProjectNameAction", @@ -476,8 +522,8 @@ def run(self, user: User, text_preprocessing_result: TextPreprocessingResult, class ProcessBehaviorAction(Action): - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: callback_id = user.message.callback_id log("%(class_name)s.run: got callback_id %(callback_id)s.", @@ -493,9 +539,9 @@ def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult return None if user.message.payload: - return user.behaviors.success(callback_id) + return await user.behaviors.success(callback_id) - return user.behaviors.fail(callback_id) + return await user.behaviors.fail(callback_id) class SelfServiceActionWithState(BasicSelfServiceActionWithState): @@ -510,7 +556,7 @@ def __init__(self, items, id=None): self.rewrite_saved_messages = items.get("rewrite_saved_messages", False) self._check_scenario: bool = items.get("check_scenario", True) - def _run(self, user, text_preprocessing_result, params=None): + async def _run(self, user, text_preprocessing_result, params=None): action_params = copy.copy(params or {}) diff --git a/scenarios/behaviors/behavior_description.py b/scenarios/behaviors/behavior_description.py index 87c52680..d4459cab 100644 --- a/scenarios/behaviors/behavior_description.py +++ b/scenarios/behaviors/behavior_description.py @@ -18,9 +18,6 @@ def __init__(self, items, id=None): self.version = items.get("version", -1) self.loop_def = items.get("loop_def", True) - def get_expire_time_from_now(self, user): - return time.time() + self.timeout(user) - def timeout(self, user): setting_timeout = user.settings["template_settings"].get("services_timeout", {}).get(self.id) return setting_timeout or self._timeout diff --git a/scenarios/behaviors/behaviors.py b/scenarios/behaviors/behaviors.py index ce397d05..972ef982 100644 --- a/scenarios/behaviors/behaviors.py +++ b/scenarios/behaviors/behaviors.py @@ -43,8 +43,8 @@ def initialize(self): for key, value in callback_action_params.get(LOCAL_VARS, {}).items(): self._user.local_vars.set(key, value) - def _add_behavior_timeout(self, expire_time_us, callback_id): - self._behavior_timeouts.append((expire_time_us, callback_id)) + def _add_behavior_timeout(self, time_left, callback_id): + self._behavior_timeouts.append((time_left, callback_id)) def get_behavior_timeouts(self): return self._behavior_timeouts @@ -60,11 +60,7 @@ def add(self, callback_id: str, behavior_id, scenario_id=None, text_preprocessin host = socket.gethostname() text_preprocessing_result_raw = text_preprocessing_result_raw or {} # behavior will be removed after timeout + EXPIRATION_DELAY - expiration_time = ( - int(time()) + - self.descriptions[behavior_id].timeout(self._user) + - self.EXPIRATION_DELAY - ) + expiration_time = int(time()) + self.descriptions[behavior_id].timeout(self._user) + self.EXPIRATION_DELAY action_params = action_params or dict() action_params[LOCAL_VARS] = pickle_deepcopy(self._user.local_vars.values) @@ -77,8 +73,7 @@ def add(self, callback_id: str, behavior_id, scenario_id=None, text_preprocessin hostname=host ) self._callbacks[callback_id] = callback - log( - f"behavior.add: adding behavior %({log_const.BEHAVIOR_ID_VALUE})s with scenario_id" + log(f"behavior.add: adding behavior %({log_const.BEHAVIOR_ID_VALUE})s with scenario_id" f" %({log_const.CHOSEN_SCENARIO_VALUE})s for callback %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s" f" expiration_time: %(expiration_time)s.", user=self._user, @@ -89,8 +84,7 @@ def add(self, callback_id: str, behavior_id, scenario_id=None, text_preprocessin "expiration_time": expiration_time}) behavior_description = self.descriptions[behavior_id] - expire_time_us = behavior_description.get_expire_time_from_now(self._user) - self._add_behavior_timeout(expire_time_us, callback_id) + self._add_behavior_timeout(behavior_description.timeout(self._user) + self.EXPIRATION_DELAY, callback_id) def _delete(self, callback_id): if callback_id in self._callbacks: @@ -121,7 +115,7 @@ def _log_callback(self, callback_id: str, log_name: str, metric, behavior_result user=self._user, params=log_params) - def success(self, callback_id: str): + async def success(self, callback_id: str): log(f"behavior.success started: got callback %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s.", self._user, params={log_const.KEY_NAME: log_const.BEHAVIOR_SUCCESS_VALUE, @@ -141,11 +135,11 @@ def success(self, callback_id: str): callback_action_params, ) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - result = behavior.success_action.run(self._user, text_preprocessing_result, callback_action_params) + result = await behavior.success_action.run(self._user, text_preprocessing_result, callback_action_params) self._delete(callback_id) return result - def fail(self, callback_id: str): + async def fail(self, callback_id: str): log(f"behavior.fail started: got callback %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s.", self._user, params={log_const.KEY_NAME: log_const.BEHAVIOR_FAIL_VALUE, @@ -161,11 +155,11 @@ def fail(self, callback_id: str): smart_kit_metrics.counter_behavior_fail, "fail", callback_action_params) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - result = behavior.fail_action.run(self._user, text_preprocessing_result, callback_action_params) + result = await behavior.fail_action.run(self._user, text_preprocessing_result, callback_action_params) self._delete(callback_id) return result - def timeout(self, callback_id: str): + async def timeout(self, callback_id: str): log(f"behavior.timeout started: got callback %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s.", self._user, params={log_const.KEY_NAME: log_const.BEHAVIOR_TIMEOUT_VALUE, @@ -180,11 +174,11 @@ def timeout(self, callback_id: str): smart_kit_metrics.counter_behavior_timeout, "timeout", callback_action_params) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - result = behavior.timeout_action.run(self._user, text_preprocessing_result, callback_action_params) + result = await behavior.timeout_action.run(self._user, text_preprocessing_result, callback_action_params) self._delete(callback_id) return result - def misstate(self, callback_id: str): + async def misstate(self, callback_id: str): log(f"behavior.misstate started: got callback %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s.", self._user, params={log_const.KEY_NAME: log_const.BEHAVIOR_MISSTATE_VALUE, @@ -200,7 +194,7 @@ def misstate(self, callback_id: str): smart_kit_metrics.counter_behavior_misstate, "misstate", callback_action_params) text_preprocessing_result = TextPreprocessingResult(callback.text_preprocessing_result) - result = behavior.misstate_action.run(self._user, text_preprocessing_result, callback_action_params) + result = await behavior.misstate_action.run(self._user, text_preprocessing_result, callback_action_params) self._delete(callback_id) return result @@ -240,8 +234,7 @@ def check_misstate(self, callback_id: str): def expire(self): callback_id_for_delete = [] - for callback_id, ( - behavior_id, expiration_time, *_) in self._callbacks.items(): + for callback_id, (behavior_id, expiration_time, *_) in self._callbacks.items(): if expiration_time <= time(): callback_id_for_delete.append(callback_id) for callback_id in callback_id_for_delete: @@ -254,8 +247,8 @@ def expire(self): log_const.BEHAVIOR_DATA_VALUE: str(self._callbacks[callback_id]), "to_message_name": to_message_name} log_params.update(app_info) - log( - f"behavior.expire: if you see this - something went wrong(should be timeout in normal case) callback %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s, with to_message_name %(to_message_name)s", + log(f"behavior.expire: if you see this - something went wrong(should be timeout in normal case) callback " + f"%({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s, with to_message_name %(to_message_name)s", params=log_params, level="WARNING", user=self._user) self._delete(callback_id) @@ -263,8 +256,8 @@ def check_got_saved_id(self, behavior_id): if self.descriptions[behavior_id].loop_def: for callback_id, (_behavior_id, *_) in self._callbacks.items(): if _behavior_id == behavior_id: - log( - f"behavior.check_got_saved_id == True: already got saved behavior %({log_const.BEHAVIOR_ID_VALUE})s for callback_id %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s", + log(f"behavior.check_got_saved_id == True: already got saved behavior " + f"%({log_const.BEHAVIOR_ID_VALUE})s for callback_id %({log_const.BEHAVIOR_CALLBACK_ID_VALUE})s", user=self._user, params={log_const.KEY_NAME: "behavior_got_saved", log_const.BEHAVIOR_CALLBACK_ID_VALUE: callback_id, diff --git a/scenarios/requirements/requirements.py b/scenarios/requirements/requirements.py index 4e2be379..3a877945 100644 --- a/scenarios/requirements/requirements.py +++ b/scenarios/requirements/requirements.py @@ -10,8 +10,8 @@ class AskAgainExistRequirement(Requirement): - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: last_scenario_id = user.last_scenarios.last_scenario_name scenario = user.descriptions["scenarios"].get(last_scenario_id) return scenario.check_ask_again_requests(text_preprocessing_result, user, params) @@ -23,8 +23,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: self._template = UnifiedTemplate(items["template"]) self._items = set(items["items"]) - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: params = params or {} collected = user.parametrizer.collect(text_preprocessing_result) params.update(collected) @@ -38,8 +38,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: self._template = UnifiedTemplate(items["template"]) self._items = set(items["items"]) - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: params = params or {} collected = user.parametrizer.collect(text_preprocessing_result) params.update(collected) @@ -56,8 +56,8 @@ def __init__(self, items: Dict[str, Any], id: Optional[str] = None) -> None: self._template = UnifiedTemplate(items["template"]) self._regexp = re.compile(items["regexp"], re.S | re.M) - def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> bool: params = params or {} collected = user.parametrizer.collect(text_preprocessing_result) params.update(collected) diff --git a/scenarios/scenario_descriptions/form_filling_scenario.py b/scenarios/scenario_descriptions/form_filling_scenario.py index fece0018..9fe34b3f 100644 --- a/scenarios/scenario_descriptions/form_filling_scenario.py +++ b/scenarios/scenario_descriptions/form_filling_scenario.py @@ -2,7 +2,7 @@ from typing import Dict, Any from core.basic_models.scenarios.base_scenario import BaseScenario -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.logging.logger_utils import log import scenarios.logging.logger_constants as log_const @@ -26,17 +26,17 @@ def _get_form(self, user): form.refresh() return form - def text_fits(self, text_preprocessing_result, user): - return self._check_field(text_preprocessing_result, user, None) + async def text_fits(self, text_preprocessing_result, user): + return await self._check_field(text_preprocessing_result, user, None) - def check_ask_again_requests(self, text_preprocessing_result, user, params): + async def check_ask_again_requests(self, text_preprocessing_result, user, params): form = user.forms[self.form_type] - question_field = self._field(form, text_preprocessing_result, user, params) + question_field = await self._field(form, text_preprocessing_result, user, params) return question_field.ask_again_counter < len(question_field.description.ask_again_requests) - def ask_again(self, text_preprocessing_result, user, params): + async def ask_again(self, text_preprocessing_result, user, params): form = user.forms[self.form_type] - question_field = self._field(form, text_preprocessing_result, user, params) + question_field = await self._field(form, text_preprocessing_result, user, params) question = question_field.description.ask_again_requests[question_field.ask_again_counter] question_field.ask_again_counter += 1 @@ -46,23 +46,21 @@ def ask_again(self, text_preprocessing_result, user, params): content={HistoryConstants.content_fields.FIELD: question_field.description.id}, results=HistoryConstants.event_results.ASK_QUESTION)) - return question.run(user, text_preprocessing_result, params) + return await question.run(user, text_preprocessing_result, params) - def _check_field(self, text_preprocessing_result, user, params): + async def _check_field(self, text_preprocessing_result, user, params): form = user.forms[self.form_type] - field = self._field(form, text_preprocessing_result, user, params) - return field.check_can_be_filled(text_preprocessing_result, user) if field else False + field = await self._field(form, text_preprocessing_result, user, params) + return await field.check_can_be_filled(text_preprocessing_result, user) if field else False - def _field(self, form, text_preprocessing_result, user, params): - return self._find_field(form, text_preprocessing_result, user, params) + async def _field(self, form, text_preprocessing_result, user, params): + return await self._find_field(form, text_preprocessing_result, user, params) - def _find_field(self, form, text_preprocessing_result, user, params): + async def _find_field(self, form, text_preprocessing_result, user, params): for field_name in form.fields.descriptions: field = form.fields[field_name] if not field.valid and field.description.has_requests and \ - field.description.requirement.check( - text_preprocessing_result, user, params - ): + await field.description.requirement.check(text_preprocessing_result, user, params): return field def get_fields_data(self, form, form_key): @@ -75,9 +73,9 @@ def get_fields_data(self, form, form_key): def _clean_key(self, key: str): return key.replace(" ", "") - def _extract_by_field_filler(self, field_key, field_descr, text_normalization_result, user, params): + async def _extract_by_field_filler(self, field_key, field_descr, text_normalization_result, user, params): result = {} - check = field_descr.requirement.check(text_normalization_result, user, params) + check = await field_descr.requirement.check(text_normalization_result, user, params) log_params = self._log_params() log_params["requirement"] = field_descr.requirement.__class__.__name__, log_params["check"] = check @@ -85,7 +83,7 @@ def _extract_by_field_filler(self, field_key, field_descr, text_normalization_re message = "FormFillingScenario.extract: field %(field_key)s requirement %(requirement)s return value: %(check)s" log(message, user, log_params) if check: - result[field_key] = field_descr.filler.run(user, text_normalization_result, params) + result[field_key] = await (field_descr.filler.run(user, text_normalization_result, params)) event = Event(type=HistoryConstants.types.FIELD_EVENT, scenario=self.root_id, content={HistoryConstants.content_fields.FIELD: field_key}, @@ -93,7 +91,7 @@ def _extract_by_field_filler(self, field_key, field_descr, text_normalization_re user.history.add_event(event) return result - def _extract_data(self, form, text_normalization_result, user, params): + async def _extract_data(self, form, text_normalization_result, user, params): result = {} callback_id = user.message.callback_id @@ -103,22 +101,24 @@ def _extract_data(self, form, text_normalization_result, user, params): field = form.fields[request_field["id"]] field_descr = form.description.fields[request_field["id"]] if field.available: - result.update(self._extract_by_field_filler(request_field["id"], field_descr, text_normalization_result, - user, params)) + result.update(await self._extract_by_field_filler(request_field["id"], field_descr, + text_normalization_result, + user, params)) else: for field_key, field_descr in form.description.fields.items(): field = form.fields[field_key] if field.available and isinstance(field, QuestionField): - result.update(self._extract_by_field_filler(field_key, field_descr, - text_normalization_result, user, params)) + result.update(await self._extract_by_field_filler(field_key, field_descr, + text_normalization_result, user, params)) return result - def _validate_extracted_data(self, user, text_preprocessing_result, form, data_extracted, params): + async def _validate_extracted_data(self, user, text_preprocessing_result, form, data_extracted, params): error_msgs = [] for field_key, field in form.description.fields.items(): value = data_extracted.get(field_key) # is not None is necessary, because 0 and False should be checked, None - shouldn't fill - if value is not None and not field.field_validator.requirement.check(value, params): + if value is not None and \ + not await field.field_validator.requirement.check(value, params): log_params = { log_const.KEY_NAME: log_const.SCENARIO_RESULT_VALUE, "field_key": field_key @@ -126,11 +126,11 @@ def _validate_extracted_data(self, user, text_preprocessing_result, form, data_e message = "Field is not valid: %(field_key)s" log(message, user, log_params) actions = field.field_validator.actions - error_msgs = self.get_action_results(user, text_preprocessing_result, actions) + error_msgs = await self.get_action_results(user, text_preprocessing_result, actions) break return error_msgs - def _fill_form(self, user, text_preprocessing_result, form, data_extracted): + async def _fill_form(self, user, text_preprocessing_result, form, data_extracted): on_filled_actions = [] fields = form.fields scenario_model = user.scenario_models[self.id] @@ -140,15 +140,15 @@ def _fill_form(self, user, text_preprocessing_result, form, data_extracted): value = data_extracted.get(key) field = fields[key] if field.fill(value): - _action = self.get_action_results(user=user, text_preprocessing_result=text_preprocessing_result, - actions=field.description.on_filled_actions) + _action = await self.get_action_results(user=user, text_preprocessing_result=text_preprocessing_result, + actions=field.description.on_filled_actions) on_filled_actions.extend(_action) if scenario_model.break_scenario: is_break = True return _action, is_break return on_filled_actions, is_break - def get_reply(self, user, text_preprocessing_result, reply_actions, field, form): + async def get_reply(self, user, text_preprocessing_result, reply_actions, field, form): action_params = {} if field: field.set_available() @@ -160,7 +160,7 @@ def get_reply(self, user, text_preprocessing_result, reply_actions, field, form) message = "Ask question on field: %(field)s" log(message, user, params) action_params[REQUEST_FIELD] = {"type": field.description.type, "id": field.description.id} - action_messages = self.get_action_results(user, text_preprocessing_result, actions, action_params) + action_messages = await self.get_action_results(user, text_preprocessing_result, actions, action_params) else: actions = reply_actions params = { @@ -170,35 +170,36 @@ def get_reply(self, user, text_preprocessing_result, reply_actions, field, form) message = "Finished scenario: %(id)s" log(message, user, params) user.preprocessing_messages_for_scenarios.clear() - action_messages = self.get_action_results(user, text_preprocessing_result, actions, action_params) + action_messages = await self.get_action_results(user, text_preprocessing_result, actions, action_params) user.last_scenarios.delete(self.id) return action_messages - @monitoring.got_histogram_decorate("scenario_time") - def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): + @monitoring.monitoring.got_histogram_decorate("scenario_time") + async def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): form = self._get_form(user) user.last_scenarios.add(self.id, text_preprocessing_result) user.preprocessing_messages_for_scenarios.add(text_preprocessing_result) - data_extracted = self._extract_data(form, text_preprocessing_result, user, params) + data_extracted = await self._extract_data(form, text_preprocessing_result, user, params) logging_params = {"data_extracted_str": str(data_extracted)} logging_params.update(self._log_params()) log("Extracted data=%(data_extracted_str)s", user, logging_params) - validation_error_msg = self._validate_extracted_data(user, text_preprocessing_result, form, data_extracted, params) + validation_error_msg = await self._validate_extracted_data(user, text_preprocessing_result, + form, data_extracted, params) if validation_error_msg: reply_messages = validation_error_msg else: - reply_messages, is_break = self._fill_form(user, text_preprocessing_result, form, data_extracted) + reply_messages, is_break = await self._fill_form(user, text_preprocessing_result, form, data_extracted) if not is_break: - field = self._field(form, text_preprocessing_result, user, params) + field = await self._field(form, text_preprocessing_result, user, params) if field: user.history.add_event( Event(type=HistoryConstants.types.FIELD_EVENT, scenario=self.root_id, content={HistoryConstants.content_fields.FIELD: field.description.id}, results=HistoryConstants.event_results.ASK_QUESTION)) - reply = self.get_reply(user, text_preprocessing_result, self.actions, field, form) + reply = await self.get_reply(user, text_preprocessing_result, self.actions, field, form) reply_messages.extend(reply) if not reply_messages: diff --git a/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py b/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py index 8274b9d5..f69c797b 100644 --- a/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py +++ b/scenarios/scenario_descriptions/tree_scenario/tree_scenario.py @@ -1,11 +1,10 @@ # coding: utf-8 -from lazy import lazy from typing import Dict, Any from scenarios.scenario_descriptions.form_filling_scenario import FormFillingScenario from scenarios.scenario_descriptions.tree_scenario.tree_scenario_node import TreeScenarioNode from core.model.factory import dict_factory -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.logging.logger_utils import log import scenarios.logging.logger_constants as log_const from scenarios.scenario_models.history import Event, HistoryConstants @@ -24,10 +23,10 @@ def __init__(self, items, id): def build_scenario_nodes(self): return self._scenario_nodes - def _field(self, form, text_preprocessing_result, user, params): + async def _field(self, form, text_preprocessing_result, user, params): current_node = self.get_current_node(user) internal_form = self._get_internal_form(form.forms, current_node.form_key) - return self._find_field(internal_form, text_preprocessing_result, user, params) + return await self._find_field(internal_form, text_preprocessing_result, user, params) def _set_current_node_id(self, user, node_id): user.scenario_models[self.id].current_node = node_id @@ -43,14 +42,14 @@ def get_current_node(self, user): current_node = self.scenario_nodes[current_node_id] return current_node - def get_next_node(self, user, node, text_preprocessing_result, params): + async def get_next_node(self, user, node, text_preprocessing_result, params): available_node_keys = node.available_nodes for key in available_node_keys: node = self.scenario_nodes[key] log_params = {log_const.KEY_NAME: log_const.CHECKING_NODE_ID_VALUE, log_const.CHECKING_NODE_ID_VALUE: node.id} log(log_const.CHECKING_NODE_ID_MESSAGE, user, log_params) - requirement_result = node.requirement.check(text_preprocessing_result, user, params) + requirement_result = await node.requirement.check(text_preprocessing_result, user, params) if requirement_result: log_params = {log_const.KEY_NAME: log_const.CHOSEN_NODE_ID_VALUE, log_const.CHOSEN_NODE_ID_VALUE: node.id} @@ -84,8 +83,8 @@ def get_fields_data(self, main_form, form_type): all_forms_fields.update(form_field_data) return all_forms_fields - @monitoring.got_histogram_decorate("scenario_time") - def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): + @monitoring.monitoring.got_histogram_decorate("scenario_time") + async def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): main_form = self._get_form(user) user.last_scenarios.add(self.id, text_preprocessing_result) user.preprocessing_messages_for_scenarios.add(text_preprocessing_result) @@ -107,7 +106,7 @@ def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): for field_key, field_descr in internal_form.description.fields.items(): field = internal_form.fields[field_key] if field.available: - extracted = field_descr.filler.run(user, text_preprocessing_result, params) + extracted = await field_descr.filler.run(user, text_preprocessing_result, params) if extracted is not None: event = Event(type=HistoryConstants.types.FIELD_EVENT, scenario=self.root_id, @@ -124,15 +123,15 @@ def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): if extracted is not None and fill_other: fill_other = fill_other and field_descr.fill_other field_data = {field_key: extracted} - _validation_error_msg = self._validate_extracted_data(user, text_preprocessing_result, - internal_form, field_data, params) + _validation_error_msg = await self._validate_extracted_data(user, text_preprocessing_result, + internal_form, field_data, params) if _validation_error_msg: # return only first validation message in form validation_error_msg = validation_error_msg or _validation_error_msg else: data_extracted.update(field_data) - on_filled_node_actions, is_break = self._fill_form(user, text_preprocessing_result, - internal_form, data_extracted) + on_filled_node_actions, is_break = await self._fill_form(user, text_preprocessing_result, + internal_form, data_extracted) if is_break: return on_filled_node_actions on_filled_actions.extend(on_filled_node_actions) @@ -157,10 +156,10 @@ def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): elif not form: form = internal_form self._set_current_node_id(user, current_node.id) - new_node = self.get_next_node(user, current_node, text_preprocessing_result, params) + new_node = await self.get_next_node(user, current_node, text_preprocessing_result, params) - field = self._find_field(form, text_preprocessing_result, - user, params) if form else None + field = await self._find_field(form, text_preprocessing_result, + user, params) if form else None reply_commands = on_filled_actions if field: @@ -170,10 +169,10 @@ def run(self, text_preprocessing_result, user, params: Dict[str, Any] = None): content={HistoryConstants.content_fields.FIELD: field.description.id}, results=HistoryConstants.event_results.ASK_QUESTION) user.history.add_event(event) - _command = self.get_reply(user, text_preprocessing_result, current_node.actions, field, main_form) + _command = await self.get_reply(user, text_preprocessing_result, current_node.actions, field, main_form) reply_commands.extend(_command) if not reply_commands: - reply_commands = self.get_no_commands_action(user, text_preprocessing_result) + reply_commands = await self.get_no_commands_action(user, text_preprocessing_result) return reply_commands diff --git a/scenarios/scenario_models/field/composite_fillers.py b/scenarios/scenario_models/field/composite_fillers.py index 76210d58..80244142 100644 --- a/scenarios/scenario_models/field/composite_fillers.py +++ b/scenarios/scenario_models/field/composite_fillers.py @@ -28,9 +28,9 @@ def on_extract_error(self, text_preprocessing_result, user, params=None): return None @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, - user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: - return self.run(user, text_preprocessing_result, params) + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: + return await self.run(user, text_preprocessing_result, params) class ChoiceFiller(ChoiceAction): @@ -60,9 +60,9 @@ def on_extract_error(self, text_preprocessing_result, user, params=None): return None @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, - user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: - return self.run(user, text_preprocessing_result, params) + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: + return await self.run(user, text_preprocessing_result, params) class ElseFiller(ElseAction): @@ -92,6 +92,6 @@ def on_extract_error(self, text_preprocessing_result, user, params=None): return None @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, - user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: - return self.run(user, text_preprocessing_result, params) + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: + return await self.run(user, text_preprocessing_result, params) diff --git a/scenarios/scenario_models/field/field.py b/scenarios/scenario_models/field/field.py index efb2b7e6..d09422b2 100644 --- a/scenarios/scenario_models/field/field.py +++ b/scenarios/scenario_models/field/field.py @@ -1,4 +1,6 @@ # coding: utf-8 +import asyncio + from core.logging.logger_utils import log from core.model.registered import Registered from core.utils.masking_message import masking @@ -39,11 +41,11 @@ def available(self): def can_be_updated(self): return self.value is not None - def check_can_be_filled(self, text_preprocessing_result, user): - return ( - self.description.requirement.check(text_preprocessing_result, user) and - self.description.filler.run(user, text_preprocessing_result) is not None - ) + async def check_can_be_filled(self, text_preprocessing_result, user): + check, run = await asyncio.gather( + self.description.requirement.check(text_preprocessing_result, user), + self.description.filler.run(user, text_preprocessing_result)) + return check and run is not None @property def valid(self): diff --git a/scenarios/scenario_models/field/field_filler_description.py b/scenarios/scenario_models/field/field_filler_description.py index 9a8d44d3..4bb56ef3 100644 --- a/scenarios/scenario_models/field/field_filler_description.py +++ b/scenarios/scenario_models/field/field_filler_description.py @@ -44,8 +44,8 @@ def _log_params(self): "filler": self.__class__.__name__ } - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> None: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> None: return None def on_extract_error(self, text_preprocessing_result, user, params=None): @@ -54,9 +54,9 @@ def on_extract_error(self, text_preprocessing_result, user, params=None): level="ERROR", exc_info=True) return None - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Any]] = None) -> None: - return self.extract(text_preprocessing_result, user, params) + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Any]] = None) -> None: + return await self.extract(text_preprocessing_result, user, params) def _postprocessing(self, user: User, item: str) -> None: last_scenario_name = user.last_scenarios.last_scenario_name @@ -71,10 +71,10 @@ def __init__(self, items: Optional[Dict[str, Any]], id: Optional[str] = None) -> self.filler = items.get("filler") @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, - user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: filler = user.descriptions["external_field_fillers"][self.filler] - return filler.run(user, text_preprocessing_result, params) + return await filler.run(user, text_preprocessing_result, params) class CompositeFiller(FieldFillerDescription): @@ -90,11 +90,11 @@ def build_fillers(self): return self._fillers @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, - user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: extracted = None for filler in self.fillers: - extracted = filler.extract(text_preprocessing_result, user, params) + extracted = await filler.extract(text_preprocessing_result, user, params) if extracted is not None: break return extracted @@ -112,8 +112,8 @@ def __init__(self, items: Optional[Dict[str, Any]], id: Optional[str] = None) -> self.template: UnifiedTemplate = UnifiedTemplate(value) @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, - user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> Optional[Union[int, float, str, bool, List, Dict]]: params = params or {} collected = user.parametrizer.collect(text_preprocessing_result) params.update(collected) @@ -137,8 +137,8 @@ def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, class FirstNumberFiller(FieldFillerDescription): @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[int]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[int]: numbers = text_preprocessing_result.num_token_values if numbers: log_params = self._log_params() @@ -151,8 +151,8 @@ def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: class FirstCurrencyFiller(FieldFillerDescription): @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: currencies = text_preprocessing_result.ccy_token_values if currencies: log_params = self._log_params() @@ -165,8 +165,8 @@ def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: class FirstOrgFiller(FieldFillerDescription): @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: orgs = text_preprocessing_result.org_token_values if orgs: log_params = self._log_params() @@ -179,8 +179,8 @@ def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: class FirstGeoFiller(FieldFillerDescription): @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: geos = text_preprocessing_result.geo_token_values if geos: log_params = self._log_params() @@ -200,8 +200,8 @@ def __init__(self, items: Optional[Dict[str, Any]], id: Optional[str] = None) -> self.delimiter = items.get("delimiter", ",") @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: original_text = text_preprocessing_result.original_text match = re.findall(self.regexp, original_text) if match: @@ -233,15 +233,16 @@ def _operation(self, original_text, typeOp, amount): return func(original_text, amount) if amount else func(original_text) @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: original_text = text_preprocessing_result.original_text if self.operations: for op in self.operations: original_text = self._operation(original_text, op["type"], op.get("amount")) text_preprocessing_result_copy = pickle_deepcopy(text_preprocessing_result) text_preprocessing_result_copy.original_text = original_text - return super(RegexpAndStringOperationsFieldFiller, self).extract(text_preprocessing_result_copy, user, params) + return await super(RegexpAndStringOperationsFieldFiller, self).extract(text_preprocessing_result_copy, + user, params) class AllRegexpsFieldFiller(FieldFillerDescription): @@ -257,8 +258,8 @@ def __init__(self, items: Optional[Dict[str, Any]], id: Optional[str] = None) -> self.original_text_lower = items.get("original_text_lower") or False @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: original_text = text_preprocessing_result.original_text if self.original_text_lower: original_text = original_text.lower() @@ -277,8 +278,8 @@ def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: class FirstPersonFiller(FieldFillerDescription): @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[Dict[str, str]]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[Dict[str, str]]: persons = text_preprocessing_result.person_token_values if persons: log_params = self._log_params() @@ -303,19 +304,19 @@ def build_filler(self): return self._filler @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: - result = self.filler.extract(text_preprocessing_result, user, params) + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: + result = await self.filler.extract(text_preprocessing_result, user, params) if result is None: - result = self._try_extract_last_messages(user, params) + result = await self._try_extract_last_messages(user, params) return result - def _try_extract_last_messages(self, user, params): + async def _try_extract_last_messages(self, user, params): processed_items = user.preprocessing_messages_for_scenarios.processed_items count = self.count - 1 if self.count else len(processed_items) for preprocessing_result_raw in islice(processed_items, 0, count): preprocessing_result = TextPreprocessingResult(preprocessing_result_raw) - result = self.filler.extract(preprocessing_result, user, params) + result = await self.filler.extract(preprocessing_result, user, params) if result is not None: return result @@ -323,8 +324,8 @@ def _try_extract_last_messages(self, user, params): class UserIdFiller(FieldFillerDescription): @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: result = user.message.uuid.get('userId') return result @@ -356,8 +357,8 @@ def __init__(self, items: Optional[Dict[str, Any]], id: Optional[str] = None) -> self.normalized_cases.append((key, tokens_list)) @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: tpr_tokenized_set = {norm.get("lemma") for norm in text_preprocessing_result.tokenized_elements_list_pymorphy if norm.get("token_type") != "SENTENCE_ENDPOINT_TOKEN"} for key, tokens_list in self.normalized_cases: @@ -396,16 +397,16 @@ def __init__(self, items: Optional[Dict[str, Any]], id: Optional[str] = None) -> self.max_days_in_period = items.get('max_days_in_period', None) self.future_days_allowed = items.get('future_days_allowed', False) - def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Dict[str, str]: - if text_preprocessing_result\ - .words_tokenized_set\ - .intersection( - [ - 'TIME_DATE_TOKEN', - 'TIME_DATE_INTERVAL_TOKEN', - 'PERIOD_TOKEN' - ]): + async def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Dict[str, str]: + if text_preprocessing_result \ + .words_tokenized_set \ + .intersection( + [ + 'TIME_DATE_TOKEN', + 'TIME_DATE_INTERVAL_TOKEN', + 'PERIOD_TOKEN' + ]): words_from_intent: List[Optional[str]] = text_preprocessing_result.human_normalized_text.lower().split() else: words_from_intent: List[Optional[str]] = text_preprocessing_result.original_text.lower().split() @@ -416,7 +417,7 @@ def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User is_determined: bool = False is_error: bool = False if not (begin_str == '' or begin_str == 'error' - or end_str == '' or end_str == 'error'): + or end_str == '' or end_str == 'error'): is_determined = True if begin_str == 'error' or end_str == 'error': @@ -448,8 +449,8 @@ def _check_exceptions(self, key, tpr_original_set): return False @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[str]: + async def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[str]: tpr_original_set = {*text_preprocessing_result.original_text.split()} for key, tokens_list in self.original_cases: for tokens in tokens_list: @@ -486,8 +487,8 @@ def __init__(self, items: Optional[Dict[str, Any]], id: Optional[str] = None) -> } @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Optional[bool]: + async def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Optional[bool]: if text_preprocessing_result.tokenized_string in self.yes_words_normalized: params = self._log_params() params["tokenized_string"] = text_preprocessing_result.tokenized_string @@ -509,9 +510,8 @@ def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User class ApproveRawTextFiller(ApproveFiller): @exc_handler(on_error_obj_method_name="on_extract_error") - def extract( - self, text_preprocessing_result: TextPreprocessingResult, user: User, params: Dict[str, Any] = None - ) -> Optional[bool]: + async def extract(self, text_preprocessing_result: TextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> Optional[bool]: original_text = ' '.join(text_preprocessing_result.original_text.split()).lower().rstrip('!.)') if original_text in self.set_yes_words: params = self._log_params() @@ -551,8 +551,8 @@ def _get_result(self, answers: List[Dict[str, Union[str, float, bool]]]) -> str: return answers[0][self._cls_const_answer_key] @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, - params: Dict[str, Any] = None) -> Union[str, None, List[Dict[str, Union[str, float, bool]]]]: + async def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: User, + params: Dict[str, Any] = None) -> Union[str, None, List[Dict[str, Union[str, float, bool]]]]: result = None classifier = self.classifier with StatsTimer() as timer: @@ -573,6 +573,6 @@ def extract(self, text_preprocessing_result: BaseTextPreprocessingResult, user: class ClassifierFillerMeta(ClassifierFiller): - def _get_result(self, answers: List[Dict[str, Union[str, float, bool]]]) -> List[ - Dict[str, Union[str, float, bool]]]: + def _get_result(self, answers: List[Dict[str, Union[str, float, bool]]]) \ + -> List[Dict[str, Union[str, float, bool]]]: return answers diff --git a/scenarios/scenario_models/field_requirements/field_requirements.py b/scenarios/scenario_models/field_requirements/field_requirements.py index 7759491c..e4b5aaeb 100644 --- a/scenarios/scenario_models/field_requirements/field_requirements.py +++ b/scenarios/scenario_models/field_requirements/field_requirements.py @@ -1,4 +1,5 @@ # coding: utf-8 +import asyncio from typing import Dict, List, Optional, Any, Set from core.basic_models.operators.operators import Operator @@ -14,7 +15,7 @@ class FieldRequirement: def __init__(self, items: Optional[Dict[str, Any]]) -> None: pass - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: return True @@ -32,13 +33,27 @@ def build_requirements(self): class AndFieldRequirement(CompositeFieldRequirement): - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: - return all(requirement.check(field_value=field_value, params=params) for requirement in self.requirements) + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + return all(await requirement.check(field_value=field_value, params=params) for requirement in self.requirements) + + +class GatherAndFieldRequirement(CompositeFieldRequirement): + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + check_results = await asyncio.gather(requirement.check(field_value=field_value, params=params) + for requirement in self.requirements) + return all(check_results) class OrFieldRequirement(CompositeFieldRequirement): - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: - return any(requirement.check(field_value=field_value, params=params) for requirement in self.requirements) + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + return any(await requirement.check(field_value=field_value, params=params) for requirement in self.requirements) + + +class GatherOrFieldRequirement(CompositeFieldRequirement): + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + check_results = await asyncio.gather(requirement.check(field_value=field_value, params=params) + for requirement in self.requirements) + return any(check_results) class NotFieldRequirement(FieldRequirement): @@ -53,8 +68,8 @@ def __init__(self, items: Optional[Dict[str, Any]]) -> None: def build_requirement(self): return self._requirement - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: - return not self.requirement.check(field_value=field_value, params=params) + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + return not await self.requirement.check(field_value=field_value, params=params) class ComparisonFieldRequirement(FieldRequirement): @@ -69,7 +84,7 @@ def __init__(self, items: Optional[Dict[str, Any]]) -> None: def build_operator(self): return self._operator - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: return self.operator.compare(field_value) @@ -77,7 +92,7 @@ class IsIntFieldRequirement(FieldRequirement): def __init__(self, items: Optional[Dict[str, Any]]) -> None: super(IsIntFieldRequirement, self).__init__(items) - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: try: int(field_value) return True @@ -92,7 +107,7 @@ def __init__(self, items: Optional[Dict[str, Any]]) -> None: super(ValueInSetRequirement, self).__init__(items) self.symbols: Set = set(items["symbols"]) - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: return field_value in self.symbols @@ -102,7 +117,7 @@ def __init__(self, items: Optional[Dict[str, Any]]) -> None: self.part = items['part'] self.values = items['values'] - def check(self, field_value: dict, params: Dict[str, Any] = None) -> bool: + async def check(self, field_value: dict, params: Dict[str, Any] = None) -> bool: return field_value[self.part] in self.values @@ -115,5 +130,5 @@ def __init__(self, items: Optional[Dict[str, Any]]) -> None: self.min_field_length = items["min_field_length"] self.max_field_length = items["max_field_length"] - def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: + async def check(self, field_value: str, params: Dict[str, Any] = None) -> bool: return self.min_field_length <= len(field_value) <= self.max_field_length diff --git a/scenarios/user/last_scenarios/last_scenarios_description.py b/scenarios/user/last_scenarios/last_scenarios_description.py index 12638a95..2c48095d 100644 --- a/scenarios/user/last_scenarios/last_scenarios_description.py +++ b/scenarios/user/last_scenarios/last_scenarios_description.py @@ -17,6 +17,7 @@ def __init__(self, items, id): def requirement(self): return self._requirement - def check(self, text_preprocessing_result, user): - return user.message.channel in self._channels and self.requirement.check(text_preprocessing_result, user) if \ - self._channels else self.requirement.check(text_preprocessing_result, user) + async def check(self, text_preprocessing_result, user): + return user.message.channel in self._channels and \ + await self.requirement.check(text_preprocessing_result, user) if self._channels else \ + await self.requirement.check(text_preprocessing_result, user) diff --git a/scenarios/user/user_model.py b/scenarios/user/user_model.py index 75af4818..96495350 100644 --- a/scenarios/user/user_model.py +++ b/scenarios/user/user_model.py @@ -18,6 +18,7 @@ from smart_kit.utils.monitoring import smart_kit_metrics import scenarios.logging.logger_constants as log_const + class User(BaseUser): forms: Forms diff --git a/setup.py b/setup.py index 6bddbdb3..8c64431c 100644 --- a/setup.py +++ b/setup.py @@ -56,8 +56,7 @@ "freezegun==1.1.0", ], classifiers=[ - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9" ] ) diff --git a/smart_kit/action/base_http.py b/smart_kit/action/base_http.py deleted file mode 100644 index 1b3ece45..00000000 --- a/smart_kit/action/base_http.py +++ /dev/null @@ -1,113 +0,0 @@ -import json -from typing import Optional, Dict, Union, List, Any - -import requests - -import core.logging.logger_constants as log_const -from core.basic_models.actions.command import Command -from core.basic_models.actions.string_actions import NodeAction -from core.logging.logger_utils import log -from core.model.base_user import BaseUser -from core.text_preprocessing.base import BaseTextPreprocessingResult - - -class BaseHttpRequestAction(NodeAction): - """ - Example: - { - // обязательные параметры - "method": "POST", - "url": "http://some_url.com/...", - - // необязательные параметры - "json": { - "data": "value", - ... - }, - "timeout": 120, - "headers": { - "Content-Type":"application/json" - } - } - """ - POST = "POST" - GET = "GET" - DEFAULT_METHOD = POST - - TIMEOUT = "TIMEOUT" - CONNECTION = "CONNECTION" - - def __init__(self, items, id=None): - super().__init__(items, id) - self.method_params = items - self.error = None - - @staticmethod - def _check_headers_validity(headers: Dict[str, Any], user) -> Dict[str, str]: - for header_name, header_value in list(headers.items()): - if not isinstance(header_value, (str, bytes)): - if isinstance(header_value, (int, float, bool)): - headers[header_name] = str(header_value) - else: - log(f"{__class__.__name__}._check_headers_validity remove header {header_name} because " - f"({type(header_value)}) is not in [int, float, bool, str, bytes]", user=user, params={ - log_const.KEY_NAME: "sent_http_remove_header", - }) - del headers[header_name] - return headers - - def _make_response(self, request_parameters, user): - try: - with requests.request(**request_parameters) as response: - response.raise_for_status() - try: - data = response.json() - except json.decoder.JSONDecodeError: - data = None - self._log_response(user, response, data) - return data - except requests.exceptions.Timeout: - self.error = self.TIMEOUT - except (requests.exceptions.ConnectionError, requests.exceptions.HTTPError): - self.error = self.CONNECTION - - def _get_request_params(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None): - collected = user.parametrizer.collect(text_preprocessing_result) - params.update(collected) - - request_parameters = self._get_rendered_tree_recursive(self._get_template_tree(self.method_params), params) - - req_headers = request_parameters.get("headers") - if req_headers: - # Заголовки в запросах должны иметь тип str или bytes. Поэтому добавлена проверка и приведение к типу str, - # на тот случай если в сценарии заголовок указали как int, float и тд - request_parameters["headers"] = self._check_headers_validity(req_headers, user) - return request_parameters - - def _log_request(self, user, request_parameters, additional_params=None): - additional_params = additional_params or {} - log(f"{self.__class__.__name__}.run sent https request ", user=user, params={ - **request_parameters, - log_const.KEY_NAME: "sent_http_request", - **additional_params, - }) - - def _log_response(self, user, response, data, additional_params=None): - additional_params = additional_params or {} - log(f"{self.__class__.__name__}.run get https response ", user=user, params={ - 'headers': dict(response.headers), - 'time': response.elapsed.microseconds, - 'cookie': {i.name: i.value for i in response.cookies}, - 'status': response.status_code, - 'data': data, - log_const.KEY_NAME: "got_http_response", - **additional_params, - }) - - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: - params = params or {} - request_parameters = self._get_request_params(user, text_preprocessing_result, params) - self._log_request(user, request_parameters) - return self._make_response(request_parameters, user) diff --git a/smart_kit/action/http.py b/smart_kit/action/http.py index 7437270c..ce09bb36 100644 --- a/smart_kit/action/http.py +++ b/smart_kit/action/http.py @@ -1,13 +1,19 @@ -from typing import Optional, Dict, Union, List +import asyncio +from typing import Optional, Dict, Union, List, Any -from core.basic_models.actions.basic_actions import Action +import aiohttp +import aiohttp.client_exceptions +from aiohttp import ClientTimeout + +import core.logging.logger_constants as log_const from core.basic_models.actions.command import Command +from core.basic_models.actions.string_actions import NodeAction +from core.logging.logger_utils import log from core.model.base_user import BaseUser from core.text_preprocessing.base import BaseTextPreprocessingResult -from smart_kit.action.base_http import BaseHttpRequestAction -class HTTPRequestAction(Action): +class HTTPRequestAction(NodeAction): """ Example: { @@ -21,31 +27,107 @@ class HTTPRequestAction(Action): } """ - HTTP_ACTION = BaseHttpRequestAction + POST = "POST" + GET = "GET" + DEFAULT_METHOD = POST + + TIMEOUT = "TIMEOUT" + CONNECTION = "CONNECTION" def __init__(self, items, id=None): - self.http_action = self.HTTP_ACTION(items["params"], id) - self.store = items["store"] - self.behavior = items["behavior"] super().__init__(items, id) + self.method_params = items['params'] + self.method_params.setdefault("method", self.DEFAULT_METHOD) + self.error = None + self.init_save_params(items) + + def init_save_params(self, items): + self.store = items["store"] + self.behavior = items.get("behavior") def preprocess(self, user, text_processing, params): behavior_description = user.descriptions["behaviors"][self.behavior] - self.http_action.method_params.setdefault("timeout", behavior_description.timeout(user)) + self.method_params.setdefault("timeout", behavior_description.timeout(user)) + self.method_params["timeout"] = ClientTimeout(self.method_params["timeout"]) - def process_result(self, result, user, text_preprocessing_result, params): - behavior_description = user.descriptions["behaviors"][self.behavior] - if self.http_action.error is None: - user.variables.set(self.store, result) - action = behavior_description.success_action - elif self.http_action.error == self.http_action.TIMEOUT: - action = behavior_description.timeout_action - else: - action = behavior_description.fail_action - return action.run(user, text_preprocessing_result, None) - - def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + @staticmethod + def _check_headers_validity(headers: Dict[str, Any], user) -> Dict[str, str]: + for header_name, header_value in list(headers.items()): + if not isinstance(header_value, (str, bytes)): + if isinstance(header_value, (int, float, bool)): + headers[header_name] = str(header_value) + else: + log(f"{__class__.__name__}._check_headers_validity remove header {header_name} because " + f"({type(header_value)}) is not in [int, float, bool, str, bytes]", user=user, params={ + log_const.KEY_NAME: "sent_http_remove_header", + }) + del headers[header_name] + return headers + + async def _make_response(self, request_parameters, user): + try: + async with aiohttp.request(**request_parameters) as response: + response.raise_for_status() + self._log_response(user, response) + return response + except (aiohttp.ServerTimeoutError, asyncio.TimeoutError): + self.error = self.TIMEOUT + except aiohttp.ClientError: + self.error = self.CONNECTION + + def _get_request_params(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None): + collected = user.parametrizer.collect(text_preprocessing_result) + params.update(collected) + request_parameters = self._get_rendered_tree_recursive(self._get_template_tree(self.method_params), params) + req_headers = request_parameters.get("headers") + if req_headers: + # Заголовки в запросах должны иметь тип str или bytes. Поэтому добавлена проверка и приведение к типу str, + # на тот случай если в сценарии заголовок указали как int, float и тд + request_parameters["headers"] = self._check_headers_validity(req_headers, user) + return request_parameters + + def _log_request(self, user, request_parameters, additional_params=None): + additional_params = additional_params or {} + log(f"{self.__class__.__name__}.run sent https request ", user=user, params={ + **request_parameters, + log_const.KEY_NAME: "sent_http_request", + **additional_params, + }) + + def _log_response(self, user, response, additional_params=None): + additional_params = additional_params or {} + log(f"{self.__class__.__name__}.run get https response ", user=user, params={ + 'headers': dict(response.headers), + 'cookie': {k: v.value for k, v in response.cookies.items()}, + 'status': response.status, + log_const.KEY_NAME: "got_http_response", + **additional_params, + }) + + async def process_result(self, response, user, text_preprocessing_result, params): + behavior_description = user.descriptions["behaviors"][self.behavior] if self.behavior else None + action = None + if self.error is None: + try: + data = await response.json() + except aiohttp.client_exceptions.ContentTypeError: + data = None + user.variables.set(self.store, data) + action = behavior_description.success_action if behavior_description else None + elif behavior_description is not None: + if self.error == self.TIMEOUT: + action = behavior_description.timeout_action + else: + action = behavior_description.fail_action + if action: + return await action.run(user, text_preprocessing_result, None) + + async def run(self, user: BaseUser, text_preprocessing_result: BaseTextPreprocessingResult, params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: self.preprocess(user, text_preprocessing_result, params) - result = self.http_action.run(user, text_preprocessing_result, params) - return self.process_result(result, user, text_preprocessing_result, params) + params = params or {} + request_parameters = self._get_request_params(user, text_preprocessing_result, params) + self._log_request(user, request_parameters) + response = await self._make_response(request_parameters, user) + return await self.process_result(response, user, text_preprocessing_result, params) diff --git a/smart_kit/compatibility/commands.py b/smart_kit/compatibility/commands.py index 69949cce..80e8577f 100644 --- a/smart_kit/compatibility/commands.py +++ b/smart_kit/compatibility/commands.py @@ -9,8 +9,6 @@ def combine_answer_to_user(commands: typing.List[Command]) -> Command: - from smart_kit.configs import get_app_config - config = get_app_config() answer = Command(name=ANSWER_TO_USER, request_data=commands[0].request_data, request_type=commands[0].request_type) summary_pronounce_text = [] diff --git a/smart_kit/configs/settings.py b/smart_kit/configs/settings.py index eface70b..22a89f44 100644 --- a/smart_kit/configs/settings.py +++ b/smart_kit/configs/settings.py @@ -1,5 +1,6 @@ import yaml import os +import asyncio from core.configs.base_config import BaseConfig from core.db_adapter.ceph.ceph_adapter import CephAdapter @@ -19,6 +20,7 @@ def __init__(self, *args, **kwargs): self.secret_path = kwargs.get("secret_path") self.app_name = kwargs.get("app_name") self.adapters = {Settings.CephAdapterKey: CephAdapter, self.OSAdapterKey: OSAdapter} + self.loop = asyncio.get_event_loop() self.repositories = [ UpdatableFileRepository( self.subfolder_path("template_config.yml"), loader=yaml.safe_load, key="template_settings" @@ -61,6 +63,9 @@ def get_source(self): adapter_settings = self.registered_repositories[ adapter_key].data if adapter_key != Settings.OSAdapterKey else None adapter = self.adapters[adapter_key](adapter_settings) - adapter.connect() + if asyncio.iscoroutinefunction(adapter.connect): + self.loop.run_until_complete(adapter.connect()) + else: + adapter.connect() source = adapter.source return source diff --git a/smart_kit/handlers/handle_close_app.py b/smart_kit/handlers/handle_close_app.py index eed21e61..fe803c8c 100644 --- a/smart_kit/handlers/handle_close_app.py +++ b/smart_kit/handlers/handle_close_app.py @@ -11,12 +11,12 @@ def __init__(self, app_name): super(HandlerCloseApp, self).__init__(app_name) self._clear_current_scenario = ClearCurrentScenarioAction(None) - def run(self, payload, user): - super().run(payload, user) + async def run(self, payload, user): + await super().run(payload, user) text_preprocessing_result = TextPreprocessingResult(payload.get("message", {})) params = { log_const.KEY_NAME: "HandlerCloseApp", "tpr_str": str(text_preprocessing_result.raw) } - self._clear_current_scenario.run(user, text_preprocessing_result) + await self._clear_current_scenario.run(user, text_preprocessing_result) log("HandlerCloseApp with text preprocessing result", user, params) diff --git a/smart_kit/handlers/handle_respond.py b/smart_kit/handlers/handle_respond.py index d4d2cf9c..070b8284 100644 --- a/smart_kit/handlers/handle_respond.py +++ b/smart_kit/handlers/handle_respond.py @@ -24,7 +24,7 @@ def get_action_params(self, payload, user): callback_id = user.message.callback_id return user.behaviors.get_callback_action_params(callback_id) - def run(self, payload, user): + async def run(self, payload, user): callback_id = user.message.callback_id action_params = self.get_action_params(payload, user) action_name = self.get_action_name(payload, user) @@ -53,7 +53,7 @@ def run(self, payload, user): log("text preprocessing result: '%(normalized_text)s'", user, params, level="DEBUG") action = user.descriptions["external_actions"][action_name] - return action.run(user, text_preprocessing_result, action_params) + return await action.run(user, text_preprocessing_result, action_params) @staticmethod def get_processing_time(user): diff --git a/smart_kit/handlers/handle_server_action.py b/smart_kit/handlers/handle_server_action.py index 3eeb10c2..8be109dd 100644 --- a/smart_kit/handlers/handle_server_action.py +++ b/smart_kit/handlers/handle_server_action.py @@ -22,7 +22,7 @@ def get_action_name(self, payload, user): def get_action_params(self, payload): return payload[SERVER_ACTION].get("parameters", {}) - def run(self, payload, user): + async def run(self, payload, user): action_params = pickle_deepcopy(self.get_action_params(payload)) params = {log_const.KEY_NAME: "handling_server_action", "server_action_params": str(action_params), @@ -35,4 +35,4 @@ def run(self, payload, user): action_id = self.get_action_name(payload, user) action = user.descriptions["external_actions"][action_id] - return action.run(user, TextPreprocessingResult({}), action_params) + return await action.run(user, TextPreprocessingResult({}), action_params) diff --git a/smart_kit/handlers/handler_base.py b/smart_kit/handlers/handler_base.py index 697ef96b..98291441 100644 --- a/smart_kit/handlers/handler_base.py +++ b/smart_kit/handlers/handler_base.py @@ -9,7 +9,7 @@ class HandlerBase: def __init__(self, app_name): self.app_name = app_name - def run(self, payload, user): + async def run(self, payload, user): # отправка события о входящем сообщении в систему мониторинга smart_kit_metrics.counter_incoming(self.app_name, user.message.message_name, self.__class__.__name__, user, app_info=user.message.app_info) diff --git a/smart_kit/handlers/handler_text.py b/smart_kit/handlers/handler_text.py index f7f8f60a..90dab1b2 100644 --- a/smart_kit/handlers/handler_text.py +++ b/smart_kit/handlers/handler_text.py @@ -19,8 +19,8 @@ def __init__(self, app_name, dialogue_manager): f"{self.__class__.__name__}.__init__ finished.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE} ) - def run(self, payload, user): - super().run(payload, user) + async def run(self, payload, user): + await super().run(payload, user) text_preprocessing_result = TextPreprocessingResult(payload.get("message", {})) params = { @@ -29,9 +29,9 @@ def run(self, payload, user): } log("text preprocessing result: '%(normalized_text)s'", user, params) - answer = self._handle_base(text_preprocessing_result, user) + answer = await self._handle_base(text_preprocessing_result, user) return answer - def _handle_base(self, text_preprocessing_result, user): - answer, is_answer_found = self.dialogue_manager.run(text_preprocessing_result, user) + async def _handle_base(self, text_preprocessing_result, user): + answer, is_answer_found = await self.dialogue_manager.run(text_preprocessing_result, user) return answer or [] diff --git a/smart_kit/handlers/handler_timeout.py b/smart_kit/handlers/handler_timeout.py index 570bcf5e..224bdad9 100644 --- a/smart_kit/handlers/handler_timeout.py +++ b/smart_kit/handlers/handler_timeout.py @@ -10,8 +10,8 @@ class HandlerTimeout(HandlerBase): - def run(self, payload, user): - super().run(payload, user) + async def run(self, payload, user): + await super().run(payload, user) callback_id = user.message.callback_id if user.behaviors.has_callback(callback_id): params = {log_const.KEY_NAME: "handling_timeout"} @@ -28,5 +28,5 @@ def run(self, payload, user): user, app_info=app_info) callback_id = user.message.callback_id - result = user.behaviors.timeout(callback_id) + result = await user.behaviors.timeout(callback_id) return result diff --git a/smart_kit/models/dialogue_manager.py b/smart_kit/models/dialogue_manager.py index afeaa235..6c0f4324 100644 --- a/smart_kit/models/dialogue_manager.py +++ b/smart_kit/models/dialogue_manager.py @@ -27,27 +27,27 @@ def __init__(self, scenario_descriptions, app_name, **kwargs): def _nothing_found_action(self): return self.actions.get(self.NOTHING_FOUND_ACTION) or NothingFoundAction() - def run(self, text_preprocessing_result, user): + async def run(self, text_preprocessing_result, user): before_action = user.descriptions["external_actions"].get("before_action") if before_action: params = user.parametrizer.collect(text_preprocessing_result) - before_action.run(user, text_preprocessing_result, params) + await before_action.run(user, text_preprocessing_result, params) scenarios_names = user.last_scenarios.scenarios_names scenario_key = user.message.payload[field.INTENT] if scenario_key in scenarios_names: scenario = self.scenarios[scenario_key] is_form_filling = isinstance(scenario, FormFillingScenario) if is_form_filling: - if not scenario.text_fits(text_preprocessing_result, user): + if not await scenario.text_fits(text_preprocessing_result, user): params = user.parametrizer.collect(text_preprocessing_result) - if scenario.check_ask_again_requests(text_preprocessing_result, user, params): - reply = scenario.ask_again(text_preprocessing_result, user, params) + if await scenario.check_ask_again_requests(text_preprocessing_result, user, params): + reply = await scenario.ask_again(text_preprocessing_result, user, params) return reply, True smart_kit_metrics.counter_nothing_found(self.app_name, scenario_key, user) - return self._nothing_found_action.run(user, text_preprocessing_result), False - return self.run_scenario(scenario_key, text_preprocessing_result, user), True + return await self._nothing_found_action.run(user, text_preprocessing_result), False + return await self.run_scenario(scenario_key, text_preprocessing_result, user), True - def run_scenario(self, scen_id, text_preprocessing_result, user): + async def run_scenario(self, scen_id, text_preprocessing_result, user): initial_last_scenario = user.last_scenarios.last_scenario_name scenario = self.scenarios[scen_id] params = {log_const.KEY_NAME: log_const.CHOSEN_SCENARIO_VALUE, @@ -55,7 +55,7 @@ def run_scenario(self, scen_id, text_preprocessing_result, user): log_const.SCENARIO_DESCRIPTION_VALUE: scenario.scenario_description } log(log_const.LAST_SCENARIO_MESSAGE, user, params) - run_scenario_result = scenario.run(text_preprocessing_result, user) + run_scenario_result = await scenario.run(text_preprocessing_result, user) actual_last_scenario = user.last_scenarios.last_scenario_name if actual_last_scenario and actual_last_scenario != initial_last_scenario: diff --git a/smart_kit/models/smartapp_model.py b/smart_kit/models/smartapp_model.py index bf5a4855..ecc74c1f 100644 --- a/smart_kit/models/smartapp_model.py +++ b/smart_kit/models/smartapp_model.py @@ -54,19 +54,19 @@ def get_handler(self, message_type): return self._handlers[message_type] @exc_handler(on_error_obj_method_name="on_answer_error") - def answer(self, message, user): + async def answer(self, message, user): user.expire() handler = self.get_handler(message.type) if not user.load_error: - commands = handler.run(message.payload, user) + commands = await handler.run(message.payload, user) else: log("Error in loading user data", user, level="ERROR", exc_info=True) raise Exception("Error in loading user data") return commands - def on_answer_error(self, message, user): + async def on_answer_error(self, message, user): user.do_not_save = True smart_kit_metrics.counter_exception(self.app_name) params = {log_const.KEY_NAME: log_const.DIALOG_ERROR_VALUE, @@ -82,6 +82,6 @@ def on_answer_error(self, message, user): if user.settings["template_settings"].get("debug_info"): set_debug_info(self.app_name, callback_action_params, error) exception_action = user.descriptions["external_actions"]["exception_action"] - commands = exception_action.run(user=user, text_preprocessing_result=None, - params=callback_action_params) + commands = await exception_action.run(user=user, text_preprocessing_result=None, + params=callback_action_params) return commands diff --git a/smart_kit/resources/__init__.py b/smart_kit/resources/__init__.py index b85379d5..e9495137 100644 --- a/smart_kit/resources/__init__.py +++ b/smart_kit/resources/__init__.py @@ -40,9 +40,7 @@ from core.db_adapter.aioredis_adapter import AIORedisAdapter from core.db_adapter.db_adapter import db_adapters from core.db_adapter.ignite_adapter import IgniteAdapter -from core.db_adapter.ignite_thread_adapter import IgniteThreadAdapter from core.db_adapter.memory_adapter import MemoryAdapter -from core.db_adapter.redis_adapter import RedisAdapter from core.descriptions.descriptions import registered_description_factories from core.model.queued_objects.limited_queued_hashable_objects_description import \ LimitedQueuedHashableObjectsDescriptionsItems @@ -393,9 +391,7 @@ def init_requests(self): def init_db_adapters(self): db_adapters[None] = MemoryAdapter db_adapters["ignite"] = IgniteAdapter - db_adapters["ignite_thread"] = IgniteThreadAdapter db_adapters["memory"] = MemoryAdapter - db_adapters["redis"] = RedisAdapter db_adapters["aioredis"] = AIORedisAdapter db_adapters["aioredis_sentinel"] = AIORedisSentinelAdapter diff --git a/smart_kit/start_points/base_main_loop.py b/smart_kit/start_points/base_main_loop.py index 532c8dea..9b301612 100644 --- a/smart_kit/start_points/base_main_loop.py +++ b/smart_kit/start_points/base_main_loop.py @@ -1,13 +1,14 @@ # coding=utf-8 from typing import Type, Iterable +import asyncio import signal import scenarios.logging.logger_constants as log_const from core.db_adapter.db_adapter import DBAdapterException from core.db_adapter.db_adapter import db_adapter_factory from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring from core.monitoring.healthcheck_handler import RootResource from core.monitoring.twisted_server import TwistedServer from core.model.base_user import BaseUser @@ -37,6 +38,7 @@ def __init__( try: signal.signal(signal.SIGINT, self.stop) signal.signal(signal.SIGTERM, self.stop) + self.loop = asyncio.get_event_loop() self.settings = settings self.app_name = self.settings.app_name self.model: SmartAppModel = model @@ -59,20 +61,20 @@ def __init__( self._init_monitoring_config(template_settings) log("%(class_name)s.__init__ completed.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE, - "class_name": self.__class__.__name__}) + "class_name": self.__class__.__name__}) except: log("%(class_name)s.__init__ exception.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE, - "class_name": self.__class__.__name__}, - level="ERROR", exc_info=True) + "class_name": self.__class__.__name__}, + level="ERROR", exc_info=True) raise def get_db(self): db_adapter = db_adapter_factory(self.settings["template_settings"].get("db_adapter", {})) - if db_adapter.IS_ASYNC: + if not db_adapter.IS_ASYNC: raise Exception( - f"Async adapter {db_adapter.__class__.__name__} doesnt compare with {self.__class__.__name__}" + f"Blocking adapter {db_adapter.__class__.__name__} is not good for {self.__class__.__name__}" ) - db_adapter.connect() + self.loop.run_until_complete(db_adapter.connect()) return db_adapter def _generate_answers(self, user, commands, message, **kwargs): @@ -93,20 +95,21 @@ def _create_health_check_server(self, settings): def _init_monitoring_config(self, template_settings): monitoring_config = template_settings["monitoring"] - monitoring.apply_config(monitoring_config) - smart_kit_metrics.apply_config(monitoring_config) + monitoring.monitoring.apply_config(monitoring_config) smart_kit_metrics.init_metrics(app_name=self.app_name) - def load_user(self, db_uid, message): + async def load_user(self, db_uid, message): db_data = None load_error = False try: - db_data = self.db_adapter.get(db_uid) + db_data = await self.db_adapter.get(db_uid) except (DBAdapterException, ValueError): log("Failed to get user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, log_const.REQUEST_VALUE: str(message.value)}, level="ERROR") load_error = True smart_kit_metrics.counter_load_error(self.app_name) + # to skip message when load failed + raise return self.user_cls( message.uid, message=message, @@ -117,22 +120,21 @@ def load_user(self, db_uid, message): load_error=load_error ) - def save_user(self, db_uid, user, message): + async def save_user(self, db_uid, user, message): no_collisions = True if user.do_not_save: log("User %(uid)s will not saved", user=user, params={"uid": user.id, log_const.KEY_NAME: "user_will_not_saved"}) else: - no_collisions = True try: str_data = user.raw_str if user.initial_db_data and self.user_save_check_for_collisions: - no_collisions = self.db_adapter.replace_if_equals(db_uid, - sample=user.initial_db_data, - data=str_data) + no_collisions = await self.db_adapter.replace_if_equals(db_uid, + sample=user.initial_db_data, + data=str_data) else: - self.db_adapter.save(db_uid, str_data) + await self.db_adapter.save(db_uid, str_data) except (DBAdapterException, ValueError): log("Failed to set user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, log_const.REQUEST_VALUE: str(message.value)}, level="ERROR") diff --git a/smart_kit/start_points/main_loop_async_http.py b/smart_kit/start_points/main_loop_async_http.py index 00e8a06a..1ba27c20 100644 --- a/smart_kit/start_points/main_loop_async_http.py +++ b/smart_kit/start_points/main_loop_async_http.py @@ -2,7 +2,6 @@ import os import asyncio -import concurrent.futures import aiohttp import aiohttp.web @@ -18,14 +17,12 @@ class AIOHttpMainLoop(BaseHttpMainLoop): def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.app = aiohttp.web.Application() + self.app.add_routes([aiohttp.web.route('*', '/health', self.get_health_check)]) self.app.add_routes([aiohttp.web.route('*', '/{tail:.*}', self.iterate)]) - super().__init__(*args, **kwargs) - max_workers = self.settings["template_settings"].get("max_workers", (os.cpu_count() or 1) * 5) - self.pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) - async def async_init(self): - await self.db_adapter.connect() + async def async_init(self):await self.db_adapter.connect() def get_db(self): db_adapter = db_adapter_factory(self.settings["template_settings"].get("db_adapter", {})) @@ -82,13 +79,13 @@ async def save_user(self, db_uid, user, message): await self.db_adapter.save(db_uid, str_data) else: if user.initial_db_data and self.user_save_check_for_collisions: - no_collisions = self.db_adapter.replace_if_equals( + no_collisions = await self.db_adapter.replace_if_equals( db_uid, sample=user.initial_db_data, data=str_data ) else: - self.db_adapter.save(db_uid, str_data) + await self.db_adapter.save(db_uid, str_data) except (DBAdapterException, ValueError): log("Failed to set user data", params={log_const.KEY_NAME: log_const.FAILED_DB_INTERACTION, log_const.REQUEST_VALUE: str(message.value)}, level="ERROR") @@ -109,19 +106,33 @@ def stop(self, signum, frame): async def handle_message(self, message: SmartAppFromMessage) -> typing.Tuple[int, str, SmartAppToMessage]: if not message.validate(): - return 400, "BAD REQUEST", SmartAppToMessage(self.BAD_REQUEST_COMMAND, message=message, request=None) + answer = SmartAppToMessage(self.BAD_REQUEST_COMMAND, message=message, request=None) + code = 400 + log(f"OUTGOING DATA: {answer.value} with code: {code}", + params={log_const.KEY_NAME: "outgoing_policy_message", "msg_id": message.incremental_id}) + return code, "BAD REQUEST", answer - answer, stats = await self.process_message(message) + answer, stats, user = await self.process_message(message) if not answer: - return 204, "NO CONTENT", SmartAppToMessage(self.NO_ANSWER_COMMAND, message=message, request=None) + answer = SmartAppToMessage(self.NO_ANSWER_COMMAND, message=message, request=None) + code = 204 + log(f"OUTGOING DATA: {answer.value} with code: {code}", + params={log_const.KEY_NAME: "outgoing_policy_message"}, user=user) + return code, "NO CONTENT", answer - answer_message = SmartAppToMessage( - answer, message, request=None, - validators=self.to_msg_validators) + answer_message = SmartAppToMessage(answer, message, request=None, validators=self.to_msg_validators) if answer_message.validate(): - return 200, "OK", answer_message + code = 200 + log_answer = str(answer_message.value).replace("%", "%%") + log(f"OUTGOING DATA: {log_answer} with code: {code}", + params={log_const.KEY_NAME: "outgoing_policy_message"}, user=user) + return code, "OK", answer_message else: - return 500, "BAD ANSWER", SmartAppToMessage(self.BAD_ANSWER_COMMAND, message=message, request=None) + code = 500 + answer = SmartAppToMessage(self.BAD_ANSWER_COMMAND, message=message, request=None) + log(f"OUTGOING DATA: {answer.value} with code: {code}", + params={log_const.KEY_NAME: "outgoing_policy_message"}, user=user) + return code, "BAD ANSWER", answer async def process_message(self, message: SmartAppFromMessage, *args, **kwargs): stats = "" @@ -133,7 +144,7 @@ async def process_message(self, message: SmartAppFromMessage, *args, **kwargs): user = await self.load_user(db_uid, message) stats += "Loading time: {} msecs\n".format(load_timer.msecs) with StatsTimer() as script_timer: - commands = await self.app.loop.run_in_executor(self.pool, self.model.answer, message, user) + commands = await self.model.answer(message, user) if commands: answer = self._generate_answers(user, commands, message) else: @@ -144,8 +155,14 @@ async def process_message(self, message: SmartAppFromMessage, *args, **kwargs): await self.save_user(db_uid, user, message) stats += "Saving time: {} msecs\n".format(save_timer.msecs) log(stats, params={log_const.KEY_NAME: "timings"}) - self.postprocessor.postprocess(user, message) - return answer, stats + await self.postprocessor.postprocess(user, message) + return answer, stats, user + + async def get_health_check(self, request: aiohttp.web.Request): + status, reason, answer = 200, "OK", "ok" + return aiohttp.web.json_response( + status=status, reason=reason, data=answer, + ) async def iterate(self, request: aiohttp.web.Request): headers = self._get_headers(request.headers) diff --git a/smart_kit/start_points/main_loop_http.py b/smart_kit/start_points/main_loop_http.py index 25cd7ce4..21cf0b8a 100644 --- a/smart_kit/start_points/main_loop_http.py +++ b/smart_kit/start_points/main_loop_http.py @@ -1,3 +1,4 @@ +import asyncio import json import typing from collections import defaultdict @@ -69,10 +70,10 @@ def process_message(self, message: SmartAppFromMessage, *args, **kwargs): db_uid = message.db_uid with StatsTimer() as load_timer: - user = self.load_user(db_uid, message) + user = self.loop.run_until_complete(self.load_user(db_uid, message)) stats += "Loading time: {} msecs\n".format(load_timer.msecs) with StatsTimer() as script_timer: - commands = self.model.answer(message, user) + commands = asyncio.get_event_loop().run_until_complete(self.model.answer(message, user)) if commands: answer = self._generate_answers(user, commands, message) else: @@ -80,10 +81,10 @@ def process_message(self, message: SmartAppFromMessage, *args, **kwargs): stats += "Script time: {} msecs\n".format(script_timer.msecs) with StatsTimer() as save_timer: - self.save_user(db_uid, user, message) + self.loop.run_until_complete(self.save_user(db_uid, user, message)) stats += "Saving time: {} msecs\n".format(save_timer.msecs) log(stats, user=user, params={log_const.KEY_NAME: "timings"}) - self.postprocessor.postprocess(user, message) + self.loop.run_until_complete(self.postprocessor.postprocess(user, message)) return answer, stats def _get_headers(self, environ): diff --git a/smart_kit/start_points/main_loop_kafka.py b/smart_kit/start_points/main_loop_kafka.py index f0bb323d..3fd05983 100644 --- a/smart_kit/start_points/main_loop_kafka.py +++ b/smart_kit/start_points/main_loop_kafka.py @@ -1,25 +1,31 @@ # coding=utf-8 +import asyncio +import cProfile +import gc +import hashlib import json -import time -from collections import namedtuple +import pstats +import random +import signal +import concurrent.futures +import tracemalloc from functools import lru_cache from confluent_kafka.cimpl import KafkaException from lazy import lazy import scenarios.logging.logger_constants as log_const -from core.basic_models.actions.push_action import PUSH_NOTIFY from core.logging.logger_utils import log, UID_STR, MESSAGE_ID_STR from core.message.from_message import SmartAppFromMessage -from core.model.heapq.heapq_storage import HeapqKV +from core.mq.kafka.async_kafka_publisher import AsyncKafkaPublisher from core.mq.kafka.kafka_consumer import KafkaConsumer -from core.mq.kafka.kafka_publisher import KafkaPublisher +from core.utils.memstats import get_top_malloc from core.utils.stats_timer import StatsTimer from core.basic_models.actions.command import Command +from core.utils.utils import current_time_ms from smart_kit.compatibility.commands import combine_commands from smart_kit.message.get_to_message import get_to_message -from smart_kit.message.smart_app_push_message import SmartAppPushToMessage from smart_kit.message.smartapp_to_message import SmartAppToMessage from smart_kit.names import message_names from smart_kit.request.kafka_request import SmartKitKafkaRequest @@ -36,13 +42,25 @@ def _enrich_config_from_secret(kafka_config, secret_config): class MainLoop(BaseMainLoop): + # in milliseconds. log event if elapsed time more than value MAX_LOG_TIME = 20 BAD_ANSWER_COMMAND = Command(message_names.ERROR, {"code": -1, "description": "Invalid Answer Message"}) def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) log("%(class_name)s.__init__ started.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE, "class_name": self.__class__.__name__}) + self.loop = asyncio.get_event_loop() + # We have many async loops for messages processing in main thread + # And 1 thread for independent consecutive Kafka reading + self.health_check_server_future = None + super().__init__(*args, **kwargs) + # We have many async loops for messages processing in main thread + # And 1 thread for independent consecutive Kafka reading + self.kafka_executor_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + self._timers = dict() # stores aio timers for callbacks + self.template_settings = self.settings["template_settings"] + self.worker_tasks = [] + try: kafka_config = _enrich_config_from_secret( self.settings["kafka"]["template-engine"], self.settings.get("secret_kafka", {}) @@ -58,7 +76,7 @@ def __init__(self, *args, **kwargs): if config.get("consumer"): consumers.update({key: KafkaConsumer(kafka_config[key])}) if config.get("publisher"): - publishers.update({key: KafkaPublisher(kafka_config[key])}) + publishers.update({key: AsyncKafkaPublisher(kafka_config[key])}) log( "%(class_name)s FINISHED CONSUMERS/PUBLISHERS CREATE", params={"class_name": self.__class__.__name__}, level="WARNING" @@ -69,44 +87,204 @@ def __init__(self, *args, **kwargs): for key in self.consumers: self.consumers[key].subscribe() self.publishers = publishers - self.behaviors_timeouts_value_cls = namedtuple('behaviors_timeouts_value', - 'db_uid, callback_id, mq_message, kafka_key') - self.behaviors_timeouts = HeapqKV(value_to_key_func=lambda val: val.callback_id) + self.concurrent_messages = 0 + log("%(class_name)s.__init__ completed.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE, "class_name": self.__class__.__name__}) - except: + except Exception: log("%(class_name)s.__init__ exception.", params={log_const.KEY_NAME: log_const.STARTUP_VALUE, "class_name": self.__class__.__name__}, level="ERROR", exc_info=True) raise - def pre_handle(self): - self.iterate_behavior_timeouts() - def run(self): + signal.signal(signal.SIGINT, self.stop) + signal.signal(signal.SIGTERM, self.stop) log("%(class_name)s.run started", params={log_const.KEY_NAME: log_const.STARTUP_VALUE, "class_name": self.__class__.__name__}) - while self.is_work: - self.pre_handle() - for kafka_key in self.consumers: - self.iterate(kafka_key) + # try: + loop = asyncio.get_event_loop() + loop.run_until_complete(self.general_coro()) - if self.health_check_server: - with StatsTimer() as health_check_server_timer: - self.health_check_server.iterate() + log("MainLoop stopping kafka", level="WARNING") - if health_check_server_timer.msecs >= self.MAX_LOG_TIME: - log("Health check iterate time: {} msecs\n".format(health_check_server_timer.msecs), - params={log_const.KEY_NAME: "slow_health_check", - "time_msecs": health_check_server_timer.msecs}, level="WARNING") - - log("Stopping Kafka handler", level="WARNING") for kafka_key in self.consumers: self.consumers[kafka_key].close() - log("Kafka consumer connection is closed", level="WARNING") + for kafka_key in self.publishers: self.publishers[kafka_key].close() - log("Kafka publisher connection is closed", level="WARNING") - log("Kafka handler is stopped", level="WARNING") + # except (SystemExit,) as e: + log("MainLoop EXIT.", level="WARNING") + # raise e + + async def general_coro(self): + tasks = [self.process_consumer(kafka_key) for kafka_key in self.consumers] + if self.health_check_server is not None: + tasks.append(self.healthcheck_coro()) + await asyncio.gather(*tasks) + + async def healthcheck_coro(self): + while self.is_work: + if not self.health_check_server_future or self.health_check_server_future.done() or \ + self.health_check_server_future.cancelled(): + self.health_check_server_future = self.loop.run_in_executor(None, self.health_check_server.iterate) + await asyncio.sleep(0.5) + log("healthcheck_coro stopped") + + async def process_consumer(self, kafka_key): + consumer = self.consumers[kafka_key] + loop = asyncio.get_event_loop() + max_concurrent_messages = self.template_settings.get("max_concurrent_messages", 100) + total_messages = 0 + + profiling_settings = self.template_settings.get("profiling", {}) + profile_cpu = profiling_settings.get("cpu", False) + profile_cpu_path = profiling_settings.get("cpu_path", "/tmp/dp.cpu.prof") + profile_memory = profiling_settings.get("memory", False) + profile_memory_log_delta = profiling_settings.get("memory_log_delta", 30) + profile_memory_depth = profiling_settings.get("memory_depth", 4) + + async def worker(iteration, queue): + nonlocal total_messages + message_value = None + user = None + validation_failed = False + last_poll_begin_time = self.loop.time() + last_mem_log = self.loop.time() + log(f"-- Starting {iteration} iter") + + while self.is_work: + if profile_memory and iteration == 0 and self.loop.time() - last_mem_log > profile_memory_log_delta: + top = get_top_malloc(trace_limit=0) + async_counts = len(self.loop._ready), len(self.loop._scheduled), len(self.loop._asyncgens) + async_values = " + ".join(map(str, async_counts)) + log( + f"Total memory: {top}; " + f"Async: {async_values} = {sum(async_counts)}; " + f"Trash: {gc.get_count()} ", + level="DEBUG" + ) + last_mem_log = self.loop.time() + + from_last_poll_begin_ms = int((self.loop.time() - last_poll_begin_time) * 1000) + stats = f"From last message coro {iteration} time: {from_last_poll_begin_ms} msecs\n" + log_params = { + log_const.KEY_NAME: "timings", + "from_last_poll_begin_ms": from_last_poll_begin_ms, + "iteration": iteration + } + last_poll_begin_time = self.loop.time() + + try: + mq_message = await queue.get() + self.concurrent_messages += 1 + if mq_message: + log(f"\n-- Processing {self.concurrent_messages} msgs at {iteration} iter\n") + total_messages += 1 + headers = mq_message.headers() + if headers is None: + raise Exception("No incoming message headers found.") + message_value = json.loads(mq_message.value()) + await self.process_message(mq_message, consumer, kafka_key, stats, log_params) + + except KafkaException as kafka_exp: + self.concurrent_messages -= 1 + log("kafka error: %(kafka_exp)s.", + params={log_const.KEY_NAME: log_const.STARTUP_VALUE, + "kafka_exp": str(kafka_exp), + log_const.REQUEST_VALUE: str(message_value)}, + level="ERROR", exc_info=True) + queue.task_done() + + except Exception: + self.concurrent_messages -= 1 + log("%(class_name)s iterate error. Kafka key %(kafka_key)s", + params={log_const.KEY_NAME: "worker_exception", + "kafka_key": kafka_key, + log_const.REQUEST_VALUE: str(message_value)}, + level="ERROR", exc_info=True) + try: + consumer.commit_offset(mq_message) + except Exception: + log("Error handling worker fail exception.", level="ERROR", exc_info=True) + raise + queue.task_done() + else: + self.concurrent_messages -= 1 + queue.task_done() + + # END while self.is_work + log(f"-- Stop {iteration} iter") + + start_time = self.loop.time() + if profile_cpu: + cpu_pr = cProfile.Profile() + cpu_pr.enable() + else: + cpu_pr = None + if profile_memory: + tracemalloc.start(profile_memory_depth) + + log(f"Starting %(class_name)s in {max_concurrent_messages} coro", + params={"class_name": self.__class__.__name__}) + + # TODO: think about queue maxsize + queues = [asyncio.Queue() for _ in range(max_concurrent_messages)] + + for i, queue in enumerate(queues): + task = asyncio.create_task(worker(f'worker-{i}', queue)) + self.worker_tasks.append(task) + + await self.poll_kafka(consumer, queues) # blocks while self.is_works + + log("waiting for process unfinished tasks in queues") + await asyncio.gather(*(queue.join() for queue in queues)) + + time_delta = self.loop.time() - start_time + log(f"Process Consumer exit: {total_messages} msg in {int(time_delta)} sec", level="DEBUG") + + t = self.loop.time() + delay = self.template_settings.get("behavior_timers_tear_down_delay", 15) + log(f"wait timers to do their jobs for {delay} secs...") + while self._timers and (self.loop.time() - t) < delay: + await asyncio.sleep(1) + + for task in self.worker_tasks: + cancell_status = task.cancel() + log(f"{task} cancell status: {cancell_status} ") + + log(f"Stop consuming messages. All workers closed, erasing {len(self._timers)} timers.") + + if profile_memory: + log(f"{get_top_malloc(trace_limit=16)}") + tracemalloc.stop() + if cpu_pr is not None: + cpu_pr.disable() + stats = pstats.Stats(cpu_pr) + stats.sort_stats(pstats.SortKey.TIME) + stats.print_stats(10) + stats.dump_stats(filename=profile_cpu_path) + + async def poll_kafka(self, consumer, queues): + while self.is_work: + with StatsTimer() as poll_timer: + # Max delay between polls configured in consumer.poll_timeout param + mq_message = consumer.poll() + if poll_timer.msecs > self.MAX_LOG_TIME: + log_params = {"kafka_polling": poll_timer.msecs} + log(f"Long poll time: %(kafka_polling)s msecs\n", params=log_params, level="WARNING") + + if mq_message: + key = mq_message.key() + if key: + queue_index = int(hashlib.sha1(key).hexdigest(), 16) % len(queues) + else: + queue_index = random.randrange(len(queues)) + + # this will block if queue is full! + await queues[queue_index].put(mq_message) + else: + await asyncio.sleep(self.template_settings.get("no_kafka_messages_poll_time", 0.01)) + log(f"Stop poll_kafka consumer.") def _generate_answers(self, user, commands, message, **kwargs): topic_key = kwargs["topic_key"] @@ -121,8 +299,8 @@ def _generate_answers(self, user, commands, message, **kwargs): request.update_empty_items({"topic_key": topic_key, "kafka_key": kafka_key}) to_message = get_to_message(command.name) answer = to_message(command=command, message=message, request=request, - masking_fields=self.masking_fields, - validators=self.to_msg_validators) + masking_fields=self.masking_fields, + validators=self.to_msg_validators) if answer.validate(): answers.append(answer) else: @@ -140,103 +318,57 @@ def _get_timeout_from_message(self, orig_message_raw, callback_id, headers): timeout_from_message.callback_id = callback_id return timeout_from_message - def iterate_behavior_timeouts(self): - now = time.time() - while now > (self.behaviors_timeouts.get_head_key() or float("inf")): - _, behavior_timeout_value = self.behaviors_timeouts.pop() - db_uid, callback_id, mq_message, kafka_key = behavior_timeout_value - try: - save_tries = 0 - user_save_no_collisions = False - user = None - while save_tries < self.user_save_collisions_tries and not user_save_no_collisions: - save_tries += 1 - - orig_message_raw = json.loads(mq_message.value()) - orig_message_raw[SmartAppFromMessage.MESSAGE_NAME] = message_names.LOCAL_TIMEOUT - - timeout_from_message = self._get_timeout_from_message(orig_message_raw, callback_id, - headers=mq_message.headers()) - - user = self.load_user(db_uid, timeout_from_message) - commands = self.model.answer(timeout_from_message, user) - topic_key = self._get_topic_key(mq_message, kafka_key) - answers = self._generate_answers(user=user, commands=commands, message=timeout_from_message, - topic_key=topic_key, - kafka_key=kafka_key) - - user_save_no_collisions = self.save_user(db_uid, user, mq_message) - - if user and not user_save_no_collisions: - log( - "MainLoop.iterate_behavior_timeouts: save user got collision on uid %(uid)s db_version %(db_version)s.", - user=user, - params={log_const.KEY_NAME: "ignite_collision", - "db_uid": db_uid, - "message_key": mq_message.key(), - "kafka_key": kafka_key, - "uid": user.id, - "db_version": str(user.variables.get(user.USER_DB_VERSION))}, - level="WARNING") - - continue - - if not user_save_no_collisions: - log( - "MainLoop.iterate_behavior_timeouts: db_save collision all tries left on uid %(uid)s db_version %(db_version)s.", - user=user, - params={log_const.KEY_NAME: "ignite_collision", - "db_uid": db_uid, - "message_key": mq_message.key(), - "message_partition": mq_message.partition(), - "kafka_key": kafka_key, - "uid": user.id, - "db_version": str(user.variables.get(user.USER_DB_VERSION))}, - level="WARNING") - - smart_kit_metrics.counter_save_collision_tries_left(self.app_name) - self.save_behavior_timeouts(user, mq_message, kafka_key) - for answer in answers: - self._send_request(user, answer, mq_message) - except: - log("%(class_name)s error.", params={log_const.KEY_NAME: "error_handling_timeout", - "class_name": self.__class__.__name__, - log_const.REQUEST_VALUE: str(mq_message.value())}, - level="ERROR", exc_info=True) - def _get_topic_key(self, mq_message, kafka_key): topic_names_2_key = self._topic_names_2_key(kafka_key) return self.default_topic_key(kafka_key) or topic_names_2_key[mq_message.topic()] - def process_message(self, mq_message, consumer, kafka_key, stats): + async def process_message(self, mq_message, consumer, kafka_key, stats, log_params): + user = None topic_key = self._get_topic_key(mq_message, kafka_key) - save_tries = 0 - user_save_no_collisions = False - user = None + user_save_ok = False + skip_timeout = False db_uid = None + validation_failed = False + message_handled_ok = False message = None - while save_tries < self.user_save_collisions_tries and not user_save_no_collisions: + while save_tries < self.user_save_collisions_tries and not user_save_ok: save_tries += 1 message_value = mq_message.value() - message = SmartAppFromMessage(message_value, - headers=mq_message.headers(), + message = SmartAppFromMessage(message_value, headers=mq_message.headers(), masking_fields=self.masking_fields, creation_time=consumer.get_msg_create_time(mq_message)) - # TODO вернуть проверку ключа!!! if message.validate(): + log( + "Incoming RAW message: %(message)s", params={"message": message.masked_value}, + level="DEBUG") waiting_message_time = 0 if message.creation_time: - waiting_message_time = time.time() * 1000 - message.creation_time - stats += "Waiting message: {} msecs\n".format(waiting_message_time) + waiting_message_time = current_time_ms() - message.creation_time + stats += f"Waiting message: {waiting_message_time} msecs\n" + log_params["waiting_message"] = waiting_message_time + + stats += f"Mid: {message.incremental_id}\n" + log_params[MESSAGE_ID_STR] = message.incremental_id - stats += "Mid: {}\n".format(message.incremental_id) smart_kit_metrics.sampling_mq_waiting_time(self.app_name, waiting_message_time / 1000) + if self._is_message_timeout_to_skip(message, waiting_message_time): + skip_timeout = True + break + + db_uid = message.db_uid + with StatsTimer() as load_timer: + user = await self.load_user(db_uid, message) self.check_message_key(message, mq_message.key(), user) + stats += f"Loading user time from DB time: {load_timer.msecs} msecs\n" + log_params["user_loading"] = load_timer.msecs + smart_kit_metrics.sampling_load_time(self.app_name, load_timer.secs) + log( - "INCOMING FROM TOPIC: %(topic)s partition %(message_partition)s HEADERS: %(headers)s DATA: %(incoming_data)s", + "INCOMING FROM TOPIC: %(topic)s partition %(message_partition)s HEADERS: %(headers)s DATA: %(" + "incoming_data)s", params={log_const.KEY_NAME: "incoming_message", "topic": mq_message.topic(), "message_partition": mq_message.partition(), @@ -250,7 +382,7 @@ def process_message(self, mq_message, consumer, kafka_key, stats): "surface": message.device.surface, MESSAGE_ID_STR: message.incremental_id}, user=user - ) + ) db_uid = message.db_uid with StatsTimer() as load_timer: @@ -258,26 +390,27 @@ def process_message(self, mq_message, consumer, kafka_key, stats): smart_kit_metrics.sampling_load_time(self.app_name, load_timer.secs) stats += "Loading time: {} msecs\n".format(load_timer.msecs) with StatsTimer() as script_timer: - commands = self.model.answer(message, user) + commands = await self.model.answer(message, user) - answers = self._generate_answers(user=user, commands=commands, message=message, - topic_key=topic_key, + answers = self._generate_answers(user=user, commands=commands, message=message, topic_key=topic_key, kafka_key=kafka_key) + + stats += f"Script time: {script_timer.msecs} msecs\n" + log_params["script_time"] = script_timer.msecs smart_kit_metrics.sampling_script_time(self.app_name, script_timer.secs) - stats += "Script time: {} msecs\n".format(script_timer.msecs) with StatsTimer() as save_timer: - user_save_no_collisions = self.save_user(db_uid, user, message) + user_save_ok = await self.save_user(db_uid, user, message) + stats += "Saving user to DB time: {} msecs\n".format(save_timer.msecs) + log_params["user_saving"] = save_timer.msecs smart_kit_metrics.sampling_save_time(self.app_name, save_timer.secs) - stats += "Saving time: {} msecs\n".format(save_timer.msecs) - if not user_save_no_collisions: - log( - "MainLoop.iterate: save user got collision on uid %(uid)s db_version %(db_version)s.", + if not user_save_ok: + log("MainLoop.iterate: save user got collision on uid %(uid)s db_version %(db_version)s.", user=user, params={log_const.KEY_NAME: "ignite_collision", "db_uid": db_uid, - "message_key": mq_message.key(), + "message_key": (mq_message.key() or b"").decode('utf-8', 'backslashreplace'), "message_partition": mq_message.partition(), "kafka_key": kafka_key, "uid": user.id, @@ -285,103 +418,117 @@ def process_message(self, mq_message, consumer, kafka_key, stats): level="WARNING") continue - self.save_behavior_timeouts(user, mq_message, kafka_key) - - if mq_message.headers() is None: - mq_message.set_headers([]) + message_handled_ok = True + if answers: + self.save_behavior_timeouts(user, mq_message, kafka_key) if answers: for answer in answers: with StatsTimer() as publish_timer: self._send_request(user, answer, mq_message) - stats += "Publishing time: {} msecs".format(publish_timer.msecs) - log(stats, user=user) + smart_kit_metrics.counter_outgoing(self.app_name, answer.command.name, answer, user) + stats += f"Publishing to Kafka time: {publish_timer.msecs} msecs\n" + log_params["kafka_publishing"] = publish_timer.msecs else: + validation_failed = True + data = None + mid = None try: data = message.masked_value + mid = message.incremental_id except: - data = "" + pass log(f"Message validation failed, skip message handling.", params={log_const.KEY_NAME: "invalid_message", - "data": data}, level="ERROR") + "data": data, + MESSAGE_ID_STR: mid}, level="ERROR") smart_kit_metrics.counter_invalid_message(self.app_name) - if user and not user_save_no_collisions: - log( - "MainLoop.iterate: db_save collision all tries left on uid %(uid)s db_version %(db_version)s.", + break + if stats: + log(stats, user=user, params=log_params) + + if user and not user_save_ok and not validation_failed and not skip_timeout: + log("MainLoop.iterate: db_save collision all tries left on uid %(uid)s db_version %(db_version)s.", user=user, params={log_const.KEY_NAME: "ignite_collision", "db_uid": db_uid, - "message_key": mq_message.key(), + "message_key": (mq_message.key() or b"").decode('utf-8', 'backslashreplace'), "message_partition": mq_message.partition(), "kafka_key": kafka_key, "uid": user.id, "db_version": str(user.variables.get(user.USER_DB_VERSION))}, level="WARNING") - self.postprocessor.postprocess(user, message) + await self.postprocessor.postprocess(user, message) smart_kit_metrics.counter_save_collision_tries_left(self.app_name) - consumer.commit_offset(mq_message) - - def iterate(self, kafka_key): - consumer = self.consumers[kafka_key] - mq_message = None - message_value = None - try: - mq_message = None - message_value = None - with StatsTimer() as poll_timer: - mq_message = consumer.poll() - if mq_message: - stats = "Polling time: {} msecs\n".format(poll_timer.msecs) - message_value = mq_message.value() # DRY! - self.process_message(mq_message, consumer, kafka_key, stats) - - except KafkaException as kafka_exp: - log("kafka error: %(kafka_exp)s. MESSAGE: {}.".format(message_value), - params={log_const.KEY_NAME: log_const.STARTUP_VALUE, - "kafka_exp": str(kafka_exp), - log_const.REQUEST_VALUE: str(message_value)}, - level="ERROR", exc_info=True) - except Exception: - try: - log("%(class_name)s iterate error. Kafka key %(kafka_key)s MESSAGE: {}.".format(message_value), - params={log_const.KEY_NAME: log_const.STARTUP_VALUE, - "kafka_key": kafka_key}, - level="ERROR", exc_info=True) - consumer.commit_offset(mq_message) - except Exception: - log("Error handling worker fail exception.", - level="ERROR", exc_info=True) + consumer.commit_offset(mq_message) + if message_handled_ok: + self.remove_timer(message) + + def remove_timer(self, kafka_message): + if kafka_message.has_callback_id: + timer = self._timers.pop(kafka_message.callback_id, None) + if timer is not None: + log(f"Removing aio timer for callback {kafka_message.callback_id}. Have {len(self._timers)} running " + f"timers.", level="DEBUG") + timer.cancel() + + def _is_message_timeout_to_skip(self, message, waiting_message_time): + # Returns True if timeout is found + waiting_message_timeout = self.settings["template_settings"].get("waiting_message_timeout", {}) + warning_delay = waiting_message_timeout.get('warning', 200) + skip_delay = waiting_message_timeout.get('skip', 8000) + log_level = None + make_break = False + + if waiting_message_time >= skip_delay: + # Too old message + log_level = "ERROR" + make_break = True + + elif waiting_message_time >= warning_delay: + # Warn, but continue message processing + log_level = "WARNING" + smart_kit_metrics.counter_mq_long_waiting(self.app_name) + + if log_level is not None: + log( + f"Out of time message %(waiting_message_time)s msecs, " + f"mid: %(mid)s {message.as_dict}", + params={ + log_const.KEY_NAME: "waiting_message_timeout", + "waiting_message_time": waiting_message_time, + "mid": message.incremental_id + }, + level=log_level) + return make_break def check_message_key(self, from_message, message_key, user): sub = from_message.sub channel = from_message.channel uid = from_message.uid - message_key = message_key or b"" + valid_key = "_".join([i for i in [channel, sub, uid] if i]) + try: - params = [channel, sub, uid] - valid_key = "" - for value in params: - if value: - valid_key = "{}{}{}".format(valid_key, "_", value) if valid_key else "{}".format(value) - key_str = message_key.decode() - - message_key_is_valid = key_str == valid_key - if not message_key_is_valid: - log(f"Failed to check Kafka message key {message_key} != {valid_key}", - params={ - log_const.KEY_NAME: "check_kafka_key_validation", - MESSAGE_ID_STR: from_message.incremental_id, - UID_STR: uid - }, user=user, - level="WARNING") - except: - log(f"Exception to check Kafka message key {message_key}", + message_key = message_key or b"" + if isinstance(message_key, bytes): + message_key = message_key.decode() + except UnicodeDecodeError: + log(f"Decode error to check Kafka message key {message_key}", params={log_const.KEY_NAME: "check_kafka_key_error", MESSAGE_ID_STR: from_message.incremental_id, UID_STR: uid }, user=user, level="ERROR") + if message_key != valid_key: + log(f"Failed to check Kafka message key {message_key} != {valid_key}", + params={ + log_const.KEY_NAME: "check_kafka_key_validation", + MESSAGE_ID_STR: from_message.incremental_id, + UID_STR: uid + }, user=user, + level="WARNING") + def _send_request(self, user, answer, mq_message): kafka_broker_settings = self.settings["template_settings"].get( "route_kafka_broker" @@ -425,24 +572,105 @@ def masking_fields(self): return self.settings["template_settings"].get("masking_fields") def save_behavior_timeouts(self, user, mq_message, kafka_key): - for i, (expire_time_us, callback_id) in enumerate(user.behaviors.get_behavior_timeouts()): - # two behaviors can be created in one query, so we need add some salt to make theirs key unique - unique_key = expire_time_us + i * 1e-5 - log( - "%(class_name)s: adding local_timeout on callback %(callback_id)s with timeout on %(unique_key)s", + for (behavior_delay, callback_id) in user.behaviors.get_behavior_timeouts(): + log("%(class_name)s: adding local_timeout on callback %(callback_id)s with delay in %(delay)s seconds.", params={log_const.KEY_NAME: "adding_local_timeout", "class_name": self.__class__.__name__, "callback_id": callback_id, - "unique_key": unique_key}) - self.behaviors_timeouts.push(unique_key, self.behaviors_timeouts_value_cls._make( - (user.message.db_uid, callback_id, mq_message, kafka_key))) + "delay": behavior_delay}) - for callback_id in user.behaviors.get_returned_callbacks(): - log("%(class_name)s: removing local_timeout on callback %(callback_id)s", - params={log_const.KEY_NAME: "removing_local_timeout", - "class_name": self.__class__.__name__, - "callback_id": callback_id}) - self.behaviors_timeouts.remove(callback_id) + self._timers[callback_id] = self.loop.call_later( + behavior_delay, self.loop.create_task, + self.do_behavior_timeout(user.message.db_uid, callback_id, mq_message, kafka_key) + ) def stop(self, signum, frame): + log("Stop signal handler!") self.is_work = False + + async def do_behavior_timeout(self, db_uid, callback_id, mq_message, kafka_key): + try: + save_tries = 0 + user_save_ok = False + answers = [] + user = None + while save_tries < self.user_save_collisions_tries and not user_save_ok: + callback_found = False + log(f"MainLoop.do_behavior_timeout: handling callback {callback_id}. for db_uid {db_uid}. try " + f"{save_tries}.") + + save_tries += 1 + + orig_message_raw = json.loads(mq_message.value()) + orig_message_raw[SmartAppFromMessage.MESSAGE_NAME] = message_names.LOCAL_TIMEOUT + + timeout_from_message = self._get_timeout_from_message(orig_message_raw, callback_id, + headers=mq_message.headers()) + + user = await self.load_user(db_uid, timeout_from_message) + # TODO: not to load user to check behaviors.has_callback ? + + self.remove_timer(timeout_from_message) + + if user.behaviors.has_callback(callback_id): + callback_found = True + commands = await self.model.answer(timeout_from_message, user) + topic_key = self._get_topic_key(mq_message, kafka_key) + answers = self._generate_answers(user=user, commands=commands, message=timeout_from_message, + topic_key=topic_key, + kafka_key=kafka_key) + + user_save_ok = await self.save_user(db_uid, user, mq_message) + + if not user_save_ok: + log("MainLoop.do_behavior_timeout: save user got collision on uid %(uid)s db_version %(" + "db_version)s.", + user=user, + params={log_const.KEY_NAME: "ignite_collision", + "db_uid": db_uid, + "message_key": mq_message.key(), + "kafka_key": kafka_key, + "uid": user.id, + "db_version": str(user.variables.get(user.USER_DB_VERSION))}, + level="WARNING") + + if not user_save_ok and callback_found: + log("MainLoop.do_behavior_timeout: db_save collision all tries left on uid %(uid)s db_version " + "%(db_version)s.", + user=user, + params={log_const.KEY_NAME: "ignite_collision", + "db_uid": db_uid, + "message_key": mq_message.key(), + "message_partition": mq_message.partition(), + "kafka_key": kafka_key, + "uid": user.id, + "db_version": str(user.variables.get(user.USER_DB_VERSION))}, + level="WARNING") + + smart_kit_metrics.counter_save_collision_tries_left(self.app_name) + if user_save_ok: + self.save_behavior_timeouts(user, mq_message, kafka_key) + for answer in answers: + self._send_request(user, answer, mq_message) + except: + log("%(class_name)s error.", params={log_const.KEY_NAME: "error_handling_timeout", + "class_name": self.__class__.__name__, + log_const.REQUEST_VALUE: str(mq_message.value())}, + level="ERROR", exc_info=True) + + def _incoming_message_log(self, user, mq_message, message, kafka_key, waiting_message_time): + log( + "INCOMING FROM TOPIC: %(topic)s partition %(message_partition)s HEADERS: %(headers)s DATA: %(" + "incoming_data)s", + params={log_const.KEY_NAME: "incoming_message", + "topic": mq_message.topic(), + "message_partition": mq_message.partition(), + "message_key": mq_message.key(), + "kafka_key": kafka_key, + "incoming_data": str(message.masked_value), + "headers": str(mq_message.headers()), + "waiting_message": waiting_message_time, + "surface": message.device.surface, + MESSAGE_ID_STR: message.incremental_id}, + user=user + ) diff --git a/smart_kit/start_points/postprocess.py b/smart_kit/start_points/postprocess.py index 08752619..377d97cf 100644 --- a/smart_kit/start_points/postprocess.py +++ b/smart_kit/start_points/postprocess.py @@ -3,19 +3,19 @@ class PostprocessMainLoop: - def postprocess(self, user, message, *args, **kwargs): + async def postprocess(self, user, message, *args, **kwargs): pass class PostprocessCompose(PostprocessMainLoop): postprocessors: List[PostprocessMainLoop] = [] - def postprocess(self, user, message, *args, **kwargs): + async def postprocess(self, user, message, *args, **kwargs): for processor in self.postprocessors: - processor.postprocess(user, message, *args, **kwargs) + await processor.postprocess(user, message, *args, **kwargs) -def postprocessor_compose(*args: List[Type[PostprocessMainLoop]]): +def postprocessor_compose(*args: Type[PostprocessMainLoop]): class Compose(PostprocessCompose): postprocessors = [processor_cls() for processor_cls in args] return Compose diff --git a/smart_kit/system_answers/nothing_found_action.py b/smart_kit/system_answers/nothing_found_action.py index e1986337..9c0947ac 100644 --- a/smart_kit/system_answers/nothing_found_action.py +++ b/smart_kit/system_answers/nothing_found_action.py @@ -18,6 +18,6 @@ def __init__(self, items: Dict[str, Any] = None, id: Optional[str] = None): super(NothingFoundAction, self).__init__(items, id) self._action = StringAction({"command": NOTHING_FOUND}) - def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: - return self._action.run(user, text_preprocessing_result, params=params) + async def run(self, user: User, text_preprocessing_result: BaseTextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> Optional[List[Command]]: + return await self._action.run(user, text_preprocessing_result, params=params) diff --git a/smart_kit/template/app/basic_entities/actions.py-tpl b/smart_kit/template/app/basic_entities/actions.py-tpl index cceafe97..9a67afad 100644 --- a/smart_kit/template/app/basic_entities/actions.py-tpl +++ b/smart_kit/template/app/basic_entities/actions.py-tpl @@ -18,7 +18,7 @@ class CustomAction(Action): items = items or {} self.test_param = items.get("test_param") - def run(self, user: User, text_preprocessing_result: TextPreprocessingResult, - params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: + async def run(self, user: User, text_preprocessing_result: TextPreprocessingResult, + params: Optional[Dict[str, Union[str, float, int]]] = None) -> None: print("Test Action") return None diff --git a/smart_kit/template/app/basic_entities/fillers.py-tpl b/smart_kit/template/app/basic_entities/fillers.py-tpl index 05234d32..2f3152e2 100644 --- a/smart_kit/template/app/basic_entities/fillers.py-tpl +++ b/smart_kit/template/app/basic_entities/fillers.py-tpl @@ -20,5 +20,5 @@ class CustomFieldFiller(FieldFillerDescription): self.test_item = items.get("test_item") @exc_handler(on_error_obj_method_name="on_extract_error") - def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, params) -> Optional[str]: + async def extract(self, text_preprocessing_result: TextPreprocessingResult, user: User, params) -> Optional[str]: return None diff --git a/smart_kit/template/app/basic_entities/requirements.py-tpl b/smart_kit/template/app/basic_entities/requirements.py-tpl index 14d0c77c..153743f7 100644 --- a/smart_kit/template/app/basic_entities/requirements.py-tpl +++ b/smart_kit/template/app/basic_entities/requirements.py-tpl @@ -18,6 +18,6 @@ class CustomRequirement(Requirement): items = items or {} self.test_param = items.get("test_param") - def check(self, text_preprocessing_result: TextPreprocessingResult, - user: User, params: Dict[str, Any] = None) -> bool: + async def check(self, text_preprocessing_result: TextPreprocessingResult, + user: User, params: Dict[str, Any] = None) -> bool: return False diff --git a/smart_kit/template/app/handlers/handlers.py-tpl b/smart_kit/template/app/handlers/handlers.py-tpl index e5d4a6f3..faaca554 100644 --- a/smart_kit/template/app/handlers/handlers.py-tpl +++ b/smart_kit/template/app/handlers/handlers.py-tpl @@ -8,5 +8,5 @@ class CustomHandler(HandlerBase): Тут создаются Handlers, которые используются для запуска логики в зависимости от типа входяшего сообщения """ - def run(self, payload, user): + async def run(self, payload, user): return [] diff --git a/smart_kit/template/app/models/dialogue_manager.py-tpl b/smart_kit/template/app/models/dialogue_manager.py-tpl index 73864d6d..cc68f113 100644 --- a/smart_kit/template/app/models/dialogue_manager.py-tpl +++ b/smart_kit/template/app/models/dialogue_manager.py-tpl @@ -8,4 +8,4 @@ class CustomDialogueManager(DialogueManager): """ def __init__(self, scenario_descriptions, app_name, **kwargs): - super(CustomDialogueManager, self).__init__(scenario_descriptions, app_name, **kwargs) \ No newline at end of file + super(CustomDialogueManager, self).__init__(scenario_descriptions, app_name, **kwargs) diff --git a/smart_kit/template/app/models/model.py-tpl b/smart_kit/template/app/models/model.py-tpl index 3d7d03c2..3b9fdea8 100644 --- a/smart_kit/template/app/models/model.py-tpl +++ b/smart_kit/template/app/models/model.py-tpl @@ -10,4 +10,3 @@ class CustomModel(SmartAppModel): def __init__(self, resources, dialogue_manager_cls, custom_settings, **kwargs): super(CustomModel, self).__init__(resources, dialogue_manager_cls, custom_settings, **kwargs) self._handlers.update({}) - diff --git a/smart_kit/template/app/resources/custom_app_resourses.py-tpl b/smart_kit/template/app/resources/custom_app_resourses.py-tpl index bd3ec8bc..7641efcf 100644 --- a/smart_kit/template/app/resources/custom_app_resourses.py-tpl +++ b/smart_kit/template/app/resources/custom_app_resourses.py-tpl @@ -39,4 +39,3 @@ class CustomAppResourses(SmartAppResources): def init_db_adapters(self): super(CustomAppResourses, self).init_db_adapters() db_adapters["custom_db_adapter"] = CustomDBAdapter - diff --git a/smart_kit/template/static/configs/template_config.yml b/smart_kit/template/static/configs/template_config.yml index 08dfc541..cdef6480 100644 --- a/smart_kit/template/static/configs/template_config.yml +++ b/smart_kit/template/static/configs/template_config.yml @@ -19,4 +19,5 @@ masking_fields: user_save_collisions_tries: 2 self_service_with_state_save_messages: true project_id: template-app-id -consumer_topic: "app" \ No newline at end of file +consumer_topic: "app" +max_concurrent_messages: 1 diff --git a/smart_kit/testing/local.py b/smart_kit/testing/local.py index b623d26e..d9cdbddf 100644 --- a/smart_kit/testing/local.py +++ b/smart_kit/testing/local.py @@ -1,3 +1,4 @@ +import asyncio import cmd import json import os @@ -148,7 +149,7 @@ def process_message(self, raw_message: str, headers: tuple = ()) -> typing.Tuple user = self.__user_cls(self.environment.user_id, message, self.user_data, self.settings, self.app_model.scenario_descriptions, self.__parametrizer_cls, load_error=False) - answers = self.app_model.answer(message, user) + answers = asyncio.get_event_loop().run_until_complete(self.app_model.answer(message, user)) return user, answers or [] def default(self, _input: str): diff --git a/smart_kit/testing/suite.py b/smart_kit/testing/suite.py index 59ed1866..2cd9e5ef 100644 --- a/smart_kit/testing/suite.py +++ b/smart_kit/testing/suite.py @@ -1,3 +1,4 @@ +import asyncio import json import os from csv import DictWriter, QUOTE_MINIMAL @@ -34,17 +35,17 @@ def run_testfile(path: AnyStr, file: AnyStr, app_model: SmartAppModel, settings: csv_case_callback = csv_file_callback(test_case) else: csv_case_callback = None - if test_case_cls( - app_model, - settings, - user_cls, - parametrizer_cls, - from_msg_cls, - **test_params, - storaged_predefined_fields=storaged_predefined_fields, - interactive=interactive, - csv_case_callback=csv_case_callback, - ).run(): + if asyncio.get_event_loop().run_until_complete(TestCase( + app_model, + settings, + user_cls, + parametrizer_cls, + from_msg_cls, + **test_params, + storaged_predefined_fields=storaged_predefined_fields, + interactive=interactive, + csv_case_callback=csv_case_callback, + ).run()): print(f"[+] {test_case} OK") success += 1 print(f"[+] {file} {success}/{len(json_obj)}") @@ -146,22 +147,22 @@ def __init__(self, app_model: SmartAppModel, settings: Settings, user_cls: type, self.__user_cls = user_cls self.__from_msg_cls = from_msg_cls - def run(self) -> bool: + async def run(self) -> bool: success = True app_callback_id = None - for index, message_ in enumerate(self.messages): + for index, message in enumerate(self.messages): print('Шаг', index) if index and self.interactive: print("Нажмите ENTER, чтобы продолжить...") input() - request = message_["request"] - response = message_["response"] + request = message["request"] + response = message["response"] # Если использован флаг linkPreviousByCallbackId и после предыдущего сообщения был сохранен app_callback_id, # сообщению добавляются заголовки. Таким образом, сработает behavior, созданный предыдущим запросом - if message_.get(LINK_BEHAVIOR_FLAG) and app_callback_id: + if message.get(LINK_BEHAVIOR_FLAG) and app_callback_id: headers = [(self.__from_msg_cls.CALLBACK_ID_HEADER_NAME, app_callback_id.encode())] else: headers = [('kafka_correlationId', 'test_123')] @@ -175,7 +176,7 @@ def run(self) -> bool: self.post_setup_user(user) - commands = self.app_model.answer(message, user) or [] + commands = await self.app_model.answer(message, user) or [] answers = self._generate_answers( user=user, commands=commands, message=message diff --git a/smart_kit/utils/monitoring.py b/smart_kit/utils/monitoring.py index 2b981fd1..4099ff25 100644 --- a/smart_kit/utils/monitoring.py +++ b/smart_kit/utils/monitoring.py @@ -1,6 +1,6 @@ from core.logging.logger_constants import KEY_NAME from core.logging.logger_utils import log -from core.monitoring.monitoring import monitoring +from core.monitoring import monitoring def _filter_monitoring_msg(msg): @@ -39,7 +39,7 @@ def init_metrics(self, app_name): "Incoming message validation error.") def _get_or_create_counter(self, monitoring_msg, descr, labels=()): - counter = monitoring.get_counter(monitoring_msg, descr, labels) + counter = monitoring.monitoring.get_counter(monitoring_msg, descr, labels) if counter is None: raise MetricDisabled('counter disabled') return counter @@ -182,22 +182,22 @@ def counter_mq_long_waiting(self, app_name): @silence_it def sampling_load_time(self, app_name, value): monitoring_msg = "{}_load_time".format(app_name) - monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) + monitoring.monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) @silence_it def sampling_script_time(self, app_name, value): monitoring_msg = "{}_script_time".format(app_name) - monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) + monitoring.monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) @silence_it def sampling_save_time(self, app_name, value): monitoring_msg = "{}_save_time".format(app_name) - monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) + monitoring.monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) @silence_it def sampling_mq_waiting_time(self, app_name, value): monitoring_msg = "{}_mq_waiting_time".format(app_name) - monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) + monitoring.monitoring.got_histogram_observe(_filter_monitoring_msg(monitoring_msg), value) smart_kit_metrics = Metrics() diff --git a/smart_kit/utils/picklable_mock.py b/smart_kit/utils/picklable_mock.py index 000be1ee..539c3e93 100644 --- a/smart_kit/utils/picklable_mock.py +++ b/smart_kit/utils/picklable_mock.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, MagicMock +from unittest.mock import Mock, MagicMock, AsyncMock class PicklableMock(Mock): @@ -6,6 +6,11 @@ def __reduce__(self): return Mock, () +class AsyncPicklableMock(AsyncMock): + def __reduce__(self): + return AsyncMock, () + + class PicklableMagicMock(MagicMock): def __reduce__(self): return MagicMock, () diff --git a/tests/core_tests/basic_scenario_models_test/action_test/test_action.py b/tests/core_tests/basic_scenario_models_test/action_test/test_action.py index 31f811f7..0decb946 100644 --- a/tests/core_tests/basic_scenario_models_test/action_test/test_action.py +++ b/tests/core_tests/basic_scenario_models_test/action_test/test_action.py @@ -44,7 +44,7 @@ def __init__(self, items=None): items = items or {} self.result = items.get("result") - def run(self, user, text_preprocessing_result, params=None): + async def run(self, user, text_preprocessing_result, params=None): return self.result or ["test action run"] @@ -54,7 +54,7 @@ def __init__(self, items=None): self.result = items.get("result") self.done = False - def run(self, user, text_preprocessing_result, params=None): + async def run(self, user, text_preprocessing_result, params=None): self.done = True @@ -62,7 +62,7 @@ class MockRequirement: def __init__(self, items): self.result = items.get("result") - def check(self, text_preprocessing_result, user, params): + async def check(self, text_preprocessing_result, user, params): return self.result @@ -75,7 +75,7 @@ def collect(self, text_preprocessing_result, filter_params=None): return self.data -class ActionTest(unittest.TestCase): +class ActionTest(unittest.IsolatedAsyncioTestCase): def test_nodes_1(self): items = {"nodes": {"answer": "test"}} action = NodeAction(items) @@ -90,34 +90,34 @@ def test_nodes_2(self): nodes = action.nodes self.assertEqual(nodes, {}) - def test_base(self): + async def test_base(self): items = {"nodes": "test"} action = Action(items) try: - action.run(None, None) + await action.run(None, None) result = False except NotImplementedError: result = True self.assertEqual(result, True) - def test_external(self): + async def test_external(self): items = {"action": "test_action_key"} action = ExternalAction(items) user = PicklableMock() user.descriptions = {"external_actions": {"test_action_key": MockAction()}} - self.assertEqual(action.run(user, None), ["test action run"]) + self.assertEqual(await action.run(user, None), ["test action run"]) - def test_doing_nothing_action(self): + async def test_doing_nothing_action(self): items = {"nodes": {"answer": "test"}, "command": "test_name"} action = DoingNothingAction(items) - result = action.run(None, None) + result = await action.run(None, None) self.assertIsInstance(result, list) command = result[0] self.assertIsInstance(command, Command) self.assertEqual(command.name, "test_name") self.assertEqual(command.payload, {"answer": "test"}) - def test_requirement_action(self): + async def test_requirement_action(self): registered_factories[Requirement] = requirement_factory requirements["test"] = MockRequirement registered_factories[Action] = action_factory @@ -126,13 +126,13 @@ def test_requirement_action(self): action = RequirementAction(items) self.assertIsInstance(action.requirement, MockRequirement) self.assertIsInstance(action.internal_item, MockAction) - self.assertEqual(action.run(None, None), ["test action run"]) + self.assertEqual(await action.run(None, None), ["test action run"]) items = {"requirement": {"type": "test", "result": False}, "action": {"type": "test"}} action = RequirementAction(items) - result = action.run(None, None) + result = await action.run(None, None) self.assertIsNone(result) - def test_requirement_choice(self): + async def test_requirement_choice(self): items = {"requirement_actions": [ {"requirement": {"type": "test", "result": False}, "action": {"type": "test", "result": "action1"}}, {"requirement": {"type": "test", "result": True}, "action": {"type": "test", "result": "action2"}} @@ -140,10 +140,10 @@ def test_requirement_choice(self): choice_action = ChoiceAction(items) self.assertIsInstance(choice_action.items, list) self.assertIsInstance(choice_action.items[0], RequirementAction) - result = choice_action.run(None, None) + result = await choice_action.run(None, None) self.assertEqual(result, "action2") - def test_requirement_choice_else(self): + async def test_requirement_choice_else(self): items = { "requirement_actions": [ {"requirement": {"type": "test", "result": False}, "action": {"type": "test", "result": "action1"}}, @@ -154,10 +154,10 @@ def test_requirement_choice_else(self): choice_action = ChoiceAction(items) self.assertIsInstance(choice_action.items, list) self.assertIsInstance(choice_action.items[0], RequirementAction) - result = choice_action.run(None, None) + result = await choice_action.run(None, None) self.assertEqual(result, "action3") - def test_string_action(self): + async def test_string_action(self): expected = [Command("cmd_id", {"item": "template", "params": "params"})] user = PicklableMagicMock() template = PicklableMock() @@ -169,11 +169,11 @@ def test_string_action(self): "nodes": {"item": "template", "params": "{{params}}"}} action = StringAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(expected[0].name, result[0].name) self.assertEqual(expected[0].payload, result[0].payload) - def test_else_action_if(self): + async def test_else_action_if(self): registered_factories[Requirement] = requirement_factory requirements["test"] = MockRequirement registered_factories[Action] = action_factory @@ -185,9 +185,9 @@ def test_else_action_if(self): "else_action": {"type": "test", "result": "else_action"} } action = ElseAction(items) - self.assertEqual(action.run(user, None), "main_action") + self.assertEqual(await action.run(user, None), "main_action") - def test_else_action_else(self): + async def test_else_action_else(self): registered_factories[Requirement] = requirement_factory requirements["test"] = MockRequirement registered_factories[Action] = action_factory @@ -199,9 +199,9 @@ def test_else_action_else(self): "else_action": {"type": "test", "result": "else_action"} } action = ElseAction(items) - self.assertEqual(action.run(user, None), "else_action") + self.assertEqual(await action.run(user, None), "else_action") - def test_else_action_no_else_if(self): + async def test_else_action_no_else_if(self): registered_factories[Requirement] = requirement_factory requirements["test"] = MockRequirement registered_factories[Action] = action_factory @@ -212,9 +212,9 @@ def test_else_action_no_else_if(self): "action": {"type": "test", "result": "main_action"}, } action = ElseAction(items) - self.assertEqual(action.run(user, None), "main_action") + self.assertEqual(await action.run(user, None), "main_action") - def test_else_action_no_else_else(self): + async def test_else_action_no_else_else(self): registered_factories[Requirement] = requirement_factory requirements["test"] = MockRequirement registered_factories[Action] = action_factory @@ -225,10 +225,10 @@ def test_else_action_no_else_else(self): "action": {"type": "test", "result": "main_action"}, } action = ElseAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertIsNone(result) - def test_composite_action(self): + async def test_composite_action(self): registered_factories[Action] = action_factory actions["action_mock"] = MockAction user = PicklableMock() @@ -239,10 +239,10 @@ def test_composite_action(self): ] } action = CompositeAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(['test action run', 'test action run'], result) - def test_node_action_support_templates(self): + async def test_node_action_support_templates(self): params = { "markup": "italic", "email": "heyho@sberbank.ru", @@ -266,10 +266,10 @@ def test_node_action_support_templates(self): self.assertIsInstance(template, UnifiedTemplate) user = PicklableMagicMock() user.parametrizer = MockSimpleParametrizer(user, {"data": params}) - output = action.run(user=user, text_preprocessing_result=None)[0].payload["answer"] + output = (await action.run(user=user, text_preprocessing_result=None))[0].payload["answer"] self.assertEqual(output, expected) - def test_string_action_support_templates(self): + async def test_string_action_support_templates(self): params = { "answer_text": "some_text", "buttons_number": 3 @@ -291,10 +291,10 @@ def test_string_action_support_templates(self): action = StringAction(items) user = PicklableMagicMock() user.parametrizer = MockSimpleParametrizer(user, {"data": params}) - output = action.run(user=user, text_preprocessing_result=None)[0].payload + output = (await action.run(user=user, text_preprocessing_result=None))[0].payload self.assertEqual(output, expected) - def test_push_action(self): + async def test_push_action(self): params = { "day_time": "morning", "deep_link_url": "some_url", @@ -329,7 +329,7 @@ def test_push_action(self): user = PicklableMagicMock() user.parametrizer = MockSimpleParametrizer(user, {"data": params}) user.settings = settings - command = action.run(user=user, text_preprocessing_result=None)[0] + command = (await action.run(user=user, text_preprocessing_result=None))[0] self.assertEqual(command.payload, expected) # проверяем наличие кастомных хэдеров для сервиса пушей self.assertTrue(SmartKitKafkaRequest.KAFKA_EXTRA_HEADERS in command.request_data) @@ -341,7 +341,7 @@ def test_push_action(self): self.assertEqual(command.name, "PUSH_NOTIFY") -class NonRepeatingActionTest(unittest.TestCase): +class NonRepeatingActionTest(unittest.IsolatedAsyncioTestCase): def setUp(self): self.expected = PicklableMock() self.expected1 = PicklableMock() @@ -355,55 +355,55 @@ def setUp(self): registered_factories[Action] = action_factory actions["action_mock"] = MockAction - def test_run_available_indexes(self): + async def test_run_available_indexes(self): self.user.last_action_ids["last_action_ids_storage"].get_list.side_effect = [[0]] - result = self.action.run(self.user, None) + result = await self.action.run(self.user, None) self.user.last_action_ids["last_action_ids_storage"].add.assert_called_once() self.assertEqual(result, self.expected1) - def test_run_no_available_indexes(self): + async def test_run_no_available_indexes(self): self.user.last_action_ids["last_action_ids_storage"].get_list.side_effect = [[0, 1]] - result = self.action.run(self.user, None) + result = await self.action.run(self.user, None) self.assertEqual(result, self.expected) -class CounterIncrementActionTest(unittest.TestCase): - def test_run(self): +class CounterIncrementActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): user = PicklableMock() counter = PicklableMock() counter.inc = PicklableMock() user.counters = {"test": counter} items = {"key": "test"} action = CounterIncrementAction(items) - action.run(user, None) + await action.run(user, None) user.counters["test"].inc.assert_called_once() -class CounterDecrementActionTest(unittest.TestCase): - def test_run(self): +class CounterDecrementActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): user = PicklableMock() counter = PicklableMock() counter.dec = PicklableMock() user.counters = {"test": counter} items = {"key": "test"} action = CounterDecrementAction(items) - action.run(user, None) + await action.run(user, None) user.counters["test"].dec.assert_called_once() -class CounterClearActionTest(unittest.TestCase): - def test_run(self): +class CounterClearActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): user = PicklableMock() user.counters = PicklableMock() user.counters.inc = PicklableMock() items = {"key": "test"} action = CounterClearAction(items) - action.run(user, None) + await action.run(user, None) user.counters.clear.assert_called_once() -class CounterSetActionTest(unittest.TestCase): - def test_run(self): +class CounterSetActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): user = PicklableMock() counter = PicklableMock() counter.inc = PicklableMock() @@ -411,12 +411,12 @@ def test_run(self): user.counters = counters items = {"key": "test"} action = CounterSetAction(items) - action.run(user, None) + await action.run(user, None) user.counters["test"].set.assert_called_once() -class CounterCopyActionTest(unittest.TestCase): - def test_run(self): +class CounterCopyActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): user = PicklableMock() counter_src = PicklableMock() counter_src.value = 10 @@ -424,13 +424,13 @@ def test_run(self): user.counters = {"src": counter_src, "dst": counter_dst} items = {"source": "src", "destination": "dst"} action = CounterCopyAction(items) - action.run(user, None) + await action.run(user, None) user.counters["dst"].set.assert_called_once_with(user.counters["src"].value, action.reset_time, action.time_shift) -class AfinaAnswerActionTest(unittest.TestCase): - def test_typical_answer(self): +class AfinaAnswerActionTest(unittest.IsolatedAsyncioTestCase): + async def test_typical_answer(self): user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) expected = [MagicMock(_name="ANSWER_TO_USER", raw={'messageName': 'ANSWER_TO_USER', @@ -441,12 +441,11 @@ def test_typical_answer(self): } } action = AfinaAnswerAction(items) - - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(expected[0]._name, result[0].name) self.assertEqual(expected[0].raw, result[0].raw) - def test_typical_answer_with_other(self): + async def test_typical_answer_with_other(self): user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) expected = [MagicMock(_name="ANSWER_TO_USER", raw={'messageName': 'ANSWER_TO_USER', @@ -461,12 +460,11 @@ def test_typical_answer_with_other(self): } } action = AfinaAnswerAction(items) - - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(expected[0]._name, result[0].name) self.assertEqual(expected[0].raw, result[0].raw) - def test_typical_answer_with_pers_info(self): + async def test_typical_answer_with_pers_info(self): expected = [MagicMock(_name="ANSWER_TO_USER", raw={'messageName': 'ANSWER_TO_USER', 'payload': {'answer': 'Ivan Ivanov'}})] user = PicklableMock() @@ -475,11 +473,11 @@ def test_typical_answer_with_pers_info(self): user.message.payload = {"personInfo": {"name": "Ivan Ivanov"}} items = {"nodes": {"answer": ["{{payload.personInfo.name}}"]}} action = AfinaAnswerAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(expected[0]._name, result[0].name) self.assertEqual(expected[0].raw, result[0].raw) - def test_items_empty(self): + async def test_items_empty(self): user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) template = PicklableMock() @@ -487,10 +485,10 @@ def test_items_empty(self): user.descriptions = {"render_templates": template} items = None action = AfinaAnswerAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(result, []) - def test__items_empty_dict(self): + async def test__items_empty_dict(self): user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) template = PicklableMock() @@ -498,12 +496,12 @@ def test__items_empty_dict(self): user.descriptions = {"render_templates": template} items = {} action = AfinaAnswerAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(result, []) -class CardAnswerActionTest(unittest.TestCase): - def test_typical_answer(self): +class CardAnswerActionTest(unittest.IsolatedAsyncioTestCase): + async def test_typical_answer(self): user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) user.message = PicklableMock() @@ -554,12 +552,11 @@ def test_typical_answer(self): expect_arr = [exp1, exp2, exp3, exp4] for i in range(10): action = SDKAnswer(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual("ANSWER_TO_USER", result[0].name) self.assertTrue(str(result[0].raw) in expect_arr) - - def test_typical_answer_without_items(self): + async def test_typical_answer_without_items(self): user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) user.message = PicklableMock() @@ -577,11 +574,11 @@ def test_typical_answer_without_items(self): exp_list = [exp1, exp2, exp3, exp4] for i in range(10): action = SDKAnswer(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual("ANSWER_TO_USER", result[0].name) self.assertTrue(str(result[0].raw) in exp_list) - def test_typical_answer_without_nodes(self): + async def test_typical_answer_without_nodes(self): user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) user.message = PicklableMock() @@ -615,13 +612,13 @@ def test_typical_answer_without_nodes(self): expect_arr = [exp1, exp2, exp3, exp4] for i in range(10): action = SDKAnswer(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual("ANSWER_TO_USER", result[0].name) self.assertTrue(str(result[0].raw) in expect_arr) -class SDKRandomAnswer(unittest.TestCase): - def test_SDKItemAnswer_full(self): +class SDKRandomAnswer(unittest.IsolatedAsyncioTestCase): + async def test_SDKItemAnswer_full(self): registered_factories[SdkAnswerItem] = items_factory answer_items["bubble_text"] = BubbleText @@ -707,10 +704,10 @@ def test_SDKItemAnswer_full(self): action = SDKAnswerToUser(items) for i in range(3): - result = action.run(user, None) + result = await action.run(user, None) self.assertTrue(str(result[0].raw) in [exp1, exp2]) - def test_SDKItemAnswer_root(self): + async def test_SDKItemAnswer_root(self): registered_factories[SdkAnswerItem] = items_factory answer_items["bubble_text"] = BubbleText @@ -756,10 +753,10 @@ def test_SDKItemAnswer_root(self): action = SDKAnswerToUser(items) for i in range(3): - result = action.run(user, None) + result = await action.run(user, None) self.assertTrue(str(result[0].raw) in [exp1, exp2]) - def test_SDKItemAnswer_simple(self): + async def test_SDKItemAnswer_simple(self): registered_factories[SdkAnswerItem] = items_factory answer_items["bubble_text"] = BubbleText @@ -777,10 +774,10 @@ def test_SDKItemAnswer_simple(self): ] } action = SDKAnswerToUser(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertDictEqual(result[0].raw, {'messageName': 'ANSWER_TO_USER', 'payload': {'items': [{'bubble': {'text': '42', 'markdown': True}}]}}) - def test_SDKItemAnswer_suggestions_template(self): + async def test_SDKItemAnswer_suggestions_template(self): registered_factories[SdkAnswerItem] = items_factory answer_items["bubble_text"] = BubbleText @@ -799,7 +796,7 @@ def test_SDKItemAnswer_suggestions_template(self): } } action = SDKAnswerToUser(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertDictEqual( result[0].raw, { @@ -811,11 +808,12 @@ def test_SDKItemAnswer_suggestions_template(self): ] } } - }) + } + ) -class GiveMeMemoryActionTest(unittest.TestCase): - def test_run(self): +class GiveMeMemoryActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): expected = [ Command("GIVE_ME_MEMORY", { @@ -873,13 +871,13 @@ def test_run(self): } } action = GiveMeMemoryAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(expected[0].name, result[0].name) self.assertEqual(expected[0].payload, result[0].payload) -class RememberThisActionTest(unittest.TestCase): - def test_run(self): +class RememberThisActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): expected = [ Command("REMEMBER_THIS", { @@ -1001,6 +999,6 @@ def test_run(self): } } action = RememberThisAction(items) - result = action.run(user, None) + result = await action.run(user, None) self.assertEqual(expected[0].name, result[0].name) self.assertEqual(expected[0].payload, result[0].payload) diff --git a/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py b/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py index 22d03a5e..2449c9a0 100644 --- a/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py +++ b/tests/core_tests/basic_scenario_models_test/action_test/test_random_action.py @@ -1,17 +1,16 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from core.basic_models.actions.basic_actions import Action, action_factory, actions, DoingNothingAction, RandomAction from core.model.registered import registered_factories -class TestRandomAction(TestCase): +class TestRandomAction(IsolatedAsyncioTestCase): - @classmethod - def setUpClass(cls) -> None: + def setUp(self) -> None: registered_factories[Action] = action_factory actions["do_nothing"] = DoingNothingAction - def test_1(self): + async def test_1(self): items = { "actions": [ @@ -32,10 +31,10 @@ def test_1(self): ] } action = RandomAction(items, 5) - result = action.run(None, None) + result = await action.run(None, None) self.assertIsNotNone(result) - def test_2(self): + async def test_2(self): items = { "actions": [ { @@ -48,5 +47,5 @@ def test_2(self): ] } action = RandomAction(items, 5) - result = action.run(None, None) + result = await action.run(None, None) self.assertIsNotNone(result) diff --git a/tests/core_tests/basic_scenario_models_test/test_parametrizer.py b/tests/core_tests/basic_scenario_models_test/test_parametrizer.py index 23337ae9..56ad1ada 100644 --- a/tests/core_tests/basic_scenario_models_test/test_parametrizer.py +++ b/tests/core_tests/basic_scenario_models_test/test_parametrizer.py @@ -6,7 +6,6 @@ class ParametrizerTest(unittest.TestCase): - @classmethod def setUpClass(cls): cls.user = Mock(message=PicklableMock()) diff --git a/tests/core_tests/requirements_test/test_requirements.py b/tests/core_tests/requirements_test/test_requirements.py index 9462db1e..7677f4e5 100644 --- a/tests/core_tests/requirements_test/test_requirements.py +++ b/tests/core_tests/requirements_test/test_requirements.py @@ -1,3 +1,4 @@ +import asyncio import os import unittest from time import time @@ -20,6 +21,10 @@ from smart_kit.utils.picklable_mock import PicklableMock +def _run(coro): + return asyncio.get_event_loop().run_until_complete(coro) + + def patch_get_app_config(mock_get_app_config): result = PicklableMock() sk_path = os.path.dirname(smart_kit.__file__) @@ -35,7 +40,7 @@ def __init__(self, items=None): items = items or {} self.cond = items.get("cond") or False - def check(self, text_preprocessing_result, user, params): + async def check(self, text_preprocessing_result, user, params): return self.cond @@ -78,113 +83,112 @@ def compare(self, value): return value == self.amount -class RequirementTest(unittest.TestCase): - def test_base(self): +class RequirementTest(unittest.IsolatedAsyncioTestCase): + async def test_base(self): requirement = Requirement(None) - assert requirement.check(None, None) + assert await requirement.check(None, None) - def test_composite(self): + async def test_composite(self): registered_factories[Requirement] = MockRequirement requirement = CompositeRequirement({"requirements": [ {"cond": True}, {"cond": True} ]}) - self.assertEqual(len(requirement.requirements), 2) - self.assertTrue(requirement.check(None, None)) + self.assertTrue(await requirement.check(None, None)) - def test_and_success(self): + async def test_and_success(self): registered_factories[Requirement] = MockRequirement requirement = AndRequirement({"requirements": [ {"cond": True}, {"cond": True} ]}) - self.assertTrue(requirement.check(None, None)) + self.assertTrue(await requirement.check(None, None)) - def test_and_fail(self): + async def test_and_fail(self): registered_factories[Requirement] = MockRequirement requirement = AndRequirement({"requirements": [ {"cond": True}, {"cond": False} ]}) - self.assertFalse(requirement.check(None, None)) + self.assertFalse(await requirement.check(None, None)) - def test_or_success(self): + async def test_or_success(self): registered_factories[Requirement] = MockRequirement requirement = OrRequirement({"requirements": [ {"cond": True}, {"cond": False} ]}) - self.assertTrue(requirement.check(None, None)) + self.assertTrue(await requirement.check(None, None)) - def test_or_fail(self): + async def test_or_fail(self): registered_factories[Requirement] = MockRequirement requirement = OrRequirement({"requirements": [ {"cond": False}, {"cond": False} ]}) - self.assertFalse(requirement.check(None, None)) + self.assertFalse(await requirement.check(None, None)) - def test_not_success(self): + async def test_not_success(self): registered_factories[Requirement] = MockRequirement requirement = NotRequirement({"requirement": {"cond": False}}) - self.assertTrue(requirement.check(None, None)) + self.assertTrue(await requirement.check(None, None)) - def test_not_fail(self): + async def test_not_fail(self): registered_factories[Requirement] = MockRequirement requirement = NotRequirement({"requirement": {"cond": True}}) - self.assertFalse(requirement.check(None, None)) + self.assertFalse(await requirement.check(None, None)) - def test_channel_success(self): + async def test_channel_success(self): user = PicklableMock() message = Mock(channel="ch1") user.message = message requirement = ChannelRequirement({"channels": ["ch1"]}) text_normalization_result = None - self.assertTrue(requirement.check(text_normalization_result, user)) + self.assertTrue(await requirement.check(text_normalization_result, user)) - def test_channel_fail(self): + async def test_channel_fail(self): user = PicklableMock() message = Mock(channel="ch2") user.message = message requirement = ChannelRequirement({"channels": ["ch1"]}) text_normalization_result = None - self.assertFalse(requirement.check(text_normalization_result, user)) + self.assertFalse(await requirement.check(text_normalization_result, user)) - def test_random_requirement_true(self): + async def test_random_requirement_true(self): requirement = RandomRequirement({"percent": 100}) - self.assertTrue(requirement.check(None, None)) + self.assertTrue(await requirement.check(None, None)) - def test_random_requirement_false(self): + async def test_random_requirement_false(self): requirement = RandomRequirement({"percent": 0}) - self.assertFalse(requirement.check(None, None)) + self.assertFalse(await requirement.check(None, None)) - def test_topic_requirement(self): + async def test_topic_requirement(self): requirement = TopicRequirement({"topics": ["test"]}) user = PicklableMock() message = PicklableMock() message.topic_key = "test" user.message = message - self.assertTrue(requirement.check(None, user)) + self.assertTrue(await requirement.check(None, user)) - def test_counter_value_requirement(self): + async def test_counter_value_requirement(self): registered_factories[Operator] = MockAmountOperator user = PicklableMock() counter = PicklableMock() counter.__gt__ = Mock(return_value=True) user.counters = {"test": counter} requirement = CounterValueRequirement({"operator": {"type": "equal", "amount": 2}, "key": "test"}) - self.assertTrue(requirement.check(None, user)) + self.assertTrue(await requirement.check(None, user)) - def test_counter_time_requirement(self): + async def test_counter_time_requirement(self): registered_factories[Operator] = MockAmountOperator user = PicklableMock() counter = PicklableMock() counter.update_time = int(time()) - 10 user.counters = {"test": counter} requirement = CounterUpdateTimeRequirement({"operator": {"type": "more_or_equal", "amount": 5}, "key": "test"}) - self.assertTrue(requirement.check(None, user)) + self.assertTrue(await requirement.check(None, user)) - def test_template_req_true(self): + async def test_template_req_true(self): items = { "template": "{{ payload.message.strip() in payload.murexIds }}" } @@ -197,9 +201,9 @@ def test_template_req_true(self): user = PicklableMock() user.parametrizer = PicklableMock() user.parametrizer.collect = Mock(return_value=params) - self.assertTrue(requirement.check(None, user)) + self.assertTrue(await requirement.check(None, user)) - def test_template_req_false(self): + async def test_template_req_false(self): items = { "template": "{{ payload.groupCode == 'BROKER' }}" } @@ -208,9 +212,9 @@ def test_template_req_false(self): user = PicklableMock() user.parametrizer = PicklableMock() user.parametrizer.collect = Mock(return_value=params) - self.assertFalse(requirement.check(None, user)) + self.assertFalse(await requirement.check(None, user)) - def test_template_req_raise(self): + async def test_template_req_raise(self): items = { "template": "{{ payload.groupCode }}" } @@ -219,23 +223,23 @@ def test_template_req_raise(self): user = PicklableMock() user.parametrizer = PicklableMock() user.parametrizer.collect = Mock(return_value=params) - self.assertRaises(TypeError, requirement.check, None, user) + self.assertRaises(TypeError, _run, requirement.check, None, user) - def test_rolling_requirement_true(self): + async def test_rolling_requirement_true(self): user = PicklableMock() user.id = "353454" requirement = RollingRequirement({"percent": 100}) text_normalization_result = None - self.assertTrue(requirement.check(text_normalization_result, user)) + self.assertTrue(await requirement.check(text_normalization_result, user)) - def test_rolling_requirement_false(self): + async def test_rolling_requirement_false(self): user = PicklableMock() user.id = "353454" requirement = RollingRequirement({"percent": 0}) text_normalization_result = None - self.assertFalse(requirement.check(text_normalization_result, user)) + self.assertFalse(await requirement.check(text_normalization_result, user)) - def test_time_requirement_true(self): + async def test_time_requirement_true(self): user = PicklableMock() user.id = "353454" user.message.payload = { @@ -255,9 +259,9 @@ def test_time_requirement_true(self): } ) text_normalization_result = None - self.assertTrue(requirement.check(text_normalization_result, user)) + self.assertTrue(await requirement.check(text_normalization_result, user)) - def test_time_requirement_false(self): + async def test_time_requirement_false(self): user = PicklableMock() user.id = "353454" user.message.payload = { @@ -277,9 +281,9 @@ def test_time_requirement_false(self): } ) text_normalization_result = None - self.assertFalse(requirement.check(text_normalization_result, user)) + self.assertFalse(await requirement.check(text_normalization_result, user)) - def test_datetime_requirement_true(self): + async def test_datetime_requirement_true(self): user = PicklableMock() user.id = "353454" user.message.payload = { @@ -296,9 +300,9 @@ def test_datetime_requirement_true(self): } ) text_normalization_result = None - self.assertTrue(requirement.check(text_normalization_result, user)) + self.assertTrue(await requirement.check(text_normalization_result, user)) - def test_datetime_requirement_false(self): + async def test_datetime_requirement_false(self): user = PicklableMock() user.id = "353454" user.message.payload = { @@ -315,10 +319,10 @@ def test_datetime_requirement_false(self): } ) text_normalization_result = None - self.assertFalse(requirement.check(text_normalization_result, user)) + self.assertFalse(await requirement.check(text_normalization_result, user)) @patch('smart_kit.configs.get_app_config') - def test_intersection_requirement_true(self, mock_get_app_config): + async def test_intersection_requirement_true(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) user = PicklableMock() requirement = IntersectionRequirement( @@ -335,10 +339,10 @@ def test_intersection_requirement_true(self, mock_get_app_config): {'lemma': 'я'}, {'lemma': 'хотеть'}, ] - self.assertTrue(requirement.check(text_normalization_result, user)) + self.assertTrue(await requirement.check(text_normalization_result, user)) @patch('smart_kit.configs.get_app_config') - def test_intersection_requirement_false(self, mock_get_app_config): + async def test_intersection_requirement_false(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) user = PicklableMock() requirement = IntersectionRequirement( @@ -356,10 +360,10 @@ def test_intersection_requirement_false(self, mock_get_app_config): {'lemma': 'за'}, {'lemma': 'что'}, ] - self.assertFalse(requirement.check(text_normalization_result, user)) + self.assertFalse(await requirement.check(text_normalization_result, user)) @patch.object(ExternalClassifier, "find_best_answer", return_value=[{"answer": "нет", "score": 1.0, "other": False}]) - def test_classifier_requirement_true(self, mock_classifier_model): + async def test_classifier_requirement_true(self, mock_classifier_model): """Тест кейз проверяет что условие возвращает True, если результат классификации запроса относится к одной из указанных категорий, прошедших порог, но не равной классу other. """ @@ -367,30 +371,30 @@ def test_classifier_requirement_true(self, mock_classifier_model): classifier_requirement = ClassifierRequirement(test_items) mock_user = PicklableMock() mock_user.descriptions = {"external_classifiers": ["read_book_or_not_classifier", "hello_scenario_classifier"]} - result = classifier_requirement.check(PicklableMock(), mock_user) + result = await classifier_requirement.check(PicklableMock(), mock_user) self.assertTrue(result) @patch.object(ExternalClassifier, "find_best_answer", return_value=[]) - def test_classifier_requirement_false(self, mock_classifier_model): + async def test_classifier_requirement_false(self, mock_classifier_model): """Тест кейз проверяет что условие возвращает False, если модель классификации не вернула ответ.""" test_items = {"type": "classifier", "classifier": {"type": "external", "classifier": "hello_scenario_classifier"}} classifier_requirement = ClassifierRequirement(test_items) mock_user = PicklableMock() mock_user.descriptions = {"external_classifiers": ["read_book_or_not_classifier", "hello_scenario_classifier"]} - result = classifier_requirement.check(PicklableMock(), mock_user) + result = await classifier_requirement.check(PicklableMock(), mock_user) self.assertFalse(result) @patch.object(ExternalClassifier, "find_best_answer", return_value=[{"answer": "other", "score": 1.0, "other": True}]) - def test_classifier_requirement_false_if_class_other(self, mock_classifier_model): + async def test_classifier_requirement_false_if_class_other(self, mock_classifier_model): """Тест кейз проверяет что условие возвращает False, если наиболее вероятный вариант есть класс other.""" test_items = {"type": "classifier", "classifier": {"type": "external", "classifier": "hello_scenario_classifier"}} classifier_requirement = ClassifierRequirement(test_items) mock_user = PicklableMock() mock_user.descriptions = {"external_classifiers": ["read_book_or_not_classifier", "hello_scenario_classifier"]} - result = classifier_requirement.check(PicklableMock(), mock_user) + result = await classifier_requirement.check(PicklableMock(), mock_user) self.assertFalse(result) - def test_form_field_value_requirement_true(self): + async def test_form_field_value_requirement_true(self): """Тест кейз проверяет что условие возвращает True, т.к в форме form_name в поле form_field значение совпадает с переданным field_value. """ @@ -406,10 +410,10 @@ def test_form_field_value_requirement_true(self): user.forms[form_name].fields = {form_field: PicklableMock(), "value": field_value} user.forms[form_name].fields[form_field].value = field_value - result = req_form_field_value.check(PicklableMock(), user) + result = await req_form_field_value.check(PicklableMock(), user) self.assertTrue(result) - def test_form_field_value_requirement_false(self): + async def test_form_field_value_requirement_false(self): """Тест кейз проверяет что условие возвращает False, т.к в форме form_name в поле form_field значение НЕ совпадает с переданным field_value. """ @@ -425,73 +429,73 @@ def test_form_field_value_requirement_false(self): user.forms[form_name].fields = {form_field: PicklableMock(), "value": "OTHER_TEST_VAL"} user.forms[form_name].fields[form_field].value = "OTHER_TEST_VAL" - result = req_form_field_value.check(PicklableMock(), user) + result = await req_form_field_value.check(PicklableMock(), user) self.assertFalse(result) @patch("smart_kit.configs.get_app_config") - def test_environment_requirement_true(self, mock_get_app_config): + async def test_environment_requirement_true(self, mock_get_app_config): """Тест кейз проверяет что условие возвращает True, т.к среда исполнения из числа values.""" patch_get_app_config(mock_get_app_config) environment_req = EnvironmentRequirement({"values": ["ift", "uat"]}) - self.assertTrue(environment_req.check(PicklableMock(), PicklableMock())) + self.assertTrue(await environment_req.check(PicklableMock(), PicklableMock())) @patch("smart_kit.configs.get_app_config") - def test_environment_requirement_false(self, mock_get_app_config): + async def test_environment_requirement_false(self, mock_get_app_config): """Тест кейз проверяет что условие возвращает False, т.к среда исполнения НЕ из числа values.""" patch_get_app_config(mock_get_app_config) environment_req = EnvironmentRequirement({"values": ["uat", "pt"]}) - self.assertFalse(environment_req.check(PicklableMock(), PicklableMock())) + self.assertFalse(await environment_req.check(PicklableMock(), PicklableMock())) - def test_any_substring_in_lowered_text_requirement_true(self): + async def test_any_substring_in_lowered_text_requirement_true(self): """Тест кейз проверяет что условие возвращает True, т.к нашлась подстрока из списка substrings, которая встречается в оригинальном тексте в нижнем регистре. """ req = AnySubstringInLoweredTextRequirement({"substrings": ["искомая подстрока", "другое знанчение"]}) text_preprocessing_result = PicklableMock() text_preprocessing_result.raw = {"original_text": "КАКОЙ-ТО ТЕКСТ С ИСКОМАЯ ПОДСТРОКА"} - result = req.check(text_preprocessing_result, PicklableMock()) + result = await req.check(text_preprocessing_result, PicklableMock()) self.assertTrue(result) - def test_any_substring_in_lowered_text_requirement_false(self): + async def test_any_substring_in_lowered_text_requirement_false(self): """Тест кейз проверяет что условие возвращает False, т.к НЕ нашлась ни одна подстрока из списка substrings, которая бы встречалась в оригинальном тексте в нижнем регистре. """ req = AnySubstringInLoweredTextRequirement({"substrings": ["искомая подстрока", "другая подстрока"]}) text_preprocessing_result = PicklableMock() text_preprocessing_result.raw = {"original_text": "КАКОЙ-ТО ТЕКСТ"} - result = req.check(text_preprocessing_result, PicklableMock()) + result = await req.check(text_preprocessing_result, PicklableMock()) self.assertFalse(result) - def test_num_in_range_requirement_true(self): + async def test_num_in_range_requirement_true(self): """Тест кейз проверяет что условие возвращает True, т.к число находится в заданном диапазоне.""" req = NumInRangeRequirement({"min_num": "5", "max_num": "10"}) text_preprocessing_result = PicklableMock() text_preprocessing_result.num_token_values = 7 - self.assertTrue(req.check(text_preprocessing_result, PicklableMock())) + self.assertTrue(await req.check(text_preprocessing_result, PicklableMock())) - def test_num_in_range_requirement_false(self): + async def test_num_in_range_requirement_false(self): """Тест кейз проверяет что условие возвращает False, т.к число НЕ находится в заданном диапазоне.""" req = NumInRangeRequirement({"min_num": "5", "max_num": "10"}) text_preprocessing_result = PicklableMock() text_preprocessing_result.num_token_values = 20 - self.assertFalse(req.check(text_preprocessing_result, PicklableMock())) + self.assertFalse(await req.check(text_preprocessing_result, PicklableMock())) - def test_phone_number_number_requirement_true(self): + async def test_phone_number_number_requirement_true(self): """Тест кейз проверяет что условие возвращает True, т.к кол-во номеров телефонов больше заданного.""" req = PhoneNumberNumberRequirement({"operator": {"type": "more", "amount": 1}}) text_preprocessing_result = PicklableMock() text_preprocessing_result.get_token_values_by_type.return_value = ["89030478799", "89092534523"] - self.assertTrue(req.check(text_preprocessing_result, PicklableMock())) + self.assertTrue(await req.check(text_preprocessing_result, PicklableMock())) - def test_phone_number_number_requirement_false(self): + async def test_phone_number_number_requirement_false(self): """Тест кейз проверяет что условие возвращает False, т.к кол-во номеров телефонов НЕ больше заданного.""" req = PhoneNumberNumberRequirement({"operator": {"type": "more", "amount": 10}}) text_preprocessing_result = PicklableMock() text_preprocessing_result.get_token_values_by_type.return_value = ["89030478799"] - self.assertFalse(req.check(text_preprocessing_result, PicklableMock())) + self.assertFalse(await req.check(text_preprocessing_result, PicklableMock())) @patch("smart_kit.configs.get_app_config") - def test_intersection_with_tokens_requirement_true(self, mock_get_app_config): + async def test_intersection_with_tokens_requirement_true(self, mock_get_app_config): """Тест кейз проверяет что условие возвращает True, т.к хотя бы одно слово из нормализованного вида запроса входит в список слов input_words. """ @@ -510,10 +514,10 @@ def test_intersection_with_tokens_requirement_true(self, mock_get_app_config): "part_of_speech": "NOUN"}, "lemma": "погода"} ]} - self.assertTrue(req.check(text_preprocessing_result, PicklableMock())) + self.assertTrue(await req.check(text_preprocessing_result, PicklableMock())) @patch("smart_kit.configs.get_app_config") - def test_intersection_with_tokens_requirement_false(self, mock_get_app_config): + async def test_intersection_with_tokens_requirement_false(self, mock_get_app_config): """Тест кейз проверяет что условие возвращает False, т.к ни одно слово из нормализованного вида запроса не входит в список слов input_words. """ @@ -532,10 +536,10 @@ def test_intersection_with_tokens_requirement_false(self, mock_get_app_config): "part_of_speech": "NOUN"}, "lemma": "погода"} ]} - self.assertFalse(req.check(text_preprocessing_result, PicklableMock())) + self.assertFalse(await req.check(text_preprocessing_result, PicklableMock())) @patch("smart_kit.configs.get_app_config") - def test_normalized_text_in_set_requirement_true(self, mock_get_app_config): + async def test_normalized_text_in_set_requirement_true(self, mock_get_app_config): """Тест кейз проверяет что условие возвращает True, т.к в нормализованном представлении запрос полностью совпадает с одной из нормализованных строк из input_words. """ @@ -546,10 +550,10 @@ def test_normalized_text_in_set_requirement_true(self, mock_get_app_config): text_preprocessing_result = PicklableMock() text_preprocessing_result.raw = {"normalized_text": "погода ."} - self.assertTrue(req.check(text_preprocessing_result, PicklableMock())) + self.assertTrue(await req.check(text_preprocessing_result, PicklableMock())) @patch("smart_kit.configs.get_app_config") - def test_normalized_text_in_set_requirement_false(self, mock_get_app_config): + async def test_normalized_text_in_set_requirement_false(self, mock_get_app_config): """Тест кейз проверяет что условие возвращает False, т.к в нормализованном представлении запрос НЕ совпадает ни с одной из нормализованных строк из input_words. """ @@ -560,33 +564,33 @@ def test_normalized_text_in_set_requirement_false(self, mock_get_app_config): text_preprocessing_result = PicklableMock() text_preprocessing_result.raw = {"normalized_text": "хотеть узнать ."} - self.assertFalse(req.check(text_preprocessing_result, PicklableMock())) + self.assertFalse(await req.check(text_preprocessing_result, PicklableMock())) - def test_character_id_requirement_true(self): + async def test_character_id_requirement_true(self): req = CharacterIdRequirement({"values": ["sber", "afina"]}) user = Mock() user.message = Mock() user.message.payload = {"character": {"id": "sber", "name": "Сбер", "gender": "male"}} - self.assertTrue(req.check(Mock(), user)) + self.assertTrue(await req.check(Mock(), user)) - def test_character_id_requirement_false(self): + async def test_character_id_requirement_false(self): req = CharacterIdRequirement({"values": ["afina"]}) user = Mock() user.message = Mock() user.message.payload = {"character": {"id": "sber", "name": "Сбер", "gender": "male"}} - self.assertFalse(req.check(Mock(), user)) + self.assertFalse(await req.check(Mock(), user)) - def test_feature_toggle_check_requirement_true(self): + async def test_feature_toggle_check_requirement_true(self): req = FeatureToggleRequirement({"toggle_name": "test_true_toggle_name"}) mock_user = Mock() mock_user.settings = {"template_settings": {"test_true_toggle_name": True}} - self.assertTrue(req.check(Mock(), mock_user)) + self.assertTrue(await req.check(Mock(), mock_user)) - def test_feature_toggle_check_requirement_false(self): + async def test_feature_toggle_check_requirement_false(self): req = FeatureToggleRequirement({"toggle_name": "test_false_toggle_name"}) mock_user = Mock() mock_user.settings = {"template_settings": {"test_false_toggle_name": False}} - self.assertFalse(req.check(Mock(), mock_user)) + self.assertFalse(await req.check(Mock(), mock_user)) if __name__ == '__main__': diff --git a/tests/scenarios_tests/actions_test/test_action.py b/tests/scenarios_tests/actions_test/test_action.py index 474ebb22..3b403717 100644 --- a/tests/scenarios_tests/actions_test/test_action.py +++ b/tests/scenarios_tests/actions_test/test_action.py @@ -1,6 +1,6 @@ import unittest from typing import Dict, Any, Union, Optional -from unittest.mock import Mock, ANY +from unittest.mock import Mock, ANY, AsyncMock from core.basic_models.actions.basic_actions import Action, action_factory, actions from core.model.registered import registered_factories @@ -16,15 +16,14 @@ SetVariableAction, SelfServiceActionWithState, SaveBehaviorAction, - ResetCurrentNodeAction, RunScenarioAction, RunLastScenarioAction, - AddHistoryEventAction + AddHistoryEventAction, ResetCurrentNodeAction ) from scenarios.actions.action import ClearFormAction, ClearInnerFormAction, BreakScenarioAction, \ RemoveFormFieldAction, RemoveCompositeFormFieldAction from scenarios.scenario_models.history import Event -from smart_kit.utils.picklable_mock import PicklableMock, PicklableMagicMock +from smart_kit.utils.picklable_mock import PicklableMock, PicklableMagicMock, AsyncPicklableMock class MockAction: @@ -32,7 +31,7 @@ def __init__(self, items=None): self.items = items or {} self.result = items.get("result") - def run(self, user, text_preprocessing_result, params=None): + async def run(self, user, text_preprocessing_result, params=None): return self.result or ["test action run"] @@ -53,35 +52,35 @@ def collect(self, text_preprocessing_result=None, filter_params=None): return data -class ClearFormIdActionTest(unittest.TestCase): - def test_run(self): +class ClearFormIdActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): action = ClearFormAction({"form": "form"}) user = PicklableMagicMock() - action.run(user, None) + await action.run(user, None) user.forms.remove_item.assert_called_once_with("form") -class RemoveCompositeFormFieldActionTest(unittest.TestCase): - def test_run(self): +class ClearInnerFormActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): action = ClearInnerFormAction({"form": "form", "inner_form": "inner_form"}) user, form = PicklableMagicMock(), PicklableMagicMock() user.forms.__getitem__.return_value = form - action.run(user, None) + await action.run(user, None) form.forms.remove_item.assert_called_once_with("inner_form") -class BreakScenarioTest(unittest.TestCase): - def test_run_1(self): +class BreakScenarioTest(unittest.IsolatedAsyncioTestCase): + async def test_run_1(self): scenario_id = "test_id" action = BreakScenarioAction({"scenario_id": scenario_id}) user = PicklableMock() scenario_model = PicklableMagicMock() scenario_model.set_break = Mock(return_value=None) user.scenario_models = {scenario_id: scenario_model} - action.run(user, None) + await action.run(user, None) user.scenario_models[scenario_id].set_break.assert_called_once() - def test_run_2(self): + async def test_run_2(self): scenario_id = "test_id" action = BreakScenarioAction({}) user = PicklableMock() @@ -89,30 +88,30 @@ def test_run_2(self): scenario_model = PicklableMagicMock() scenario_model.set_break = Mock(return_value=None) user.scenario_models = {scenario_id: scenario_model} - action.run(user, None) + await action.run(user, None) user.scenario_models[scenario_id].set_break.assert_called_once() -class RemoveFormFieldActionTest(unittest.TestCase): - def test_run(self): +class RemoveFormFieldActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): action = RemoveFormFieldAction({"form": "form", "field": "field"}) user, form = PicklableMagicMock(), PicklableMagicMock() user.forms.__getitem__.return_value = form - action.run(user, None) + await action.run(user, None) form.fields.remove_item.assert_called_once_with("field") -class RemoveCompositeFormFieldActionTest(unittest.TestCase): - def test_run(self): +class RemoveCompositeFormFieldActionTest(unittest.IsolatedAsyncioTestCase): + async def test_run(self): action = RemoveCompositeFormFieldAction({"form": "form", "inner_form": "form", "field": "field"}) user, inner_form, form = PicklableMagicMock(), PicklableMagicMock(), PicklableMagicMock() form.forms.__getitem__.return_value = inner_form user.forms.__getitem__.return_value = form - action.run(user, None) + await action.run(user, None) inner_form.fields.remove_item.assert_called_once_with("field") -class SaveBehaviorActionTest(unittest.TestCase): +class SaveBehaviorActionTest(unittest.IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): user = PicklableMock() @@ -123,7 +122,7 @@ def setUpClass(cls): user.message.incremental_id = test_incremental_id cls.user = user - def test_save_behavior_scenario_name(self): + async def test_save_behavior_scenario_name(self): data = {"behavior": "test"} behavior = PicklableMock() behavior.add = PicklableMock() @@ -131,12 +130,12 @@ def test_save_behavior_scenario_name(self): action = SaveBehaviorAction(data) tpr = PicklableMock() tpr_raw = tpr.raw - action.run(self.user, tpr) + await action.run(self.user, tpr) self.user.behaviors.add.assert_called_once_with(self.user.message.generate_new_callback_id(), "test", self.user.last_scenarios.last_scenario_name, tpr_raw, action_params=None) - def test_save_behavior_without_scenario_name(self): + async def test_save_behavior_without_scenario_name(self): data = {"behavior": "test", "check_scenario": False} behavior = PicklableMock() behavior.add = PicklableMock() @@ -144,17 +143,17 @@ def test_save_behavior_without_scenario_name(self): action = SaveBehaviorAction(data) text_preprocessing_result_raw = PicklableMock() text_preprocessing_result = Mock(raw=text_preprocessing_result_raw) - action.run(self.user, text_preprocessing_result, None) + await action.run(self.user, text_preprocessing_result, None) self.user.behaviors.add.assert_called_once_with(self.user.message.generate_new_callback_id(), "test", None, text_preprocessing_result_raw, action_params=None) -class SelfServiceActionWithStateTest(unittest.TestCase): +class SelfServiceActionWithStateTest(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.user = PicklableMock() self.user.settings = {"template_settings": {"self_service_with_state_save_messages": True}} - def test_action_1(self): + async def test_action_1(self): data = {"behavior": "test", "check_scenario": False, "command_action": {"command": "cmd_id", "nodes": {}, "request_data": {}}} registered_factories[Action] = action_factory @@ -172,13 +171,13 @@ def test_action_1(self): action = SelfServiceActionWithState(data) text_preprocessing_result_raw = PicklableMock() text_preprocessing_result = Mock(raw=text_preprocessing_result_raw) - result = action.run(self.user, text_preprocessing_result, None) + result = await action.run(self.user, text_preprocessing_result, None) behavior.check_got_saved_id.assert_called_once() behavior.add.assert_called_once() self.assertEqual(result[0].name, "cmd_id") self.assertEqual(result[0].raw, {'messageName': 'cmd_id', 'payload': {}}) - def test_action_2(self): + async def test_action_2(self): data = {"behavior": "test", "check_scenario": False, "command_action": {"command": "cmd_id", "nodes": {}}} self.user.parametrizer = MockParametrizer(self.user, {}) self.user.message = PicklableMock() @@ -188,11 +187,11 @@ def test_action_2(self): self.user.behaviors = behavior behavior.check_got_saved_id = Mock(return_value=True) action = SelfServiceActionWithState(data) - result = action.run(self.user, None) + result = await action.run(self.user, None) behavior.add.assert_not_called() self.assertIsNone(result) - def test_action_3(self): + async def test_action_3(self): data = {"behavior": "test", "command_action": {"command": "cmd_id", "nodes": {}, "request_data": {}}} registered_factories[Action] = action_factory actions["action_mock"] = MockAction @@ -217,7 +216,7 @@ def test_action_3(self): action = SelfServiceActionWithState(data) text_preprocessing_result_raw = PicklableMock() text_preprocessing_result = Mock(raw=text_preprocessing_result_raw) - result = action.run(self.user, text_preprocessing_result, None) + result = await action.run(self.user, text_preprocessing_result, None) behavior.check_got_saved_id.assert_called_once() behavior.add.assert_called_once() self.assertEqual(result[0].name, "cmd_id") @@ -227,7 +226,7 @@ def test_action_3(self): ) -class SetVariableActionTest(unittest.TestCase): +class SetVariableActionTest(unittest.IsolatedAsyncioTestCase): def setUp(self): template = PicklableMock() @@ -242,25 +241,25 @@ def setUp(self): user.variables.set = PicklableMock() self.user = user - def test_action(self): + async def test_action(self): action = SetVariableAction({"key": "some_key", "value": "some_value"}) - action.run(self.user, None) + await action.run(self.user, None) self.user.variables.set.assert_called_with("some_key", "some_value", None) - def test_action_jinja_key_default(self): + async def test_action_jinja_key_default(self): self.user.message.payload = {"some_value": "some_value_test"} action = SetVariableAction({"key": "some_key", "value": "{{payload.some_value}}"}) - action.run(self.user, None) + await action.run(self.user, None) self.user.variables.set.assert_called_with("some_key", "some_value_test", None) - def test_action_jinja_no_key(self): + async def test_action_jinja_no_key(self): self.user.message.payload = {"some_value": "some_value_test"} action = SetVariableAction({"key": "some_key", "value": "{{payload.no_key}}"}) - action.run(self.user, None) + await action.run(self.user, None) self.user.variables.set.assert_called_with("some_key", "", None) -class DeleteVariableActionTest(unittest.TestCase): +class DeleteVariableActionTest(unittest.IsolatedAsyncioTestCase): def setUp(self): user = PicklableMock() @@ -272,13 +271,13 @@ def setUp(self): user.variables.delete = PicklableMock() self.user = user - def test_action(self): + async def test_action(self): action = DeleteVariableAction({"key": "some_key_1"}) - action.run(self.user, None) + await action.run(self.user, None) self.user.variables.delete.assert_called_with("some_key_1") -class ClearVariablesActionTest(unittest.TestCase): +class ClearVariablesActionTest(unittest.IsolatedAsyncioTestCase): def setUp(self): self.var_value = { @@ -294,15 +293,15 @@ def setUp(self): user.variables.clear = PicklableMock() self.user = user - def test_action(self): + async def test_action(self): action = ClearVariablesAction() - action.run(self.user, None) + await action.run(self.user, None) self.user.variables.clear.assert_called_with() -class FillFieldActionTest(unittest.TestCase): +class FillFieldActionTest(unittest.IsolatedAsyncioTestCase): - def test_fill_field(self): + async def test_fill_field(self): params = {"test_field": "test_data"} data = {"form": "test_form", "field": "test_field", "data_path": "{{test_field}}"} action = FillFieldAction(data) @@ -312,13 +311,13 @@ def test_fill_field(self): field = PicklableMock() field.fill = PicklableMock() user.forms["test_form"].fields = {"test_field": field} - action.run(user, None) + await action.run(user, None) field.fill.assert_called_once_with(params["test_field"]) -class CompositeFillFieldActionTest(unittest.TestCase): +class CompositeFillFieldActionTest(unittest.IsolatedAsyncioTestCase): - def test_fill_field(self): + async def test_fill_field(self): params = {"test_field": "test_data"} data = {"form": "test_form", "field": "test_field", "internal_form": "test_internal_form", "data_path": "{{test_field}}", "parametrizer": {"data": params}} @@ -332,37 +331,37 @@ def test_fill_field(self): field = PicklableMock() field.fill = PicklableMock() user.forms["test_form"].forms["test_internal_form"].fields = {"test_field": field} - action.run(user, None) + await action.run(user, None) field.fill.assert_called_once_with(params["test_field"]) -class ScenarioActionTest(unittest.TestCase): - def test_scenario_action(self): +class ScenarioActionTest(unittest.IsolatedAsyncioTestCase): + async def test_scenario_action(self): action = RunScenarioAction({"scenario": "test"}) user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) - scen = PicklableMock() + scen = AsyncPicklableMock() scen_result = 'done' scen.run.return_value = scen_result user.descriptions = {"scenarios": {"test": scen}} - result = action.run(user, PicklableMock()) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) - def test_scenario_action_with_jinja_good(self): + async def test_scenario_action_with_jinja_good(self): params = {'next_scenario': 'ANNA.pipeline.scenario'} items = {"scenario": "{{next_scenario}}"} action = RunScenarioAction(items) user = PicklableMock() user.parametrizer = MockParametrizer(user, {"data": params}) - scen = PicklableMock() + scen = AsyncPicklableMock() scen_result = 'done' scen.run.return_value = scen_result user.descriptions = {"scenarios": {"ANNA.pipeline.scenario": scen}} - result = action.run(user, PicklableMock()) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) - def test_scenario_action_no_scenario(self): + async def test_scenario_action_no_scenario(self): action = RunScenarioAction({"scenario": "{{next_scenario}}"}) user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) @@ -370,27 +369,27 @@ def test_scenario_action_no_scenario(self): scen_result = 'done' scen.run.return_value = scen_result user.descriptions = {"scenarios": {"next_scenario": scen}} - result = action.run(user, PicklableMock()) + result = await action.run(user, PicklableMock()) self.assertEqual(result, None) - def test_scenario_action_without_jinja(self): + async def test_scenario_action_without_jinja(self): action = RunScenarioAction({"scenario": "next_scenario"}) user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) - scen = PicklableMock() + scen = AsyncPicklableMock() scen_result = 'done' scen.run.return_value = scen_result user.descriptions = {"scenarios": {"next_scenario": scen}} - result = action.run(user, PicklableMock()) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) -class RunLastScenarioActionTest(unittest.TestCase): +class RunLastScenarioActionTest(unittest.IsolatedAsyncioTestCase): - def test_scenario_action(self): + async def test_scenario_action(self): action = RunLastScenarioAction({}) user = PicklableMock() - scen = PicklableMock() + scen = AsyncPicklableMock() scen_result = 'done' scen.run.return_value = scen_result user.descriptions = {"scenarios": {"test": scen}} @@ -398,25 +397,25 @@ def test_scenario_action(self): last_scenario_name = "test" user.last_scenarios.scenarios_names = [last_scenario_name] user.last_scenarios.last_scenario_name = last_scenario_name - result = action.run(user, PicklableMock()) + result = await action.run(user, PicklableMock()) self.assertEqual(result, scen_result) -class ChoiceScenarioActionTest(unittest.TestCase): +class ChoiceScenarioActionTest(unittest.IsolatedAsyncioTestCase): @staticmethod - def mock_and_perform_action(test_items: Dict[str, Any], expected_result: Optional[str] = None, - expected_scen: Optional[str] = None) -> Union[str, None]: + async def mock_and_perform_action(test_items: Dict[str, Any], expected_result: Optional[str] = None, + expected_scen: Optional[str] = None) -> Union[str, None]: action = ChoiceScenarioAction(test_items) user = PicklableMock() user.parametrizer = MockParametrizer(user, {}) - scen = PicklableMock() + scen = AsyncPicklableMock() scen.run.return_value = expected_result if expected_scen: user.descriptions = {"scenarios": {expected_scen: scen}} - return action.run(user, PicklableMock()) + return await action.run(user, PicklableMock()) - def test_choice_scenario_action(self): + async def test_choice_scenario_action(self): # Проверяем, что запустили нужный сценарий, в случае если выполнился его requirement test_items = { "scenarios": [ @@ -436,11 +435,12 @@ def test_choice_scenario_action(self): "else_action": {"type": "test", "result": "ELSE ACTION IS DONE"} } expected_scen_result = "test_N_done" - real_scen_result = self.mock_and_perform_action( - test_items, expected_result=expected_scen_result, expected_scen="test_N") + real_scen_result = await self.mock_and_perform_action( + test_items, expected_result=expected_scen_result, expected_scen="test_N" + ) self.assertEqual(real_scen_result, expected_scen_result) - def test_choice_scenario_action_no_else_action(self): + async def test_choice_scenario_action_no_else_action(self): # Проверяем, что вернули None в случае если ни один сценарий не запустился (requirement=False) и else_action нет test_items = { "scenarios": [ @@ -454,10 +454,10 @@ def test_choice_scenario_action_no_else_action(self): } ] } - real_scen_result = self.mock_and_perform_action(test_items) + real_scen_result = await self.mock_and_perform_action(test_items) self.assertIsNone(real_scen_result) - def test_choice_scenario_action_with_else_action(self): + async def test_choice_scenario_action_with_else_action(self): # Проверяем, что выполняется else_action в случае если ни один сценарий не запустился т.к их requirement=False test_items = { "scenarios": [ @@ -473,13 +473,13 @@ def test_choice_scenario_action_with_else_action(self): "else_action": {"type": "test", "result": "ELSE ACTION IS DONE"} } expected_scen_result = "ELSE ACTION IS DONE" - real_scen_result = self.mock_and_perform_action(test_items, expected_result=expected_scen_result) + real_scen_result = await self.mock_and_perform_action(test_items, expected_result=expected_scen_result) self.assertEqual(real_scen_result, expected_scen_result) -class ClearCurrentScenarioActionTest(unittest.TestCase): +class ClearCurrentScenarioActionTest(unittest.IsolatedAsyncioTestCase): - def test_action(self): + async def test_action(self): scenario_name = "test_scenario" user = PicklableMock() user.forms.remove_item = PicklableMock() @@ -491,12 +491,12 @@ def test_action(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearCurrentScenarioAction({}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) user.last_scenarios.delete.assert_called_once() user.forms.remove_item.assert_called_once() - def test_action_with_empty_scenarios_names(self): + async def test_action_with_empty_scenarios_names(self): user = PicklableMock() user.forms.remove_item = PicklableMock() @@ -504,15 +504,15 @@ def test_action_with_empty_scenarios_names(self): user.last_scenarios.delete = PicklableMock() action = ClearCurrentScenarioAction({}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) user.last_scenarios.delete.assert_not_called() user.forms.remove_item.assert_not_called() -class ClearScenarioByIdActionTest(unittest.TestCase): +class ClearScenarioByIdActionTest(unittest.IsolatedAsyncioTestCase): - def test_action(self): + async def test_action(self): scenario_name = "test_scenario" user = PicklableMock() user.forms = PicklableMock() @@ -523,26 +523,26 @@ def test_action(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearScenarioByIdAction({"scenario_id": scenario_name}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) user.last_scenarios.delete.assert_called_once() user.forms.remove_item.assert_called_once() - def test_action_with_empty_scenarios_names(self): + async def test_action_with_empty_scenarios_names(self): user = PicklableMock() user.forms = PicklableMock() user.last_scenarios.last_scenario_name = "test_scenario" action = ClearScenarioByIdAction({}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) user.last_scenarios.delete.assert_not_called() user.forms.remove_item.assert_not_called() -class ClearCurrentScenarioFormActionTest(unittest.TestCase): - def test_action(self): +class ClearCurrentScenarioFormActionTest(unittest.IsolatedAsyncioTestCase): + async def test_action(self): scenario_name = "test_scenario" user = PicklableMock() user.forms = PicklableMock() @@ -555,11 +555,11 @@ def test_action(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearCurrentScenarioFormAction({}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) user.forms.clear_form.assert_called_once() - def test_action_with_empty_last_scenario(self): + async def test_action_with_empty_last_scenario(self): scenario_name = "test_scenario" user = PicklableMock() user.forms = PicklableMock() @@ -572,13 +572,13 @@ def test_action_with_empty_last_scenario(self): user.descriptions = {"scenarios": {scenario_name: scenario}} action = ClearCurrentScenarioFormAction({}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) user.forms.remove_item.assert_not_called() -class ResetCurrentNodeActionTest(unittest.TestCase): - def test_action(self): +class ResetCurrentNodeActionTest(unittest.IsolatedAsyncioTestCase): + async def test_action(self): user = PicklableMock() user.forms = PicklableMock() user.last_scenarios.last_scenario_name = 'test_scenario' @@ -587,11 +587,11 @@ def test_action(self): user.scenario_models = {'test_scenario': scenario_model} action = ResetCurrentNodeAction({}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) self.assertIsNone(user.scenario_models['test_scenario'].current_node) - def test_action_with_empty_last_scenario(self): + async def test_action_with_empty_last_scenario(self): user = PicklableMock() user.forms = PicklableMock() user.last_scenarios.last_scenario_name = None @@ -600,11 +600,11 @@ def test_action_with_empty_last_scenario(self): user.scenario_models = {'test_scenario': scenario_model} action = ResetCurrentNodeAction({}) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) self.assertEqual('some_node', user.scenario_models['test_scenario'].current_node) - def test_specific_target(self): + async def test_specific_target(self): user = PicklableMock() user.forms = PicklableMock() user.last_scenarios.last_scenario_name = 'test_scenario' @@ -616,12 +616,12 @@ def test_specific_target(self): 'node_id': 'another_node' } action = ResetCurrentNodeAction(items) - result = action.run(user, {}, {}) + result = await action.run(user, {}, {}) self.assertIsNone(result) self.assertEqual('another_node', user.scenario_models['test_scenario'].current_node) -class AddHistoryEventActionTest(unittest.TestCase): +class AddHistoryEventActionTest(unittest.IsolatedAsyncioTestCase): def setUp(self): main_form = PicklableMock() @@ -636,7 +636,7 @@ def setUp(self): self.user.history.add_event = PicklableMock() self.user.last_scenarios.last_scenario_name = 'test_scenario' - def test_action_with_non_empty_scenario(self): + async def test_action_with_non_empty_scenario(self): scenario = PicklableMock() scenario.id = 'name' scenario.version = '1.0' @@ -655,12 +655,12 @@ def test_action_with_non_empty_scenario(self): ) action = AddHistoryEventAction(items) - action.run(self.user, None, None) + await action.run(self.user, None, None) self.user.history.add_event.assert_called_once() self.user.history.add_event.assert_called_once_with(expected) - def test_action_with_empty_scenario(self): + async def test_action_with_empty_scenario(self): self.user.descriptions = {'scenarios': {}} items = { 'event_type': 'type', @@ -669,11 +669,11 @@ def test_action_with_empty_scenario(self): } action = AddHistoryEventAction(items) - action.run(self.user, None, None) + await action.run(self.user, None, None) self.user.history.add_event.assert_not_called() - def test_action_with_jinja(self): + async def test_action_with_jinja(self): scenario = PicklableMock() scenario.id = 'name' scenario.version = '1.0' @@ -692,7 +692,7 @@ def test_action_with_jinja(self): ) action = AddHistoryEventAction(items) - action.run(self.user, None, None) + await action.run(self.user, None, None) self.user.history.add_event.assert_called_once() self.user.history.add_event.assert_called_once_with(expected) diff --git a/tests/scenarios_tests/behaviors_test/test_behavior_description.py b/tests/scenarios_tests/behaviors_test/test_behavior_description.py index 6ccd20cc..8c63ac73 100644 --- a/tests/scenarios_tests/behaviors_test/test_behavior_description.py +++ b/tests/scenarios_tests/behaviors_test/test_behavior_description.py @@ -8,7 +8,7 @@ class MockAction(Action): - def run(self, user, text_preprocessing_result, params=None): + async def run(self, user, text_preprocessing_result, params=None): return [] @@ -17,7 +17,7 @@ def __init__(self): self.id = '123-456-789' self.version = 21 # какая-то версия - def run(self, a, b, c): + async def run(self, a, b, c): return 123 # некий результат diff --git a/tests/scenarios_tests/behaviors_test/test_behavior_model.py b/tests/scenarios_tests/behaviors_test/test_behavior_model.py index 79195d46..d331f673 100644 --- a/tests/scenarios_tests/behaviors_test/test_behavior_model.py +++ b/tests/scenarios_tests/behaviors_test/test_behavior_model.py @@ -6,10 +6,10 @@ from unittest.mock import Mock import scenarios.behaviors.behaviors -from smart_kit.utils.picklable_mock import PicklableMock +from smart_kit.utils.picklable_mock import PicklableMock, AsyncPicklableMock -class BehaviorsTest(unittest.TestCase): +class BehaviorsTest(unittest.IsolatedAsyncioTestCase): def setUp(self): self.user = PicklableMock() self.user.settings = PicklableMock() @@ -17,10 +17,10 @@ def setUp(self): self.user.local_vars.values = {"test_local_var_key": "test_local_var_value"} self.description = PicklableMock() self.description.timeout = Mock(return_value=10) - self.success_action = PicklableMock() - self.success_action.run = PicklableMock() - self.fail_action = PicklableMock() - self.timeout_action = PicklableMock() + self.success_action = AsyncPicklableMock() + self.success_action.run = AsyncPicklableMock() + self.fail_action = AsyncPicklableMock() + self.timeout_action = AsyncPicklableMock() self.description.success_action = self.success_action self.description.fail_action = self.fail_action @@ -28,7 +28,7 @@ def setUp(self): self.descriptions = {"test": self.description} self._callback = namedtuple('Callback', 'behavior_id expire_time scenario_id') - def test_success(self): + async def test_success(self): callback_id = "123" behavior_id = "test" item = {"behavior_id": behavior_id, "expire_time": 2554416000, "scenario_id": None, @@ -36,38 +36,38 @@ def test_success(self): items = {str(callback_id): item} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - behaviors.success(callback_id) + await behaviors.success(callback_id) # self.success_action.run.assert_called_once_with(self.user, TextPreprocessingResult({})) self.success_action.run.assert_called_once() self.assertDictEqual(behaviors.raw, {}) - def test_success_2(self): + async def test_success_2(self): callback_id = "123" items = {} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - behaviors.success(callback_id) + await behaviors.success(callback_id) self.success_action.run.assert_not_called() - def test_fail(self): + async def test_fail(self): callback_id = "123" behavior_id = "test" item = {"behavior_id": behavior_id, "expire_time": 2554416000, "scenario_id": None} items = {str(callback_id): item} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - behaviors.fail(callback_id) + await behaviors.fail(callback_id) self.fail_action.run.assert_called_once() self.assertDictEqual(behaviors.raw, {}) - def test_timeout(self): + async def test_timeout(self): callback_id = "123" behavior_id = "test" item = {"behavior_id": behavior_id, "expire_time": 2554416000, "scenario_id": None} items = {str(callback_id): item} behaviors = scenarios.behaviors.behaviors.Behaviors(items, self.descriptions, self.user) behaviors.initialize() - behaviors.timeout(callback_id) + await behaviors.timeout(callback_id) self.timeout_action.run.assert_called_once() self.assertDictEqual(behaviors.raw, {}) diff --git a/tests/scenarios_tests/fillers/_trial_temp/_trial_marker b/tests/scenarios_tests/fillers/_trial_temp/_trial_marker deleted file mode 100755 index e69de29b..00000000 diff --git a/tests/scenarios_tests/fillers/test_approve.py b/tests/scenarios_tests/fillers/test_approve.py index 82e97a2f..da669d11 100644 --- a/tests/scenarios_tests/fillers/test_approve.py +++ b/tests/scenarios_tests/fillers/test_approve.py @@ -1,5 +1,5 @@ import os -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from unittest.mock import Mock, patch import smart_kit @@ -20,10 +20,10 @@ def patch_get_app_config(mock_get_app_config): mock_get_app_config.return_value = result -class TestApproveFiller(TestCase): +class TestApproveFiller(IsolatedAsyncioTestCase): @patch('smart_kit.configs.get_app_config') - def test_1(self, mock_get_app_config): + async def test_1(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) items = { 'yes_words': [ @@ -49,24 +49,24 @@ def test_1(self, mock_get_app_config): user_phrase = 'даю' text_pre_result = TextPreprocessingResult(normalizer(user_phrase)) - result = filler.extract(text_pre_result, None) + result = await filler.extract(text_pre_result, None) self.assertTrue(result) user_phrase = 'да нет' text_pre_result = TextPreprocessingResult(normalizer(user_phrase)) - result = filler.extract(text_pre_result, None) + result = await filler.extract(text_pre_result, None) self.assertFalse(result) user_phrase = 'даю добро' text_pre_result = TextPreprocessingResult(normalizer(user_phrase)) - result = filler.extract(text_pre_result, None) + result = await filler.extract(text_pre_result, None) self.assertIsNone(result) -class TestApproveRawTextFiller(TestCase): +class TestApproveRawTextFiller(IsolatedAsyncioTestCase): @patch('smart_kit.configs.get_app_config') - def test_1(self, mock_get_app_config): + async def test_1(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) items = { 'yes_words': [ @@ -92,15 +92,15 @@ def test_1(self, mock_get_app_config): user_phrase = 'конечно' text_pre_result = TextPreprocessingResult(normalizer(user_phrase)) - result = filler.extract(text_pre_result, None) + result = await filler.extract(text_pre_result, None) self.assertTrue(result) user_phrase = 'да нет' text_pre_result = TextPreprocessingResult(normalizer(user_phrase)) - result = filler.extract(text_pre_result, None) + result = await filler.extract(text_pre_result, None) self.assertFalse(result) user_phrase = 'даю' text_pre_result = TextPreprocessingResult(normalizer(user_phrase)) - result = filler.extract(text_pre_result, None) + result = await filler.extract(text_pre_result, None) self.assertIsNone(result) diff --git a/tests/scenarios_tests/fillers/test_available_info_filler.py b/tests/scenarios_tests/fillers/test_available_info_filler.py index 58de5ba5..ee1b408c 100644 --- a/tests/scenarios_tests/fillers/test_available_info_filler.py +++ b/tests/scenarios_tests/fillers/test_available_info_filler.py @@ -1,4 +1,4 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from unittest.mock import Mock from scenarios.scenario_models.field.field_filler_description import AvailableInfoFiller @@ -23,7 +23,7 @@ def collect(self, text_preprocessing_result=None, filter_params=None): return data -class TestAvailableInfoFiller(TestCase): +class TestAvailableInfoFiller(IsolatedAsyncioTestCase): @classmethod def setUpClass(cls): cls.address = "Address!" @@ -40,7 +40,7 @@ def setUp(self): user.descriptions = {"render_templates": template} self.user = user - def test_getting_person_info_value(self): + async def test_getting_person_info_value(self): name = "Name!" surname = "Surname!" self.user.person_info.raw = PicklableMock() @@ -48,34 +48,34 @@ def test_getting_person_info_value(self): person_info_items = {'value': '{{person_info.full_name.surname}}'} person_info_filler = AvailableInfoFiller(person_info_items) - result = person_info_filler.extract(None, self.user) + result = await person_info_filler.extract(None, self.user) self.assertEqual(result, surname) - def test_getting_payload_value(self): + async def test_getting_payload_value(self): self.user.message.payload = {"sf_answer": {"address": self.address}} - result = self.payload_filler.extract(None, self.user) + result = await self.payload_filler.extract(None, self.user) self.assertEqual(result, self.address) - def test_getting_uuid_value(self): + async def test_getting_uuid_value(self): uuid = "15" self.user.message.uuid = {"chatId": uuid} uuid_items = {'value': '{{uuid.chatId}}'} uuid_filler = AvailableInfoFiller(uuid_items) - result = uuid_filler.extract(None, self.user) + result = await uuid_filler.extract(None, self.user) self.assertEqual(result, uuid) - def test_not_failing_on_wrong_path(self): + async def test_not_failing_on_wrong_path(self): self.user.message.payload = {"other_answer": {"address": self.address}} - result = self.payload_filler.extract(None, self.user) + result = await self.payload_filler.extract(None, self.user) self.assertIsNone(result) - def test_return_empty_value(self): + async def test_return_empty_value(self): self.user.message.payload = {"sf_answer": '1'} - result = self.payload_filler.extract(None, self.user) + result = await self.payload_filler.extract(None, self.user) self.assertEqual("", result) - def test_filter(self): + async def test_filter(self): template = PicklableMock() template.get_template = Mock(return_value=["payload.personInfo.identityCard"]) self.user.parametrizer = MockParametrizer(self.user, {"filter": True}) @@ -83,10 +83,10 @@ def test_filter(self): self.user.descriptions = {"render_templates": template} payload_items = {'value': '{{filter}}'} filler = AvailableInfoFiller(payload_items) - result = filler.extract(None, self.user) + result = await filler.extract(None, self.user) self.assertEqual("filter_out", result) - def test_getting_payload_parsed_value(self): + async def test_getting_payload_parsed_value(self): data = [ { "id": 1, @@ -144,5 +144,5 @@ def test_getting_payload_parsed_value(self): "cacheGuid": "FHGDDASDHDAKSGFLAK", "data": data } - result = payload_filler.extract(None, self.user) + result = await payload_filler.extract(None, self.user) self.assertEqual(result, data) diff --git a/tests/scenarios_tests/fillers/test_classifier_filler.py b/tests/scenarios_tests/fillers/test_classifier_filler.py index c4506b73..7c36cac4 100644 --- a/tests/scenarios_tests/fillers/test_classifier_filler.py +++ b/tests/scenarios_tests/fillers/test_classifier_filler.py @@ -1,4 +1,4 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from unittest.mock import patch from core.basic_models.classifiers.basic_classifiers import ExternalClassifier @@ -6,7 +6,7 @@ from smart_kit.utils.picklable_mock import PicklableMock -class TestClassifierFiller(TestCase): +class TestClassifierFiller(IsolatedAsyncioTestCase): def setUp(self): test_items = { @@ -26,20 +26,20 @@ def setUp(self): "find_best_answer", return_value=[{"answer": "нет", "score": 0.7, "other": False}, {"answer": "да", "score": 0.3, "other": False}] ) - def test_filler_extract(self, mock_classifier_model): + async def test_filler_extract(self, mock_classifier_model): """Тест кейз проверяет что поле заполнено наиболее вероятным значением, которое вернула модель.""" expected_res = "нет" - actual_res = self.filler.extract(self.mock_text_preprocessing_result, self.mock_user) + actual_res = await self.filler.extract(self.mock_text_preprocessing_result, self.mock_user) self.assertEqual(expected_res, actual_res) @patch.object(ExternalClassifier, "find_best_answer", return_value=[]) - def test_filler_extract_if_no_model_answer(self, mock_classifier_model): + async def test_filler_extract_if_no_model_answer(self, mock_classifier_model): """Тест кейз проверяет что поле осталось не заполненным те результат None, если модель не выдала ответ.""" - actual_res = self.filler.extract(self.mock_text_preprocessing_result, self.mock_user) + actual_res = await self.filler.extract(self.mock_text_preprocessing_result, self.mock_user) self.assertIsNone(actual_res) -class TestClassifierFillerMeta(TestCase): +class TestClassifierFillerMeta(IsolatedAsyncioTestCase): def setUp(self): test_items = { @@ -55,14 +55,14 @@ def setUp(self): "external_classifiers": ["read_book_or_not_classifier", "hello_scenario_classifier"]} @patch.object(ExternalClassifier, "find_best_answer", return_value=[{"answer": "нет", "score": 1.0, "other": False}]) - def test_filler_extract(self, mock_classifier_model): + async def test_filler_extract(self, mock_classifier_model): """Тест кейз проверяет что мы получаем тот же самый ответ, что вернула модель.""" expected_res = [{"answer": "нет", "score": 1.0, "other": False}] - actual_res = self.filler_meta.extract(self.mock_text_preprocessing_result, self.mock_user) + actual_res = await self.filler_meta.extract(self.mock_text_preprocessing_result, self.mock_user) self.assertEqual(expected_res, actual_res) @patch.object(ExternalClassifier, "find_best_answer", return_value=[]) - def test_filler_extract_if_no_model_answer(self, mock_classifier_model): + async def test_filler_extract_if_no_model_answer(self, mock_classifier_model): """Тест кейз проверяет результат None, если модель не выдала ответ.""" - actual_res = self.filler_meta.extract(self.mock_text_preprocessing_result, self.mock_user) + actual_res = await self.filler_meta.extract(self.mock_text_preprocessing_result, self.mock_user) self.assertIsNone(actual_res) diff --git a/tests/scenarios_tests/fillers/test_composite_filler.py b/tests/scenarios_tests/fillers/test_composite_filler.py index 2a9c46a2..3fefdf36 100644 --- a/tests/scenarios_tests/fillers/test_composite_filler.py +++ b/tests/scenarios_tests/fillers/test_composite_filler.py @@ -1,4 +1,4 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from core.model.registered import registered_factories from scenarios.scenario_models.field.field_filler_description import FieldFillerDescription, CompositeFiller @@ -11,19 +11,18 @@ def __init__(self, items=None): items = items or {} self.result = items.get("result") - def extract(self, text_preprocessing_result, user, params): + async def extract(self, text_preprocessing_result, user, params): return self.result -class TestCompositeFiller(TestCase): +class TestCompositeFiller(IsolatedAsyncioTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): registered_factories[FieldFillerDescription] = field_filler_factory field_filler_description["mock_filler"] = MockFiller TestCompositeFiller.user = PicklableMock() - def test_first_filler(self): + async def test_first_filler(self): expected = "first" items = { "fillers": [ @@ -32,10 +31,10 @@ def test_first_filler(self): ] } filler = CompositeFiller(items) - result = filler.extract(None, self.user) + result = await filler.extract(None, self.user) self.assertEqual(expected, result) - def test_second_filler(self): + async def test_second_filler(self): expected = "second" items = { "fillers": [ @@ -44,10 +43,10 @@ def test_second_filler(self): ] } filler = CompositeFiller(items) - result = filler.extract(None, self.user) + result = await filler.extract(None, self.user) self.assertEqual(expected, result) - def test_not_fit(self): + async def test_not_fit(self): items = { "fillers": [ {"type": "mock_filler"}, @@ -55,5 +54,5 @@ def test_not_fit(self): ] } filler = CompositeFiller(items) - result = filler.extract(None, self.user) - self.assertIsNone(result) \ No newline at end of file + result = await filler.extract(None, self.user) + self.assertIsNone(result) diff --git a/tests/scenarios_tests/fillers/test_external_filler.py b/tests/scenarios_tests/fillers/test_external_filler.py index 76dd22e6..b540592c 100644 --- a/tests/scenarios_tests/fillers/test_external_filler.py +++ b/tests/scenarios_tests/fillers/test_external_filler.py @@ -1,22 +1,22 @@ -from unittest import TestCase -from unittest.mock import Mock +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock from scenarios.scenario_models.field.field_filler_description import ExternalFieldFillerDescription from smart_kit.utils.picklable_mock import PicklableMock -class TestExternalFieldFillerDescription(TestCase): +class TestExternalFieldFillerDescription(IsolatedAsyncioTestCase): - def test_1(self): + async def test_1(self): expected = 5 items = {"filler": "my_key"} mock_filler = PicklableMock() - mock_filler.run = Mock(return_value=expected) + mock_filler.run = AsyncMock(return_value=expected) mock_user = PicklableMock() mock_user.descriptions = {"external_field_fillers": {"my_key": mock_filler}} filler = ExternalFieldFillerDescription(items) - result = filler.extract(None, mock_user) + result = await filler.extract(None, mock_user) self.assertEqual(expected, result) diff --git a/tests/scenarios_tests/fillers/test_first_meeting.py b/tests/scenarios_tests/fillers/test_first_meeting.py index 9247487a..9d1b37b8 100644 --- a/tests/scenarios_tests/fillers/test_first_meeting.py +++ b/tests/scenarios_tests/fillers/test_first_meeting.py @@ -1,51 +1,50 @@ -from unittest import TestCase - +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field.field_filler_description import FirstNumberFiller, \ FirstCurrencyFiller from smart_kit.utils.picklable_mock import PicklableMock -class TestFirstNumberFiller(TestCase): - def test_1(self): +class TestFirstNumberFiller(IsolatedAsyncioTestCase): + async def test_1(self): expected = "5" items = {} text_preprocessing_result = PicklableMock() text_preprocessing_result.num_token_values = [expected] filler = FirstNumberFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) - def test_2(self): + async def test_2(self): items = {} text_preprocessing_result = PicklableMock() text_preprocessing_result.num_token_values = [] filler = FirstNumberFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertIsNone(result) -class TestFirstCurrencyFiller(TestCase): - def test_1(self): +class TestFirstCurrencyFiller(IsolatedAsyncioTestCase): + async def test_1(self): expected = "ru" items = {} text_preprocessing_result = PicklableMock() text_preprocessing_result.ccy_token_values = [expected] filler = FirstCurrencyFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) - def test_2(self): + async def test_2(self): items = {} text_preprocessing_result = PicklableMock() text_preprocessing_result.ccy_token_values = [] filler = FirstCurrencyFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertIsNone(result) diff --git a/tests/scenarios_tests/fillers/test_geo_token_filler.py b/tests/scenarios_tests/fillers/test_geo_token_filler.py index 4ef874b7..2288217c 100644 --- a/tests/scenarios_tests/fillers/test_geo_token_filler.py +++ b/tests/scenarios_tests/fillers/test_geo_token_filler.py @@ -1,36 +1,36 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field.field_filler_description import FirstGeoFiller from smart_kit.utils.picklable_mock import PicklableMock -class TestFirstGeoFiller(TestCase): +class TestFirstGeoFiller(IsolatedAsyncioTestCase): def setUp(self): items = {} self.filler = FirstGeoFiller(items) - def test_1(self): + async def test_1(self): expected = "москва" text_preprocessing_result = PicklableMock() text_preprocessing_result.geo_token_values = ["москва"] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) - def test_2(self): + async def test_2(self): expected = "москва" text_preprocessing_result = PicklableMock() text_preprocessing_result.geo_token_values = ["москва", "питер", "казань"] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) - def test_3(self): + async def test_3(self): text_preprocessing_result = PicklableMock() text_preprocessing_result.geo_token_values = [] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertIsNone(result) diff --git a/tests/scenarios_tests/fillers/test_intersection.py b/tests/scenarios_tests/fillers/test_intersection.py index d9906b49..bd9fb5fd 100644 --- a/tests/scenarios_tests/fillers/test_intersection.py +++ b/tests/scenarios_tests/fillers/test_intersection.py @@ -1,5 +1,5 @@ import os -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from unittest.mock import patch import smart_kit @@ -20,10 +20,10 @@ def patch_get_app_config(mock_get_app_config): mock_get_app_config.return_value = result -class TestIntersectionFieldFiller(TestCase): +class TestIntersectionFieldFiller(IsolatedAsyncioTestCase): @patch('smart_kit.configs.get_app_config') - def test_1(self, mock_get_app_config): + async def test_1(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) expected = 'лосось' items = { @@ -47,12 +47,12 @@ def test_1(self, mock_get_app_config): ] filler = IntersectionFieldFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) @patch('smart_kit.configs.get_app_config') - def test_2(self, mock_get_app_config): + async def test_2(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) items = { 'strict': True, @@ -76,12 +76,12 @@ def test_2(self, mock_get_app_config): ] filler = IntersectionFieldFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertIsNone(result) @patch('smart_kit.configs.get_app_config') - def test_3(self, mock_get_app_config): + async def test_3(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) expected = 'лосось' items = { @@ -102,24 +102,24 @@ def test_3(self, mock_get_app_config): ] filler = IntersectionFieldFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) @patch('smart_kit.configs.get_app_config') - def test_4(self, mock_get_app_config): + async def test_4(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) items = {} text_preprocessing_result = PicklableMock() text_preprocessing_result.tokenized_elements_list_pymorphy = [] filler = IntersectionFieldFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertIsNone(result) @patch('smart_kit.configs.get_app_config') - def test_5(self, mock_get_app_config): + async def test_5(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) expected = 'дефолтный тунец' items = { @@ -143,14 +143,14 @@ def test_5(self, mock_get_app_config): ] filler = IntersectionFieldFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) -class TestIntersectionOriginalTextFiller(TestCase): +class TestIntersectionOriginalTextFiller(IsolatedAsyncioTestCase): @patch('smart_kit.configs.get_app_config') - def test_1(self, mock_get_app_config): + async def test_1(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) items = { 'cases': { @@ -166,12 +166,12 @@ def test_1(self, mock_get_app_config): text_preprocessing_result.original_text = 'всего хорошего и спасибо за рыбу' filler = IntersectionOriginalTextFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertIsNone(result) @patch('smart_kit.configs.get_app_config') - def test_2(self, mock_get_app_config): + async def test_2(self, mock_get_app_config): expected = 'лосось' patch_get_app_config(mock_get_app_config) items = { @@ -188,12 +188,12 @@ def test_2(self, mock_get_app_config): text_preprocessing_result.original_text = 'всего хорошая и спасибо за рыба' filler = IntersectionOriginalTextFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) @patch('smart_kit.configs.get_app_config') - def test_3(self, mock_get_app_config): + async def test_3(self, mock_get_app_config): patch_get_app_config(mock_get_app_config) items = { 'cases': { @@ -212,6 +212,6 @@ def test_3(self, mock_get_app_config): text_preprocessing_result.original_text = 'не это хорошая рыба' filler = IntersectionOriginalTextFiller(items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertIsNone(result) diff --git a/tests/scenarios_tests/fillers/test_org_token_filler.py b/tests/scenarios_tests/fillers/test_org_token_filler.py index 78c52525..541c13c6 100644 --- a/tests/scenarios_tests/fillers/test_org_token_filler.py +++ b/tests/scenarios_tests/fillers/test_org_token_filler.py @@ -1,36 +1,36 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field.field_filler_description import FirstOrgFiller from smart_kit.utils.picklable_mock import PicklableMock -class TestFirstOrgFiller(TestCase): +class TestFirstOrgFiller(IsolatedAsyncioTestCase): def setUp(self): items = {} self.filler = FirstOrgFiller(items) - def test_1(self): + async def test_1(self): expected = "тинькофф" text_preprocessing_result = PicklableMock() text_preprocessing_result.org_token_values = ["тинькофф"] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) - def test_2(self): + async def test_2(self): expected = "тинькофф" text_preprocessing_result = PicklableMock() text_preprocessing_result.org_token_values = ["тинькофф", "втб", "мегафон"] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertEqual(expected, result) - def test_3(self): + async def test_3(self): text_preprocessing_result = PicklableMock() text_preprocessing_result.org_token_values = [] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertIsNone(result) diff --git a/tests/scenarios_tests/fillers/test_person_filler.py b/tests/scenarios_tests/fillers/test_person_filler.py index 13059991..19639a03 100644 --- a/tests/scenarios_tests/fillers/test_person_filler.py +++ b/tests/scenarios_tests/fillers/test_person_filler.py @@ -1,36 +1,36 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field.field_filler_description import FirstPersonFiller from smart_kit.utils.picklable_mock import PicklableMock -class TestFirstPersonFiller(TestCase): +class TestFirstPersonFiller(IsolatedAsyncioTestCase): def setUp(self): items = {} self.filler = FirstPersonFiller(items) - def test_1(self): + async def test_1(self): expected = {"name": "иван"} text_preprocessing_result = PicklableMock() text_preprocessing_result.person_token_values = [{"name": "иван"}] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertDictEqual(expected, result) - def test_2(self): + async def test_2(self): expected = {"name": "иван"} text_preprocessing_result = PicklableMock() text_preprocessing_result.person_token_values = [{"name": "иван"}, {"name": "иван", "patronymic": "иванович"}] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertDictEqual(expected, result) - def test_3(self): + async def test_3(self): text_preprocessing_result = PicklableMock() text_preprocessing_result.person_token_values = [] - result = self.filler.extract(text_preprocessing_result, None) + result = await self.filler.extract(text_preprocessing_result, None) self.assertIsNone(result) \ No newline at end of file diff --git a/tests/scenarios_tests/fillers/test_previous_messages_filler.py b/tests/scenarios_tests/fillers/test_previous_messages_filler.py index 60b750ee..f04f6b73 100644 --- a/tests/scenarios_tests/fillers/test_previous_messages_filler.py +++ b/tests/scenarios_tests/fillers/test_previous_messages_filler.py @@ -10,12 +10,12 @@ class MockFiller: def __init__(self, items=None): self.count = 0 - def extract(self, text_preprocessing_result, user, params): + async def extract(self, text_preprocessing_result, user, params): self.count += 1 -class PreviousMessagesFillerTest(unittest.TestCase): - def test_fill_1(self): +class PreviousMessagesFillerTest(unittest.IsolatedAsyncioTestCase): + async def test_fill_1(self): registered_factories[FieldFillerDescription] = field_filler_factory field_filler_description["mock_filler"] = MockFiller expected = "first" @@ -24,10 +24,10 @@ def test_fill_1(self): user.preprocessing_messages_for_scenarios = PicklableMock() user.preprocessing_messages_for_scenarios.processed_items = [{}, {}, {}] filler = PreviousMessagesFiller(items) - filler.extract(None, user) + await filler.extract(None, user) self.assertEqual(filler.filler.count, 4) - def test_fill_2(self): + async def test_fill_2(self): registered_factories[FieldFillerDescription] = field_filler_factory field_filler_description["mock_filler"] = MockFiller expected = "first" @@ -36,5 +36,5 @@ def test_fill_2(self): user.preprocessing_messages_for_scenarios = PicklableMock() user.preprocessing_messages_for_scenarios.processed_items = [{}, {}, {}] filler = PreviousMessagesFiller(items) - filler.extract(None, user) + await filler.extract(None, user) self.assertEqual(filler.filler.count, 2) diff --git a/tests/scenarios_tests/fillers/test_regexp_and_string_operations_filler.py b/tests/scenarios_tests/fillers/test_regexp_and_string_operations_filler.py index 3e0c5cb6..79ca5a2d 100644 --- a/tests/scenarios_tests/fillers/test_regexp_and_string_operations_filler.py +++ b/tests/scenarios_tests/fillers/test_regexp_and_string_operations_filler.py @@ -1,25 +1,16 @@ -from unittest import TestCase -from unittest.mock import Mock +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field.field_filler_description import RegexpAndStringOperationsFieldFiller +from smart_kit.utils.picklable_mock import PicklableMock -class PickableMock(Mock): - def __reduce__(self): - return (Mock, ()) - - -class PickablePicklableMock: - pass - - -class TestRegexpStringOperationsFiller(TestCase): +class TestRegexpStringOperationsFiller(IsolatedAsyncioTestCase): def setUp(self): self.items = {"exp": "1-[0-9A-Z]{7}"} def _test_operation(self, field_value, type_op, amount): self.items["operations"] = [] - text_preprocessing_result = PickablePicklableMock() + text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = field_value filler = RegexpAndStringOperationsFieldFiller(self.items) @@ -40,38 +31,38 @@ def test_operation_amount(self): result = self._test_operation(field_value, type_op, amount) self.assertEqual(field_value.lstrip(amount), result) - def _test_extract(self, field_value): - text_preprocessing_result = PickablePicklableMock() + async def _test_extract(self, field_value): + text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = field_value filler = RegexpAndStringOperationsFieldFiller(self.items) - return filler.extract(text_preprocessing_result, None) + return await filler.extract(text_preprocessing_result, None) - def test_extract_upper(self): + async def test_extract_upper(self): field_value = "1-rsar09a" self.items["operations"] = [{"type":"upper"}] - result = self._test_extract(field_value) + result = await self._test_extract(field_value) self.assertEqual(field_value.upper(), result) - def test_extract_rstrip(self): + async def test_extract_rstrip(self): field_value = "1-RSAR09A !)" self.items["operations"] = [{"type":"rstrip", "amount": "!) "}] - result = self._test_extract(field_value) + result = await self._test_extract(field_value) self.assertEqual(field_value.rstrip("!) "), result) - def test_extract_upper_rstrip(self): + async def test_extract_upper_rstrip(self): field_value = "1-rsar09a !)" self.items["operations"] = [ {"type":"upper"}, {"type":"rstrip", "amount": "!) "} ] - result = self._test_extract(field_value) + result = await self._test_extract(field_value) self.assertEqual(field_value.upper().rstrip("!) "), result) - def test_extract_no_operations(self): + async def test_extract_no_operations(self): field_value = "1-rsar09a !)" self.items["operations"] = [] - result = self._test_extract(field_value) + result = await self._test_extract(field_value) self.assertIsNone(result) diff --git a/tests/scenarios_tests/fillers/test_regexp_filler.py b/tests/scenarios_tests/fillers/test_regexp_filler.py index f26add9e..b272b51e 100644 --- a/tests/scenarios_tests/fillers/test_regexp_filler.py +++ b/tests/scenarios_tests/fillers/test_regexp_filler.py @@ -1,10 +1,9 @@ -from unittest import TestCase - +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field.field_filler_description import RegexpFieldFiller from smart_kit.utils.picklable_mock import PicklableMock -class TestRegexpFiller(TestCase): +class TestRegexpFiller(IsolatedAsyncioTestCase): def setUp(self): self.items = {"exp": "1-[0-9A-Z]{7}"} self.user = PicklableMock() @@ -14,46 +13,46 @@ def setUp(self): def test_no_exp_init(self): self.assertRaises(KeyError, RegexpFieldFiller, {}) - def test_no_exp(self): + async def test_no_exp(self): field_value = "1-RSAR09A" text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = field_value filler = RegexpFieldFiller(self.items) filler.regexp = None - self.assertIsNone(filler.extract(text_preprocessing_result, self.user)) + self.assertIsNone(await filler.extract(text_preprocessing_result, self.user)) - def test_extract(self): + async def test_extract(self): field_value = "1-RSAR09A" text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = field_value filler = RegexpFieldFiller(self.items) - result = filler.extract(text_preprocessing_result, self.user) + result = await filler.extract(text_preprocessing_result, self.user) self.assertEqual(field_value, result) - def test_extract_no_match(self): + async def test_extract_no_match(self): text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = "text" filler = RegexpFieldFiller(self.items) - result = filler.extract(text_preprocessing_result, self.user) + result = await filler.extract(text_preprocessing_result, self.user) self.assertIsNone(result) - def test_extract_mult_match_default_delimiter(self): + async def test_extract_mult_match_default_delimiter(self): field_value = "1-RSAR09A пустой тест 1-RSAR02A" res = ",".join(['1-RSAR09A', '1-RSAR02A']) text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = field_value filler = RegexpFieldFiller(self.items) - result = filler.extract(text_preprocessing_result, self.user) + result = await filler.extract(text_preprocessing_result, self.user) self.assertEqual(res, result) - def test_extract_mult_match_custom_delimiter(self): + async def test_extract_mult_match_custom_delimiter(self): field_value = "1-RSAR09A пустой тест 1-RSAR02B" self.items["delimiter"] = ";" res = self.items["delimiter"].join(['1-RSAR09A', '1-RSAR02B']) @@ -61,6 +60,6 @@ def test_extract_mult_match_custom_delimiter(self): text_preprocessing_result.original_text = field_value filler = RegexpFieldFiller(self.items) - result = filler.extract(text_preprocessing_result, self.user) + result = await filler.extract(text_preprocessing_result, self.user) self.assertEqual(res, result) diff --git a/tests/scenarios_tests/fillers/test_regexps_filler.py b/tests/scenarios_tests/fillers/test_regexps_filler.py index 428de51b..2df532c0 100644 --- a/tests/scenarios_tests/fillers/test_regexps_filler.py +++ b/tests/scenarios_tests/fillers/test_regexps_filler.py @@ -1,44 +1,43 @@ -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field.field_filler_description import AllRegexpsFieldFiller from smart_kit.utils.picklable_mock import PicklableMock -class Test_regexps_filler(TestCase): - @classmethod - def setUpClass(cls): - cls.items = {} - cls.items["exps"] = ["номер[а-я]*\.?\s?(\d+)", "n\.?\s?(\d+)", "nn\.?\s?(\d+)", "#\.?\s?(\d+)", +class Test_regexps_filler(IsolatedAsyncioTestCase): + def setUp(self): + self.items = {} + self.items["exps"] = ["номер[а-я]*\.?\s?(\d+)", "n\.?\s?(\d+)", "nn\.?\s?(\d+)", "#\.?\s?(\d+)", "##\.?\s?(\d+)", "№\.?\s?(\d+)", "№№\.?\s?(\d+)", "платеж[а-я]+\.?\s?(\d+)", "поручен[а-я]+\.?\s?(\d+)", "п\\s?,\\s?п\.?\s?(\d+)", "п\\s?\\/\\s?п\.?\s?(\d+)"] - cls.items["delimiter"] = "|" + self.items["delimiter"] = "|" - cls.filler = AllRegexpsFieldFiller(cls.items) + self.filler = AllRegexpsFieldFiller(self.items) - def test_extract_1(self): + async def test_extract_1(self): field_value = "Просим отозвать платежное поручение 14 от 23.01.19 на сумму 3500 и вернуть деньги на расчетный счет." text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = field_value filler = AllRegexpsFieldFiller(self.items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual('14', result) - def test_extract_2(self): + async def test_extract_2(self): field_value = "поручение12 поручение14 #1 n3 п/п70 n33" text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = field_value filler = AllRegexpsFieldFiller(self.items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertEqual("3|33|1|12|14|70", result) - def test_extract_no_match(self): + async def test_extract_no_match(self): text_preprocessing_result = PicklableMock() text_preprocessing_result.original_text = "текст без искомых номеров" filler = AllRegexpsFieldFiller(self.items) - result = filler.extract(text_preprocessing_result, None) + result = await filler.extract(text_preprocessing_result, None) self.assertIsNone(result) diff --git a/tests/scenarios_tests/requirements_test/test_requirements.py b/tests/scenarios_tests/requirements_test/test_requirements.py index e8d2ae28..7133c32f 100644 --- a/tests/scenarios_tests/requirements_test/test_requirements.py +++ b/tests/scenarios_tests/requirements_test/test_requirements.py @@ -12,7 +12,7 @@ def __init__(self, items=None): items = items or {} self.cond = items.get("cond") or False - def check(self, text_preprocessing_result, user): + async def check(self, text_preprocessing_result, user): return self.cond @@ -55,9 +55,9 @@ def compare(self, value): return value == self.amount -class RequirementTest(unittest.TestCase): +class RequirementTest(unittest.IsolatedAsyncioTestCase): - def test_template_in_array_req_true(self): + async def test_template_in_array_req_true(self): items = { "template": "{{ payload.userInfo.tbcode }}", "items": ["32", "33"] @@ -72,144 +72,137 @@ def test_template_in_array_req_true(self): user = PicklableMock() user.parametrizer = PicklableMock() user.parametrizer.collect = Mock(return_value=params) - self.assertTrue(requirement.check(None, user)) - - def test_template_in_array_req_true2(self): - items = { - "template": "{{ payload.message.strip() }}", - "items": ["AAA", "BBB", "CCC"] - } - requirement = TemplateInArrayRequirement(items) - params = {"payload": { - "userInfo": { - "tbcode": "32" - }, - "message": " BBB " - }} - user = PicklableMock() - user.parametrizer = PicklableMock() - user.parametrizer.collect = Mock(return_value=params) - self.assertTrue(requirement.check(None, user)) - - def test_template_in_array_req_false(self): - items = { - "template": "{{ payload.message.strip() }}", - "items": ["AAA", "CCC"] - } - requirement = TemplateInArrayRequirement(items) - params = {"payload": { - "userInfo": { - "tbcode": "32", - }, - "message": " BBB " - }} - user = PicklableMock() - user.parametrizer = PicklableMock() - user.parametrizer.collect = Mock(return_value=params) - self.assertFalse(requirement.check(None, user)) - - - def test_array_in_template_req_true(self): - items = { - "template": { - "type": "unified_template", - "template": "{{ payload.userInfo.departcode.split('/')|tojson }}", - "loader": "json" - }, - "items": ["111", "456"] - } - requirement = ArrayItemInTemplateRequirement(items) - params = {"payload": { - "userInfo": { - "tbcode": "32", - "departcode": "123/2345/456" - }, - "message": " BBB " - }} - user = PicklableMock() - user.parametrizer = PicklableMock() - user.parametrizer.collect = Mock(return_value=params) - self.assertTrue(requirement.check(None, user)) - - - def test_array_in_template_req_true2(self): - items = { - "template": "{{ payload.message.strip() }}", - "items": ["AAA", "BBB"] - } - requirement = ArrayItemInTemplateRequirement(items) - params = {"payload": { - "userInfo": { - "tbcode": "32", - "departcode": "123/2345/456" - }, - "message": " BBB " - }} - user = PicklableMock() - user.parametrizer = PicklableMock() - user.parametrizer.collect = Mock(return_value=params) - self.assertTrue(requirement.check(None, user)) - - - def test_array_in_template_req_false(self): - items = { - "template": { - "type": "unified_template", - "template": "{{ payload.userInfo.departcode.split('/')|tojson }}", - "loader": "json" - }, - "items": ["111", "222"] - } - requirement = ArrayItemInTemplateRequirement(items) - params = {"payload": { - "userInfo": { - "tbcode": "32", - "departcode": "123/2345/456" - }, - "message": " BBB " - }} - user = PicklableMock() - user.parametrizer = PicklableMock() - user.parametrizer.collect = Mock(return_value=params) - self.assertFalse(requirement.check(None, user)) - - - def test_regexp_in_template_req_true(self): - items = { - "template": "{{ payload.message.strip() }}", - "regexp": "(^|\s)[Фф](\.|-)?1(\-)?(у|У)?($|\s)" - } - requirement = RegexpInTemplateRequirement(items) - params = {"payload": { - "userInfo": { - "tbcode": "32", - }, - "message": "карточки ф1у" - }} - user = PicklableMock() - user.parametrizer = PicklableMock() - user.parametrizer.collect = Mock(return_value=params) - self.assertTrue(requirement.check(None, user)) - - - def test_regexp_in_template_req_false(self): - items = { - "template": "{{ payload.message.strip() }}", - "regexp": "(^|\s)[Фф](\.|-)?1(\-)?(у|У)?($|\s)" - } - requirement = RegexpInTemplateRequirement(items) - params = {"payload": { - "userInfo": { - "tbcode": "32", - }, - "message": "карточки конг фу 1" - }} - user = PicklableMock() - user.parametrizer = PicklableMock() - user.parametrizer.collect = Mock(return_value=params) - self.assertFalse(requirement.check(None, user)) + self.assertTrue(await requirement.check(None, user)) + async def test_template_in_array_req_true2(self): + items = { + "template": "{{ payload.message.strip() }}", + "items": ["AAA", "BBB", "CCC"] + } + requirement = TemplateInArrayRequirement(items) + params = {"payload": { + "userInfo": { + "tbcode": "32" + }, + "message": " BBB " + }} + user = PicklableMock() + user.parametrizer = PicklableMock() + user.parametrizer.collect = Mock(return_value=params) + self.assertTrue(await requirement.check(None, user)) + async def test_template_in_array_req_false(self): + items = { + "template": "{{ payload.message.strip() }}", + "items": ["AAA", "CCC"] + } + requirement = TemplateInArrayRequirement(items) + params = {"payload": { + "userInfo": { + "tbcode": "32", + }, + "message": " BBB " + }} + user = PicklableMock() + user.parametrizer = PicklableMock() + user.parametrizer.collect = Mock(return_value=params) + self.assertFalse(await requirement.check(None, user)) + + async def test_array_in_template_req_true(self): + items = { + "template": { + "type": "unified_template", + "template": "{{ payload.userInfo.departcode.split('/')|tojson }}", + "loader": "json" + }, + "items": ["111", "456"] + } + requirement = ArrayItemInTemplateRequirement(items) + params = {"payload": { + "userInfo": { + "tbcode": "32", + "departcode": "123/2345/456" + }, + "message": " BBB " + }} + user = PicklableMock() + user.parametrizer = PicklableMock() + user.parametrizer.collect = Mock(return_value=params) + self.assertTrue(await requirement.check(None, user)) + + async def test_array_in_template_req_true2(self): + items = { + "template": "{{ payload.message.strip() }}", + "items": ["AAA", "BBB"] + } + requirement = ArrayItemInTemplateRequirement(items) + params = {"payload": { + "userInfo": { + "tbcode": "32", + "departcode": "123/2345/456" + }, + "message": " BBB " + }} + user = PicklableMock() + user.parametrizer = PicklableMock() + user.parametrizer.collect = Mock(return_value=params) + self.assertTrue(await requirement.check(None, user)) + + async def test_array_in_template_req_false(self): + items = { + "template": { + "type": "unified_template", + "template": "{{ payload.userInfo.departcode.split('/')|tojson }}", + "loader": "json" + }, + "items": ["111", "222"] + } + requirement = ArrayItemInTemplateRequirement(items) + params = {"payload": { + "userInfo": { + "tbcode": "32", + "departcode": "123/2345/456" + }, + "message": " BBB " + }} + user = PicklableMock() + user.parametrizer = PicklableMock() + user.parametrizer.collect = Mock(return_value=params) + self.assertFalse(await requirement.check(None, user)) + + async def test_regexp_in_template_req_true(self): + items = { + "template": "{{ payload.message.strip() }}", + "regexp": "(^|\s)[Фф](\.|-)?1(\-)?(у|У)?($|\s)" + } + requirement = RegexpInTemplateRequirement(items) + params = {"payload": { + "userInfo": { + "tbcode": "32", + }, + "message": "карточки ф1у" + }} + user = PicklableMock() + user.parametrizer = PicklableMock() + user.parametrizer.collect = Mock(return_value=params) + self.assertTrue(await requirement.check(None, user)) + + async def test_regexp_in_template_req_false(self): + items = { + "template": "{{ payload.message.strip() }}", + "regexp": "(^|\s)[Фф](\.|-)?1(\-)?(у|У)?($|\s)" + } + requirement = RegexpInTemplateRequirement(items) + params = {"payload": { + "userInfo": { + "tbcode": "32", + }, + "message": "карточки конг фу 1" + }} + user = PicklableMock() + user.parametrizer = PicklableMock() + user.parametrizer.collect = Mock(return_value=params) + self.assertFalse(await requirement.check(None, user)) if __name__ == '__main__': diff --git a/tests/scenarios_tests/scenarios_test/test_tree_scenario.py b/tests/scenarios_tests/scenarios_test/test_tree_scenario.py index c025d25a..76fd8043 100644 --- a/tests/scenarios_tests/scenarios_test/test_tree_scenario.py +++ b/tests/scenarios_tests/scenarios_test/test_tree_scenario.py @@ -1,11 +1,11 @@ -from unittest import TestCase -from unittest.mock import Mock, MagicMock +from unittest import IsolatedAsyncioTestCase +from unittest.mock import Mock, MagicMock, AsyncMock from core.basic_models.actions.basic_actions import Action, action_factory, actions from core.basic_models.actions.command import Command from core.model.registered import registered_factories from scenarios.scenario_descriptions.tree_scenario.tree_scenario import TreeScenario -from smart_kit.utils.picklable_mock import PicklableMock, PicklableMagicMock +from smart_kit.utils.picklable_mock import PicklableMock, PicklableMagicMock, AsyncPicklableMock class MockAction: @@ -13,7 +13,7 @@ def __init__(self, items=None, command_name=None): self.called = False self.command_name = command_name - def run(self, user, text_preprocessing_result, params): + async def run(self, user, text_preprocessing_result, params): self.called = True if self.command_name: return [Command(self.command_name)] @@ -23,14 +23,14 @@ class BreakAction: def __init__(self, items=None): pass - def run(self, user, text_preprocessing_result, params): + async def run(self, user, text_preprocessing_result, params): user.scenario_models["some_id"].break_scenario = True return [] -class TestTreeScenario(TestCase): +class TestTreeScenario(IsolatedAsyncioTestCase): - def test_1(self): + async def test_1(self): """ Тест проверяет сценарий из одного узла. Предполагается идеальный случай, когда одно поле и мы смогли его заполнить. @@ -51,13 +51,14 @@ def test_1(self): "scenario_nodes": {"node_1": node_mock}} field_descriptor = PicklableMock(name="field_descriptor_mock") - field_descriptor.filler.extract = PicklableMock(name="my_field_value_1", return_value=61) + field_descriptor.filler.run = AsyncMock(name="my_field_value_1", return_value=61) field_descriptor.fill_other = False field_descriptor.field_validator.actions = [] + field_descriptor.field_validator.requirement.check = AsyncMock(return_value=True) internal_form = PicklableMock(name="internal_form_mock") internal_form.description.fields.items = PicklableMock(return_value=[("age", field_descriptor)]) - internal_form.field.field_validator.requirement.check = PicklableMock(return_value=True) + internal_form.field.field_validator.requirement.check = AsyncPicklableMock(return_value=True) internal_form.fields = PicklableMagicMock() internal_form.fields.values.items = PicklableMock(return_value={"age": 61}) internal_form.is_valid = PicklableMock(return_value=True) @@ -85,11 +86,11 @@ def test_1(self): scenario = TreeScenario(items, 1) - result = scenario.run(text_preprocessing_result, user) + await scenario.run(text_preprocessing_result, user) self.assertIsNone(current_node_mock.current_node) context_forms.new.assert_called_once_with(form_type) - def test_breake(self): + async def test_break(self): """ Тест проверяет выход из сценария если сработает флаг break_scenario """ @@ -98,6 +99,7 @@ def test_breake(self): actions["test"] = MockAction actions["break"] = MockAction actions["success"] = MockAction + actions["external"] = MockAction form_type = "form for doing smth" internal_form_key = "my form key" @@ -108,15 +110,16 @@ def test_breake(self): "scenario_nodes": {"node_1": node_mock}, "actions": [{"type": "success"}]} field_descriptor = PicklableMock(name="field_descriptor_mock") - field_descriptor.filler.extract = PicklableMock(name="my_field_value_1", return_value=61) + field_descriptor.filler.run = AsyncMock(name="my_field_value_1", return_value=61) field_descriptor.fill_other = False field_descriptor.field_validator.actions = [] + field_descriptor.field_validator.requirement.check = AsyncMock(return_value=True) field_descriptor.on_filled_actions = [BreakAction(), MockAction(command_name="break action result")] field_descriptor.id = "age" internal_form = PicklableMock(name="internal_form_mock") internal_form.description.fields.items = PicklableMock(return_value=[("age", field_descriptor)]) - internal_form.field.field_validator.requirement.check = PicklableMock(return_value=True) + internal_form.field.field_validator.requirement.check = AsyncPicklableMock(return_value=True) field = PicklableMock() field.description = field_descriptor field.value = 61 @@ -146,7 +149,7 @@ def test_breake(self): scenario = TreeScenario(items, 1) - result = scenario.run(text_preprocessing_result, user) + result = await scenario.run(text_preprocessing_result, user) self.assertFalse(scenario.actions[0].called) self.assertEqual(result[0].name, "break action result") diff --git a/tests/scenarios_tests/user_models/test_is_int_value.py b/tests/scenarios_tests/user_models/test_is_int_value.py index 9636a6e1..df08405c 100644 --- a/tests/scenarios_tests/user_models/test_is_int_value.py +++ b/tests/scenarios_tests/user_models/test_is_int_value.py @@ -1,28 +1,27 @@ # coding: utf-8 -from unittest import TestCase +from unittest import IsolatedAsyncioTestCase from scenarios.scenario_models.field_requirements.field_requirements import IsIntFieldRequirement -class IsIntFieldRequirementTest(TestCase): +class IsIntFieldRequirementTest(IsolatedAsyncioTestCase): - @classmethod - def setUpClass(cls): + def setUp(self): items = {} - cls.requirement = IsIntFieldRequirement(items) + self.requirement = IsIntFieldRequirement(items) - def test_is_int_number_string(self): + async def test_is_int_number_string(self): text = "123" - self.assertTrue(self.requirement.check(text)) + self.assertTrue(await self.requirement.check(text)) - def test_is_int_float_string(self): + async def test_is_int_float_string(self): text = "1.23" - self.assertFalse(self.requirement.check(text)) + self.assertFalse(await self.requirement.check(text)) - def test_is_int_text_string(self): + async def test_is_int_text_string(self): text = "test" - self.assertFalse(self.requirement.check(text)) + self.assertFalse(await self.requirement.check(text)) - def test_is_int_empty_string(self): + async def test_is_int_empty_string(self): text = "" - self.assertFalse(self.requirement.check(text)) + self.assertFalse(await self.requirement.check(text)) diff --git a/tests/scenarios_tests/user_models/test_token_part_in_set_requirement.py b/tests/scenarios_tests/user_models/test_token_part_in_set_requirement.py index aa35f05c..f2f6120b 100644 --- a/tests/scenarios_tests/user_models/test_token_part_in_set_requirement.py +++ b/tests/scenarios_tests/user_models/test_token_part_in_set_requirement.py @@ -4,170 +4,188 @@ from scenarios.scenario_models.field_requirements.field_requirements import TokenPartInSet -class RequirementTest(unittest.TestCase): +class RequirementTest(unittest.IsolatedAsyncioTestCase): - def test_token_part_in_set_requirement_equal_false(self): - requirement_items ={ - "type": "token_part_in_set", - "part": "locality_type", - "values": ["DISTRICT", "REGION"] - } - token_val = {'value': 'Амадора', - 'locality_type': 'CITY', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False} + async def test_token_part_in_set_requirement_equal_false(self): + requirement_items = { + "type": "token_part_in_set", + "part": "locality_type", + "values": ["DISTRICT", "REGION"] + } + token_val = { + 'value': 'Амадора', + 'locality_type': 'CITY', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False + } requirement = TokenPartInSet(requirement_items) - self.assertFalse(requirement.check(token_val)) + self.assertFalse(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_true(self): - requirement_items ={ - "type": "token_part_in_set", - "part": "locality_type", - "values": ["DISTRICT", "CITY"] - } - token_val = {'value': 'Амадора', - 'locality_type': 'CITY', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False} + async def test_token_part_in_set_requirement_equal_true(self): + requirement_items = { + "type": "token_part_in_set", + "part": "locality_type", + "values": ["DISTRICT", "CITY"] + } + token_val = { + 'value': 'Амадора', + 'locality_type': 'CITY', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False + } requirement = TokenPartInSet(requirement_items) - self.assertTrue(requirement.check(token_val)) + self.assertTrue(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_False_double_empty(self): - requirement_items ={ - "type": "token_part_in_set", - "part": "locality_type", - "values": [] - } - token_val = {'value': 'Амадора', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False, - 'locality_type': []} + async def test_token_part_in_set_requirement_equal_False_double_empty(self): + requirement_items = { + "type": "token_part_in_set", + "part": "locality_type", + "values": [] + } + token_val = { + 'value': 'Амадора', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False, + 'locality_type': [] + } requirement = TokenPartInSet(requirement_items) - self.assertFalse(requirement.check(token_val)) + self.assertFalse(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_False_empty_val_none(self): - requirement_items ={ - "type": "token_part_in_set", - "part": "locality_type", - "values": [] - } - token_val = {'value': 'Амадора', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False, - 'locality_type': None} + async def test_token_part_in_set_requirement_equal_False_empty_val_none(self): + requirement_items = { + "type": "token_part_in_set", + "part": "locality_type", + "values": [] + } + token_val = { + 'value': 'Амадора', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False, + 'locality_type': None + } requirement = TokenPartInSet(requirement_items) - self.assertFalse(requirement.check(token_val)) + self.assertFalse(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_False_string(self): - requirement_items ={ - "type": "token_part_in_set", - "part": "value", - "values": 'cba' - } - token_val = {'value': 'abc', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False} + async def test_token_part_in_set_requirement_equal_False_string(self): + requirement_items = { + "type": "token_part_in_set", + "part": "value", + "values": 'cba' + } + token_val = { + 'value': 'abc', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False + } requirement = TokenPartInSet(requirement_items) - self.assertFalse(requirement.check(token_val)) + self.assertFalse(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_False(self): - requirement_items ={ - "type": "token_part_in_set", - "part": "country_hidden", - "values": [1, 2, 3] - } - token_val = {'value': 'Амадора', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False} + async def test_token_part_in_set_requirement_equal_False(self): + requirement_items = { + "type": "token_part_in_set", + "part": "country_hidden", + "values": [1, 2, 3] + } + token_val = { + 'value': 'Амадора', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False + } requirement = TokenPartInSet(requirement_items) - self.assertFalse(requirement.check(token_val)) + self.assertFalse(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_False_val_int(self): - requirement_items ={ - "type": "token_part_in_set", - "part": 'capital', - "values": [-9.23083] - } - token_val = {'value': 'Амадора', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False} + async def test_token_part_in_set_requirement_equal_False_val_int(self): + requirement_items = { + "type": "token_part_in_set", + "part": 'capital', + "values": [-9.23083] + } + token_val = { + 'value': 'Амадора', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False + } requirement = TokenPartInSet(requirement_items) - self.assertFalse(requirement.check(token_val)) + self.assertFalse(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_True_arr(self): - requirement_items ={ - "type": "token_part_in_set", - "part": 'timezone', - "values": [[[None, 1.0]]] - } - token_val = {'value': 'Амадора', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False} + async def test_token_part_in_set_requirement_equal_True_arr(self): + requirement_items = { + "type": "token_part_in_set", + "part": 'timezone', + "values": [[[None, 1.0]]] + } + token_val = { + 'value': 'Амадора', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False + } requirement = TokenPartInSet(requirement_items) - self.assertTrue(requirement.check(token_val)) + self.assertTrue(await requirement.check(token_val)) - def test_token_part_in_set_requirement_equal_False_arr(self): - requirement_items ={ - "type": "token_part_in_set", - "part": 'timezone', - "values": [[[1.0, None]]] - } - token_val = {'value': 'Амадора', - 'latitude': 38.75382, - 'longitude': -9.23083, - 'capital': None, - 'locative_value': None, - 'timezone': [[None, 1.0]], - 'currency': ['EUR', 'евро'], - 'country': 'Португалия', - 'country_hidden': False} + async def test_token_part_in_set_requirement_equal_False_arr(self): + requirement_items = { + "type": "token_part_in_set", + "part": 'timezone', + "values": [[[1.0, None]]] + } + token_val = { + 'value': 'Амадора', + 'latitude': 38.75382, + 'longitude': -9.23083, + 'capital': None, + 'locative_value': None, + 'timezone': [[None, 1.0]], + 'currency': ['EUR', 'евро'], + 'country': 'Португалия', + 'country_hidden': False + } requirement = TokenPartInSet(requirement_items) - self.assertFalse(requirement.check(token_val)) + self.assertFalse(await requirement.check(token_val)) diff --git a/tests/smart_kit_tests/action/test_base_http_action.py b/tests/smart_kit_tests/action/test_base_http_action.py deleted file mode 100644 index cab465ad..00000000 --- a/tests/smart_kit_tests/action/test_base_http_action.py +++ /dev/null @@ -1,71 +0,0 @@ -import unittest -from unittest.mock import Mock, patch - -from smart_kit.action.base_http import BaseHttpRequestAction - - -class BaseHttpRequestActionTest(unittest.TestCase): - def setUp(self): - self.user = Mock(parametrizer=Mock(collect=lambda *args, **kwargs: {})) - - @staticmethod - def set_request_mock_attribute(request_mock, return_value=None): - return_value = return_value or {} - request_mock.return_value = Mock( - __enter__=Mock(return_value=Mock( - json=Mock(return_value=return_value), - cookies={}, - headers={}, - ),), - __exit__=Mock() - ) - - @patch('requests.request') - def test_simple_request(self, request_mock: Mock): - self.set_request_mock_attribute(request_mock) - items = { - "method": "POST", - "url": "https://my.url.com", - } - result = BaseHttpRequestAction(items).run(self.user, None, {}) - request_mock.assert_called_with(url="https://my.url.com", method='POST') - self.assertEqual(result, {}) - - @patch('requests.request') - def test_render_params(self, request_mock: Mock): - self.set_request_mock_attribute(request_mock) - items = { - "method": "POST", - "url": "https://{{url}}", - "timeout": 3, - "json": { - "param": "{{value}}" - } - } - params = { - "url": "my.url.com", - "value": "my_value" - } - result = BaseHttpRequestAction(items).run(self.user, None, params) - request_mock.assert_called_with(url="https://my.url.com", method='POST', timeout=3, json={"param": "my_value"}) - self.assertEqual(result, {}) - - @patch('requests.request') - def test_headers_fix(self, request_mock): - self.set_request_mock_attribute(request_mock) - items = { - "headers": { - "header_1": 32, - "header_2": 32.03, - "header_3": b"d32", - "header_4": None, - "header_5": {"data": "value"}, - }, - } - result = BaseHttpRequestAction(items).run(self.user, None, {}) - request_mock.assert_called_with(headers={ - "header_1": "32", - "header_2": "32.03", - "header_3": b"d32" - }) - self.assertEqual(result, {}) diff --git a/tests/smart_kit_tests/action/test_http.py b/tests/smart_kit_tests/action/test_http.py index d418b7cd..ee61417c 100644 --- a/tests/smart_kit_tests/action/test_http.py +++ b/tests/smart_kit_tests/action/test_http.py @@ -1,24 +1,39 @@ import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, AsyncMock + +from aiohttp import ClientTimeout from smart_kit.action.http import HTTPRequestAction -from tests.smart_kit_tests.action.test_base_http_action import BaseHttpRequestActionTest -class HttpRequestActionTest(unittest.TestCase): +class HttpRequestActionTest(unittest.IsolatedAsyncioTestCase): + TIMEOUT = 3 + def setUp(self): self.user = Mock( parametrizer=Mock(collect=lambda *args, **kwargs: {}), descriptions={ "behaviors": { - "my_behavior": Mock(timeout=Mock(return_value=3)) + "my_behavior": AsyncMock(timeout=Mock(return_value=3)) } } ) - @patch('requests.request') - def test_simple_request(self, request_mock: Mock): - BaseHttpRequestActionTest.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) + def set_request_mock_attribute(self, request_mock, return_value=None): + return_value = return_value or {} + request_mock.return_value = Mock( + __aenter__=AsyncMock(return_value=Mock( + # response + json=AsyncMock(return_value=return_value), + cookies={}, + headers={}, + ), ), + __aexit__=AsyncMock() + ) + + @patch('aiohttp.request') + async def test_simple_request(self, request_mock: Mock): + self.set_request_mock_attribute(request_mock, return_value={'data': 'value'}) items = { "params": { "method": "POST", @@ -27,8 +42,53 @@ def test_simple_request(self, request_mock: Mock): "store": "user_variable", "behavior": "my_behavior", } - HTTPRequestAction(items).run(self.user, None, {}) - request_mock.assert_called_with(url="https://my.url.com", method='POST', timeout=3) + await HTTPRequestAction(items).run(self.user, None, {}) + request_mock.assert_called_with(url="https://my.url.com", method='POST', timeout=ClientTimeout(3)) self.assertTrue(self.user.descriptions["behaviors"]["my_behavior"].success_action.run.called) self.assertTrue(self.user.variables.set.called) self.user.variables.set.assert_called_with("user_variable", {'data': 'value'}) + + @patch('aiohttp.request') + async def test_render_params(self, request_mock: Mock): + self.set_request_mock_attribute(request_mock) + items = { + "params": { + "method": "POST", + "url": "https://{{url}}", + "timeout": 3, + "json": { + "param": "{{value}}" + } + }, + "store": "user_variable", + "behavior": "my_behavior", + } + params = { + "url": "my.url.com", + "value": "my_value" + } + await HTTPRequestAction(items).run(self.user, None, params) + request_mock.assert_called_with(url="https://my.url.com", method='POST', timeout=ClientTimeout(3), json={"param": "my_value"}) + + @patch('aiohttp.request') + async def test_headers_fix(self, request_mock): + self.set_request_mock_attribute(request_mock) + items = { + "params": { + "headers": { + "header_1": 32, + "header_2": 32.03, + "header_3": b"d32", + "header_4": None, + "header_5": {"data": "value"}, + }, + }, + "store": "user_variable", + "behavior": "my_behavior", + } + await HTTPRequestAction(items).run(self.user, None, {}) + request_mock.assert_called_with(headers={ + "header_1": "32", + "header_2": "32.03", + "header_3": b"d32" + }, method=HTTPRequestAction.DEFAULT_METHOD, timeout=ClientTimeout(self.TIMEOUT)) diff --git a/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py b/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py index 616ac188..8a2c6bf6 100644 --- a/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py +++ b/tests/smart_kit_tests/action/test_run_scenario_by_project_name.py @@ -7,11 +7,11 @@ class TestScenarioDesc(dict): - def run(self, argv1, argv2, params): + async def run(self, argv1, argv2, params): return 'result to run scenario' -class RunScenarioByProjectNameActionTest1(unittest.TestCase): +class RunScenarioByProjectNameActionTest1(unittest.IsolatedAsyncioTestCase): def setUp(self): self.test_text_preprocessing_result = Mock('text_preprocessing_result') self.test_user1 = Mock('User') @@ -32,13 +32,13 @@ def setUp(self): self.test_text_preprocessing_result = Mock('TextPreprocessingResult') self.items = {"any_key": "any value"} - def test_run_scenario_by_project_name_run(self): + async def test_run_scenario_by_project_name_run(self): obj1 = RunScenarioByProjectNameAction(self.items) # без оглядки на аннотации из PEP 484 - self.assertTrue(obj1.run(self.test_user1, self.test_text_preprocessing_result, {'any_attr': {'any_data'}}) == + self.assertTrue(await obj1.run(self.test_user1, self.test_text_preprocessing_result, {'any_attr': {'any_data'}}) == 'result to run scenario') obj2 = RunScenarioByProjectNameAction(self.items) - self.assertIsNone(obj2.run(self.test_user2, self.test_text_preprocessing_result)) + self.assertIsNone(await obj2.run(self.test_user2, self.test_text_preprocessing_result)) def test_run_scenario_by_project_name_log_vars(self): obj = RunScenarioByProjectNameAction(self.items) diff --git a/tests/smart_kit_tests/adapters/test_memory_adapter.py b/tests/smart_kit_tests/adapters/test_memory_adapter.py index a3c10772..ece4090d 100644 --- a/tests/smart_kit_tests/adapters/test_memory_adapter.py +++ b/tests/smart_kit_tests/adapters/test_memory_adapter.py @@ -4,50 +4,50 @@ from core.db_adapter import memory_adapter -class AdapterTest1(unittest.TestCase): +class AdapterTest1(unittest.IsolatedAsyncioTestCase): - def test_memory_adapter_init(self): + async def test_memory_adapter_init(self): obj1 = memory_adapter.MemoryAdapter() obj2 = memory_adapter.MemoryAdapter({'try_count': 3}) self.assertTrue(hasattr(obj1, 'open')) self.assertTrue(hasattr(obj2, 'open')) - self.assertTrue(obj1.memory_storage == {}) - self.assertTrue(obj2.memory_storage == {}) - self.assertTrue(obj1.try_count == 5) # взято из исходников - self.assertTrue(obj2.try_count == 3) + self.assertEqual(obj1.memory_storage, {}) + self.assertEqual(obj2.memory_storage, {}) + self.assertEqual(obj1.try_count, 5) # взято из исходников + self.assertEqual(obj2.try_count, 3) - def test_memory_adapter_connect(self): + async def test_memory_adapter_connect(self): obj = memory_adapter.MemoryAdapter() self.assertTrue(hasattr(obj, 'connect')) - def test_memory_adapter_open(self): + async def test_memory_adapter_open(self): obj = memory_adapter.MemoryAdapter() self.assertTrue(hasattr(obj, '_open')) with self.assertRaises(TypeError): - obj._open() + await obj._open() - def test_memory_adapter_save(self): + async def test_memory_adapter_save(self): obj1 = memory_adapter.MemoryAdapter() obj2 = memory_adapter.MemoryAdapter() - obj1._save(10, {'any_data'}) - obj1._save(11, {'any_data'}) - obj2._save(10, {'any_data'}) - self.assertTrue(obj1.memory_storage == {10: {'any_data'}, 11: {'any_data'}}) - self.assertTrue(obj2.memory_storage == {10: {'any_data'}}) + await obj1._save(10, {'any_data'}) + await obj1._save(11, {'any_data'}) + await obj2._save(10, {'any_data'}) + self.assertEqual(obj1.memory_storage, {10: {'any_data'}, 11: {'any_data'}}) + self.assertEqual(obj2.memory_storage, {10: {'any_data'}}) # метод переписывает значения - obj2._save(10, 'any_data') - self.assertTrue(obj2.memory_storage == {10: 'any_data'}) + await obj2._save(10, 'any_data') + self.assertEqual(obj2.memory_storage, {10: 'any_data'}) - def test_memory_adapter_get(self): + async def test_memory_adapter_get(self): obj1 = memory_adapter.MemoryAdapter() - obj1._save(10, {'any_data'}) - obj1._save(11, 'any_data') - self.assertTrue(obj1._get(10) == {'any_data'}) - self.assertTrue(obj1._get(11) == 'any_data') - self.assertIsNone(obj1._get(12)) + await obj1._save(10, {'any_data'}) + await obj1._save(11, 'any_data') + self.assertEqual(await obj1._get(10), {'any_data'}) + self.assertEqual(await obj1._get(11), 'any_data') + self.assertIsNone(await obj1._get(12)) - def test_memory_adapter_list_dir(self): + async def test_memory_adapter_list_dir(self): obj = memory_adapter.MemoryAdapter() self.assertTrue(hasattr(obj, '_list_dir')) with self.assertRaises(TypeError): - obj._open() + await obj._open() diff --git a/tests/smart_kit_tests/handlers/test_handle_close_app.py b/tests/smart_kit_tests/handlers/test_handle_close_app.py index 1b30d9db..b220fd15 100644 --- a/tests/smart_kit_tests/handlers/test_handle_close_app.py +++ b/tests/smart_kit_tests/handlers/test_handle_close_app.py @@ -11,7 +11,7 @@ def form_type(self): return 'type' -class HandlerTest6(unittest.TestCase): +class HandlerTest6(unittest.IsolatedAsyncioTestCase): def setUp(self): self.test_text_preprocessing_result = Mock('text_preprocessing_result') self.test_user = PicklableMock() @@ -37,7 +37,7 @@ def test_handler_close_app_init(self): self.assertIsNotNone(obj.KAFKA_KEY) self.assertIsNotNone(obj._clear_current_scenario) - def test_handler_close_app_run(self): + async def test_handler_close_app_run(self): self.assertIsNotNone(handle_close_app.log_const.KEY_NAME) obj = handle_close_app.HandlerCloseApp(app_name=self.app_name) - self.assertIsNone(obj.run(self.test_payload, self.test_user)) + self.assertIsNone(await obj.run(self.test_payload, self.test_user)) diff --git a/tests/smart_kit_tests/handlers/test_handle_respond.py b/tests/smart_kit_tests/handlers/test_handle_respond.py index cd8007e6..8ac619e4 100644 --- a/tests/smart_kit_tests/handlers/test_handle_respond.py +++ b/tests/smart_kit_tests/handlers/test_handle_respond.py @@ -6,7 +6,11 @@ from smart_kit.utils.picklable_mock import PicklableMock, PicklableMagicMock -class HandlerTest4(unittest.TestCase): +async def mock_test_action_run(x, y, z): + return 10 + + +class HandlerTest4(unittest.IsolatedAsyncioTestCase): def setUp(self): self.app_name = "TestAppName" self.test_user1 = Mock('user') @@ -24,7 +28,7 @@ def setUp(self): self.test_user1.behaviors = PicklableMagicMock() self.test_action = Mock('action') - self.test_action.run = lambda x, y, z: 10 # пусть что то возвращает + self.test_action.run = mock_test_action_run # пусть что то возвращает. self.test_user2 = MagicMock('user') self.test_user2.id = '123-345-678' # пусть чему-то равняется self.test_user2.descriptions = {'external_actions': {'any action name': self.test_action}} @@ -62,7 +66,7 @@ def test_handler_respond_get_action_params(self): self.assertTrue(obj.get_action_params("any data", self.test_user2) == self.callback11_action_params) self.assertTrue(obj.get_action_params(None, self.test_user2) == self.callback11_action_params) - def test_handler_respond_run(self): + async def test_handler_respond_run(self): self.assertIsNotNone(handle_respond.TextPreprocessingResult(self.test_payload.get("message", {}))) self.assertIsNotNone(handle_respond.log_const.KEY_NAME) self.assertIsNotNone(handle_respond.log_const.NORMALIZED_TEXT_VALUE) @@ -70,5 +74,5 @@ def test_handler_respond_run(self): obj1 = handle_respond.HandlerRespond(app_name=self.app_name) obj2 = handle_respond.HandlerRespond(self.app_name, "any action name") with self.assertRaises(KeyError): - obj1.run(self.test_payload, self.test_user1) - self.assertTrue(obj2.run(self.test_payload, self.test_user2) == 10) + await obj1.run(self.test_payload, self.test_user1) + self.assertTrue(await obj2.run(self.test_payload, self.test_user2) == 10) diff --git a/tests/smart_kit_tests/handlers/test_handler_text.py b/tests/smart_kit_tests/handlers/test_handler_text.py index 692480d8..0a68b61b 100644 --- a/tests/smart_kit_tests/handlers/test_handler_text.py +++ b/tests/smart_kit_tests/handlers/test_handler_text.py @@ -6,13 +6,21 @@ from smart_kit.utils.picklable_mock import PicklableMock -class HandlerTest5(unittest.TestCase): +async def mock_dialogue_manager1_run(x, y): + return "TestAnswer", True + + +async def mock_dialogue_manager2_run(x, y): + return "", False + + +class HandlerTest5(unittest.IsolatedAsyncioTestCase): def setUp(self): self.app_name = "TestAppName" self.test_dialog_manager1 = Mock('dialog_manager') - self.test_dialog_manager1.run = lambda x, y: ("TestAnswer", True) + self.test_dialog_manager1.run = mock_dialogue_manager1_run self.test_dialog_manager2 = Mock('dialog_manager') - self.test_dialog_manager2.run = lambda x, y: ("", False) + self.test_dialog_manager2.run = mock_dialogue_manager2_run self.test_text_preprocessing_result = Mock('text_preprocessing_result') self.test_text_preprocessing_result.raw = 'any raw' self.test_user = Mock('User') @@ -39,15 +47,15 @@ def test_handler_text_init(self): self.assertIsNotNone(handler_text.log_const.STARTUP_VALUE) self.assertIsNotNone(obj1.__class__.__name__) - def test_handler_text_handle_base(self): + async def test_handler_text_handle_base(self): obj1 = handler_text.HandlerText(self.app_name, self.test_dialog_manager1) obj2 = handler_text.HandlerText(self.app_name, self.test_dialog_manager2) - self.assertTrue(obj1._handle_base(self.test_text_preprocessing_result, self.test_user) == "TestAnswer") - self.assertTrue(obj2._handle_base(self.test_text_preprocessing_result, self.test_user) == []) + self.assertTrue(await obj1._handle_base(self.test_text_preprocessing_result, self.test_user) == "TestAnswer") + self.assertTrue(await obj2._handle_base(self.test_text_preprocessing_result, self.test_user) == []) - def test_handler_text_run(self): + async def test_handler_text_run(self): self.assertIsNotNone(handler_text.log_const.NORMALIZED_TEXT_VALUE) obj1 = handler_text.HandlerText(self.app_name, self.test_dialog_manager1) obj2 = handler_text.HandlerText(self.app_name, self.test_dialog_manager2) - self.assertTrue(obj1.run(self.test_payload, self.test_user) == "TestAnswer") - self.assertTrue(obj2.run(self.test_payload, self.test_user) == []) + self.assertTrue(await obj1.run(self.test_payload, self.test_user) == "TestAnswer") + self.assertTrue(await obj2.run(self.test_payload, self.test_user) == []) diff --git a/tests/smart_kit_tests/handlers/test_handler_timeout.py b/tests/smart_kit_tests/handlers/test_handler_timeout.py index 0238a3ff..a4b3ee91 100644 --- a/tests/smart_kit_tests/handlers/test_handler_timeout.py +++ b/tests/smart_kit_tests/handlers/test_handler_timeout.py @@ -6,7 +6,11 @@ from smart_kit.utils.picklable_mock import PicklableMock, PicklableMagicMock -class HandlerTest2(unittest.TestCase): +async def mock_behaviors_timeout(x): + return 120 + + +class HandlerTest2(unittest.IsolatedAsyncioTestCase): def setUp(self): self.app_name = "TastAppName" self.test_user = Mock('user') @@ -22,13 +26,13 @@ def setUp(self): self.test_user.message.device.surface = "surface" self.test_user.behaviors = Mock('behaviors') - self.test_user.behaviors.timeout = lambda x: 120 + self.test_user.behaviors.timeout = mock_behaviors_timeout self.test_user.behaviors.has_callback = lambda *x, **y: PicklableMagicMock() self.test_user.behaviors.get_callback_action_params = lambda *x, **y: {} self.test_payload = Mock('payload') - def test_handler_timeout(self): + async def test_handler_timeout(self): obj = handler_timeout.HandlerTimeout(self.app_name) self.assertIsNotNone(obj.KAFKA_KEY) self.assertIsNotNone(handler_timeout.log_const.KEY_NAME) - self.assertTrue(obj.run(self.test_payload, self.test_user) == 120) + self.assertTrue(await obj.run(self.test_payload, self.test_user) == 120) diff --git a/tests/smart_kit_tests/models/test_dialogue_manager.py b/tests/smart_kit_tests/models/test_dialogue_manager.py index 5af86bf2..0e0d6c86 100644 --- a/tests/smart_kit_tests/models/test_dialogue_manager.py +++ b/tests/smart_kit_tests/models/test_dialogue_manager.py @@ -10,7 +10,27 @@ def get_keys(self): return self.keys() -class ModelsTest1(unittest.TestCase): +async def mock_two_parameters_return_false(x, y): + return False + + +async def mock_scenario1_text_fits(): + return False + + +async def mock_scenario2_text_fits(): + return True + + +async def mock_scenario1_run(x, y): + return x.name + y.name + + +async def mock_scenario2_run(x, y): + return y.name + x.name + + +class ModelsTest1(unittest.IsolatedAsyncioTestCase): def setUp(self): self.test_user1 = PicklableMock() self.test_user1.name = "TestName" @@ -40,16 +60,17 @@ def setUp(self): self.test_text_preprocessing_result.name = "Result" self.test_scenario1 = PicklableMock() self.test_scenario1.scenario_description = "This is test scenario 1 desc" - self.test_scenario1.text_fits = lambda x, y: False - self.test_scenario1.run = lambda x, y: x.name + y.name + self.test_scenario1.text_fits = mock_scenario1_text_fits + self.test_scenario1.text_fits = mock_two_parameters_return_false + self.test_scenario1.run = mock_scenario1_run self.test_scenario2 = PicklableMock() self.test_scenario2.scenario_description = "This is test scenario 2 desc" - self.test_scenario2.text_fits = lambda x, y: True - self.test_scenario2.run = lambda x, y: y.name + x.name + self.test_scenario2.text_fits = mock_scenario2_text_fits + self.test_scenario2.run = mock_scenario2_run self.test_scenarios = TestScenarioDesc({1: self.test_scenario1, 2: self.test_scenario2}) self.TestAction = PicklableMock() self.TestAction.description = "test_function" - self.TestAction.run = lambda x, y: x.name + y.name + self.TestAction.run = mock_scenario1_run self.app_name = "test" def test_log_const(self): @@ -79,7 +100,7 @@ def test_dialogue_manager_found_action(self): with self.assertRaises(TypeError): obj2._nothing_found_action() - def test_dialogue_manager_run(self): + async def test_dialogue_manager_run(self): obj1 = dialogue_manager.DialogueManager({'scenarios': self.test_scenarios, 'external_actions': {'nothing_found_action': self.TestAction}}, self.app_name) @@ -87,22 +108,29 @@ def test_dialogue_manager_run(self): 'external_actions': {}}, self.app_name) # путь по умолчанию без выполнения условий - self.assertTrue(obj1.run(self.test_text_preprocessing_result, self.test_user1) == ("TestNameResult", True)) - self.assertTrue(obj2.run(self.test_text_preprocessing_result, self.test_user1) == ("TestNameResult", True)) + self.assertTrue( + await obj1.run(self.test_text_preprocessing_result, self.test_user1) == ("TestNameResult", True) + ) + self.assertTrue( + await obj2.run(self.test_text_preprocessing_result, self.test_user1) == ("TestNameResult", True) + ) # случай когда срабатоли оба условия - self.assertTrue(obj1.run(self.test_text_preprocessing_result, self.test_user2) == ("TestNameResult", True)) + self.assertTrue( + await obj1.run(self.test_text_preprocessing_result, self.test_user2) == ("TestNameResult", True) + ) # случай, когда 2-е условие не выполнено - self.assertTrue(obj2.run(self.test_text_preprocessing_result, self.test_user3) == ('TestNameResult', True)) - - # проверка на вызов before_action, если такой задан в external_actions - obj2.run(self.test_text_preprocessing_result, self.test_user4) - self.assertTrue(self.test_user4.descriptions['external_actions']['before_action'].run.called) + self.assertTrue( + await obj2.run(self.test_text_preprocessing_result, self.test_user3) == ('TestNameResult', True) + ) - def test_dialogue_manager_run_scenario(self): + async def test_dialogue_manager_run_scenario(self): obj = dialogue_manager.DialogueManager({'scenarios': self.test_scenarios, 'external_actions': {'nothing_found_action': self.TestAction}}, self.app_name) - self.assertTrue(obj.run_scenario(1, self.test_text_preprocessing_result, self.test_user1) == "ResultTestName") - self.assertTrue(obj.run_scenario(2, self.test_text_preprocessing_result, self.test_user1) == "TestNameResult") - + self.assertTrue( + await obj.run_scenario(1, self.test_text_preprocessing_result, self.test_user1) == "ResultTestName" + ) + self.assertTrue( + await obj.run_scenario(2, self.test_text_preprocessing_result, self.test_user1) == "TestNameResult" + ) diff --git a/tests/smart_kit_tests/requirement/test_device_requirements.py b/tests/smart_kit_tests/requirement/test_device_requirements.py index 9e2c8ead..373995ba 100644 --- a/tests/smart_kit_tests/requirement/test_device_requirements.py +++ b/tests/smart_kit_tests/requirement/test_device_requirements.py @@ -1,11 +1,12 @@ # coding: utf-8 +import asyncio import unittest from unittest.mock import Mock from core.basic_models.requirement import device_requirements -class RequirementTest1(unittest.TestCase): +class RequirementTest1(unittest.IsolatedAsyncioTestCase): def setUp(self): self.test_items1 = {"platfrom_type": "any platform"} # PLATFROM - так задумано? self.test_items2 = {"platform_type": "any platform 2"} @@ -30,10 +31,10 @@ def test_platform_type_requirement_init(self): with self.assertRaises(KeyError): obj3 = device_requirements.PlatformTypeRequirement(self.test_items2, self.test_id) - def test_platform_type_requirement_check(self): + async def test_platform_type_requirement_check(self): obj1 = device_requirements.PlatformTypeRequirement(self.test_items1, self.test_id) - self.assertTrue(obj1.check(self.test_text_processing_result, self.test_user1)) - self.assertTrue(not obj1.check(self.test_text_processing_result, self.test_user2)) + self.assertTrue(await obj1.check(self.test_text_processing_result, self.test_user1)) + self.assertTrue(not await obj1.check(self.test_text_processing_result, self.test_user2)) class RequirementTest2(unittest.TestCase): diff --git a/tests/smart_kit_tests/system_answers/test_nothing_found_action.py b/tests/smart_kit_tests/system_answers/test_nothing_found_action.py index 7f004f11..da201d07 100644 --- a/tests/smart_kit_tests/system_answers/test_nothing_found_action.py +++ b/tests/smart_kit_tests/system_answers/test_nothing_found_action.py @@ -8,7 +8,7 @@ from smart_kit.system_answers import nothing_found_action -class SystemAnswersTest1(unittest.TestCase): +class SystemAnswersTest1(unittest.IsolatedAsyncioTestCase): def setUp(self): self.test_command_1 = Mock('Command') self.test_id = '123-345-678' # пусть чему-то равняется @@ -26,8 +26,8 @@ def test_system_answers_nothing_found_action_init(self): self.assertTrue(isinstance(obj1._action, StringAction)) self.assertTrue(obj1._action.command == NOTHING_FOUND) - def test_system_answer_nothing_found_action_run(self): + async def test_system_answer_nothing_found_action_run(self): obj1 = nothing_found_action.NothingFoundAction() obj2 = nothing_found_action.NothingFoundAction(self.test_items1, self.test_id) - self.assertTrue(isinstance(obj1.run(self.test_user1, self.test_text_preprocessing_result).pop(), Command)) - self.assertTrue(isinstance(obj2.run(self.test_user1, self.test_text_preprocessing_result).pop(), Command)) + self.assertTrue(isinstance((await obj1.run(self.test_user1, self.test_text_preprocessing_result)).pop(), Command)) + self.assertTrue(isinstance((await obj2.run(self.test_user1, self.test_text_preprocessing_result)).pop(), Command))