diff --git a/.gitignore b/.gitignore index 63ec466..0ff8ab0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ syntax: glob .settings .classpath .pydevproject -.coverage +.coverage* .pytest_cache .env* htmlcov diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 169cab9..17f2224 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -72,3 +72,4 @@ repos: entry: make test language: system always_run: True + pass_filenames: false diff --git a/Makefile b/Makefile index 865c747..b74a08c 100644 --- a/Makefile +++ b/Makefile @@ -25,16 +25,16 @@ $(eval POETRY_VERSION_NEW=$(POETRY_VERSION_MAIN)$(POETRY_VERSION_NAME)$(POETRY_V .PHONY: test test: - pytest --cov=maxbot --cov-report html --cov-fail-under=95 + pytest -p no:maxbot_stories --cov=maxbot --cov-report html --cov-fail-under=95 stories: - maxbot stories -B examples/hello-world - maxbot stories -B examples/echo - maxbot stories -B examples/restaurant - maxbot stories -B examples/reservation-basic - maxbot stories -B examples/reservation - maxbot stories -B examples/digression-showcase - maxbot stories -B examples/rpc-showcase + pytest --bot examples/hello-world examples/hello-world/stories.yaml + pytest --bot examples/echo examples/echo/stories.yaml + pytest --bot examples/restaurant examples/restaurant/stories.yaml + pytest --bot examples/reservation-basic examples/reservation-basic/stories.yaml + pytest --bot examples/reservation examples/reservation/stories.yaml + pytest --bot examples/digression-showcase examples/digression-showcase/stories.yaml + pytest --bot examples/rpc-showcase examples/rpc-showcase/stories.yaml clean: rm -f dist/maxbot-*.*.*-py3-none-any.whl diff --git a/README.md b/README.md index ea68e1e..36b140c 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,14 @@ Press `Ctrl-C` to exit MaxBot CLI app. Congratulations! You have successfully created and launched a simple bot and chatted with it. +### Advanced examples + +There are several examples of services built on Maxbot. They demonstrate advanced Maxbot features such as custom messenger controls, integration with various REST services, databases, and so on. You can also check the implementation details of these features in the examples below. + +- [Bank Bot example](https://github.com/maxbot-ai/bank_bot). +- [Taxi Bot example](https://github.com/maxbot-ai/taxi_bot). +- [Transport Bot example](https://github.com/maxbot-ai/transport_bot). + ## Where to ask questions The **Maxbot** project is maintained by the [Maxbot team](https://maxbot.ai). diff --git a/maxbot/bot.py b/maxbot/bot.py index 1b9c3a5..060d6a4 100644 --- a/maxbot/bot.py +++ b/maxbot/bot.py @@ -1,13 +1,12 @@ """Create and run conversations applications.""" import asyncio import logging -import os from .channels import ChannelsCollection from .dialog_manager import DialogManager from .errors import BotError from .resources import Resources -from .user_locks import AsyncioLocks +from .user_locks import AsyncioLocks, UnixSocketStreams logger = logging.getLogger(__name__) @@ -20,22 +19,28 @@ def __init__( dialog_manager=None, channels=None, user_locks=None, - state_store=None, + persistence_manager=None, resources=None, + history_tracked=False, ): """Create new class instance. :param DialogManager dialog_manager: Dialog manager. :param ChannelsCollection channels: Channels for communication with users. - :param StateStore state_store: State store. + :param PersistenceManager persistence_manager: Persistence manager.
:param Resources resources: Resources for tracking and reloading changes. """ self.dialog_manager = dialog_manager or DialogManager() self.channels = channels or ChannelsCollection.empty() - self._state_store = state_store # the default value is initialized lazily - self.user_locks = user_locks or AsyncioLocks() + self._persistence_manager = persistence_manager # the default value is initialized lazily + self._history_tracked = history_tracked + self._user_locks = user_locks self.resources = resources or Resources.empty() + SocketStreams = UnixSocketStreams + SUFFIX_LOCKS = "-locks.sock" + SUFFIX_DB = ".db" + @classmethod def builder(cls, **kwargs): """Create a :class:`~BotBuilder` in a convenient way. @@ -84,6 +89,23 @@ def from_directory(cls, bot_dir, **kwargs): builder.use_directory_resources(bot_dir) return builder.build() + @property + def user_locks(self): + """Get user locks implementation.""" + if self._user_locks is None: + self._user_locks = AsyncioLocks() + return self._user_locks + + def setdefault_user_locks(self, value): + """Set .user_locks field value if it is not set. + + :param AsyncioLocks value: User locks object. + :return AsyncioLocks: .user_locks field value + """ + if self._user_locks is None: + self._user_locks = value + return self._user_locks + @property def rpc(self): """Get RPC manager used by the bot. @@ -93,14 +115,24 @@ def rpc(self): return self.dialog_manager.rpc @property - def state_store(self): - """State store used to maintain state variables.""" - if self._state_store is None: + def persistence_manager(self): + """Return persistence manager.""" + if self._persistence_manager is None: # lazy import to speed up load time - from .state_store import SQLAlchemyStateStore + from .persistence_manager import SQLAlchemyManager + + self._persistence_manager = SQLAlchemyManager() + return self._persistence_manager - self._state_store = SQLAlchemyStateStore() - return self._state_store + def setdefault_persistence_manager(self, factory): + """Set .persistence_manager field value if it is not set. + + :param callable factory: Persistence manager factory. + :return SQLAlchemyStateStore: .persistence_manager field value. + """ + if self._persistence_manager is None: + self._persistence_manager = factory() + return self._persistence_manager def process_message(self, message, dialog=None): """Process user message. @@ -115,8 +147,10 @@ def process_message(self, message, dialog=None): """ if dialog is None: dialog = self._default_dialog() - with self.state_store(dialog) as state: - return asyncio.run(self.dialog_manager.process_message(message, dialog, state)) + with self.persistence_manager(dialog) as tracker: + return asyncio.run( + self.dialog_manager.process_message(message, dialog, tracker.get_state()) + ) def process_rpc(self, request, dialog=None): """Process RPC request. 
@@ -131,8 +165,10 @@ def process_rpc(self, request, dialog=None): """ if dialog is None: dialog = self._default_dialog() - with self.state_store(dialog) as state: - return asyncio.run(self.dialog_manager.process_rpc(request, dialog, state)) + with self.persistence_manager(dialog) as tracker: + return asyncio.run( + self.dialog_manager.process_rpc(request, dialog, tracker.get_state()) + ) def _default_dialog(self): return {"channel_name": "builtin", "user_id": "1"} @@ -148,10 +184,14 @@ async def default_channel_adapter(self, data, channel): message = await channel.call_receivers(data) if message is None: return - with self.state_store(dialog) as state: - commands = await self.dialog_manager.process_message(message, dialog, state) + with self.persistence_manager(dialog) as tracker: + commands = await self.dialog_manager.process_message( + message, dialog, tracker.get_state() + ) for command in commands: await channel.call_senders(command, dialog) + if self._history_tracked: + tracker.set_message_history(message, commands) async def default_rpc_adapter(self, request, channel, user_id): """Handle RPC request for specific channel. @@ -162,66 +202,14 @@ async def default_rpc_adapter(self, request, channel, user_id): """ dialog = {"channel_name": channel.name, "user_id": str(user_id)} async with self.user_locks(dialog): - with self.state_store(dialog) as state: - commands = await self.dialog_manager.process_rpc(request, dialog, state) + with self.persistence_manager(dialog) as tracker: + commands = await self.dialog_manager.process_rpc( + request, dialog, tracker.get_state() + ) for command in commands: await channel.call_senders(command, dialog) - - def run_webapp(self, host="localhost", port="8080", *, public_url=None, autoreload=False): - """Run web application. - - :param str host: Hostname or IP address on which to listen. - :param int port: TCP port on which to listen. - :param str public_url: Base url to register webhook. - :param bool autoreload: Enable tracking and reloading bot resource changes. - """ - # lazy import to speed up load time - import sanic - - self._validate_at_least_one_channel() - - app = sanic.Sanic("maxbot", configure_logging=False) - app.config.FALLBACK_ERROR_FORMAT = "text" - - for channel in self.channels: - if public_url is None: - logger.warning( - "Make sure you have a public URL that is forwarded to -> " - f"http://{host}:{port}/{channel.name} and register webhook for it." - ) - - app.blueprint( - channel.blueprint( - self.default_channel_adapter, - public_url=public_url, - webhook_path=f"/{channel.name}", - ) - ) - - if self.rpc: - app.blueprint(self.rpc.blueprint(self.channels, self.default_rpc_adapter)) - - if autoreload: - - @app.after_server_start - async def start_autoreloader(app, loop): - app.add_task(self.autoreloader, name="autoreloader") - - @app.before_server_stop - async def stop_autoreloader(app, loop): - await app.cancel_task("autoreloader") - - @app.after_server_start - async def report_started(app, loop): - logger.info( - f"Started webhooks updater on http://{host}:{port}. Press 'Ctrl-C' to exit." - ) - - if sanic.__version__.startswith("21."): - app.run(host, port, motd=False, workers=1) - else: - os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true" - app.run(host, port, motd=False, single_process=True) + if self._history_tracked: + tracker.set_rpc_history(request, commands) def run_polling(self, autoreload=False): """Run polling application. 
@@ -229,13 +217,17 @@ def run_polling(self, autoreload=False): :param bool autoreload: Enable tracking and reloading bot resource changes. """ # lazy import to speed up load time - from telegram.ext import ApplicationBuilder, MessageHandler, filters + from telegram.ext import ApplicationBuilder, CallbackQueryHandler, MessageHandler, filters - self._validate_at_least_one_channel() + self.validate_at_least_one_channel() self._validate_polling_support() builder = ApplicationBuilder() builder.token(self.channels.telegram.config["api_token"]) + + builder.request(self.channels.telegram.create_request()) + builder.get_updates_request(self.channels.telegram.create_request()) + background_tasks = [] @builder.post_init @@ -263,6 +255,7 @@ async def error_handler(update, context): app = builder.build() app.add_handler(MessageHandler(filters.ALL, callback)) + app.add_handler(CallbackQueryHandler(callback=callback, pattern=None)) app.add_error_handler(error_handler) app.run_polling() @@ -311,7 +304,8 @@ def _exclude_unsupported_changes(self, changes): ) return changes - unsupported - def _validate_at_least_one_channel(self): + def validate_at_least_one_channel(self): + """Raise BotError if at least one channel is missing.""" if not self.channels: raise BotError( "At least one channel is required to run a bot. " diff --git a/maxbot/builder.py b/maxbot/builder.py index e3ab050..26f5a8b 100644 --- a/maxbot/builder.py +++ b/maxbot/builder.py @@ -39,13 +39,14 @@ def __init__(self, *, available_extensions=None): self._bot_created = False self.resources = Resources.empty() self._user_locks = None - self._state_store = None + self._persistence_manager = None self._nlu = None self._message_schemas = {} self._command_schemas = {} self._before_turn_hooks = [] self._after_turn_hooks = [] self._middlewares = [] + self._history_tracked = False def add_message(self, schema, name): """Register a custom message. @@ -152,35 +153,37 @@ def user_locks(self, value): self._user_locks = value @property - def state_store(self): - """State store used to maintain state variables. + def persistence_manager(self): + """Return persistence manager. - See default implementation :class:`~maxbot.state_store.SQLAlchemyStateStore` for more information. + Used, for example, to save-restore state variables. + + See default implementation :class:`~maxbot.persistence_manager.SQLAlchemyManager` for more information. You can use this property to configure default state tracker:: - builder.state_store.engine = sqlalchemy.create_engine(...) + builder.persistence_manager.engine = sqlalchemy.create_engine(...) or set your own implementation:: - class CustomStateStore: + class CustomPersistenceManager: @contextmanager def __call__(self, dialog): # load variables... yield StateVariables(...) # save variables... 
- builder.state_store = CustomStateStore() + builder.persistence_manager = CustomPersistenceManager() """ - if self._state_store is None: + if self._persistence_manager is None: # lazy import to speed up load time - from .state_store import SQLAlchemyStateStore + from .persistence_manager import SQLAlchemyManager - self._state_store = SQLAlchemyStateStore() - return self._state_store + self._persistence_manager = SQLAlchemyManager() + return self._persistence_manager - @state_store.setter - def state_store(self, value): - self._state_store = value + @persistence_manager.setter + def persistence_manager(self, value): + self._persistence_manager = value @property def nlu(self): @@ -407,6 +410,10 @@ def use_resources(self, resources): """ self.resources = resources + def track_history(self, value=True): + """Set/reset flag that controls history recording.""" + self._history_tracked = value + def _create_dialog_manager(self): message_schema = self._create_message_schema() command_schema = self._create_command_schema() @@ -449,7 +456,8 @@ def build(self): return MaxBot( self._create_dialog_manager(), channels, - self.user_locks, - self._state_store, + self._user_locks, + self._persistence_manager, self.resources, + self._history_tracked, ) diff --git a/maxbot/channels/facebook.py b/maxbot/channels/facebook.py index 10c0071..3893e94 100644 --- a/maxbot/channels/facebook.py +++ b/maxbot/channels/facebook.py @@ -6,7 +6,7 @@ import httpx -from ..maxml import Schema, fields +from ..maxml import PoolLimitSchema, Schema, TimeoutSchema, fields logger = logging.getLogger(__name__) @@ -14,18 +14,18 @@ class Gateway: """Facebook sender and verifier incoming messages.""" - # Facebook graph api version with which this gateway is well tested. - # @See https://developers.facebook.com/docs/graph-api/guides/versioning - httpx_client = httpx.AsyncClient(base_url="https://graph.facebook.com/v15.0", timeout=3) - - def __init__(self, app_secret, access_token): + def __init__(self, app_secret, access_token, **kwargs): """Create a new class instance. :param str app_secret: Facebook application secret :param str access_token: Facebook access_token + :param dict kwargs: Arguments for creating HTTPX asynchronous client. """ self.app_secret = app_secret self.access_token = access_token + self.httpx_client = httpx.AsyncClient( + base_url="https://graph.facebook.com/v15.0", **kwargs + ) async def send_request(self, json_data): """Send request to Facebook. 
@@ -131,6 +131,12 @@ class ConfigSchema(Schema): # Facebook access_token # @See https://developers.facebook.com/docs/facebook-login/security/#appsecret access_token = fields.Str(required=True) + # Default HTTP request timeouts + # @See https://www.python-httpx.org/advanced/#timeout-configuration + timeout = fields.Nested(TimeoutSchema()) + # Pool limit configuration + # @See https://www.python-httpx.org/advanced/#pool-limit-configuration + limits = fields.Nested(PoolLimitSchema()) @cached_property def gateway(self): @@ -138,7 +144,12 @@ def gateway(self): :return Gateway: """ - return Gateway(self.config["app_secret"], self.config["access_token"]) + return Gateway( + self.config["app_secret"], + self.config["access_token"], + timeout=self.config.get("timeout", TimeoutSchema.DEFAULT), + limits=self.config.get("limits", PoolLimitSchema.DEFAULT), + ) @cached_property def _api(self): @@ -210,10 +221,11 @@ async def receive_image(self, messaging: dict): return {"image": {"url": images[0]}} return None - def blueprint(self, callback, public_url=None, webhook_path=None): + def blueprint(self, callback, execute_once, public_url=None, webhook_path=None): """Create web application blueprint to receive incoming updates. :param callable callback: a callback for received messages. + :param callable execute_once: Execute only for first WEB application worker. :param string public_url: Base url to register webhook. :param string webhook_path: An url path to receive incoming updates. :return Blueprint: Blueprint for sanic app. @@ -229,6 +241,7 @@ def blueprint(self, callback, public_url=None, webhook_path=None): @bp.post(webhook_path) async def webhook(request): # @See https://developers.facebook.com/docs/messenger-platform/webhooks#event-notifications + logger.debug("%s", request.json) http_code = self.gateway.verify_token(request.body, request.headers) if http_code != 200: diff --git a/maxbot/channels/telegram.py b/maxbot/channels/telegram.py index c5855fb..e82d366 100644 --- a/maxbot/channels/telegram.py +++ b/maxbot/channels/telegram.py @@ -6,8 +6,9 @@ import httpx from telegram import Bot, Update +from telegram.request import HTTPXRequest -from ..maxml import Schema, fields +from ..maxml import PoolLimitSchema, Schema, TimeoutSchema, fields TG_FILE_URL = "https://api.telegram.org/file" @@ -31,7 +32,51 @@ class ConfigSchema(Schema): # @see https://core.telegram.org/bots#6-botfather. api_token = fields.Str(required=True) - httpx_client = httpx.AsyncClient(timeout=3) + # Default HTTP request timeouts + # @see https://www.python-httpx.org/advanced/#timeout-configuration + timeout = fields.Nested(TimeoutSchema()) + + # Pool limit configuration + # @see https://www.python-httpx.org/advanced/#pool-limit-configuration + limits = fields.Nested(PoolLimitSchema()) + + class Request(HTTPXRequest): + """Local implementation of telegram.request.HTTPXRequest.""" + + def __init__(self, timeout, limits): # pylint: disable=super-init-not-called + """Create new instance. + + :param httpx.Timeout timeout: HTTPX client timeout. + :param httpx.Limits limits: HTTPX client limits. 
+ """ + self._http_version = "1.1" + self._client_kwargs = { + "timeout": timeout, + "proxies": None, + "limits": limits, + "http1": True, + "http2": False, + } + self._client = self._build_client() + + def create_request(self): + """Create new instance of `telegram.request.BaseRequest` implementation.""" + return self.Request(self.timeout, self.limits) + + @cached_property + def timeout(self): + """Create `httpx.Timeout` from channel configuration.""" + return self.config.get("timeout", TimeoutSchema.DEFAULT) + + @cached_property + def limits(self): + """Create `httpx.Limits` from channel configuration.""" + return self.config.get("limits", PoolLimitSchema.DEFAULT) + + @cached_property + def httpx_client(self): + """Create HTTPX asynchronous client.""" + return httpx.AsyncClient(timeout=self.timeout, limits=self.limits) @cached_property def bot(self): @@ -41,7 +86,11 @@ def bot(self): :return Bot: """ - return Bot(self.config["api_token"]) + return Bot( + self.config["api_token"], + get_updates_request=self.create_request(), + request=self.create_request(), + ) async def create_dialog(self, update: Update): """Create a dialog object from the incomming update. @@ -124,10 +173,11 @@ async def send_image(self, command: dict, dialog: dict): dialog["user_id"], image["url"], None if caption is None else caption.render() ) - def blueprint(self, callback, public_url=None, webhook_path=None): + def blueprint(self, callback, execute_once, public_url=None, webhook_path=None): """Create web application blueprint to receive incoming updates. :param callable callback: a callback for received messages. + :param callable execute_once: Execute only for first WEB application worker. :param string public_url: Base url to register webhook. :param string webhook_path: An url path to receive incoming updates. :return Blueprint: Blueprint for sanic app. @@ -142,6 +192,7 @@ def blueprint(self, callback, public_url=None, webhook_path=None): @bp.post(webhook_path) async def webhook(request): + logger.debug("%s", request.json) update = Update.de_json(data=request.json, bot=self.bot) await callback(update, self) return empty() @@ -150,8 +201,11 @@ async def webhook(request): @bp.after_server_start async def register_webhook(app, loop): - webhook_url = urljoin(public_url, webhook_path) - await self.bot.setWebhook(webhook_url) - logger.info(f"Registered webhook {webhook_url}.") + async def _impl(): + webhook_url = urljoin(public_url, webhook_path) + await self.bot.setWebhook(webhook_url) + logger.info(f"Registered webhook {webhook_url}.") + + await execute_once(app, _impl) return bp diff --git a/maxbot/channels/viber.py b/maxbot/channels/viber.py index 77e1052..eeb6e4e 100644 --- a/maxbot/channels/viber.py +++ b/maxbot/channels/viber.py @@ -9,7 +9,7 @@ from viberbot.api.viber_requests import ViberMessageRequest, create_request from viberbot.api.viber_requests.viber_request import ViberRequest -from ..maxml import Schema, fields +from ..maxml import PoolLimitSchema, Schema, TimeoutSchema, fields logger = logging.getLogger(__name__) @@ -17,14 +17,14 @@ class Gateway: """Viber Gateway.""" - httpx_client = httpx.AsyncClient(base_url="https://chatapi.viber.com/pa", timeout=3) - - def __init__(self, api_token): + def __init__(self, api_token, **kwargs): """Create a new class instance. :param str auth_token: Viber auth token. + :param dict kwargs: Arguments for creating HTTPX asynchronous client. 
""" self.api_token = api_token + self.httpx_client = httpx.AsyncClient(base_url="https://chatapi.viber.com/pa", **kwargs) async def send_request(self, method, payload=None): """Send request to Facebook. @@ -120,6 +120,14 @@ class ConfigSchema(Schema): # https://developers.viber.com/docs/api/python-bot-api/#userprofile-object avatar = fields.Str() + # Default HTTP request timeouts + # https://www.python-httpx.org/advanced/#timeout-configuration + timeout = fields.Nested(TimeoutSchema()) + + # Pool limit configuration + # https://www.python-httpx.org/advanced/#pool-limit-configuration + limits = fields.Nested(PoolLimitSchema()) + @cached_property def _api(self): """Return viber api connected to your bot. @@ -129,7 +137,13 @@ def _api(self): :return Api: """ return _Api( - Gateway(self.config["api_token"]), self.config.get("name"), self.config.get("avatar") + Gateway( + self.config["api_token"], + timeout=self.config.get("timeout", TimeoutSchema.DEFAULT), + limits=self.config.get("limits", PoolLimitSchema.DEFAULT), + ), + self.config.get("name"), + self.config.get("avatar"), ) async def create_dialog(self, request: ViberRequest): @@ -204,10 +218,11 @@ async def receive_image(self, request: ViberRequest): return content return None - def blueprint(self, callback, public_url=None, webhook_path=None): + def blueprint(self, callback, execute_once, public_url=None, webhook_path=None): """Create web application blueprint to receive incoming updates. :param callable callback: a callback for received messages. + :param callable execute_once: Execute only for first WEB application worker. :param string public_url: Base url to register webhook. :param string webhook_path: An url path to receive incoming updates. :return Blueprint: Blueprint for sanic app. @@ -222,6 +237,7 @@ def blueprint(self, callback, public_url=None, webhook_path=None): @bp.post(webhook_path) async def webhook(request): + logger.debug("%s", request.json) request_data = create_request(request.json) if request_data.event_type == "message": await callback(request_data, self) @@ -231,8 +247,11 @@ async def webhook(request): @bp.after_server_start async def register_webhook(app, loop): - webhook_url = urljoin(public_url, webhook_path) - await self._api.set_webhook(webhook_url) - logger.info(f"Registered webhook {webhook_url}.") + async def _impl(): + webhook_url = urljoin(public_url, webhook_path) + await self._api.set_webhook(webhook_url) + logger.info(f"Registered webhook {webhook_url}.") + + await execute_once(app, _impl) return bp diff --git a/maxbot/channels/vk.py b/maxbot/channels/vk.py index 7fc45ee..fa6c9f4 100644 --- a/maxbot/channels/vk.py +++ b/maxbot/channels/vk.py @@ -7,7 +7,8 @@ import httpx from .._download import download_to_tempfile -from ..maxml import Schema, fields +from ..errors import BotError +from ..maxml import PoolLimitSchema, Schema, TimeoutSchema, fields logger = logging.getLogger(__name__) @@ -27,14 +28,25 @@ class Gateway: # @See https://dev.vk.com/reference/versions API_VERSION = "5.131" - httpx_client = httpx.AsyncClient(base_url="https://api.vk.com/method", timeout=3) - - def __init__(self, access_token): + def __init__(self, access_token, group_id, **kwargs): """Create a new class instance. :param int access_token: VK access_token + :param int group_id: VK group_id, may be None + :param dict kwargs: Arguments for creating HTTPX asynchronous client. 
""" self.access_token = access_token + self.group_id = group_id + self.httpx_client = httpx.AsyncClient(base_url="https://api.vk.com/method", **kwargs) + + async def _send_request(self, method, common_params, payload, error_on_empty_response=True): + params = {**common_params, **(payload or {})} + response = await self.httpx_client.post(method, data=params) + response.raise_for_status() + result = response.json().get("response") + if error_on_empty_response and result is None: + raise RuntimeError(f"empty response: {response.json()!r}") + return result async def send_request(self, method, payload=None): """Send VK method with payload. @@ -44,21 +56,32 @@ async def send_request(self, method, payload=None): :raise RuntimeError: unexpected response. :return dict response: response data """ - response = await self.httpx_client.post( - method, - json={ - **(payload or {}), - "v": self.API_VERSION, - "access_token": self.access_token, - # https://dev.vk.com/method/messages.send: random_id < max(int32) - "random_id": secrets.randbelow(0x7FFFFFFF), - }, - ) - response.raise_for_status() - result = response.json().get("response") - if not result: - raise RuntimeError(f"empty result: {result!r}") - return result + common_params = { + "v": self.API_VERSION, + "access_token": self.access_token, + # https://dev.vk.com/method/messages.send: random_id < max(int32) + "random_id": secrets.randbelow(0x7FFFFFFF), + } + return await self._send_request(method, common_params, payload) + + async def send_callback_api_request(self, method, payload=None, error_on_empty_response=True): + """Send VK method with payload for configuration callback API. + + :param str method: example 'messages.send' + :param dict payload: additional payload + :param true error_on_empty_response: raise error on empty response + :raise RuntimeError: unexpected response. + :raise BotError: `group_id` is not set. + :return dict response: response data + """ + if not self.group_id: + raise BotError("`group_id` is not set") + common_params = { + "v": self.API_VERSION, + "access_token": self.access_token, + "group_id": self.group_id, + } + return await self._send_request(method, common_params, payload, error_on_empty_response) class _Api: @@ -71,14 +94,14 @@ class _Api: * upload photo (upload_media_file) """ - upload_client = httpx.AsyncClient(timeout=3) - - def __init__(self, gateway): + def __init__(self, gateway, **kwargs): """Create a new class instance. :param Gateway gateway: send VK command + :param dict kwargs: Arguments for creating HTTPX asynchronous client. """ self.gateway = gateway + self.upload_client = httpx.AsyncClient(**kwargs) async def send_text(self, user_id, text): """Send text message. @@ -147,6 +170,77 @@ async def save_photo(self, server, photo, hash_param): _response_validate(result[0], ["owner_id", "id"]) return f"photo{result[0]['owner_id']}_{result[0]['id']}" + async def get_callback_confirmation_code(self): + """Get confirmation code for VK callback API. + + @See https://dev.vk.com/method/groups.getCallbackConfirmationCode + + :return str: confirmation code, will use in answer on VK request with type=confirmation + """ + result = await self.gateway.send_callback_api_request("groups.getCallbackConfirmationCode") + _response_validate(result, ["code"]) + return result["code"] + + async def get_callback_servers(self): + """Get callback API servers. 
+ + @See https://dev.vk.com/method/groups.getCallbackServers + + :return list: list of server id + """ + result = await self.gateway.send_callback_api_request("groups.getCallbackServers") + return result.get("items", []) + + async def delete_callback_servers(self, server_id): + """Delete callback API server. + + @See https://dev.vk.com/method/groups.deleteCallbackServer + + :param int server_id: server id + """ + payload = {"server_id": server_id} + result = await self.gateway.send_callback_api_request( + "groups.deleteCallbackServer", payload, error_on_empty_response=False + ) + return result + + async def add_callback_server(self, webhook_url, secret_key, title): + """Add callback API server. + + @See https://dev.vk.com/method/groups.addCallbackServer + + :param str webhook_url: webhook url for incoming updates + :param str secret_key: webhook secret key + :param str title: Server title + :return int: server id + """ + payload = {"url": webhook_url, "secret_key": secret_key, "title": title} + result = await self.gateway.send_callback_api_request("groups.addCallbackServer", payload) + _response_validate(result, ["server_id"]) + return result["server_id"] + + async def set_callback_settings(self, server_id): + """Set callback API server settings. + + @See https://dev.vk.com/method/groups.setCallbackSettings + + :param int server_id: server id + :raise RuntimeError: unexpected response. + """ + payload = { + "api_version": self.gateway.API_VERSION, + "server_id": server_id, + "message_new": 1, + "message_reply": 1, + "message_allow": 1, + "message_deny": 1, + } + result = await self.gateway.send_callback_api_request( + "groups.setCallbackSettings", payload + ) + if result != 1: + raise RuntimeError(f"setCallbackSettings result error: {result!r}") + class VkChannel: """Channel for VK Bots. 
@@ -166,11 +260,29 @@ class ConfigSchema(Schema): access_token = fields.Str(required=True) # Group_id for VK page, if present, the incoming messages will be checked against it + # And use for set webhook group_id = fields.Integer() - # Secret string for confirmation answer - # @See https://dev.vk.com/api/callback/getting-started - confirm_secret = fields.Str() + # Secret key, use for set webhook + # @See https://dev.vk.com/method/groups.addCallbackServer + secret_key = fields.Str() + + # Server title, use for set webhook + # @See https://dev.vk.com/method/groups.addCallbackServer + server_title = fields.Str(load_default="MAXBOT") + + # Default HTTP request timeouts + # @See https://www.python-httpx.org/advanced/#timeout-configuration + timeout = fields.Nested(TimeoutSchema()) + + # Pool limit configuration + # @See https://www.python-httpx.org/advanced/#pool-limit-configuration + limits = fields.Nested(PoolLimitSchema()) + + @cached_property + def timeout(self): + """Create `httpx.Timeout` from channel configuration.""" + return self.config.get("timeout", TimeoutSchema.DEFAULT) @cached_property def gateway(self): @@ -178,11 +290,20 @@ def gateway(self): :return Gateway: """ - return Gateway(self.config["access_token"]) + return Gateway( + self.config["access_token"], + self.config.get("group_id"), + timeout=self.timeout, + limits=self.config.get("limits", PoolLimitSchema.DEFAULT), + ) @cached_property def _api(self): - return _Api(self.gateway) + return _Api( + self.gateway, + timeout=self.timeout, + limits=self.config.get("limits", PoolLimitSchema.DEFAULT), + ) async def create_dialog(self, incoming_message: dict): """ @@ -261,10 +382,33 @@ async def receive_image(self, incoming_message: dict): return {"image": payload} return None - def blueprint(self, callback, public_url=None, webhook_path=None): + async def set_webhook(self, webhook_url): + """Set webhook url for receive incoming updates. + + See https://dev.vk.com/api/callback/getting-started + + :param str webhook_url: An url to receive incoming updates. + :raise BotError: `secret_key` is not set. + :return str: confirmation code, will use in answer on VK request with type=confirmation + """ + if not self.config.get("secret_key"): + raise BotError("`secret_key` is not set") + servers = await self._api.get_callback_servers() + for server in servers: + result = await self._api.delete_callback_servers(server["id"]) + if result != 1: + logger.error("deleteCallbackServer error: %s (%s)", server, result) + server_id = await self._api.add_callback_server( + webhook_url, self.config["secret_key"], self.config["server_title"] + ) + await self._api.set_callback_settings(server_id) + return await self._api.get_callback_confirmation_code() + + def blueprint(self, callback, execute_once, public_url=None, webhook_path=None): """Create web application blueprint to receive incoming updates. :param callable callback: a callback for received messages. + :param callable execute_once: Execute only for first WEB application worker. :param string public_url: Base url to register webhook. :param string webhook_path: An url path to receive incoming updates. :return Blueprint: Blueprint for sanic app. 
@@ -277,25 +421,38 @@ def blueprint(self, callback, public_url=None, webhook_path=None): if webhook_path is None: webhook_path = f"/{self.name}" + bp.ctx.confirmation_code = None + @bp.post(webhook_path) async def endpoint(request): + logger.debug("%s", request.json) data = request.json if "group_id" in self.config and data.get("group_id") != self.config["group_id"]: return text_response("not my group", status=400) if data.get("type") == "message_new": await callback(data["object"]["message"], self) elif data.get("type") == "confirmation": - if self.config.get("confirm_secret"): - return text_response(self.config["confirm_secret"]) - raise RuntimeError("No confirm_secret configuration") + if bp.ctx.confirmation_code: + return text_response(bp.ctx.confirmation_code) + raise RuntimeError("No confirmation_code") # @See https://dev.vk.com/api/callback/getting-started return text_response("ok") if public_url: - # FIXME: actually there is an API - webhook_url = urljoin(public_url, webhook_path) - logger.warning( - f"The {self.name} platform has no suitable api, register a webhook yourself {webhook_url}." - ) + if self.config.get("group_id") and self.config.get("secret_key"): + + @bp.after_server_start + async def register_webhook(app, loop): + async def _impl(): + webhook_url = urljoin(public_url, webhook_path) + bp.ctx.confirmation_code = await self.set_webhook(webhook_url) + logger.info(f"Registered webhook {webhook_url}.") + + await execute_once(app, _impl) + + else: + logger.warning( + "Webhook registration skipped: set secret_key and group_id to register a webhook" + ) return bp diff --git a/maxbot/cli/__init__.py b/maxbot/cli/__init__.py index ed220b3..9620d74 100644 --- a/maxbot/cli/__init__.py +++ b/maxbot/cli/__init__.py @@ -3,23 +3,21 @@ import dotenv from .info import info as info_command -from .run import run -from .stories import stories as stories_command +from .run import run as run_command @click.group() def main(): """Execute the cli script for MaxBot applications. - Provides commands to run bots, test them with stories etc. + Provides commands to run bots, etc. """ path = dotenv.find_dotenv(".env", usecwd=True) if path: dotenv.load_dotenv(path, encoding="utf-8") -main.add_command(run) -main.add_command(stories_command) +main.add_command(run_command) main.add_command(info_command) __all__ = ("main",) diff --git a/maxbot/cli/_bot.py b/maxbot/cli/_bot.py index 859eafd..1eacaf5 100644 --- a/maxbot/cli/_bot.py +++ b/maxbot/cli/_bot.py @@ -1,15 +1,25 @@ """Prepare bot for CLI.""" -import logging -import pkgutil -from pathlib import Path -from types import ModuleType - import click -from ..bot import MaxBot -from ..errors import BotError +from ..resolver import BotResolver + + +class _CliBotResolver(BotResolver): + def on_bot_error(self, exc): + raise click.Abort() + + def on_error(self, exc): + raise click.Abort() from exc + + def on_unknown_source(self, pkg_error): + raise click.BadParameter( + f'file or directory not found, import causes error "{pkg_error}".', param_hint="--bot" + ) -logger = logging.getLogger(__name__) + def on_invalid_type(self): + raise click.BadParameter( + f"a valid MaxBot instance was not obtained from {self.spec!r}.", param_hint="--bot" + ) def resolve_bot(spec): """Create a bot object from the spec. :param str spec: Path for bot file or directory or the name of the variable :raise click.BadParameter: Provided spec is invalid. :raise click.Abort: An error occured while creating the object. :return MaxBot: Created object.
""" - pkg_error = None - try: - rv = pkgutil.resolve_name(spec) - except ( - ValueError, # invalid spec - ModuleNotFoundError, # module from spec is not found - ) as exc: - # we have to leave "except" block - pkg_error = exc - except BotError as exc: - logger.critical("Bot Error %s", exc) - raise click.Abort() - except Exception as exc: - logger.exception(f"While loading {spec!r}, an exception was raised") - raise click.Abort() from exc - if pkg_error: - # fallback to paths - path = Path(spec) - try: - if path.is_file(): - builder = MaxBot.builder() - builder.use_directory_resources(path.parent, path.name) - return builder.build() - if path.is_dir(): - return MaxBot.from_directory(path) - except BotError as exc: - logger.critical("Bot Error %s", exc) - raise click.Abort() - raise click.BadParameter( - f'file or directory not found, import causes error "{pkg_error}".', param_hint="--bot" - ) - - # if attribute name is not provided, use the default one. - if isinstance(rv, ModuleType) and hasattr(rv, "bot"): - rv = rv.bot - if not isinstance(rv, MaxBot): - raise click.BadParameter( - f"a valid MaxBot instance was not obtained from {spec!r}.", param_hint="--bot" - ) - return rv + return _CliBotResolver(spec)() diff --git a/maxbot/cli/_journal.py b/maxbot/cli/_journal.py index bb58b6a..bc25944 100644 --- a/maxbot/cli/_journal.py +++ b/maxbot/cli/_journal.py @@ -1,49 +1,63 @@ """Bot journals.""" import json import os +import sys from dataclasses import asdict from ..maxml import pretty from ._yaml_dumper import yaml_frendly_dumps -def create_journal(verbose, quiet, journal_file, journal_output): +def create_journal(verbosity, journal_file, journal_output): """Create bot journal from provided specification. - :param int verbose: Verbose level. - :param bool quiet: Do not log to console. + :param int verbosity: Verbose level. :param file journal_file: Logging to the file. :param str journal_output: File journal output format ("json" or "yaml"). """ - journal = JournalChain() - if not quiet: - from ._rich import PrettyJournal # speed up loading time + if verbosity < -1: + return no_journal + + if not journal_file and _stdout_is_non_interactive(): + journal_file = sys.stdout - journal.chain.append(PrettyJournal(verbose)) if journal_file: - journal.chain.append( - FileJournal( - journal_file, - {"json": Dumper.json_line, "yaml": Dumper.yaml_triple_dash}[journal_output], - ) + journal_class = FileQuietJournal if verbosity < 0 else FileJournal + return journal_class( + journal_file, + {"json": Dumper.json_line, "yaml": Dumper.yaml_triple_dash}[journal_output], ) - return journal + + from ._rich import PrettyJournal # speed up loading time + + return PrettyJournal(verbosity) -class JournalChain: - """Sequentially call several journals.""" +def no_journal(ctx): + """Silent journal.""" + return - def __init__(self, chain=None): - """Create class instance.""" - self.chain = chain or [] + +class FileQuietJournal: + """Error logging only.""" + + def __init__(self, f, dumps): + """Create class instance. + + :param file f: Target file to write. + :dump callable dumps: Dump object to string. + """ + self.f = f + self.dumps = dumps def __call__(self, ctx): - """Process turn context. + """Write turn context. :param TurnContext ctx: Context of the dialog turn. 
""" - for journal in self.chain: - journal(ctx) + if ctx.error: + self.f.write(self.dumps({"error": {"message": ctx.error.message}})) + self.f.flush() class FileJournal: @@ -102,4 +116,8 @@ def default(o): @staticmethod def yaml_triple_dash(data): """Dump object to YAML with three dashes (`---`) at end.""" - return yaml_frendly_dumps(data) + "---" + os.linesep + return yaml_frendly_dumps(data, aliases_allowed=False) + "---" + os.linesep + + +def _stdout_is_non_interactive(): + return not sys.stdout.isatty() diff --git a/maxbot/cli/_logging.py b/maxbot/cli/_logging.py index ab01b02..545da8e 100644 --- a/maxbot/cli/_logging.py +++ b/maxbot/cli/_logging.py @@ -1,6 +1,7 @@ """CLI logging output.""" import logging import logging.handlers +import sys import click @@ -11,18 +12,25 @@ def configure_logging(target, verbosity): :param str target: The name of the logger and its arguments. :param int verbosity: Output verbosity. """ + if verbosity < -1: + logging.disable() + return + _configure_libs(verbosity) if target == "console": - from ._rich import ConsoleLogHandler # speed up loading time + if _stderr_is_non_interactive(): + handler = _create_stderr_handler() + else: + from ._rich import ConsoleLogHandler # speed up loading time - handler = ConsoleLogHandler() + handler = ConsoleLogHandler() elif target.startswith("file:"): handler = _create_file_handler(filename=target.removeprefix("file:")) else: raise click.BadParameter(f"unknown logger {target!r}.", param_hint="--logger") - loglevel = [logging.INFO, logging.DEBUG][min(verbosity, 1)] + loglevel = logging.DEBUG if verbosity >= 1 else logging.INFO handler.setLevel(loglevel) logging.basicConfig(level="NOTSET", handlers=[handler]) @@ -30,7 +38,7 @@ def configure_logging(target, verbosity): def _configure_libs(verbosity): - loglevel = [logging.WARNING, logging.INFO, logging.DEBUG][min(verbosity, 2)] + loglevel = [logging.ERROR, logging.WARNING, logging.INFO, logging.DEBUG][min(verbosity + 1, 3)] libs = [ "asyncio", "httpx", @@ -54,3 +62,17 @@ def _create_file_handler(filename): ) handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) return handler + + +def _create_stderr_handler(): + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter( + "%(asctime)s - %(processName)s - %(name)s - %(levelname)s - %(message)s" + ) + ) + return handler + + +def _stderr_is_non_interactive(): + return not sys.stderr.isatty() diff --git a/maxbot/cli/_rich.py b/maxbot/cli/_rich.py index 5084a09..f7f4790 100644 --- a/maxbot/cli/_rich.py +++ b/maxbot/cli/_rich.py @@ -27,7 +27,9 @@ STDOUT = Console() STDERR = Console(stderr=True) # Make sure the STDERR output displayed above the progress display. -Progress = functools.partial(_Progress, console=STDERR) +Progress = functools.partial( + _Progress, console=STDERR, redirect_stdout=False, redirect_stderr=False +) class PrettyJournal: @@ -36,16 +38,19 @@ class PrettyJournal: class VerbosityLevel(IntEnum): """Level of verbosity.""" + INFO = 0 NLU = 1 + VAR_DIFF = 1 JOURNAL = 2 + VAR_FULL = 2 - def __init__(self, verbose=0, console=None): + def __init__(self, verbosity=0, console=None): """Create class instance. - :param int verbose: Output verbosity. + :param int verbosity: Output verbosity. :param Console|None console: Console to write. """ - self.verbose = verbose + self.verbosity = verbosity self.console = console or STDOUT def __call__(self, ctx): @@ -53,21 +58,29 @@ def __call__(self, ctx): :param TurnContext ctx: Context of the dialog turn. 
""" - self.console.line() - self.print_dialog(ctx.dialog) - if ctx.message: - self.print_message(ctx.message) - if ctx.rpc: - self.print_rpc(asdict(ctx.rpc.request)) - if ctx.commands: - self.print_commands(ctx.commands, ctx.command_schema) - if self.verbose >= self.VerbosityLevel.NLU: - self.print_intents(ctx.intents) - if ctx.entities.all_objects: - self.print_entities(ctx.entities) - if ctx.journal_events: - self.print_journal_events(ctx.journal_events) - if ctx.error: + if self.verbosity >= self.VerbosityLevel.INFO: + self.console.line() + self.print_dialog(ctx.dialog) + if ctx.message: + self.print_message(ctx.message) + if ctx.rpc: + self.print_rpc(asdict(ctx.rpc.request)) + if ctx.commands: + self.print_commands(ctx.commands, ctx.command_schema) + if self.verbosity >= self.VerbosityLevel.NLU: + self.print_intents(ctx.intents) + if ctx.entities.all_objects: + self.print_entities(ctx.entities) + if ctx.journal_events: + self.print_journal_events(ctx.journal_events) + if ctx.error: + self.print_error(ctx.error) + if self.verbosity >= self.VerbosityLevel.VAR_FULL or ( + ctx.error and self.verbosity >= self.VerbosityLevel.VAR_DIFF + ): + self.print_var_full(ctx.state) + elif ctx.error: + self.console.line() self.print_error(ctx.error) def print_dialog(self, dialog): @@ -170,25 +183,33 @@ def print_journal_events(self, journal_events): :param list[dict] journal_events: Journal. """ - verbose_journal = self.verbose >= self.VerbosityLevel.JOURNAL - output = Table.grid(padding=(0, 1)) - output.title = "journal_events" if verbose_journal else "logs" - output.title_justify = "left" - output.expand = True - output.add_column() - output.add_column(ratio=1) + extractors = [ + self._extract_log_event, + ] + if self.verbosity >= self.VerbosityLevel.VAR_DIFF: + extractors.append(self._extract_user_diff) + extractors.append(self._extract_slots_diff) + if self.verbosity >= self.VerbosityLevel.JOURNAL: + extractors.append(self._extract_journal_event) + + table = Table.grid(padding=(0, 1)) + table.title = "journal_events" if len(extractors) > 1 else "logs" + table.title_justify = "left" + table.expand = True + table.add_column() + table.add_column(ratio=1) + + table_is_empty = True for event in journal_events: - level, message = self._extract_log_event(event) - if level: - if not isinstance(message, str): - message = Pretty(message, indent_size=2, indent_guides=True) - output.add_row(self._LOG_LEVELS[level.upper()], message) - elif verbose_journal: - output.add_row( - Text.styled(event.get("type", "?"), "yellow"), - _yaml_syntax(event.get("payload")), - ) - self.console.print(output) + for e in extractors: + first, second = e(event) + if first: + table.add_row(first, second) + table_is_empty = False + break + + if not table_is_empty: + self.console.print(table) def print_error(self, exc): """Print bot error. @@ -201,6 +222,29 @@ def print_error(self, exc): message.extend(bot_error_snippet(snippet)) self._print_log("ERROR", message) + def print_var_full(self, state): + """Print snapshot of slots.* and user.*. + + :param StateVariables state: Container of variables. 
+ """ + self._print_full(state, "slots") + self._print_full(state, "user") + + def _print_full(self, state, field_name): + d = getattr(state, field_name) + if d: + table = Table.grid(padding=(0, 1)) + table.title = field_name + table.title_justify = "left" + table.expand = True + table.add_column() + table.add_column(ratio=1) + + for name, value in d.items(): + table.add_row(Text.styled(f".{name}", "bold"), _yaml_syntax(value)) + + self.console.print(table) + def _print_speech(self, speaker, data, lexer=None): output = Table.grid(padding=(0, 1)) output.expand = True @@ -231,7 +275,34 @@ def _print_log(self, level, message): def _extract_log_event(self, event): level, message = TurnContext.extract_log_event(event) - return (level, message) if level in self._LOG_LEVELS else (None, None) + level = self._LOG_LEVELS.get(level.upper()) if isinstance(level, str) else None + if not isinstance(message, str): + message = Pretty(message, indent_size=2, indent_guides=True) + return (level, message) if level else (None, None) + + def _extract_user_diff(self, event): + return self._extract_diff(event, "user") + + def _extract_slots_diff(self, event): + return self._extract_diff(event, "slots") + + def _extract_diff(self, event, kind): + t, p = event.get("type"), event.get("payload") + if isinstance(p, dict): + if t == "assign": + name, value = p.get(kind), p.get("value") + if isinstance(name, str): + return Text.styled(f"{kind}.{name} =", "green"), Pretty( + value, indent_size=2, indent_guides=True + ) + elif t == "delete": + name = p.get(kind) + if isinstance(name, str): + return Text.styled("❌ delete", "red"), Text.styled(f"{kind}.{name}", "bold") + return None, None + + def _extract_journal_event(self, event): + return Text.styled(event.get("type", "?"), "yellow"), _yaml_syntax(event.get("payload")) class ConsoleLogHandler(RichHandler): @@ -256,7 +327,7 @@ class ConsoleLogHandler(RichHandler): def __init__(self): """Create new class instance.""" super().__init__(console=STDERR, show_time=False, show_path=False) - self.setFormatter(logging.Formatter("%(message)s")) + self.setFormatter(logging.Formatter("%(asctime)s - %(processName)s - %(message)s")) def get_level_text(self, record): """Render the level prefix for log record. 
diff --git a/maxbot/cli/_yaml_dumper.py index cfcae3f..f5e0bec 100644 --- a/maxbot/cli/_yaml_dumper.py +++ b/maxbot/cli/_yaml_dumper.py @@ -3,32 +3,48 @@ from yaml import SafeDumper, dump -class YamlFrendlyDumper(SafeDumper): - """Human friendly dumps.""" +def _create_dumper(aliases_allowed): + class _Dumper(SafeDumper): + """Human friendly dumps.""" - @staticmethod - def represent_str_literal(dumper, data): - """Represent multiline strings using literal style.""" - data = str(data) - if "\n" in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") - return dumper.represent_str(data) + @staticmethod + def represent_str_literal(dumper, data): + """Represent multiline strings using literal style.""" + data = str(data) + if "\n" in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + return dumper.represent_str(data) - @staticmethod - def represent_undefined_to_repr(dumper, data): - """Represetn all unknown objects as repr-string.""" - return dumper.represent_str(repr(data)) + @staticmethod + def represent_undefined_to_repr(dumper, data): + """Represent all unknown objects as repr-strings.""" + return dumper.represent_str(repr(data)) + def ignore_aliases(self, data): + """Ignore aliases when they are not allowed.""" + return True if not aliases_allowed else super().ignore_aliases(data) -YamlFrendlyDumper.add_representer(str, YamlFrendlyDumper.represent_str_literal) -YamlFrendlyDumper.add_representer(None, YamlFrendlyDumper.represent_undefined_to_repr) + _Dumper.add_representer(str, _Dumper.represent_str_literal) + _Dumper.add_representer(None, _Dumper.represent_undefined_to_repr) + return _Dumper -def yaml_frendly_dumps(data): +Dumper = _create_dumper(aliases_allowed=True) +DumperNoAliases = _create_dumper(aliases_allowed=False) + + +def yaml_frendly_dumps(data, aliases_allowed=True): """Dump object to YAML string (human-friendly). Dump all unknown objects as repr-string. :param any data: Object to dump. + :param bool aliases_allowed: Enable/disable anchors and aliases usage. + :return str: Dumped YAML string. """ - return dump(data, Dumper=YamlFrendlyDumper, sort_keys=False) + return dump( + data, + Dumper=Dumper if aliases_allowed else DumperNoAliases, + sort_keys=False, + allow_unicode=True, + ) diff --git a/maxbot/cli/run.py b/maxbot/cli/run.py index af05ad0..8a710a8 100644 --- a/maxbot/cli/run.py +++ b/maxbot/cli/run.py @@ -1,6 +1,9 @@ """Command Line Interface for a Bot.""" +from functools import partial + import click +from ..webapp import run_webapp from ._bot import resolve_bot from ._journal import create_journal from ._logging import configure_logging @@ -77,7 +80,8 @@ default=True, help="Watch bot files and reload on changes.", ) -@click.option("-v", "--verbose", count=True, help="Set the verbosity level.") +@click.option("-v", "--verbose", count=True, help="Increase the verbosity level.") +@click.option("-q", "--quiet", count=True, help="Decrease the verbosity level.") @click.option( "--logger", type=str, @@ -88,13 +92,6 @@ "Use the --journal-file option to redirect the journal." ), ) -@click.option( - "--quiet", - "-q", - is_flag=True, - default=False, - help="Do not log to console.", -) @click.option( "--journal-file", type=click.File(mode="a", encoding="utf8"), @@ -107,6 +104,26 @@ show_default=True, help="Journal file format", ) +@click.option( + "--workers", + type=int, + default=1, + show_default=True, + help="Number of web application worker processes to spawn.
Cannot be used with `fast`.", +) +@click.option( + "--fast", + is_flag=True, + default=False, + help="Set the number of web application workers to max allowed.", +) +@click.option( + "--single-process", + "single_process", + is_flag=True, + default=False, + help="Run web application in a single process.", +) @click.pass_context def run( ctx, @@ -123,6 +140,9 @@ def run( quiet, journal_file, journal_output, + workers, + fast, + single_process, ): """ Run the bot. @@ -158,31 +178,41 @@ def run( maxbot run --bot bot.yaml --ngrok \b - # log to file - maxbot run --bot bot.yaml --logger file:/var/log/maxbot.log - - \b - # journal to file and to console - maxbot run --bot bot.yaml --journal-file /var/log/maxbot.jsonl + # verbose journal output and discard logger messages + maxbot run --bot bot.yaml -vv --logger file:/dev/null \b - # journal to file only - maxbot run --bot bot.yaml -q --journal-file /var/log/maxbot.jsonl + # print to console (journal and logger) errors only + maxbot run --bot bot.yaml -q """ - if not quiet: - configure_logging(logger, verbose) + if quiet and verbose: + raise click.UsageError("Options -q and -v are mutually exclusive.") + + verbosity = verbose if verbose else (0 - quiet) from ._rich import Progress + init_logging = partial(configure_logging, logger, verbosity) + init_logging() + + bot_factory = partial(create_bot, bot_spec, logger, verbosity, journal_file, journal_output) with Progress(transient=True) as progress: progress.add_task("Loading resources", total=None) - bot = resolve_bot(bot_spec) - bot.dialog_manager.journal(create_journal(verbose, quiet, journal_file, journal_output)) + bot = bot_factory() polling_conflicts = [ next(p.get_error_hint(ctx) for p in ctx.command.params if p.name == name) - for name in ("host", "port", "public_url", "ngrok", "ngrok_url") + for name in ( + "host", + "port", + "public_url", + "ngrok", + "ngrok_url", + "workers", + "fast", + "single_process", + ) if ctx.get_parameter_source(name) != click.core.ParameterSource.DEFAULT ] @@ -214,6 +244,25 @@ def run( f"Option '--ngrok'/'--ngrok-url' conflicts with {', '.join(ngrok_conflicts)}." 
) host, port, public_url = ask_ngrok(ngrok_url) - bot.run_webapp(host, port, public_url=public_url, autoreload=autoreload) + + run_webapp( + bot, + bot_factory, + host, + port, + init_logging=init_logging, + public_url=public_url, + autoreload=autoreload, + workers=workers, + fast=fast, + single_process=single_process, + ) else: raise AssertionError(f"Unexpected updater {updater}.") # pragma: no cover + + +def create_bot(bot_spec, logger, verbosity, journal_file, journal_output): + """Create new instance of MaxBot.""" + bot = resolve_bot(bot_spec) + bot.dialog_manager.journal(create_journal(verbosity, journal_file, journal_output)) + return bot diff --git a/maxbot/cli/stories.py b/maxbot/cli/stories.py index fa84cec..cf960cd 100644 --- a/maxbot/cli/stories.py +++ b/maxbot/cli/stories.py @@ -1,4 +1,5 @@ """Command `stories` of bots.""" +import asyncio import pprint from datetime import timedelta, timezone from pathlib import Path @@ -41,13 +42,22 @@ def stories(bot_spec, stories_file): try: markup.Value.COMPARATOR = markup_value_rendered_comparator if stories_file: - return _stories_impl(bot, stories_file) - return _stories_impl(bot, Path(bot.resources.base_directory) / "stories.yaml") + asyncio.run(_stories_impl(bot, stories_file)) + elif hasattr(bot.resources, "base_directory"): + asyncio.run(_stories_impl(bot, Path(bot.resources.base_directory) / "stories.yaml")) + else: + raise click.BadParameter( + ( + "stories file cannot be defined, " + "explicit specification of parameter --stories/-S is required." + ), + param_hint="--bot", + ) finally: markup.Value.COMPARATOR = original_comparator -def _stories_impl(bot, stories_file): +async def _stories_impl(bot, stories_file): console = Console() command_schema = bot.dialog_manager.CommandSchema(many=True) bot.dialog_manager.utc_time_provider = StoryUtcTimeProvider() @@ -60,12 +70,17 @@ def _stories_impl(bot, stories_file): for i, turn in enumerate(story["turns"]): bot.dialog_manager.utc_time_provider.tick(turn.get("utc_time")) - if "message" in turn: - response = bot.process_message(turn["message"], dialog) - elif "rpc" in turn: - response = bot.process_rpc(turn["rpc"], dialog) - else: - raise AssertionError("Either message or rpc must be provided.") + with bot.persistence_manager(dialog) as tracker: + if "message" in turn: + response = await bot.dialog_manager.process_message( + turn["message"], dialog, tracker.get_state() + ) + elif "rpc" in turn: + response = await bot.dialog_manager.process_rpc( + turn["rpc"], dialog, tracker.get_state() + ) + else: + raise AssertionError("Either message or rpc must be provided.") for expected in turn["response"]: if command_schema.loads(expected) == response: diff --git a/maxbot/context.py b/maxbot/context.py index 4dcfddb..3ac691b 100644 --- a/maxbot/context.py +++ b/maxbot/context.py @@ -1,5 +1,6 @@ """The context of the dialog turns.""" import logging +from collections.abc import MutableMapping from dataclasses import dataclass, field, fields from datetime import datetime, timezone from operator import attrgetter @@ -312,15 +313,61 @@ def __rich_repr__(self): yield name, _ReprAsIs("EntitiesProxy(...)") +class JournalledDict(MutableMapping): + """Wraps a dictionary of state variables to journal their changes.""" + + def __init__(self): + """Create a class instance. + + You cannot change the object before calling `set_journal_event_function` method. 
+ """ + self.__impl = {} + self.__journal_event = None + self.__journal_event_name = None + + def set_journal_event_function(self, journal_event, name): + """Start of journaling. + + :param callable journal_event: Journaling function. + :param str name: Kind name. + """ + self.__journal_event = journal_event + self.__journal_event_name = name + + def __getitem__(self, name): + """Abstract method of MutableMapping implementation.""" + return self.__impl[name] + + def __setitem__(self, name, value): + """Abstract method of MutableMapping implementation.""" + self.__impl[name] = value + if self.__journal_event: + self.__journal_event("assign", {self.__journal_event_name: name, "value": value}) + + def __delitem__(self, name): + """Abstract method of MutableMapping implementation.""" + del self.__impl[name] + if self.__journal_event: + self.__journal_event("delete", {self.__journal_event_name: name}) + + def __iter__(self): + """Abstract method of MutableMapping implementation.""" + return iter(self.__impl) + + def __len__(self): + """Abstract method of MutableMapping implementation.""" + return len(self.__impl) + + @dataclass(frozen=True) class StateVariables: """A container for state variables loaded by state tracker.""" # User variables that live forever. - user: dict = field(default_factory=dict) + user: JournalledDict = field(default_factory=JournalledDict) # Skill variables that live during discussing a topic. - slots: dict = field(default_factory=dict) + slots: JournalledDict = field(default_factory=JournalledDict) # Private variables used by **maxbot** internal components. components: dict = field(default_factory=dict) @@ -351,7 +398,7 @@ def from_kv_pairs(cls, kv_pairs): :return StateVariables: :raise ValueError: Meet unknown namespace. """ - data = {f.name: {} for f in fields(cls)} + data = {f.name: f.default_factory() for f in fields(cls)} for name, value in kv_pairs: ns, name = name.split(".", 1) if ns not in data: @@ -484,9 +531,15 @@ class TurnContext: error: Optional[BotError] = field(default=None, init=False) def __post_init__(self): - """Make sure that turn is either foreground or background which is mutually exclusive.""" + """Make sure that turn is either foreground or background which is mutually exclusive. + + And pre-initialization `.state` field. + """ assert bool(self.message) != bool(self.rpc) + self.state.slots.set_journal_event_function(self.journal_event, "slots") + self.state.user.set_journal_event_function(self.journal_event, "user") + def get_state_variable(self, key): """Get state variable for the component. 
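# Illustrative sketch (not part of the patch): what JournalledDict above records once a
# journal function is attached; before set_journal_event_function() is called it behaves
# like a plain dict and nothing is journaled. Import path follows this diff.
from maxbot.context import JournalledDict

events = []
slots = JournalledDict()
slots["guests"] = 2  # no journal function yet -> not recorded
slots.set_journal_event_function(lambda kind, payload: events.append((kind, payload)), "slots")
slots["date"] = "2023-05-01"  # -> ("assign", {"slots": "date", "value": "2023-05-01"})
del slots["guests"]           # -> ("delete", {"slots": "guests"})
assert events == [
    ("assign", {"slots": "date", "value": "2023-05-01"}),
    ("delete", {"slots": "guests"}),
]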
diff --git a/maxbot/extensions/rest.py b/maxbot/extensions/rest.py index 6088c14..87f3571 100644 --- a/maxbot/extensions/rest.py +++ b/maxbot/extensions/rest.py @@ -1,34 +1,29 @@ """Builtin MaxBot extension: REST calls from jinja scenarios.""" +import json import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone from urllib.parse import urljoin import httpx from jinja2 import nodes from jinja2.ext import Extension -from ..errors import BotError -from ..maxml import Schema, fields, validate +from ..errors import BotError, YamlSnippet +from ..maxml import PoolLimitSchema, Schema, TimeDeltaSchema, TimeoutSchema, fields, validate logger = logging.getLogger(__name__) -class _ServiceAuth(Schema): - user = fields.Str(required=True) - password = fields.Str(required=True) - - -class _Service(Schema): - name = fields.Str(required=True) - method = fields.Str(validate=validate.OneOf(["get", "post", "put", "delete"])) - auth = fields.Nested(_ServiceAuth()) - headers = fields.Dict(keys=fields.Str(), values=fields.Str(), load_default=dict) - parameters = fields.Dict(keys=fields.Str(), values=fields.Str(), load_default=dict) - timeout = fields.Int() - base_url = fields.Url() - - class _JinjaExtension(Extension): - tags = {"GET", "POST", "PUT", "DELETE"} + tags = { + "GET", + "POST", + "PUT", + "DELETE", + "PATCH", + } def parse(self, parser): method = parser.stream.current.value.lower() @@ -51,6 +46,37 @@ def parse(self, parser): return nodes.Assign(target, restcall).set_lineno(lineno) +def _now(): + return datetime.now(timezone.utc) + + +class _ServiceAuth(Schema): + user = fields.Str(required=True) + password = fields.Str(required=True) + + +class _Service(Schema): + name = fields.Str(required=True) + method = fields.Str( + validate=validate.OneOf( + [ + "get", + "post", + "put", + "delete", + "patch", + ] + ) + ) + auth = fields.Nested(_ServiceAuth()) + headers = fields.Dict(keys=fields.Str(), values=fields.Str(), load_default=dict) + parameters = fields.Dict(keys=fields.Str(), values=fields.Str(), load_default=dict) + timeout = fields.Nested(TimeoutSchema()) + base_url = fields.Url() + limits = fields.Nested(PoolLimitSchema()) + cache = fields.Nested(TimeDeltaSchema()) + + class RestExtension: """Extension class.""" @@ -58,6 +84,16 @@ class ConfigSchema(Schema): """Extension configuration schema.""" services = fields.List(fields.Nested(_Service()), load_default=list) + timeout = fields.Nested(TimeoutSchema()) + limits = fields.Nested(PoolLimitSchema()) + garbage_collector_timeout = fields.Nested( + TimeDeltaSchema(), load_default=TimeDeltaSchema.VALUE_TYPE(hours=1) + ) + + @dataclass(frozen=True) + class _CacheValue: + return_value: dict + created: datetime = field(default_factory=_now, init=False) def __init__(self, builder, config): """Extension entry point. @@ -65,9 +101,24 @@ def __init__(self, builder, config): :param BotBuilder builder: MaxBot builder. :param dict config: Extension configuration. 
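# Illustrative sketch (not part of the patch): the extended REST extension configuration
# that ConfigSchema (below) accepts, written as a plain dict. The keys mirror the schema
# in this diff; the concrete service values ("catalog", the URL) are made up.
from datetime import timedelta
import httpx
from maxbot.extensions.rest import RestExtension

config = RestExtension.ConfigSchema().load(
    {
        "timeout": 3.5,                    # short syntax, handled by TimeoutSchema
        "limits": {"max_connections": 50},
        "garbage_collector_timeout": 600,  # seconds, handled by TimeDeltaSchema
        "services": [
            {
                "name": "catalog",
                "method": "get",
                "base_url": "https://api.example.com/v1/",
                "cache": 30,               # cache successful responses for 30 seconds
            }
        ],
    }
)
assert isinstance(config["timeout"], httpx.Timeout)
assert config["services"][0]["cache"] == timedelta(seconds=30)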
""" - self.services = {s["name"]: s for s in config.get("services", [])} + self.services = {} + for service in config.get("services", []): + service_name = service["name"].lower() + if service_name in self.services: + raise BotError( + f"Duplicate REST service names: {service['name']!r} and " + f"{self.services[service_name]['name']!r}", + YamlSnippet.from_data(service["name"]), + ) + self.services[service_name] = service + self.allowed_schemes = frozenset({"http", "https"} | set(self.services.keys())) + self.timeout = config.get("timeout", TimeoutSchema.DEFAULT) + self.limits = config.get("limits", PoolLimitSchema.DEFAULT) + self.garbage_collector_timeout = config["garbage_collector_timeout"] builder.add_template_global(self._rest_call, "rest_call") builder.jinja_env.add_extension(_JinjaExtension) + self.cache_container = {} + self.cache_container_garbage_collector_ts = _now() async def _rest_call(self, **args): args = dict(args) @@ -77,32 +128,74 @@ async def _rest_call(self, **args): url = _prepare_url(args, service) logger.debug("%s %s", method.upper(), url) + headers = _prepare_headers(args, service) + params = _prepare_params(args, service) + body = args.get("body") + auth = _prepare_auth(args, service) + + cache_key, return_value = None, None + cache_timeout = _prepare_cache_timeout(args, service) + if cache_timeout: + cache_key = _create_cache_key(method, url, headers, params, body, auth) + + now = _now() + if now > (self.cache_container_garbage_collector_ts + self.garbage_collector_timeout): + # ⚔️ garbage collection time + self.cache_container_garbage_collector_ts = now + garbage = [] + for key, value in self.cache_container.items(): + if cache_key == key and now <= (value.created + cache_timeout): + return_value = value.return_value + elif now >= (value.created + self.garbage_collector_timeout): + garbage.append(key) + for key in garbage: + self.cache_container.pop(key, None) + else: + value = self.cache_container.get(cache_key) + if value and now <= (value.created + cache_timeout): + return_value = value.return_value + + if return_value: + logger.debug("cache hit: %s", return_value) + return return_value + on_error = _prepare_on_error(args) session = service["session"] - resp = await session.request( - method, - url, - headers=_prepare_headers(args, service), - params=_prepare_params(args, service), - timeout=_prepare_timeout(args, service), - json=args.get("body"), - auth=_prepare_auth(args, service), - ) + kwargs = {} + headers = _prepare_headers(args, service) + body = args.get("body") + if isinstance(body, Mapping): + _, content_type = next( + iter([(k, v) for k, v in headers.items() if k.lower() == "content-type"]), + (None, "application/json"), + ) + if content_type.lower().strip().startswith("application/x-www-form-urlencoded"): + kwargs.update(data=body) + else: + kwargs.update(json=body) + else: + kwargs.update(content=body) + try: + resp = await session.request( + method, + url, + headers=headers, + params=params, + auth=auth, + timeout=self._prepare_timeout(args, service), + **kwargs, + ) resp.raise_for_status() + except httpx.HTTPStatusError as error: + result = {"ok": False, "status_code": resp.status_code, "json": {}} + return _on_request_exception(on_error, error, result) except httpx.HTTPError as error: - message = "REST call failed: " + str(error) - - logger.exception(message) - if on_error == "continue": - return {"ok": False, "status_code": resp.status_code, "json": {}} - - assert on_error == "break_flow" - _raise(message) + return 
_on_request_exception(on_error, error, result={"ok": False}) try: - return {"ok": True, "status_code": resp.status_code, "json": resp.json()} + return_value = {"ok": True, "status_code": resp.status_code, "json": resp.json()} except ValueError as error: logger.exception( "\n".join( @@ -114,25 +207,44 @@ async def _rest_call(self, **args): ] ) ) - return {"ok": True, "status_code": resp.status_code, "json": {}} + return {"ok": True, "status_code": resp.status_code, "json": {}} + if cache_timeout: + self.cache_container[cache_key] = self._CacheValue(return_value) + return return_value def _prepare_service(self, args): service_name = args.get("service") if service_name: - service = self.services.get(service_name) + service = self.services.get(service_name.lower()) if service is None: _raise(f'Unknown REST service "{service_name}"') else: - splitted = args.get("url", "").split("://") - service = self.services.get(splitted[0], {}) if len(splitted) == 2 else {} - if service: - args["url"] = splitted[1] + splitted = args.get("url", "").split("://", 1) + service = {} + if len(splitted) == 2: + scheme = splitted[0].lower() + service = self.services.get(scheme, {}) + if service: + args["url"] = splitted[1] + elif scheme not in self.allowed_schemes: + one_of = ["http", "https"] + list(self.services.keys()) + raise BotError( + ( + f"Unknown schema ({splitted[0]!r}) in URL {args['url']!r}\n" + f"Must be one of: {', '.join(one_of)}" + ) + ) + else: + service = {} if "session" not in service: - service["session"] = httpx.AsyncClient() + service["session"] = httpx.AsyncClient(limits=service.get("limits", self.limits)) return service + def _prepare_timeout(self, args, service): + return args.get("timeout") or service.get("timeout") or self.timeout + def _prepare_on_error(args): on_error = args.get("on_error", "break_flow") @@ -162,14 +274,34 @@ def _prepare_params(args, service): return {**service.get("parameters", {}), **args.get("parameters", {})} -def _prepare_timeout(args, service): - return args.get("timeout") or service.get("timeout") or 5 - - def _prepare_auth(args, service): auth = args.get("auth") or service.get("auth") return (auth["user"], auth["password"]) if auth else None +def _prepare_cache_timeout(args, service): + cache_timeout = args.get("cache") + if cache_timeout is None: + cache_timeout = service.get("cache") + if isinstance(cache_timeout, timedelta): + return cache_timeout + return None if cache_timeout is None else timedelta(seconds=cache_timeout) + + +def _create_cache_key(method, url, headers, params, body, auth): + return json.dumps( + {"m": method, "u": url, "h": headers, "p": params, "b": body, "a": auth}, sort_keys=True + ) + + def _raise(message): raise BotError(message) + + +def _on_request_exception(on_error, error, result): + message = "REST call failed: " + str(error) + logger.exception(message) + if on_error == "continue": + return result + assert on_error == "break_flow" + return _raise(message) diff --git a/maxbot/flows/dialog_flow.py b/maxbot/flows/dialog_flow.py index 0fe5e32..b766e56 100644 --- a/maxbot/flows/dialog_flow.py +++ b/maxbot/flows/dialog_flow.py @@ -53,6 +53,9 @@ async def turn(self, ctx): try: result = await self._root_component(ctx) except BotError as exc: + ctx.warning( + "An error has occurred. The dialog will be reset (including the slot values)." 
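# Illustrative sketch (not part of the patch): the precedence implemented by
# _prepare_cache_timeout above — a per-call `cache` argument wins over the service
# setting, and a plain number is interpreted as seconds (service values are already
# timedelta objects after schema loading). The helper name is made up for this sketch.
from datetime import timedelta

def resolve_cache_timeout(args, service):
    cache_timeout = args.get("cache")
    if cache_timeout is None:
        cache_timeout = service.get("cache")
    if isinstance(cache_timeout, timedelta):
        return cache_timeout
    return None if cache_timeout is None else timedelta(seconds=cache_timeout)

assert resolve_cache_timeout({"cache": 10}, {"cache": timedelta(minutes=5)}) == timedelta(seconds=10)
assert resolve_cache_timeout({}, {"cache": timedelta(minutes=5)}) == timedelta(minutes=5)
assert resolve_cache_timeout({}, {}) is None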
+ ) ctx.set_error(exc) result = FlowResult.DONE diff --git a/maxbot/flows/dialog_tree.py b/maxbot/flows/dialog_tree.py index 37325a4..9d78c32 100644 --- a/maxbot/flows/dialog_tree.py +++ b/maxbot/flows/dialog_tree.py @@ -380,12 +380,12 @@ async def digression(self, digressed_node): :param Node digressed_node: Node to switch from. :return FlowResult: The result of the turn of the flow. """ - logger.debug("digression from %s", digressed_node) + self.journal_event("digression_from", digressed_node) for node in self.tree.root_nodes(self.ctx): if node == digressed_node: continue if node.condition(self.ctx, digressing=True): - return await self.trigger_maybe_digressed(node) + return await self.trigger_maybe_digressed(node, digressing=True) if self.ctx.rpc: return FlowResult.LISTEN if digressed_node.followup_allow_return: @@ -416,21 +416,23 @@ async def return_after_digression(self, result=DigressionResult.FOUND): # nowere to return return None - async def trigger_maybe_digressed(self, node): + async def trigger_maybe_digressed(self, node, digressing=False): """Trigger the node or return to the node after digression. :param Node node: Triggered node. + :param bool digressing: True if digression occurs. :return FlowResult: The result of the turn of the flow. """ if self.stack.remove(node): - return await self.trigger(node, DigressionResult.FOUND) - return await self.trigger(node) + return await self.trigger(node, DigressionResult.FOUND, digressing=digressing) + return await self.trigger(node, digressing=digressing) - async def trigger(self, node, digression_result=None): + async def trigger(self, node, digression_result=None, digressing=False): """Go through the node's slot filling flow (if any) and/or execute its response scenario. :param Node node: Triggered node. :param DigressionResult digression_result: The result with which we return from digression. + :param bool digressing: True if digression occurs. :raise ValueError: Unknown slot filling result. :return FlowResult: The result of the turn of the flow. """ @@ -439,35 +441,44 @@ async def trigger(self, node, digression_result=None): result = await node.slot_filling(self.ctx, digression_result) if result == FlowResult.DONE: self.stack.remove(node) - return await self.response(node, digression_result) + return await self.response(node, digression_result, digressing=digressing) if result == FlowResult.LISTEN: self.stack.push(node, "slot_filling") return FlowResult.LISTEN if result == FlowResult.DIGRESS: return await self.digression(node) raise ValueError(f"Unknown flow result {result!r}") - return await self.response(node, digression_result) + return await self.response(node, digression_result, digressing=digressing) - async def response(self, node, digression_result): + async def response(self, node, digression_result, digressing=False): """Execute the response scenario of the node. :param Node node: Triggered node. :param DigressionResult digression_result: The result with which we return from digression. + :param bool digressing: True if digression occurs. :return FlowResult: The result of the turn of the flow. 
""" payload = self.journal_event("response", node) - for command in await node.response(self.ctx, returning=digression_result is not None): + params = {"returning": digression_result is not None, "digressing": digressing} + for command in await node.response(self.ctx, **params): if "jump_to" in command: - payload.update(control_command="jump_to") + payload.update( + control_command={ + "jump_to": { + "node": command["jump_to"]["node"], + "transition": command["jump_to"]["transition"], + } + } + ) return await self.command_jump_to(node, command["jump_to"]) if "listen" in command: - payload.update(control_command="listen") + payload.update(control_command={"listen": {}}) return await self.command_listen(node) if "end" in command: - payload.update(control_command="end") + payload.update(control_command={"end": {}}) return self.command_end() if "followup" in command: - payload.update(control_command="followup") + payload.update(control_command={"followup": {}}) return await self.command_followup(node) self.ctx.commands.append(command) if node.followup: diff --git a/maxbot/flows/slot_filling.py b/maxbot/flows/slot_filling.py index 4ddd455..3bd98a8 100644 --- a/maxbot/flows/slot_filling.py +++ b/maxbot/flows/slot_filling.py @@ -233,6 +233,7 @@ def elicit(self, slot): value = True logger.debug("elicit slot %r value %r", slot.name, value) previous_value = self.ctx.state.slots.get(slot.name) + self.ctx.journal_event("slot_filling", {"slot": slot.name}) self.ctx.state.slots[slot.name] = value self.found_slots.append((slot, {"previous_value": previous_value, "current_value": value})) @@ -343,9 +344,6 @@ async def __call__(self): self.elicit(slot) # found for slot, params in self.found_slots: - self.ctx.journal_event( - "slot_filling", {"slot": slot.name, "value": params["current_value"]} - ) if slot.found: await self.found(slot, params) if self.state.get("slot_in_focus") and not self.found_slots: diff --git a/maxbot/maxml/__init__.py b/maxbot/maxml/__init__.py index b81e56c..3c5bea7 100644 --- a/maxbot/maxml/__init__.py +++ b/maxbot/maxml/__init__.py @@ -7,4 +7,8 @@ post_load, pre_load, validate, + validates_schema, ) + +from .http import PoolLimitSchema, TimeoutSchema # noqa: F401 +from .timedelta import TimeDeltaSchema # noqa: F401 diff --git a/maxbot/maxml/http.py b/maxbot/maxml/http.py new file mode 100644 index 0000000..c8cea26 --- /dev/null +++ b/maxbot/maxml/http.py @@ -0,0 +1,71 @@ +"""HTTPX library (https://www.python-httpx.org/) configuration.""" +import httpx +from marshmallow import Schema, fields, post_load, pre_load # noqa: F401 + + +class TimeoutSchema(Schema): + """HTTP request timeout schema. + + @see https://www.python-httpx.org/advanced/#timeout-configuration + """ + + default = fields.Float(load_default=5.0) + connect = fields.Float() + read = fields.Float() + write = fields.Float() + pool = fields.Float() + + @pre_load + def short_syntax(self, data, **kwargs): + """Short syntax is available. 
+ + timeout: 1.2 + -> httpx.Timeout(connect=1.2, read=1.2, write=1.2, pool=1.2) + """ + if isinstance( + data, + ( + float, + int, + str, + ), + ): + try: + return {"default": float(data)} + except ValueError: + pass + return data + + @post_load + def return_httpx_timeout(self, data, **kwargs): + """Create and return httpx.Timeout by loaded data.""" + return httpx.Timeout( + connect=data.get("connect", data["default"]), + read=data.get("read", data["default"]), + write=data.get("write", data["default"]), + pool=data.get("pool", data["default"]), + ) + + DEFAULT = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0) + + +class PoolLimitSchema(Schema): + """HTTP pool limit schema. + + @see https://www.python-httpx.org/advanced/#pool-limit-configuration + """ + + max_keepalive_connections = fields.Int(load_default=20) + max_connections = fields.Int(load_default=100) + keepalive_expiry = fields.Float(load_default=5.0) + + @post_load + def return_httpx_limits(self, data, **kwargs): + """Create and return httpx.Timeout by loaded data.""" + return httpx.Limits( + max_keepalive_connections=data["max_keepalive_connections"], + max_connections=data["max_connections"], + keepalive_expiry=data["keepalive_expiry"], + ) + + DEFAULT = httpx.Limits(max_keepalive_connections=20, max_connections=100, keepalive_expiry=5.0) diff --git a/maxbot/maxml/pretty.py b/maxbot/maxml/pretty.py index 8493232..4506ddc 100644 --- a/maxbot/maxml/pretty.py +++ b/maxbot/maxml/pretty.py @@ -29,7 +29,7 @@ def _write_element_markup(self, level, name, value): self._write_element_markup_headless(level, name, value) def _write_element_markup_headless(self, level, name, value): - lines = _markup_to_lines(value) + lines = markup_to_lines(value) if len(lines) > 1: self._result.write(self._newline) for line in lines: @@ -135,7 +135,12 @@ def new_line(self): self._new_line = True -def _markup_to_lines(value): +def markup_to_lines(value): + """Print markup value to list of strings. + + :param markup.Value value: Markup value. + :return list[str]: XML-escaped lines. + """ lines = _Lines() for i, item in enumerate(value.items): if item.kind == markup.START_TAG: diff --git a/maxbot/maxml/timedelta.py b/maxbot/maxml/timedelta.py new file mode 100644 index 0000000..2570207 --- /dev/null +++ b/maxbot/maxml/timedelta.py @@ -0,0 +1,48 @@ +"""`datetime.timedelta` representation.""" +from datetime import timedelta + +from marshmallow import Schema, fields, post_load, pre_load # noqa: F401 + + +class TimeDeltaSchema(Schema): + """Time shift value representation. + + @see https://docs.python.org/3/library/datetime.html#timedelta-objects + """ + + VALUE_TYPE = timedelta + + days = fields.Int(load_default=0) + seconds = fields.Int(load_default=0) + microseconds = fields.Int(load_default=0) + milliseconds = fields.Int(load_default=0) + minutes = fields.Int(load_default=0) + hours = fields.Int(load_default=0) + weeks = fields.Int(load_default=0) + + @pre_load + def short_syntax(self, data, **kwargs): + """Short syntax is available. 
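# Illustrative sketch (not part of the patch): the two syntaxes accepted by TimeoutSchema
# and PoolLimitSchema above; both load into httpx configuration objects with the
# documented defaults filled in.
from maxbot.maxml import PoolLimitSchema, TimeoutSchema

short = TimeoutSchema().load(1.2)                          # short scalar syntax
assert short.connect == 1.2 and short.pool == 1.2

detailed = TimeoutSchema().load({"default": 5, "connect": 1})
assert detailed.connect == 1.0 and detailed.read == 5.0

limits = PoolLimitSchema().load({"max_connections": 50})
assert limits.max_connections == 50 and limits.max_keepalive_connections == 20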
+ + my_time: 5 + -> timedelta(seconds=5) + """ + if isinstance(data, (int, str)): + try: + return {"seconds": int(data)} + except ValueError: + pass + return data + + @post_load + def return_httpx_timeout(self, data, **kwargs): + """Create and return datetime.timedelta by loaded data.""" + return self.VALUE_TYPE( + days=data["days"], + seconds=data["seconds"], + microseconds=data["microseconds"], + milliseconds=data["milliseconds"], + minutes=data["minutes"], + hours=data["hours"], + weeks=data["weeks"], + ) diff --git a/maxbot/nlu.py b/maxbot/nlu.py index 7e0e31c..ef658c8 100644 --- a/maxbot/nlu.py +++ b/maxbot/nlu.py @@ -373,14 +373,17 @@ def __call__(self, doc, utc_time=None): with warnings.catch_warnings(): warnings.simplefilter("ignore") - search_dates_settings = {"PREFER_DATES_FROM": "future"} + shift = 0 + settings = {"PREFER_DATES_FROM": "future", "STRICT_PARSING": True} if utc_time: - search_dates_settings["RELATIVE_BASE"] = utc_time + settings["RELATIVE_BASE"] = utc_time - shift = 0 - results = search_dates(doc.text, languages=["en"], settings=search_dates_settings) + results = search_dates(doc.text, languages=["en"], settings=settings) if results is None: - return + del settings["STRICT_PARSING"] + results = search_dates(doc.text, languages=["en"], settings=settings) + if results is None: + return for literal, dt in results: start_char = doc.text.index(literal, shift) end_char = start_char + len(literal) @@ -389,15 +392,17 @@ def __call__(self, doc, utc_time=None): dd = self.ddp.get_date_data(literal) if dd.period == "time": if parse(literal, languages=["en"], settings={"REQUIRE_PARTS": ["day"]}): + name = "date" if settings.get("STRICT_PARSING") else "latent_date" yield RecognizedEntity( - "date", dt.date().isoformat(), literal, start_char, end_char + name, dt.date().isoformat(), literal, start_char, end_char ) yield RecognizedEntity( "time", dt.time().isoformat(), literal, start_char, end_char ) else: + name = "date" if settings.get("STRICT_PARSING") else "latent_date" yield RecognizedEntity( - "date", dt.date().isoformat(), literal, start_char, end_char + name, dt.date().isoformat(), literal, start_char, end_char ) diff --git a/maxbot/persistence_manager.py b/maxbot/persistence_manager.py new file mode 100644 index 0000000..20f2dc7 --- /dev/null +++ b/maxbot/persistence_manager.py @@ -0,0 +1,253 @@ +"""Default implementation for state tracker based on sqlalchemy.""" +import copy +import enum +import json +import logging +from contextlib import contextmanager +from datetime import datetime, timezone + +from sqlalchemy import ( + JSON, + Column, + DateTime, + Enum, + ForeignKey, + Integer, + String, + UniqueConstraint, + create_engine, + select, +) +from sqlalchemy.orm import Session, declarative_base, relationship +from sqlalchemy.pool import StaticPool + +from .context import StateVariables +from .maxml import markup, pretty + +logger = logging.getLogger(__name__) +Base = declarative_base() + + +def create_json_serializer(default_serializers=None): + """Create JSON serializer. + + https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.JSON + (see "Customizing the JSON Serializer") + + :param list[tuple] default_serializers: Pairs: object type, object serializier. + :return callable: Value for json_serializer argument of create_engine. 
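# Illustrative sketch (not part of the patch): the short and long forms accepted by
# TimeDeltaSchema above, both loading into a datetime.timedelta.
from datetime import timedelta
from maxbot.maxml import TimeDeltaSchema

assert TimeDeltaSchema().load(90) == timedelta(seconds=90)
assert TimeDeltaSchema().load({"hours": 1, "minutes": 30}) == timedelta(hours=1, minutes=30)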
+ """ + default_serializers = list(default_serializers) if default_serializers else [] + default_serializers.append((markup.Value, lambda o: "\n".join(pretty.markup_to_lines(o)))) + + def _default(o): + for type_, serializer in default_serializers: + if isinstance(o, type_): + return serializer(o) + raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") + + return lambda o: json.dumps(o, default=_default) + + +class DialogTable(Base): + """Stores the information about a conversation channel and user.""" + + __tablename__ = "dialog" + + dialog_id = Column(Integer, primary_key=True) + channel_name = Column(String, nullable=False) + user_id = Column(String, nullable=False) + + variables = relationship( + lambda: VariableTable, + backref="user", + cascade="all, delete-orphan", + order_by=lambda: VariableTable.variable_id, + ) + + history = relationship( + lambda: HistoryTable, + cascade="all, delete-orphan", + order_by=lambda: HistoryTable.history_id, + ) + + __table_args__ = (UniqueConstraint("channel_name", "user_id"),) + + +class VariableTable(Base): + """Stores the state of a conversation.""" + + __tablename__ = "variable" + + variable_id = Column(Integer, primary_key=True) + dialog_id = Column(Integer, ForeignKey("dialog.dialog_id", ondelete="CASCADE"), nullable=False) + name = Column(String, nullable=False) + value = Column(JSON, nullable=False) + + __table_args__ = (UniqueConstraint("dialog_id", "name"),) + + +class RequestType(enum.Enum): + """Type of request source.""" + + # Message from user + message = 1 + + # RPC message + rpc = 2 + + +class HistoryTable(Base): + """Stores the dialog turn history.""" + + __tablename__ = "history" + + history_id = Column(Integer, primary_key=True) + dialog_id = Column(Integer, ForeignKey("dialog.dialog_id", ondelete="CASCADE"), nullable=False) + request_date = Column(DateTime(timezone=True), nullable=False) + request_type = Column(Enum(RequestType), nullable=False) + request = Column(JSON, nullable=False) + response = Column(JSON, nullable=False) + + +class PersistenceTracker: + """Dialog turn persistence tracker.""" + + def __init__(self, user): + """Create new class instance. + + :param DialogTable user: Current dialog with user. + """ + self.user = user + + # make a deep copy to allow sqlalchemy track changes by + # comparing with original values + kv_pairs = [(v.name, copy.deepcopy(v.value)) for v in self.user.variables] + self.variables = StateVariables.from_kv_pairs(kv_pairs) + + def get_state(self): + """Return state variables.""" + return self.variables + + def set_message_history(self, message, commands): + """Track incoming user message. + + :param any message: JSON-serializable object of user message. + :param list commands: List of JSON-serializable objects of response commands. + """ + self.user.history.append( + HistoryTable( + request_date=datetime.now(timezone.utc), + request_type=RequestType.message, + request=message, + response=commands, + ) + ) + + def set_rpc_history(self, rpc, commands): + """Track RPC. + + :param any rpc: JSON-serializable object of RPC. + :param list commands: List of JSON-serializable objects of response commands. 
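# Illustrative sketch (not part of the patch): using the serializer hook above.
# markup.Value objects stored in state are rendered to text before being written as
# JSON, and extra (type, serializer) pairs can be passed in; the date handler here is
# just an example.
from datetime import date
from maxbot.persistence_manager import create_json_serializer

serialize = create_json_serializer(default_serializers=[(date, lambda d: d.isoformat())])
assert serialize({"when": date(2023, 5, 1)}) == '{"when": "2023-05-01"}'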
+ """ + self.user.history.append( + HistoryTable( + request_date=datetime.now(timezone.utc), + request_type=RequestType.rpc, + request=rpc, + response=commands, + ) + ) + + @contextmanager + def __call__(self): + """Wrap dialog turn.""" + yield self + + existing = {v.name: v for v in self.user.variables} + for name, value in self.variables.to_kv_pairs(): + if name in existing: + var = existing.pop(name) + var.value = value + else: + var = VariableTable(name=name, value=value) + self.user.variables.append(var) + # variables that not in kv_pairs anymore must be deleted + for var in existing.values(): + self.user.variables.remove(var) + + +class SQLAlchemyManager: + """Load and save the state during the conversation.""" + + def __init__(self): + """Create new class instance.""" + self._engine = None + + @property + def engine(self): + """Sqlalchemy engine. + + You can set your own engine, for example, + + from sqlalchemy import create_engine + + state_tracker.engine = create_engine("postgresql://scott:tiger@localhost:5432/mydatabase") + + All necessary tables are created immediately after your set engine. + + Default: in-memory sqlite db that supports multithreading. + """ + if self._engine is None: + self._engine = self._create_default_engine() + self._create_tables(self._engine) + return self._engine + + @engine.setter + def engine(self, value): + self._engine = value + + def create_tables(self): + """Create storage schema.""" + if self._engine is None: + self._engine = self._create_default_engine() + self._create_tables(self.engine) + + @staticmethod + def _create_default_engine(): + return create_engine( + "sqlite://", + future=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + json_serializer=create_json_serializer(), + ) + + @staticmethod + def _create_tables(engine): + Base.metadata.create_all(engine, checkfirst=True) + + @contextmanager + def __call__(self, dialog): + """Load and save persistence state. + + :param dict dialog: A dialog for which the state is being loaded, with the schema :class:`~maxbot.schemas.DialogSchema`. + :return PersistenceTracker: Persistence tracker of current dialog turn. + """ + with Session(self.engine) as session: + stmt = ( + select(DialogTable) + .where(DialogTable.channel_name == dialog["channel_name"]) + .where(DialogTable.user_id == str(dialog["user_id"])) + .with_for_update() + ) + user = session.scalars(stmt).one_or_none() + if user is None: + user = DialogTable(channel_name=dialog["channel_name"], user_id=dialog["user_id"]) + session.add(user) + + tracker = PersistenceTracker(user) + with tracker(): + yield tracker + + session.commit() diff --git a/maxbot/resolver.py b/maxbot/resolver.py new file mode 100644 index 0000000..7361e00 --- /dev/null +++ b/maxbot/resolver.py @@ -0,0 +1,85 @@ +"""Resolve bot from command line argument.""" +import logging +import pkgutil +from pathlib import Path +from types import ModuleType + +from .bot import MaxBot +from .errors import BotError + +logger = logging.getLogger(__name__) + + +class BotResolver: + """Create a bot object from the given specification. + + You can override .on_* handlers to generate your own errors. + """ + + def __init__(self, spec): + """Create new instance. + + :param str spec: Path for bot file or directory or a name to an object. + """ + self.spec = spec + + def __call__(self): + """Create a bot object from the given specification. + + :return MaxBot: Created object. 
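# Illustrative sketch (not part of the patch): using the persistence manager on its own,
# mirroring how the bot drives a turn. The context manager loads (or creates) the
# DialogTable row, hands out a PersistenceTracker, and writes the state variables back
# on exit; the default engine is an in-memory SQLite database.
from maxbot.persistence_manager import SQLAlchemyManager

manager = SQLAlchemyManager()
dialog = {"channel_name": "cli", "user_id": "42"}

with manager(dialog) as tracker:
    state = tracker.get_state()
    state.user["name"] = "Alice"  # journaled only when wired into a TurnContext

with manager(dialog) as tracker:
    assert tracker.get_state().user["name"] == "Alice"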
+ """ + pkg_error = None + try: + rv = pkgutil.resolve_name(self.spec) + except ( + ValueError, # invalid spec + ModuleNotFoundError, # module from spec is not found + ) as exc: + # we have to leave "except" block + pkg_error = exc + except BotError as exc: + logger.critical("Bot Error %s", exc) + self.on_bot_error(exc) + raise exc + except Exception as exc: + logger.exception(f"While loading {self.spec!r}, an exception was raised") + self.on_error(exc) + raise exc + if pkg_error: + # fallback to paths + path = Path(self.spec) + try: + if path.is_file(): + builder = MaxBot.builder() + builder.use_directory_resources(path.parent, path.name) + return builder.build() + if path.is_dir(): + return MaxBot.from_directory(path) + except BotError as exc: + logger.critical("Bot Error %s", exc) + self.on_bot_error(exc) + raise exc + self.on_unknown_source(pkg_error) + raise RuntimeError( + f"{self.spec!r} file or directory not found, import causes error {pkg_error!r}" + ) + + # if attribute name is not provided, use the default one. + if isinstance(rv, ModuleType) and hasattr(rv, "bot"): + rv = rv.bot + if not isinstance(rv, MaxBot): + self.on_invalid_type() + raise RuntimeError(f"A valid MaxBot instance was not obtained from {self.spec!r}") + return rv + + def on_bot_error(self, exc): + """Handle a BotError that occurred when the bot was loaded.""" + + def on_error(self, exc): + """Raise exception (not a BotError) that occurred when the bot was loaded.""" + + def on_unknown_source(self, pkg_error): + """Handle a situation where `spec` is not a file, directory, or package.""" + + def on_invalid_type(self): + """Handle a situation where the resolved bot is of type other than MaxBot.""" diff --git a/maxbot/state_store.py b/maxbot/state_store.py deleted file mode 100644 index d6223fb..0000000 --- a/maxbot/state_store.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Default implementation for state tracker based on sqlalchemy.""" -import copy -from contextlib import contextmanager - -from sqlalchemy import ( - JSON, - Column, - ForeignKey, - Integer, - String, - UniqueConstraint, - create_engine, - select, -) -from sqlalchemy.orm import Session, declarative_base, relationship -from sqlalchemy.pool import StaticPool - -from .context import StateVariables - -Base = declarative_base() - - -class DialogTable(Base): - """Stores the information about a conversation channel and user.""" - - __tablename__ = "dialog" - - dialog_id = Column(Integer, primary_key=True) - channel_name = Column(String, nullable=False) - user_id = Column(String, nullable=False) - - variables = relationship( - lambda: VariableTable, - backref="user", - cascade="all, delete-orphan", - order_by=lambda: VariableTable.variable_id, - ) - - __table_args__ = (UniqueConstraint("channel_name", "user_id"),) - - -class VariableTable(Base): - """Stores the state of a conversation.""" - - __tablename__ = "variable" - - variable_id = Column(Integer, primary_key=True) - dialog_id = Column(Integer, ForeignKey("dialog.dialog_id", ondelete="CASCADE"), nullable=False) - name = Column(String, nullable=False) - value = Column(JSON, nullable=False) - - __table_args__ = (UniqueConstraint("dialog_id", "name"),) - - -class SQLAlchemyStateStore: - """Load and save the state during the conversation.""" - - def __init__(self): - """Create new class instance.""" - self._engine = None - - @property - def engine(self): - """Sqlalchemy engine. 
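# Illustrative sketch (not part of the patch): reusing BotResolver above with a custom
# error hook. The spec forms follow its docstring — a bot file/directory path or an
# importable object reference; the path and module name here are hypothetical.
from maxbot.resolver import BotResolver

class QuietResolver(BotResolver):
    def on_bot_error(self, exc):
        print(f"bot failed to load: {exc}")

# bot = QuietResolver("path/to/bot_dir")()      # directory containing bot.yaml
# bot = QuietResolver("my_pkg.main:bot")()      # hypothetical importable MaxBot instance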
- - You can set your own engine, for example, - - from sqlalchemy import create_engine - - state_tracker.engine = create_engine("postgresql://scott:tiger@localhost:5432/mydatabase") - - All necessary tables are created immediately after your set engine. - - Default: in-memory sqlite db that supports multithreading. - """ - if self._engine is None: - self._set_engine( - create_engine( - "sqlite://", - future=True, - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) - ) - return self._engine - - @engine.setter - def engine(self, value): - self._set_engine(value) - - def _set_engine(self, value): - self._engine = value - Base.metadata.create_all(self._engine, checkfirst=True) - - @contextmanager - def __call__(self, dialog): - """Load and save state variables. - - :param dict dialog: A dialog for which the state is being loaded, with the schema :class:`~maxbot.schemas.DialogSchema`. - :return StateVariables: A container for state variables. - """ - with Session(self.engine) as session: - stmt = ( - select(DialogTable) - .where(DialogTable.channel_name == dialog["channel_name"]) - .where(DialogTable.user_id == str(dialog["user_id"])) - .with_for_update() - ) - user = session.scalars(stmt).one_or_none() - if user is None: - user = DialogTable(channel_name=dialog["channel_name"], user_id=dialog["user_id"]) - session.add(user) - - # make a deep copy to allow sqlalchemy track changes by - # comparing with original values - kv_pairs = [(v.name, copy.deepcopy(v.value)) for v in user.variables] - variables = StateVariables.from_kv_pairs(kv_pairs) - yield variables - existing = {v.name: v for v in user.variables} - for name, value in variables.to_kv_pairs(): - if name in existing: - var = existing.pop(name) - var.value = value - else: - var = VariableTable(name=name, value=value) - user.variables.append(var) - # variables that not in kv_pairs anymore must be deleted - for var in existing.values(): - user.variables.remove(var) - session.commit() diff --git a/maxbot/stories/__init__.py b/maxbot/stories/__init__.py new file mode 100644 index 0000000..12e59a6 --- /dev/null +++ b/maxbot/stories/__init__.py @@ -0,0 +1,159 @@ +"""MaxBot stories engine.""" +import pprint +from asyncio import new_event_loop +from datetime import timedelta, timezone +from uuid import uuid4 + +from ..context import get_utc_time_default +from ..maxml import Schema, ValidationError, fields, markup, pre_load, pretty, validates_schema +from ..rpc import RpcRequestSchema +from ..schemas import ResourceSchema + + +class Stories: + """Load and run stories.""" + + class MismatchError(Exception): + """Run story mismatch error.""" + + def __init__(self, message): + """Create new instance.""" + super().__init__(message) + + @property + def message(self): + """Mismatch formatted message.""" + return self.args[0] + + def __init__(self, bot): + """Create new instance. + + :param MaxBot bot: Bot instance. 
+ """ + self.bot = bot + self.command_schema = self.bot.dialog_manager.CommandSchema(many=True) + self.loop = new_event_loop() + + class _RpcRequestSchemaWithDesc(RpcRequestSchema): + @validates_schema + def validates_schema(self, data, **kwargs): + params_schema = bot.rpc.get_params_schema(data["method"]) + if params_schema is None: + raise ValidationError("Method not found", field_name="method") + errors = params_schema().validate(data.get("params", {})) + if errors: + raise ValidationError(pprint.pformat(errors), field_name="params") + + class _TurnSchema(Schema): + utc_time = fields.DateTime() + message = fields.Nested(bot.dialog_manager.MessageSchema) + rpc = fields.Nested(_RpcRequestSchemaWithDesc) + response = fields.List(fields.Str(), required=True) + + @pre_load + def response_short_syntax(self, data, **kwargs): + response = data.get("response") + if isinstance(response, str): + data.update(response=[response]) + return data + + @validates_schema + def validates_schema(self, data, **kwargs): + if ("message" in data) == ("rpc" in data): + raise ValidationError("Exactly one of 'message' or 'rpc' is required.") + + class _StorySchema(ResourceSchema): + name = fields.Str(required=True) + turns = fields.Nested(_TurnSchema, many=True, required=True) + markers = fields.List(fields.Str(), load_default=list) + + self.schema = _StorySchema(many=True) + + def load(self, fspath): + """Load stories from file. + + :param str fspath: Stories file path. + :return list[dict]: Loaded stories. + """ + return self.schema.load_file(fspath) + + def run(self, story): + """Run one story.""" + self.loop.run_until_complete(self.arun(story)) + + async def arun(self, story): + """Run one story asynchronously.""" + original_comparator = markup.Value.COMPARATOR + markup.Value.COMPARATOR = markup_value_rendered_comparator + try: + self.bot.dialog_manager.utc_time_provider = StoryUtcTimeProvider() + dialog = {"channel_name": "stories", "user_id": str(uuid4())} + + for i, turn in enumerate(story["turns"]): + self.bot.dialog_manager.utc_time_provider.tick(turn.get("utc_time")) + + with self.bot.persistence_manager(dialog) as tracker: + if "message" in turn: + response = await self.bot.dialog_manager.process_message( + turn["message"], dialog, tracker.get_state() + ) + elif "rpc" in turn: + response = await self.bot.dialog_manager.process_rpc( + turn["rpc"], dialog, tracker.get_state() + ) + else: + raise AssertionError() # pragma: not covered + + for expected in turn["response"]: + if self.command_schema.loads(expected) == response: + break + else: + expected = [self.command_schema.loads(r) for r in turn["response"]] + raise self.MismatchError( + _format_mismatch(i, expected, response, self.command_schema) + ) + finally: + markup.Value.COMPARATOR = original_comparator + + +class StoryUtcTimeProvider: + """Stories datetime.""" + + def __init__(self): + """Create new class instance.""" + self.value = None + + def tick(self, dt=None): + """Calculate datetime for next step of story. + + :param datetime dt: Datetime from stry turn (optional). 
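# Illustrative sketch (not part of the patch): driving the stories engine
# programmatically, the same way the pytest plugin below does. Paths are hypothetical;
# Stories.MismatchError carries the formatted expected-vs-actual command diff.
from maxbot.bot import MaxBot
from maxbot.stories import Stories

bot = MaxBot.from_directory("path/to/bot_dir")       # hypothetical bot directory
stories = Stories(bot)
for story in stories.load("path/to/stories.yaml"):   # hypothetical stories file
    try:
        stories.run(story)
    except Stories.MismatchError as exc:
        print(f"{story['name']} failed:\n{exc.message}")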
+ """ + if dt: + if dt.tzinfo: + self.value = dt.astimezone(timezone.utc) + else: + self.value = dt.replace(tzinfo=timezone.utc) + elif self.value is not None: + self.value += timedelta(seconds=10) + + def __call__(self): + """Get current datetime.""" + return get_utc_time_default() if self.value is None else self.value + + +def markup_value_rendered_comparator(lhs, rhs): + """Compare `markup.Value` by rendered value.""" + if isinstance(lhs, (markup.Value, str)) and isinstance(rhs, (markup.Value, str)): + lhs_rendered = lhs if isinstance(lhs, str) else lhs.render() + return lhs_rendered == (rhs if isinstance(rhs, str) else rhs.render()) + return False + + +def _format_mismatch(turn_index, expected, actual, command_schema): + expected_str = "\n-or-\n".join(pretty.print_xml(e, command_schema) for e in expected) + actual_str = pretty.print_xml(actual, command_schema) + + def _shift(s): + return "\n".join(f" {line}" for line in s.splitlines()) + + return f"Mismatch at step [{turn_index}]\nExpected:\n{_shift(expected_str)}\nActual:\n{_shift(actual_str)}" diff --git a/maxbot/stories/pytest.py b/maxbot/stories/pytest.py new file mode 100644 index 0000000..7f86156 --- /dev/null +++ b/maxbot/stories/pytest.py @@ -0,0 +1,87 @@ +"""MaxBot stories pytest plugin. + +See: +- https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html +- https://docs.pytest.org/en/7.1.x/reference/reference.html +- https://docs.pytest.org/en/7.1.x/example/nonpython.html +""" +import pytest +from dotenv import find_dotenv, load_dotenv + +from ..resolver import BotResolver +from . import Stories + +_STORIES = None +_CONFIG = None + + +def pytest_addoption(parser): + """Add options to pytest command line.""" + group = parser.getgroup("maxbot") + group.addoption( + "--bot", + action="store", + dest="bot_spec", + help=( + "Path for bot file or directory or the Maxbot instance to load. " + "The instance can be in the form 'module:name'. " + "Module can be a dotted import. Name is not required if it is 'bot'." 
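# Illustrative sketch (not part of the patch): StoryUtcTimeProvider above. An explicit
# utc_time in a story turn pins the clock (naive datetimes are treated as UTC), and each
# turn without one advances it by ten seconds.
from datetime import datetime, timezone
from maxbot.stories import StoryUtcTimeProvider

clock = StoryUtcTimeProvider()
clock.tick(datetime(2023, 5, 1, 12, 0))   # naive -> assumed UTC
assert clock() == datetime(2023, 5, 1, 12, 0, tzinfo=timezone.utc)
clock.tick()                              # turn without utc_time
assert clock() == datetime(2023, 5, 1, 12, 0, 10, tzinfo=timezone.utc)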
+ ), + ) + + +def pytest_configure(config): + """Perform initial configuration.""" + bot_spec = config.getoption("bot_spec") + if bot_spec: + path = find_dotenv(".env", usecwd=True) + if path: + load_dotenv(path, encoding="utf-8") + + global _STORIES, _CONFIG # pylint: disable=W0603 + _STORIES = Stories(BotResolver(bot_spec)()) + _CONFIG = config + + +def pytest_collect_file(file_path, parent): + """Create collector of stories.""" + return StoriesFile.from_parent(parent, path=file_path) if _STORIES else None + + +class StoriesFile(pytest.File): + """Stories YAML file.""" + + def collect(self): + """Collect stories from current file.""" + for story in _STORIES.load(self.fspath): + item = StoryItem.from_parent(self, name=story["name"]) + item.user_properties.append(("maxbot_story", story)) + for mark in story["markers"]: + _CONFIG.addinivalue_line("markers", mark) + item.add_marker(mark) + yield item + + +class StoryItem(pytest.Item): + """One story pytest representation.""" + + def runtest(self): + """Run story.""" + _, story = next(pair for pair in self.user_properties if pair[0] == "maxbot_story") + _STORIES.run(story) + + def repr_failure(self, excinfo, style=None): + """Return a representation of a test failure.""" + if isinstance(excinfo.value, _STORIES.MismatchError): + return excinfo.value.message + return excinfo.getrepr(style="long" if _CONFIG.getoption("verbose") > 0 else "short") + + def reportinfo(self): + """Get location information for this item for test reports.""" + return self.fspath, None, f"{self.name}" + + +def pytest_sessionfinish(session, exitstatus): + """Cleanup pytest session.""" + if _STORIES: + _STORIES.loop.close() diff --git a/maxbot/user_locks/__init__.py b/maxbot/user_locks/__init__.py new file mode 100644 index 0000000..22deb75 --- /dev/null +++ b/maxbot/user_locks/__init__.py @@ -0,0 +1,4 @@ +"""User Locks allow user requests to be processed in FIFO order.""" + +from .asyncio import AsyncioLocks # noqa: F401 +from .mp import MultiProcessLocks, MultiProcessLocksServer, UnixSocketStreams # noqa: F401 diff --git a/maxbot/user_locks.py b/maxbot/user_locks/asyncio.py similarity index 93% rename from maxbot/user_locks.py rename to maxbot/user_locks/asyncio.py index 1cfb8be..824b46f 100644 --- a/maxbot/user_locks.py +++ b/maxbot/user_locks/asyncio.py @@ -1,4 +1,4 @@ -"""User Locks allow user requests to be processed in FIFO order.""" +"""An `asyncio.Lock` -based implementation of user lock.""" import asyncio from contextlib import asynccontextmanager diff --git a/maxbot/user_locks/mp.py b/maxbot/user_locks/mp.py new file mode 100644 index 0000000..1059e7d --- /dev/null +++ b/maxbot/user_locks/mp.py @@ -0,0 +1,277 @@ +"""Multiprocess implementation of user lock.""" +import asyncio +import logging +import os +from base64 import b64encode +from contextlib import asynccontextmanager +from multiprocessing import current_process +from signal import SIGINT, signal + +from .asyncio import AsyncioLocks + +logger = logging.getLogger(__name__) + +_OP_ACQUIRE = b"a" +_ACQ_ACQUIRED = b"\x01" +_OP_RELEASE = b"r" +_EOF = b"\0" + + +class ServerClosedConnectionError(RuntimeError): + """Exception from `MultiProcessLocks.__call__`: server closed the connection.""" + + +class UnixSocketStreams: + """Asyncronus streams implemented on UNIX sockets.""" + + def __init__(self, path): + """Create new class instance. + + :param str|bytes|os.PathLike path: UNIX socket path. 
+ """ + self.path = path + + @asynccontextmanager + async def start_server(self, client_connected_cb): + """Start new server.""" + try: + yield await asyncio.start_unix_server(client_connected_cb, path=self.path) + finally: + os.unlink(self.path) + + async def open_connection(self): + """Open client connection.""" + return await asyncio.open_unix_connection(path=self.path) + + +class MultiProcessLocks: + """Multiprocess implementation of user lock.""" + + def __init__(self, open_connection): + """Create new class instance. + + * The implementation in not thread-safe. + * The FIFO order is guaranteed by underlying asyncio.Lock. + + Usage example: + + from multiprocessing import Process, Event + + # socket streams implementation (IPC transport) + streams = UnixSocketStreams("/tmp/maxbot-locks.sock") + # server ready and server stop events + ready_event, stop_event = Event(), Event() + # server in dedicated process + Process( + target=MultiProcessLocksServer(streams.start_server, ready_event, stop_event), + daemon=True + ).start() + # waiting for server to be ready to accept connections + ready_event.wait() + + # asynchronous user locks + locks = MultiProcessLocks(streams.open_connection) + + # lock-unlock user1 (from dialog1) + async with locks(dialog1): + pass + + # lock-unlock user2 (from dialog2) + async with locks(dialog2): + pass + + :param callable open_connection: Open connection to server. + """ + self._open_connection = open_connection + self._streams_lock = asyncio.Lock() + self._reader, self._writer = None, None + self._for_current_process = AsyncioLocks() + + @asynccontextmanager + async def __call__(self, dialog): + """Acquire and release a lock on the given dialog. + + :param dict dialog: The dialog that needs to be locked. + """ + async with self._for_current_process(dialog): + key = ( + b64encode(str(dialog["channel_name"]).encode()) + + b"|" + + b64encode(str(dialog["user_id"]).encode()) + ) + try: + await self._connect() + await self._acquire(key) + try: + yield + finally: + await self._release(key) + except (ConnectionResetError, BrokenPipeError) as exc: + raise ServerClosedConnectionError() from exc # pragma: not covered + + async def _connect(self): + async with self._streams_lock: + if not self._reader: + assert not self._writer + self._reader, self._writer = await self._open_connection() + self._writer.write(current_process().name.encode() + _EOF) + await self._writer.drain() + + async def _acquire(self, key): + async with self._streams_lock: + self._writer.write(_OP_ACQUIRE + key + _EOF) + await self._writer.drain() + + acq = await self._reader.read(len(_ACQ_ACQUIRED)) + if acq != _ACQ_ACQUIRED: + if acq: + raise AssertionError(f"Unexpected server answer: {acq}") + raise ServerClosedConnectionError() + + async def _release(self, key): + async with self._streams_lock: + self._writer.write(_OP_RELEASE + key + _EOF) + await self._writer.drain() + + async def disconnect(self): + """Disconnect from server process.""" + if self._writer: + self._writer.close() + await self._writer.wait_closed() + + self._reader, self._writer = None, None + + +class MultiProcessLocksServer: + """Dedicated process of multiprocess user locks.""" + + def __init__(self, start_server, ready_event, stop_event): + """Create new class instance. + + :param callable start_server: Start new server. + :param Event ready_event: An event that is set when server is accepting new connections. + :param Event stop_event: Stop server event. 
+ """ + self._start_server = start_server + self._ready_event = ready_event + self._stop_event = stop_event + self._connected_clients = 0 + self._locks = {} + + async def _log_exception(self, coro, name): + try: + await coro + except Exception: + logger.exception(f"Unhandled exception in {name}") + + def _create_task(self, coro, name): + return asyncio.create_task(self._log_exception(coro, name), name=name) + + @staticmethod + async def _fatal_error(writer, message): + writer.close() + await writer.wait_closed() + raise AssertionError(message) + + async def _client_conected(self, reader, writer): + await self._log_exception(self._client_conected_impl(reader, writer), "client_connected") + + async def _client_conected_impl(self, reader, writer): + self._connected_clients += 1 + try: + client_name = await reader.readuntil(_EOF) + client_name = client_name[:-1].decode() + logger.info("%s connected", client_name) + + locked_by_client = {} + try: + while True: + try: + buffer = await reader.readuntil(_EOF) + except asyncio.IncompleteReadError as error: + assert not error.partial + return + + op = buffer[0:1] + key = buffer[1:-1] + if op == _OP_ACQUIRE: + logger.debug("%s acquire %s", client_name, key) + if key in locked_by_client: + await self._fatal_error( + writer, f"Recursive lock {key!r}: {locked_by_client[key]}" + ) + + user_lock = self._locks.get(key) + if user_lock is None: + user_lock = asyncio.Lock() + self._locks[key] = user_lock + + await user_lock.acquire() + locked_by_client[key] = user_lock + + writer.write(_ACQ_ACQUIRED) + await writer.drain() + elif op == _OP_RELEASE: + logger.debug("%s release %s", client_name, key) + user_lock = locked_by_client.pop(key, None) + if not user_lock: + await self._fatal_error(writer, f"{key} is not locked") + + user_lock.release() + else: + await self._fatal_error(writer, f"Unexpected: {buffer}") # no cov + finally: + for key, user_lock in locked_by_client.items(): + logger.warning( + "%s %s unlocked on client %s disconnect", key, user_lock, client_name + ) + user_lock.release() + + logger.info("%s disconnected", client_name) + finally: + assert self._connected_clients > 0 + self._connected_clients -= 1 + + async def _serve(self, server): + async with server: + logger.info("Locks server started") + while True: + if server.is_serving(): + self._ready_event.set() + + if self._stop_event.is_set() and self._connected_clients == 0: + logger.info("Locks server stopping...") + return + + await asyncio.sleep(0.1) + continue + + def _sigint_handler(self, signum, frame): + logger.warning( + "Locks server has received an %s signal. Waiting for %s clients to disconnect.", + signum, + self._connected_clients, + ) + self._stop_event.set() + + def __call__(self, initialize_process_fns=None): + """Server process entry point. + + :param iterable[calable] initialize_process_fns: Additional in-process initialization. 
+ """ + signal(SIGINT, self._sigint_handler) + + for fn in initialize_process_fns or []: + fn() + logger.info("Locks server staring...") + + async def main(): + async with self._start_server(self._client_conected) as server: + await asyncio.gather( + self._create_task(self._serve(server), name="serve"), + self._create_task(server.serve_forever(), name="serve_forever"), + return_exceptions=True, + ) + + asyncio.run(main()) + logger.info("Locks server stopped") diff --git a/maxbot/webapp.py b/maxbot/webapp.py new file mode 100644 index 0000000..b92a322 --- /dev/null +++ b/maxbot/webapp.py @@ -0,0 +1,245 @@ +"""Sanic WEB application.""" + +import logging +import os +from multiprocessing import get_context +from tempfile import NamedTemporaryFile + +from .user_locks import MultiProcessLocks, MultiProcessLocksServer + +logger = logging.getLogger(__name__) + + +def run_webapp( + bot, + bot_factory, + host, + port, + *, + init_logging=None, + public_url=None, + autoreload=False, + workers=1, + fast=False, + single_process=False, +): + """Run WEB application. + + Function does not return control. + + :param MaxBot bot: Bot. + :param calable bot_factory: MaxBot factory. + :param str host: Hostname or IP address on which to listen. + :param int port: TCP port on which to listen. + :param calable init_logging: Initialize logging for new processes. + :param str public_url: Base url to register webhook. + :param bool autoreload: Enable tracking and reloading bot resource changes. + :param int workers: Number of worker processes to spawn. + :param bool fast: Whether to maximize worker processes. + :param bool single_process: Single process mode. + """ + factory = Factory( + bot, bot_factory, host, port, init_logging, public_url, autoreload, single_process + ) + + # lazy import to speed up load time + import sanic + + if sanic.__version__.startswith("21."): + factory.single_process = True + factory().run(host, port, motd=False, workers=1) + return + + os.environ["SANIC_IGNORE_PRODUCTION_WARNING"] = "true" + if single_process: + factory().run(host, port, motd=False, single_process=True) + return + + from sanic.worker.loader import AppLoader + + loader = AppLoader(factory=factory) + app = loader.load() + app.prepare( + host, + port, + motd=False, + fast=fast, + workers=workers, + single_process=single_process, + ) + sanic.Sanic.serve(primary=app, app_loader=loader) + + +class Factory: + """WEB application factory. + + Re-create MaxBot object in another process and create WEB application. + """ + + def __init__( + self, bot, bot_factory, host, port, init_logging, public_url, autoreload, single_process + ): + """Create new class instance. + + :param MaxBot bot: Bot. + :param calable bot_factory: MaxBot factory. + :param str host: Hostname or IP address on which to listen. + :param int|str port: TCP port on which to listen. + :param calable init_logging: Initialize logging for new processes. + :param str public_url: Base url to register webhook. + :param bool autoreload: Enable tracking and reloading bot resource changes. + :param bool single_process: Single process mode. 
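# Illustrative sketch (not part of the patch): launching the web application with the
# new multi-worker support via run_webapp above. The path is hypothetical and the bot
# directory is assumed to configure at least one channel; unless single_process is set,
# a dedicated locks-server process is spawned and each Sanic worker re-creates the bot
# through bot_factory.
from functools import partial
from maxbot.bot import MaxBot
from maxbot.webapp import run_webapp

if __name__ == "__main__":
    bot_factory = partial(MaxBot.from_directory, "path/to/bot_dir")
    run_webapp(bot_factory(), bot_factory, "127.0.0.1", 8080, workers=2)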
+ """ + self.bot = bot + self.bot_factory = bot_factory + self.host = host + self.port = port + self.init_logging = init_logging + self.public_url = public_url + self.autoreload = autoreload + self.single_process = single_process + with NamedTemporaryFile(prefix="maxbot-") as f: + self.base_file_name = f.name + + def __getstate__(self): + """Transfer state to another process.""" + state = self.__dict__.copy() + state.update(bot=None) + return state + + def __setstate__(self, state): + """Apply transfered state from another process.""" + self.__dict__.update(state) + if self.init_logging: + self.init_logging() # for new process + self.bot = self.bot_factory() + + def __call__(self): + """Create and return WEB application.""" + # lazy import to speed up load time + import sanic + + self.bot.validate_at_least_one_channel() + + app = sanic.Sanic("maxbot", configure_logging=False) + app.config.FALLBACK_ERROR_FORMAT = "text" + + for channel in self.bot.channels: + app.blueprint( + channel.blueprint( + self.bot.default_channel_adapter, + self.execute_once, + public_url=self.public_url, + webhook_path=f"/{channel.name}", + ) + ) + + if self.bot.rpc: + app.blueprint(self.bot.rpc.blueprint(self.bot.channels, self.bot.default_rpc_adapter)) + + if self.autoreload: + + @app.after_server_start + async def start_autoreloader(app, loop): + app.add_task(self.bot.autoreloader, name="autoreloader") + + @app.before_server_stop + async def stop_autoreloader(app, loop): + await app.cancel_task("autoreloader") + + if not self.single_process: + mp_ctx = { + "locks_file_path": f"{self.base_file_name}{self.bot.SUFFIX_LOCKS}", + "db_file_path": f"{self.base_file_name}{self.bot.SUFFIX_DB}", + } + mp_ctx["locks_streams"] = self.bot.SocketStreams(mp_ctx["locks_file_path"]) + mp_ctx["default_locks"] = MultiProcessLocks(mp_ctx["locks_streams"].open_connection) + + @app.main_process_start + async def main_process_started(app, loop): + logger.info("Sanic multi-process server starting...") + user_locks = self.bot.setdefault_user_locks(mp_ctx["default_locks"]) + if isinstance(user_locks, MultiProcessLocks): + ctx = get_context("spawn") + mp_ctx["locks_server_ready"] = ctx.Event() + mp_ctx["locks_server_stop"] = ctx.Event() + mp_ctx["locks_server"] = ctx.Process( + target=MultiProcessLocksServer( + mp_ctx["locks_streams"].start_server, + mp_ctx["locks_server_ready"], + mp_ctx["locks_server_stop"], + ), + args=( + [ + self.init_logging, + ], + ), + name="MpUserLocks", + ) + mp_ctx["locks_server"].start() + + def _create_default_mp_persistence_manager_and_tables(): + persistence_manager = self._create_default_mp_persistence_manager( + mp_ctx["db_file_path"] + ) + persistence_manager.create_tables() + return persistence_manager + + self.bot.setdefault_persistence_manager( + _create_default_mp_persistence_manager_and_tables + ) + + @app.main_process_ready + async def main_process_ready(app, loop): + if "locks_server" in mp_ctx: + mp_ctx["locks_server_ready"].wait() + + @app.main_process_stop + async def main_process_stopping(app, loop): + logger.info("Sanic multi-process server stopping...") + if "locks_server" in mp_ctx: + mp_ctx["locks_server_stop"].set() + + @app.after_server_start + async def server_started(app, loop): + if not self.single_process: + self.bot.setdefault_user_locks(mp_ctx["default_locks"]) + self.bot.setdefault_persistence_manager( + lambda: self._create_default_mp_persistence_manager(mp_ctx["db_file_path"]) + ) + + async def _log_messages_for_user(): + logger.debug("bot.user_locks = %s", self.bot.user_locks) 
+ logger.debug( + "bot.persistence_manager.engine = %s", self.bot.persistence_manager.engine + ) + + if self.public_url is None: + for channel in self.bot.channels: + logger.warning( + "Make sure you have a public URL that is forwarded to -> " + f"http://{self.host}:{self.port}/{channel.name} and register webhook for it." + ) + logger.info( + f"Started webhooks updater on http://{self.host}:{self.port}. Press 'Ctrl-C' to exit." + ) + + await self.execute_once(app, _log_messages_for_user) + + return app + + @staticmethod + def _create_default_mp_persistence_manager(file_path): + from .persistence_manager import SQLAlchemyManager, create_engine, create_json_serializer + + persistence_manager = SQLAlchemyManager() + persistence_manager.engine = create_engine( + f"sqlite:///{file_path}", json_serializer=create_json_serializer() + ) + return persistence_manager + + async def execute_once(self, app, fn): + """Execute only for first worker of WEB application.""" + if self.single_process or (app.m.name.endswith("-0-0") and app.m.state["starts"] == 1): + # Run in first worker + await fn() diff --git a/pyproject.toml b/pyproject.toml index bef1904..141d069 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "Maxbot" -version = "0.2.0" +version = "0.3.0" description = "Maxbot is an open source library and framework for creating conversational apps." license = "MIT" authors = ["Maxbot team "] @@ -104,3 +104,6 @@ asyncio_mode = "auto" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.poetry.plugins.pytest11] +maxbot_stories = "maxbot.stories.pytest" diff --git a/tests/conftest.py b/tests/conftest.py index c13dc3f..cc54ee7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from maxbot.context import StateVariables from maxbot.errors import YamlSymbols +pytest_plugins = "pytester" Sanic.test_mode = True diff --git a/tests/test_bot/test_adapters.py b/tests/test_bot/test_adapters.py index 840d8ee..c33c5c5 100644 --- a/tests/test_bot/test_adapters.py +++ b/tests/test_bot/test_adapters.py @@ -72,11 +72,53 @@ async def test_default_rpc_adapter(bot, channel): assert channel.sent == ["hello world"] -async def test_state_store(channel): +async def test_persistence_manager(channel): mock = MagicMock() - bot = MaxBot(state_store=Mock(return_value=mock)) + bot = MaxBot(persistence_manager=Mock(return_value=mock)) await bot.default_channel_adapter("hey bot", channel) - bot.state_store.assert_called_once_with({"channel_name": channel.name, "user_id": "23"}) + bot.persistence_manager.assert_called_once_with( + {"channel_name": channel.name, "user_id": "23"} + ) mock.__enter__.assert_called_once() mock.__exit__.assert_called_once() + + +async def test_track_history_channel(channel): + tracker = MagicMock() + mock = MagicMock() + mock.__enter__ = Mock(return_value=tracker) + bot = MaxBot(persistence_manager=Mock(return_value=mock), history_tracked=True) + await bot.default_channel_adapter("hey bot", channel) + + tracker.set_message_history.assert_called_once() + + +async def test_track_history_rpc(channel): + tracker = MagicMock() + mock = MagicMock() + mock.__enter__ = Mock(return_value=tracker) + bot = MaxBot(persistence_manager=Mock(return_value=mock), history_tracked=True) + await bot.default_rpc_adapter({}, channel, "34") + + tracker.set_rpc_history.assert_called_once() + + +async def test_track_history_channel_default(channel): + tracker = MagicMock() + mock = MagicMock() + mock.__enter__ = 
Mock(return_value=tracker) + bot = MaxBot(persistence_manager=Mock(return_value=mock)) + await bot.default_channel_adapter("hey bot", channel) + + tracker.set_message_history.assert_not_called() + + +async def test_track_history_rpc_default(channel): + tracker = MagicMock() + mock = MagicMock() + mock.__enter__ = Mock(return_value=tracker) + bot = MaxBot(persistence_manager=Mock(return_value=mock)) + await bot.default_rpc_adapter({}, channel, "34") + + tracker.set_rpc_history.assert_not_called() diff --git a/tests/test_bot/test_run_polling.py b/tests/test_bot/test_run_polling.py index 76362f5..ba21b9f 100644 --- a/tests/test_bot/test_run_polling.py +++ b/tests/test_bot/test_run_polling.py @@ -132,3 +132,18 @@ async def test_autoreload(tmp_path, post_init, post_stop): await post_stop() assert task.cancelled() + + +async def test_timeout(bot, monkeypatch): + app = None + + def run_polling(self): + nonlocal app + app = self + + monkeypatch.setattr(Application, "run_polling", run_polling) + + bot.run_polling() + + assert app.bot._request[0]._client_kwargs["timeout"] == bot.channels.telegram.timeout + assert app.bot._request[1]._client_kwargs["timeout"] == bot.channels.telegram.timeout diff --git a/tests/test_bot/test_run_webapp.py b/tests/test_bot/test_run_webapp.py deleted file mode 100644 index 1b78781..0000000 --- a/tests/test_bot/test_run_webapp.py +++ /dev/null @@ -1,131 +0,0 @@ -import asyncio -import logging -from unittest.mock import ANY, AsyncMock, Mock - -import pytest -from sanic import Sanic - -from maxbot.bot import MaxBot -from maxbot.channels import ChannelsCollection -from maxbot.errors import BotError -from maxbot.rpc import RpcManager - - -@pytest.fixture(autouse=True) -def mock_sanic_run(monkeypatch): - monkeypatch.setattr(Sanic, "run", Mock()) - - -@pytest.fixture -def bot(): - # we need at least one channel to run the bot - channel = Mock() - channel.configure_mock(name="my_channel") - bot = MaxBot(channels=ChannelsCollection([channel])) - return bot - - -@pytest.fixture -def after_server_start(monkeypatch): - monkeypatch.setattr(Sanic, "after_server_start", Mock()) - - async def execute(app=None): - for call in Sanic.after_server_start.call_args_list: - (coro,) = call.args - await coro(app or Mock(), loop=Mock()) - - return execute - - -@pytest.fixture -def before_server_stop(monkeypatch): - monkeypatch.setattr(Sanic, "before_server_stop", Mock()) - - async def execute(app=None): - for call in Sanic.before_server_stop.call_args_list: - (coro,) = call.args - await coro(app or Mock(), loop=Mock()) - - return execute - - -def test_run_webapp(bot): - bot.run_webapp("localhost", 8080) - - assert Sanic.run.call_args.args == ("localhost", 8080) - - ch = bot.channels.my_channel - assert ch.blueprint.called - - -async def test_report_started(bot, after_server_start, caplog): - bot.run_webapp("localhost", 8080) - - with caplog.at_level(logging.INFO): - await after_server_start() - assert ( - "Started webhooks updater on http://localhost:8080. Press 'Ctrl-C' to exit." - ) in caplog.text - - -def test_no_channels(): - bot = MaxBot() - with pytest.raises(BotError) as excinfo: - bot.run_webapp("localhost", 8080) - assert excinfo.value.message == ( - "At least one channel is required to run a bot. " - "Please, fill the 'channels' section of your bot.yaml." 
- ) - - -def test_rpc_enabled(bot, monkeypatch): - monkeypatch.setattr(RpcManager, "blueprint", Mock()) - - bot.dialog_manager.load_inline_resources( - """ - rpc: - - method: say_hello - """ - ) - bot.run_webapp("localhost", 8080) - - assert bot.rpc.blueprint.called - - -def test_rpc_disabled(bot, monkeypatch): - monkeypatch.setattr(RpcManager, "blueprint", Mock()) - - bot.run_webapp("localhost", 8080) - - assert not bot.rpc.blueprint.called - - -async def test_autoreload(bot, after_server_start, before_server_stop): - bot.run_webapp("localhost", 8080, autoreload=True) - - app = Mock() - await after_server_start(app) - app.add_task.assert_called_with(bot.autoreloader, name="autoreloader") - - app = AsyncMock() - await before_server_stop(app) - app.cancel_task.assert_called_with("autoreloader") - - -def test_public_url_missing(bot, caplog): - bot.channels.my_channel.configure_mock(name="my_channel") - - with caplog.at_level(logging.WARNING): - bot.run_webapp("localhost", 8080) - assert ( - "Make sure you have a public URL that is forwarded to -> " - "http://localhost:8080/my_channel and register webhook for it." - ) in caplog.text - - -def test_public_url_present(bot): - bot.run_webapp("localhost", 8080, public_url="https://example.com") - - ch = bot.channels.my_channel - kw = ch.blueprint.call_args.kwargs - assert kw["public_url"] == "https://example.com" diff --git a/tests/test_builder.py b/tests/test_builder.py index eb0cc2e..ed9882a 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -82,11 +82,15 @@ async def test_channel_unknown(builder): ) -def test_state_store(builder): - builder.state_store = sentinel.state_store - assert builder.state_store is sentinel.state_store +def test_persistence_manager(builder): + builder.persistence_manager = sentinel.persistence_manager + assert builder.persistence_manager is sentinel.persistence_manager bot = builder.build() - assert bot.state_store is sentinel.state_store + assert bot.persistence_manager is sentinel.persistence_manager + + +def test_persistence_manager_default(builder): + assert type(builder.persistence_manager).__name__ == "SQLAlchemyManager" def test_user_locks(builder): @@ -96,6 +100,10 @@ def test_user_locks(builder): assert bot.user_locks is sentinel.user_locks +def test_user_locks_default(builder): + assert type(builder.user_locks).__name__ == "AsyncioLocks" + + def test_nlu(builder): nlu = Mock() builder.nlu = nlu @@ -105,6 +113,10 @@ def test_nlu(builder): nlu.load_resources.assert_called_once() +def test_nlu_default(builder): + assert type(builder.nlu).__name__ == "Nlu" + + def test_jinja_options(builder): builder.jinja_options["optimized"] = False assert builder.jinja_env.optimized == False @@ -274,3 +286,12 @@ def test_use_package_resources(builder, tmp_path, monkeypatch): bot = builder.build() commands = bot.process_message("hey bot") assert commands == [{"text": "hello world"}] + + +def test_history_tracked(builder): + builder.track_history() + assert builder.build()._history_tracked == True + + +def test_history_tracked_default(builder): + assert builder.build()._history_tracked == False diff --git a/tests/test_channels/test_facebook.py b/tests/test_channels/test_facebook.py index 67503b6..fa89809 100644 --- a/tests/test_channels/test_facebook.py +++ b/tests/test_channels/test_facebook.py @@ -164,7 +164,7 @@ async def test_sanic_endpoint(bot): callback = AsyncMock() app = Sanic(__name__) - app.blueprint(bot.channels.facebook.blueprint(callback)) + app.blueprint(bot.channels.facebook.blueprint(callback, 
None)) _, response = await app.asgi_client.post( "/facebook", json=UPDATE_TEXT, @@ -189,8 +189,59 @@ async def test_sanic_endpoint(bot): async def test_sanic_register_webhook(bot, caplog): + async def execute_once(app, fn): + await fn() + with caplog.at_level(logging.WARNING): - bot.channels.facebook.blueprint(AsyncMock(), public_url="https://example.com/") + bot.channels.facebook.blueprint( + AsyncMock(), execute_once, public_url="https://example.com/" + ) assert ( "The facebook platform has no suitable api, register a webhook yourself https://example.com/facebook." ) in caplog.text + + +async def test_timeout_not_specified(bot, dialog, respx_mock): + respx_mock.post(f"{API_URL}/me/messages?access_token={ACCESS_TOKEN}").respond(json={}) + text = Mock() + text.render = Mock(return_value=MESSAGE_TEXT) + await bot.channels.facebook.call_senders({"text": text}, dialog) + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 5.0, "pool": 5.0, "read": 5.0, "write": 5.0}, + ] + + +async def test_timeout(dialog, respx_mock): + bot = MaxBot.inline( + f""" + channels: + facebook: + app_secret: {FB_APP_SECRET} + access_token: {ACCESS_TOKEN} + timeout: + default: 3.1 + connect: 10 + """ + ) + respx_mock.post(f"{API_URL}/me/messages?access_token={ACCESS_TOKEN}").respond(json={}) + text = Mock() + text.render = Mock(return_value=MESSAGE_TEXT) + await bot.channels.facebook.call_senders({"text": text}, dialog) + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 10, "pool": 3.1, "read": 3.1, "write": 3.1}, + ] + + +def test_limits(): + MaxBot.inline( + f""" + channels: + facebook: + app_secret: {FB_APP_SECRET} + access_token: {ACCESS_TOKEN} + limits: + max_keepalive_connections: 1 + max_connections: 2 + keepalive_expiry: 3 + """ + ) diff --git a/tests/test_channels/test_telegram.py b/tests/test_channels/test_telegram.py index bcaa16d..fe2d59d 100644 --- a/tests/test_channels/test_telegram.py +++ b/tests/test_channels/test_telegram.py @@ -187,7 +187,7 @@ async def test_sanic_endpoint(bot, respx_mock, update_text): callback = AsyncMock() app = Sanic(__name__) - app.blueprint(bot.channels.telegram.blueprint(callback)) + app.blueprint(bot.channels.telegram.blueprint(callback, None)) _, response = await app.asgi_client.post("/telegram", json=UPDATE_TEXT) assert response.status_code == 204, response.text assert response.text == "" @@ -198,10 +198,62 @@ async def test_sanic_register_webhook(bot, respx_mock, monkeypatch): respx_mock.post(f"{API_URL}/setWebhook").respond(json={"result": {}}) monkeypatch.setattr(Blueprint, "after_server_start", Mock()) - bp = bot.channels.telegram.blueprint(AsyncMock(), public_url="https://example.com/") + async def execute_once(app, fn): + await fn() + + bp = bot.channels.telegram.blueprint( + AsyncMock(), execute_once, public_url="https://example.com/" + ) for call in bp.after_server_start.call_args_list: (coro,) = call.args await coro(Mock(), Mock()) request = respx_mock.calls.last.request assert request.content == b"url=https%3A%2F%2Fexample.com%2Ftelegram" + + +async def test_timeout_not_specified(bot, respx_mock, dialog): + respx_mock.post(f"{API_URL}/sendMessage").respond(json={"result": {}}) + text = Mock() + text.render = Mock(return_value=MESSAGE_TEXT) + + await bot.channels.telegram.call_senders({"text": text}, dialog) + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 5.0, "read": 5.0, "write": 5.0, "pool": 5.0}, + ] + + +async def test_timeout_send_photo(respx_mock, 
dialog): + bot = MaxBot.inline( + f""" + channels: + telegram: + api_token: {API_TOKEN} + timeout: + default: 3.5 + connect: 1.0 + """ + ) + respx_mock.post(f"{API_URL}/sendPhoto").respond(json={"result": {}}) + url = "https://api.telegram.org/file/bot110201543/photos/file_21.jpg" + respx_mock.get(url).respond(stream=b"IMAGE_CONTENT") + + await bot.channels.telegram.call_senders({"image": {"url": url}}, dialog) + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 1.0, "read": 3.5, "write": 3.5, "pool": 3.5}, + {"connect": 1.0, "read": 3.5, "write": 20, "pool": 3.5}, # send_photo(write_timeout=20) + ] + + +async def test_limits(): + MaxBot.inline( + f""" + channels: + telegram: + api_token: {API_TOKEN} + limits: + max_keepalive_connections: 1 + max_connections: 2 + keepalive_expiry: 3 + """ + ) diff --git a/tests/test_channels/test_viber.py b/tests/test_channels/test_viber.py index f7b33c7..43c7bb4 100644 --- a/tests/test_channels/test_viber.py +++ b/tests/test_channels/test_viber.py @@ -142,7 +142,7 @@ async def test_sanic_endpoint(bot, respx_mock): callback = AsyncMock() app = Sanic(__name__) - app.blueprint(bot.channels.viber.blueprint(callback)) + app.blueprint(bot.channels.viber.blueprint(callback, None)) _, response = await app.asgi_client.post("/viber", json=UPDATE_TEXT) assert response.status_code == 204, response.text assert response.text == "" @@ -155,7 +155,10 @@ async def test_sanic_register_webhook(bot, respx_mock, monkeypatch): respx_mock.post(f"{API_URL}/set_webhook").respond(json={"status": 0, "event_types": []}) monkeypatch.setattr(Blueprint, "after_server_start", Mock()) - bp = bot.channels.viber.blueprint(AsyncMock(), public_url="https://example.com/") + async def execute_once(app, fn): + await fn() + + bp = bot.channels.viber.blueprint(AsyncMock(), execute_once, public_url="https://example.com/") for call in bp.after_server_start.call_args_list: (coro,) = call.args await coro(Mock(), Mock()) @@ -196,3 +199,55 @@ async def test_get_avatar_and_name(dialog, respx_mock): "receiver": USER_ID, "sender": {"name": name, "avatar": avatar}, } + + +async def test_timeout_not_specified(bot, dialog, respx_mock): + respx_mock.post(f"{API_URL}/send_message").respond(json={"status": 0, "message_token": "11"}) + + text = Mock() + text.render = Mock(return_value=MESSAGE_TEXT) + await bot.channels.viber.call_senders({"text": text}, dialog) + + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 5.0, "pool": 5.0, "read": 5.0, "write": 5.0}, + ] + + +async def test_timeout(dialog, respx_mock): + bot = MaxBot.inline( + f""" + channels: + viber: + api_token: {API_TOKEN} + avatar: {DEFAULT_AVATAR} + name: MAXBOT + timeout: + default: 3.1 + connect: 10 + """ + ) + respx_mock.post(f"{API_URL}/send_message").respond(json={"status": 0, "message_token": "11"}) + + text = Mock() + text.render = Mock(return_value=MESSAGE_TEXT) + await bot.channels.viber.call_senders({"text": text}, dialog) + + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 10, "pool": 3.1, "read": 3.1, "write": 3.1}, + ] + + +async def test_limits(): + MaxBot.inline( + f""" + channels: + viber: + api_token: {API_TOKEN} + avatar: {DEFAULT_AVATAR} + name: MAXBOT + limits: + max_keepalive_connections: 1 + max_connections: 2 + keepalive_expiry: 3 + """ + ) diff --git a/tests/test_channels/test_vk.py b/tests/test_channels/test_vk.py index 01fb0aa..87e1c39 100644 --- a/tests/test_channels/test_vk.py +++ 
b/tests/test_channels/test_vk.py @@ -1,19 +1,19 @@ -import json import logging +import urllib.parse from unittest.mock import AsyncMock, Mock import pytest -from sanic import Sanic +from sanic import Blueprint, Sanic from maxbot import MaxBot from maxbot.channels.vk import Gateway -from maxbot.errors import BotError from maxbot.schemas import MessageSchema VK_ACCESS_TOKEN = "4fdfac4de0e7e5af-8ea26569db6b60d8-adf115afb5cfe2d0" USER_ID = 12345 VK_GROUP_ID = 12345678 VK_CONFIRM_SECRET = "confirm" +SECRET_KEY = "secret" API_URL = "https://api.vk.com/method" MESSAGE_TEXT = "hello world" IMAGE_URL = "http://example.com/123.jpg" @@ -77,9 +77,9 @@ def builder(): f""" channels: vk: - confirm_secret: {VK_CONFIRM_SECRET} access_token: {VK_ACCESS_TOKEN} group_id: {VK_GROUP_ID} + secret_key: {SECRET_KEY} """ ) return builder @@ -95,6 +95,10 @@ def dialog(): return {"channel_name": "vk", "user_id": str(USER_ID)} +def get_dict(content): + return dict([(k, v[0]) for k, v in urllib.parse.parse_qs(content.decode("utf-8")).items()]) + + async def test_create_dialog(bot, dialog): incoming_message = UPDATE_IMAGE["object"]["message"] assert dialog == await bot.channels.vk.create_dialog(incoming_message) @@ -114,10 +118,10 @@ async def test_send_text(bot, dialog, respx_mock): request = respx_mock.calls.last.request assert ( - json.loads(request.content).items() + get_dict(request.content).items() >= { "message": MESSAGE_TEXT, - "user_id": USER_ID, + "user_id": str(USER_ID), "v": Gateway.API_VERSION, "access_token": VK_ACCESS_TOKEN, }.items() @@ -181,7 +185,7 @@ async def test_send_image(bot, dialog, url, headers, respx_mock): request = respx_mock.calls[1].request assert request.url.path.endswith("photos.getMessagesUploadServer") assert ( - json.loads(request.content).items() + get_dict(request.content).items() > {"v": Gateway.API_VERSION, "access_token": VK_ACCESS_TOKEN}.items() ) @@ -191,7 +195,7 @@ async def test_send_image(bot, dialog, url, headers, respx_mock): request = respx_mock.calls[3].request assert request.url.path.endswith("photos.saveMessagesPhoto") assert ( - json.loads(request.content).items() + get_dict(request.content).items() > { "server": "server_vk", "photo": "file_id", @@ -204,10 +208,10 @@ async def test_send_image(bot, dialog, url, headers, respx_mock): request = respx_mock.calls[4].request assert request.url.path.endswith("messages.send") assert ( - json.loads(request.content).items() + get_dict(request.content).items() > { "attachment": "photo121_212", - "user_id": USER_ID, + "user_id": str(USER_ID), "message": MESSAGE_TEXT, "v": Gateway.API_VERSION, "access_token": VK_ACCESS_TOKEN, @@ -226,30 +230,243 @@ async def test_send_image_error(bot, dialog, respx_mock): ) +def _webhook_mock(bot, respx_mock, monkeypatch, responses, webhook=None): + app = Sanic(__name__) + monkeypatch.setattr(Blueprint, "after_server_start", Mock()) + + async def execute_once(app, fn): + await fn() + + bp = bot.channels.vk.blueprint( + AsyncMock(), execute_once, public_url=webhook or "http://webhook" + ) + app.blueprint(bp) + while len(responses) < 5: + responses.append({}) + respx_mock.post(f"{API_URL}/groups.getCallbackServers").respond(json=responses[0]) + respx_mock.post(f"{API_URL}/groups.deleteCallbackServer").respond(json=responses[1]) + respx_mock.post(f"{API_URL}/groups.addCallbackServer").respond(json=responses[2]) + respx_mock.post(f"{API_URL}/groups.setCallbackSettings").respond(json=responses[3]) + respx_mock.post(f"{API_URL}/groups.getCallbackConfirmationCode").respond(json=responses[4]) + return 
app, bp + + +async def test_set_webhook(bot, respx_mock, monkeypatch): + confirm_code = "123456" + old_server1, old_server2, new_server = 11, 22, 33 + webhook = "https://webhook.ai/" + responses = [ + {"response": {"items": [{"id": old_server1}, {"id": old_server2}]}}, + {"response": 1}, + {"response": {"server_id": new_server}}, + {"response": 1}, + {"response": {"code": confirm_code}}, + ] + app, bp = _webhook_mock(bot, respx_mock, monkeypatch, responses, webhook) + + for call in bp.after_server_start.call_args_list: + (coro,) = call.args + await coro(Mock(), Mock()) + + calls = respx_mock.calls + assert len(calls) == 6 + expected = {"v": "5.131", "access_token": VK_ACCESS_TOKEN, "group_id": str(VK_GROUP_ID)} + assert get_dict(calls[0].request.content) == expected + assert get_dict(calls[1].request.content) == dict( + **expected, **{"server_id": str(old_server1)} + ) + assert get_dict(calls[2].request.content) == dict( + **expected, **{"server_id": str(old_server2)} + ) + data = {"secret_key": SECRET_KEY, "title": "MAXBOT", "url": f"{webhook}vk"} + assert get_dict(calls[3].request.content) == dict(**expected, **data) + data = { + "api_version": "5.131", + "message_allow": "1", + "message_deny": "1", + "message_new": "1", + "message_reply": "1", + "server_id": str(new_server), + } + assert get_dict(calls[4].request.content) == dict(**expected, **data) + assert get_dict(calls[5].request.content) == expected + + _, response = await app.asgi_client.post("/vk", json=UPDATE_CONFIRM) + assert response.status_code == 200, response.text + assert response.text == confirm_code + + +async def _assert_webhhok_error(bp, respx_mock, count): + with pytest.raises(RuntimeError): + for call in bp.after_server_start.call_args_list: + (coro,) = call.args + await coro(Mock(), Mock()) + calls = respx_mock.calls + assert len(calls) == count + + +async def test_get_callback_error(bot, respx_mock, monkeypatch): + responses = [{"error": {"error_code": 10}}] + app, bp = _webhook_mock(bot, respx_mock, monkeypatch, responses) + await _assert_webhhok_error(bp, respx_mock, count=1) + + +async def test_set_callback_error(bot, respx_mock, monkeypatch): + responses = [ + {"response": {"items": [{"id": 1}, {"id": 2}]}}, + {"response": 0}, + {"error": {"error_code": 10}}, + ] + app, bp = _webhook_mock(bot, respx_mock, monkeypatch, responses) + await _assert_webhhok_error(bp, respx_mock, count=4) + + +async def test_callback_settings_error(bot, respx_mock, monkeypatch): + responses = [ + {"response": {"items": [{"id": 1}, {"id": 2}]}}, + {"response": 1}, + {"response": {"server_id": 3}}, + {"response": 0}, + ] + app, bp = _webhook_mock(bot, respx_mock, monkeypatch, responses) + await _assert_webhhok_error(bp, respx_mock, count=5) + + +async def test_confirm_code_error(bot, respx_mock, monkeypatch): + responses = [ + {"response": {"items": [{"id": 1}, {"id": 2}]}}, + {"response": 1}, + {"response": {"server_id": 3}}, + {"response": 1}, + {"error": {"error_code": 10}}, + ] + app, bp = _webhook_mock(bot, respx_mock, monkeypatch, responses) + await _assert_webhhok_error(bp, respx_mock, count=6) + + async def test_sanic_endpoint(bot): callback = AsyncMock() app = Sanic(__name__) - app.blueprint(bot.channels.vk.blueprint(callback)) + app.blueprint(bot.channels.vk.blueprint(callback, None)) _, response = await app.asgi_client.post("/vk", json=UPDATE_TEXT) assert response.status_code == 200, response.text assert response.text == "ok" callback.assert_called_once_with(UPDATE_TEXT["object"]["message"], bot.channels.vk) - _, response = 
await app.asgi_client.post("/vk", json=UPDATE_CONFIRM) - assert response.status_code == 200, response.text - assert response.text == "confirm" - _, response = await app.asgi_client.post( "/vk", json={"type": "confirmation", "group_id": 12345} ) assert response.status_code == 400, response.text + _, response = await app.asgi_client.post( + "/vk", json={"type": "confirmation", "group_id": VK_GROUP_ID} + ) + assert response.status_code == 500, response.text + + +async def test_skip_register_webhook(bot, caplog): + builder = MaxBot.builder() + builder.use_inline_resources( + f""" channels: + vk: + access_token: {VK_ACCESS_TOKEN} + """ + ) + bot = builder.build() + + async def execute_once(app, fn): + await fn() -async def test_sanic_register_webhook(bot, caplog): with caplog.at_level(logging.WARNING): - bot.channels.vk.blueprint(AsyncMock(), public_url="https://example.com/") + bot.channels.vk.blueprint(AsyncMock(), execute_once, public_url="https://example.com/") assert ( - "The vk platform has no suitable api, register a webhook yourself https://example.com/vk." + "Skip register webhook, set secret_key and group_id for register new webhook" ) in caplog.text + + +async def test_timeout_not_specified(bot, dialog, respx_mock): + respx_mock.post(f"{API_URL}/messages.send").respond(json={"response": 1}) + + text = Mock() + text.render = Mock(return_value=MESSAGE_TEXT) + await bot.channels.vk.call_senders( + command={"text": text}, + dialog=dialog, + ) + + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 5.0, "pool": 5.0, "read": 5.0, "write": 5.0}, + ] + + +async def test_timeout(dialog, respx_mock): + bot = MaxBot.inline( + f""" + channels: + vk: + access_token: {VK_ACCESS_TOKEN} + group_id: {VK_GROUP_ID} + secret_key: {SECRET_KEY} + timeout: + default: 3.1 + connect: 10 + """ + ) + respx_mock.post(f"{API_URL}/messages.send").respond(json={"response": 1}) + + text = Mock() + text.render = Mock(return_value=MESSAGE_TEXT) + await bot.channels.vk.call_senders( + command={"text": text}, + dialog=dialog, + ) + + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 10, "pool": 3.1, "read": 3.1, "write": 3.1}, + ] + + +async def test_send_image_timeout(bot, dialog, respx_mock): + respx_mock.get(IMAGE_URL).respond( + headers={"content-type": "image/jpeg"}, stream=b"image contents" + ) + respx_mock.post(f"{API_URL}/photos.getMessagesUploadServer").respond( + json={"response": {"upload_url": "http://upload_photo.vk"}} + ) + respx_mock.post("http://upload_photo.vk").respond( + json={"photo": "file_id", "server": "server_vk", "hash": "hash_vk"} + ) + respx_mock.post(f"{API_URL}/photos.saveMessagesPhoto").respond( + json={"response": [{"owner_id": 121, "id": 212}]} + ) + respx_mock.post(f"{API_URL}/messages.send").respond(json={"response": 1}) + + caption = Mock() + caption.render = Mock(return_value=MESSAGE_TEXT) + await bot.channels.vk.call_senders({"image": {"url": IMAGE_URL, "caption": caption}}, dialog) + + assert [c.request.extensions["timeout"] for c in respx_mock.calls] == [ + {"connect": 30, "pool": 30, "read": 30, "write": 30}, # maxbot._download.HTTPX_CLIENT + {"connect": 5.0, "pool": 5.0, "read": 5.0, "write": 5.0}, + {"connect": 5.0, "pool": 5.0, "read": 5.0, "write": 5.0}, + {"connect": 5.0, "pool": 5.0, "read": 5.0, "write": 5.0}, + {"connect": 5.0, "pool": 5.0, "read": 5.0, "write": 5.0}, + ] + + +async def test_limits(): + MaxBot.inline( + f""" + channels: + vk: + access_token: {VK_ACCESS_TOKEN} + group_id: {VK_GROUP_ID} + 
secret_key: {SECRET_KEY} + limits: + max_keepalive_connections: 1 + max_connections: 2 + keepalive_expiry: 3 + """ + ) diff --git a/tests/test_cli/test_journal.py b/tests/test_cli/test_journal.py index e499255..6abd7cf 100644 --- a/tests/test_cli/test_journal.py +++ b/tests/test_cli/test_journal.py @@ -2,7 +2,7 @@ import pytest -from maxbot.cli._journal import Dumper, FileJournal, JournalChain +from maxbot.cli._journal import Dumper, FileJournal, FileQuietJournal, no_journal from maxbot.context import ( EntitiesResult, IntentsResult, @@ -27,6 +27,16 @@ def yaml_journal(tmp_path): yield from _file_journal(tmp_path, Dumper.yaml_triple_dash) +@pytest.fixture +def jsonl_quiet_journal(tmp_path): + yield from _file_journal(tmp_path, Dumper.json_line, klass=FileQuietJournal) + + +@pytest.fixture +def yaml_quiet_journal(tmp_path): + yield from _file_journal(tmp_path, Dumper.yaml_triple_dash, klass=FileQuietJournal) + + @pytest.fixture def ctx(): ctx = TurnContext( @@ -47,6 +57,11 @@ def test_jsonl_basic(ctx, jsonl_journal): assert '"response": "Good day to you!"' in out, out +def test_jsonl_quiet_basic(ctx, jsonl_quiet_journal): + ctx.commands.append({"text": markup.Value([markup.Item(markup.TEXT, "Good day to you!")])}) + assert not jsonl_quiet_journal(ctx) + + def test_yaml_basic(ctx, yaml_journal): ctx.commands.append({"text": markup.Value([markup.Item(markup.TEXT, "Good day to you!")])}) @@ -60,6 +75,11 @@ def test_yaml_basic(ctx, yaml_journal): assert "response: Good day to you!" in out, out +def test_yaml_quiet_basic(ctx, yaml_quiet_journal): + ctx.commands.append({"text": markup.Value([markup.Item(markup.TEXT, "Good day to you!")])}) + assert not yaml_quiet_journal(ctx) + + def test_jsonl_intents(jsonl_journal): ctx = TurnContext( dialog={"channel_name": "test", "user_id": "123"}, @@ -232,6 +252,15 @@ def test_jsonl_logs(ctx, jsonl_journal): ) in out +def test_jsonl_quiet_error(ctx, jsonl_quiet_journal): + ctx.debug(("what is here?", {"xxx": "yyy"})) + ctx.warning("something wrong") + ctx.set_error(BotError("some error")) + + out = jsonl_quiet_journal(ctx) + assert out.strip() == '{"error": {"message": "some error"}}' + + def test_yaml_logs(ctx, yaml_journal): ctx.debug(("what is here?", {"xxx": "yyy"})) ctx.warning("something wrong") @@ -258,6 +287,26 @@ def test_yaml_logs(ctx, yaml_journal): ) in out +def test_yaml_quiet_error(ctx, yaml_quiet_journal): + ctx.debug(("what is here?", {"xxx": "yyy"})) + ctx.warning("something wrong") + ctx.set_error(BotError("some error")) + + out = yaml_quiet_journal(ctx) + assert out.strip() == "error:\n message: some error\n---" + + +def test_no_journal(ctx, capsys): + ctx.commands.append({"text": markup.Value([markup.Item(markup.TEXT, "Good day to you!")])}) + ctx.debug(("what is here?", {"xxx": "yyy"})) + ctx.warning("something wrong") + ctx.set_error(BotError("some error")) + no_journal(ctx) + captured = capsys.readouterr() + assert captured.out == "" + assert captured.err == "" + + def test_jsonl_error_snippet(ctx, jsonl_journal): class C(ResourceSchema): s = fields.String() @@ -366,17 +415,20 @@ def test_yaml_journal_events(ctx, yaml_journal): ) in out, out -def test_journal_chain(): - history = [] - journal = JournalChain([lambda ctx: history.append(1), lambda ctx: history.append(2)]) - journal(ctx=None) - assert history == [1, 2] +def test_yaml_journal_events_no_aliases(ctx, yaml_journal): + d = {"key1": "value1", "key2": "value2", "key3": "value3", "key4": "value4"} + ctx.journal_event("test", {"d1": d, "d2": d}) + + out = yaml_journal(ctx) + + 
assert "d1:\n" in out, out + assert "d2:\n" in out, out -def _file_journal(tmp_path, dumps): +def _file_journal(tmp_path, dumps, klass=FileJournal): journal_file = tmp_path / "maxbot.journal" with journal_file.open("a") as f: - journal = FileJournal(f, dumps) + journal = klass(f, dumps) def call(ctx): journal(ctx) diff --git a/tests/test_cli/test_ngrok.py b/tests/test_cli/test_ngrok.py index 9f6574d..af2b9ee 100644 --- a/tests/test_cli/test_ngrok.py +++ b/tests/test_cli/test_ngrok.py @@ -7,11 +7,10 @@ from sanic import Sanic import maxbot.cli -from maxbot import MaxBot def test_ngrok(runner, monkeypatch, respx_mock, botfile): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) respx_mock.get("http://localhost:4040/api/tunnels").respond( json={ @@ -37,13 +36,14 @@ def test_ngrok(runner, monkeypatch, respx_mock, botfile): "--bot", botfile, "--ngrok", + "--single-process", ], catch_exceptions=False, ) assert result.exit_code == 0, result.output - ca = MaxBot.run_webapp.call_args - assert ca.args == ("localhost", 8080) + ca = maxbot.cli.run.run_webapp.call_args + assert ca.args[2:] == ("localhost", 8080) assert ca.kwargs["public_url"] == "https://7ad1-109-172-248-9.ngrok.io" request = respx_mock.calls.last.request @@ -51,7 +51,7 @@ def test_ngrok(runner, monkeypatch, respx_mock, botfile): def test_url(runner, monkeypatch, respx_mock, botfile): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) respx_mock.get("http://localhost:4041/api/tunnels").respond( json={ @@ -72,7 +72,7 @@ def test_url(runner, monkeypatch, respx_mock, botfile): result = runner.invoke( maxbot.cli.main, - ["run", "--bot", botfile, "--ngrok-url", "http://localhost:4041/"], + ["run", "--bot", botfile, "--ngrok-url", "http://localhost:4041/", "--single-process"], catch_exceptions=False, ) diff --git a/tests/test_cli/test_resolve_bot.py b/tests/test_cli/test_resolve_bot.py index 079615b..cbf44f9 100644 --- a/tests/test_cli/test_resolve_bot.py +++ b/tests/test_cli/test_resolve_bot.py @@ -25,7 +25,9 @@ def pkgdir(botfile, monkeypatch): def test_from_file(runner, botfile, monkeypatch): monkeypatch.setattr(Sanic, "run", Mock()) - result = runner.invoke(maxbot.cli.main, ["run", "--bot", botfile], catch_exceptions=False) + result = runner.invoke( + maxbot.cli.main, ["run", "--bot", botfile, "--single-process"], catch_exceptions=False + ) assert result.exit_code == 0, result.output assert Sanic.run.call_count == 1 @@ -35,7 +37,9 @@ def test_custom_filename(runner, botfile, monkeypatch): botfile.rename(customfile) monkeypatch.setattr(Sanic, "run", Mock()) - result = runner.invoke(maxbot.cli.main, ["run", "--bot", customfile], catch_exceptions=False) + result = runner.invoke( + maxbot.cli.main, ["run", "--bot", customfile, "--single-process"], catch_exceptions=False + ) assert result.exit_code == 0, result.output assert Sanic.run.call_count == 1 @@ -50,7 +54,9 @@ def test_from_file_bot_error(runner, botfile, monkeypatch, caplog): ) with caplog.at_level(logging.CRITICAL): - result = runner.invoke(maxbot.cli.main, ["run", "--bot", botfile], catch_exceptions=False) + result = runner.invoke( + maxbot.cli.main, ["run", "--bot", botfile, "--single-process"], catch_exceptions=False + ) assert result.exit_code == 1, result.output assert "Invalid input type" in caplog.text assert " dialog: XXX" in caplog.text @@ -60,7 +66,9 @@ def test_from_directory(runner, botfile, monkeypatch): monkeypatch.setattr(Sanic, "run", Mock()) 
result = runner.invoke( - maxbot.cli.main, ["run", "--bot", botfile.parent], catch_exceptions=False + maxbot.cli.main, + ["run", "--bot", botfile.parent, "--single-process"], + catch_exceptions=False, ) assert result.exit_code == 0, result.output assert Sanic.run.call_count == 1 @@ -69,7 +77,9 @@ def test_from_directory(runner, botfile, monkeypatch): def test_package(runner, pkgdir, monkeypatch): monkeypatch.setattr(Sanic, "run", Mock()) - result = runner.invoke(maxbot.cli.main, ["run", "--bot", pkgdir.name], catch_exceptions=False) + result = runner.invoke( + maxbot.cli.main, ["run", "--bot", pkgdir.name, "--single-process"], catch_exceptions=False + ) assert result.exit_code == 0, result.output assert Sanic.run.call_count == 1 @@ -86,7 +96,7 @@ def test_package_custom_bot_name(runner, pkgdir, monkeypatch): result = runner.invoke( maxbot.cli.main, - ["run", "--bot", f"{pkgdir.name}:custom_bot"], + ["run", "--bot", f"{pkgdir.name}:custom_bot", "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output diff --git a/tests/test_cli/test_rich.py b/tests/test_cli/test_rich.py index 1de4055..1f1ae96 100644 --- a/tests/test_cli/test_rich.py +++ b/tests/test_cli/test_rich.py @@ -12,23 +12,29 @@ TurnContext, ) from maxbot.errors import BotError, YamlSnippet +from maxbot.flows.slot_filling import SlotFilling, SlotSchema from maxbot.maxml import fields, markup from maxbot.schemas import CommandSchema, ResourceSchema +@pytest.fixture +def console_journal_q(): + return _create_console_journal(verbosity=-1) + + @pytest.fixture def console_journal(): - return _create_console_journal(verbose=0) + return _create_console_journal(verbosity=0) @pytest.fixture def console_journal_v(): - return _create_console_journal(verbose=1) + return _create_console_journal(verbosity=1) @pytest.fixture def console_journal_vv(): - return _create_console_journal(verbose=2) + return _create_console_journal(verbosity=2) @pytest.fixture @@ -51,7 +57,7 @@ def test_console_basic(ctx, console_journal): assert "Good day to you!" 
in out, out -def test_console_commands_yaml(ctx, console_journal): +def test_console_commands_xml(ctx, console_journal): ctx.commands.append({"text": markup.Value([markup.Item(markup.TEXT, "Hello, John!")])}) ctx.commands.append( { @@ -69,21 +75,23 @@ def test_console_commands_yaml(ctx, console_journal): def test_console_empty_journal_events(ctx, console_journal): - out = console_journal(ctx) - assert "journal_events" not in out - assert "logs" not in out + _test_console_empty_journal(ctx, console_journal) def test_console_empty_journal_events_v(ctx, console_journal_v): - out = console_journal_v(ctx) - assert "journal_events" not in out - assert "logs" not in out + _test_console_empty_journal(ctx, console_journal_v) def test_console_empty_journal_events_vv(ctx, console_journal_vv): - out = console_journal_vv(ctx) + _test_console_empty_journal(ctx, console_journal_vv) + + +def _test_console_empty_journal(ctx, console_journal_): + out = console_journal_(ctx) assert "journal_events" not in out assert "logs" not in out + assert "user" not in out + assert "slots" not in out def test_console_logs(ctx, console_journal): @@ -244,9 +252,136 @@ def test_console_journal_events_not_serializable(ctx, console_journal_vv): assert f"test {value!r}" in out, out -def _create_console_journal(verbose): +def test_console_journal_events_alias(ctx, console_journal_vv): + d = {"key1": "value1", "key2": "value2", "key3": "value3", "key4": "value4"} + ctx.journal_event("test", {"d1": d, "d2": d}) + + out = console_journal_vv(ctx) + assert "test d1: &" in out, out + assert " d2: *" in out, out + + +def test_console_journal_q_empty(console_journal_q): + ctx = TurnContext( + dialog={"channel_name": "test", "user_id": "321"}, + message={"text": "hello world"}, + intents=IntentsResult.resolve([RecognizedIntent("i1", 1)]), + entities=EntitiesResult.resolve([RecognizedEntity("e1", 1, "one", 2, 34)]), + ) + ctx.commands.append({"text": markup.Value([markup.Item(markup.TEXT, "Hello, John!")])}) + ctx.journal_event("event1", {"data": 1}) + ctx.debug(("what is here?", {"xxx": "yyy"})) + ctx.warning("something wrong") + + assert console_journal_q(ctx) == "" + + +def test_console_journal_q_error(console_journal_q): + ctx = TurnContext( + dialog={"channel_name": "test", "user_id": "321"}, + message={"text": "hello world"}, + intents=IntentsResult.resolve([RecognizedIntent("i1", 1)]), + entities=EntitiesResult.resolve([RecognizedEntity("e1", 1, "one", 2, 34)]), + ) + ctx.commands.append({"text": markup.Value([markup.Item(markup.TEXT, "Hello, John!")])}) + ctx.journal_event("event1", {"data": 1}) + ctx.debug(("what is here?", {"xxx": "yyy"})) + ctx.warning("something wrong") + ctx.set_error(BotError("some error")) + + assert console_journal_q(ctx).strip() == "✗ some error" + + +@pytest.mark.parametrize( + "state_field", + ( + "user", + "slots", + ), +) +def test_console_journal_diff(ctx, console_journal_v, state_field): + getattr(ctx.state, state_field)["test"] = 1 + del getattr(ctx.state, state_field)["test"] + + out = [s.strip().split() for s in console_journal_v(ctx).splitlines()] + assert ["journal_events"] in out, out + assert [f"{state_field}.test", "=", "1"] in out, out + assert ["❌", "delete", f"{state_field}.test"] in out, out + + +def test_console_journal_diff_clear(ctx, console_journal_v): + ctx.state.slots["test1"] = 1 + ctx.state.slots["test2"] = 1 + ctx.clear_state_variables() + + out = [s.strip().split() for s in console_journal_v(ctx).splitlines()] + assert ["journal_events"] in out, out + assert ["❌", "delete", 
"slots.test1"] in out, out + assert ["❌", "delete", "slots.test2"] in out, out + + +def test_console_journal_diff_clear_empty(ctx, console_journal_v): + ctx.clear_state_variables() + + out = console_journal_v(ctx) + assert f"journal_events" not in out, out + + +@pytest.mark.parametrize( + "state_field", + ( + "user", + "slots", + ), +) +def test_console_journal_full(ctx, console_journal_vv, state_field): + getattr(ctx.state, state_field)["test"] = 1 + + out = console_journal_vv(ctx) + assert f"{state_field}\n" in out, out + assert ".test 1" in out, out + + +@pytest.mark.parametrize( + "state_field", + ( + "user", + "slots", + ), +) +def test_console_journal_full_if_error(ctx, console_journal_v, state_field): + getattr(ctx.state, state_field)["test"] = 1 + ctx.set_error(BotError("some error")) + + out = console_journal_v(ctx) + assert f"{state_field}\n" in out, out + assert ".test 1" in out, out + + +async def test_jourtnal_slot_filling(ctx, console_journal_vv): + model = SlotFilling( + SlotSchema(many=True).loads( + """ + - name: slot1 + check_for: true + found: "" """ + ), + [], + ) + await model(ctx, ctx.state.components.setdefault("xxx", {})) + + out = console_journal_vv(ctx) + assert "journal_events\n" in out, out + assert "slot_filling slot: slot1", out + assert "slots.slot1 = True", out + assert "found slot: slot1", out + assert " control_command: prompt_again", out + assert "❌ delete slots.slot1", out + + +def _create_console_journal(verbosity): console = Console(force_terminal=False, soft_wrap=True) - journal = PrettyJournal(verbose=verbose, console=console) + journal = PrettyJournal(verbosity=verbosity, console=console) def call(ctx): with console.capture() as capture: diff --git a/tests/test_cli/test_run.py b/tests/test_cli/test_run.py index 12d9b5f..786d05a 100644 --- a/tests/test_cli/test_run.py +++ b/tests/test_cli/test_run.py @@ -1,4 +1,5 @@ import logging +import sys from pathlib import Path from unittest.mock import AsyncMock, Mock, sentinel @@ -8,16 +9,30 @@ import maxbot.cli from maxbot import MaxBot -from maxbot.cli._journal import Dumper, FileJournal +from maxbot.cli._journal import Dumper, FileJournal, FileQuietJournal, no_journal from maxbot.cli._rich import PrettyJournal from maxbot.dialog_manager import DialogManager +def test_updater_webhooks_default(runner, botfile, monkeypatch): + monkeypatch.setattr(Sanic, "serve", Mock()) + + result = runner.invoke( + maxbot.cli.main, + ["run", "--bot", botfile], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + assert Sanic.serve.call_count == 1 + + def test_updater_webhooks(runner, botfile, monkeypatch): monkeypatch.setattr(Sanic, "run", Mock()) result = runner.invoke( - maxbot.cli.main, ["run", "--bot", botfile, "--updater", "webhooks"], catch_exceptions=False + maxbot.cli.main, + ["run", "--bot", botfile, "--updater", "webhooks", "--single-process"], + catch_exceptions=False, ) assert result.exit_code == 0, result.output assert Sanic.run.call_count == 1 @@ -61,26 +76,28 @@ def test_updater_for_host(runner, botfile, monkeypatch): monkeypatch.setattr(Sanic, "run", Mock()) result = runner.invoke( - maxbot.cli.main, ["run", "--bot", botfile, "--host", "localhost"], catch_exceptions=False + maxbot.cli.main, + ["run", "--bot", botfile, "--host", "localhost", "--single-process"], + catch_exceptions=False, ) assert result.exit_code == 0, result.output assert Sanic.run.call_count == 1 def test_updater_for_port(runner, botfile, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + 
monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) result = runner.invoke( maxbot.cli.main, - ["run", "--bot", botfile, "--host", "myhost", "--port", "123"], + ["run", "--bot", botfile, "--host", "myhost", "--port", "123", "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output - assert MaxBot.run_webapp.call_args.args == ("myhost", 123) + assert maxbot.cli.run.run_webapp.call_args.args[2:] == ("myhost", 123) def test_public_url(runner, botfile, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) result = runner.invoke( maxbot.cli.main, @@ -90,33 +107,28 @@ def test_public_url(runner, botfile, monkeypatch): botfile, "--public-url", "http://example.com", + "--single-process", ], catch_exceptions=False, ) - assert MaxBot.run_webapp.call_args.kwargs["public_url"] == "http://example.com" + assert maxbot.cli.run.run_webapp.call_args.kwargs["public_url"] == "http://example.com" def test_journal_file(runner, botfile, tmp_path, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(DialogManager, "journal", Mock()) journal_file = tmp_path / "maxbot.jsonl" result = runner.invoke( maxbot.cli.main, - [ - "run", - "--bot", - botfile, - "--journal-file", - f"{journal_file}", - ], + ["run", "--bot", botfile, "--journal-file", f"{journal_file}", "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output (journal,) = DialogManager.journal.call_args.args - assert journal.chain[-1].f.name == str(journal_file) + assert journal.f.name == str(journal_file) def test_journal_bad_file(runner, botfile, tmp_path): @@ -142,45 +154,69 @@ def test_journal_bad_file(runner, botfile, tmp_path): "output, dumps", (("json", Dumper.json_line), ("yaml", Dumper.yaml_triple_dash)) ) def test_journal_output(runner, botfile, tmp_path, monkeypatch, output, dumps): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(DialogManager, "journal", Mock()) journal_file = tmp_path / "maxbot.jsonl" result = runner.invoke( maxbot.cli.main, - ["run", "--bot", botfile, "--journal-file", f"{journal_file}", "--journal-output", output], + [ + "run", + "--bot", + botfile, + "--journal-file", + f"{journal_file}", + "--journal-output", + output, + "--single-process", + ], catch_exceptions=False, ) assert result.exit_code == 0, result.output (journal,) = DialogManager.journal.call_args.args - assert journal.chain[-1].dumps == dumps + assert journal.dumps == dumps def test_no_journal(runner, botfile, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(DialogManager, "journal", Mock()) + try: + result = runner.invoke( + maxbot.cli.main, + ["run", "--bot", botfile, "-qq", "--single-process"], + catch_exceptions=False, + ) + finally: + logging.disable(logging.NOTSET) + assert result.exit_code == 0, result.output + + (journal,) = DialogManager.journal.call_args.args + assert journal == no_journal + + +def test_console_journal(runner, botfile, monkeypatch): + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) + monkeypatch.setattr(DialogManager, "journal", Mock()) + monkeypatch.setattr(maxbot.cli._journal, "_stdout_is_non_interactive", lambda: False) + result = runner.invoke( maxbot.cli.main, - [ - "run", - "--bot", - botfile, - "-q", - ], + 
["run", "--bot", botfile, "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output (journal,) = DialogManager.journal.call_args.args - assert not journal.chain + assert isinstance(journal, PrettyJournal) -def test_console_journal_only(runner, botfile, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) +def test_non_interactive_journal(runner, botfile, monkeypatch): + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(DialogManager, "journal", Mock()) + monkeypatch.setattr(maxbot.cli._journal, "_stdout_is_non_interactive", lambda: True) result = runner.invoke( maxbot.cli.main, @@ -188,59 +224,55 @@ def test_console_journal_only(runner, botfile, monkeypatch): "run", "--bot", botfile, + "--single-process", ], catch_exceptions=False, ) assert result.exit_code == 0, result.output (journal,) = DialogManager.journal.call_args.args - (console_journal,) = journal.chain - assert isinstance(console_journal, PrettyJournal) + assert isinstance(journal, FileJournal) -def test_file_journal_only(runner, botfile, tmp_path, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) +def test_file_journal(runner, botfile, tmp_path, monkeypatch): + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(DialogManager, "journal", Mock()) journal_file = tmp_path / "maxbot.jsonl" result = runner.invoke( maxbot.cli.main, - [ - "run", - "--bot", - botfile, - "-q", - "--journal-file", - f"{journal_file}", - ], + ["run", "--bot", botfile, "--journal-file", f"{journal_file}", "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output (journal,) = DialogManager.journal.call_args.args - (file_journal,) = journal.chain - assert isinstance(file_journal, FileJournal) + assert isinstance(journal, FileJournal) -def test_console_and_file_journal(runner, botfile, tmp_path, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) +def test_quiet_file_journal(runner, botfile, tmp_path, monkeypatch): + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(DialogManager, "journal", Mock()) journal_file = tmp_path / "maxbot.jsonl" result = runner.invoke( maxbot.cli.main, - [ - "run", - "--bot", - botfile, - "--journal-file", - f"{journal_file}", - ], + ["run", "--bot", botfile, "-q", "--journal-file", f"{journal_file}", "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output (journal,) = DialogManager.journal.call_args.args - assert len(journal.chain) == 2 + assert isinstance(journal, FileQuietJournal) + + +def test_q_v_mutually_exclusive(runner, telegram_botfile): + result = runner.invoke( + maxbot.cli.main, + ["run", "--bot", telegram_botfile, "-v", "-q"], + catch_exceptions=False, + ) + assert result.exit_code == 2, result.output + assert "Options -q and -v are mutually exclusive." 
in result.output diff --git a/tests/test_cli/test_run_logging.py b/tests/test_cli/test_run_logging.py index 118cb85..f583bf8 100644 --- a/tests/test_cli/test_run_logging.py +++ b/tests/test_cli/test_run_logging.py @@ -13,20 +13,14 @@ def test_logger_file(runner, botfile, tmp_path, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(logging, "basicConfig", Mock()) logfile = tmp_path / "maxbot.log" result = runner.invoke( maxbot.cli.main, - [ - "run", - "--bot", - botfile, - "--logger", - f"file:{logfile}", - ], + ["run", "--bot", botfile, "--logger", f"file:{logfile}", "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output @@ -36,18 +30,17 @@ def test_logger_file(runner, botfile, tmp_path, monkeypatch): @pytest.fixture def console_handler(runner, botfile, monkeypatch): - monkeypatch.setattr(MaxBot, "run_webapp", Mock()) + return _console_handler(runner, botfile, monkeypatch) + + +def _console_handler(runner, botfile, monkeypatch, non_interactive=False): + monkeypatch.setattr(maxbot.cli.run, "run_webapp", Mock()) monkeypatch.setattr(logging, "basicConfig", Mock()) + monkeypatch.setattr(maxbot.cli._logging, "_stderr_is_non_interactive", lambda: non_interactive) result = runner.invoke( maxbot.cli.main, - [ - "run", - "--bot", - botfile, - "--logger", - "console", - ], + ["run", "--bot", botfile, "--logger", "console", "--single-process"], catch_exceptions=False, ) assert result.exit_code == 0, result.output @@ -62,7 +55,7 @@ def _make_log_record(log_level, message, args=tuple()): def test_logger_console(console_handler, capsys): console_handler.emit(_make_log_record(logging.INFO, "foo bar")) _, err = capsys.readouterr() - assert "✓ foo bar" in err + assert " - MainProcess - foo bar" in err def test_logger_console_bot_error(console_handler, capsys): @@ -79,6 +72,11 @@ class C(ResourceSchema): assert " ❱ 1 s: hello world" in err +def test_logger_non_interactive(runner, botfile, monkeypatch): + handler = _console_handler(runner, botfile, monkeypatch, non_interactive=True) + assert isinstance(handler, logging.StreamHandler) + + def test_logger_bad_file(runner, tmp_path): badfile = tmp_path / "maxbot.log" badfile.mkdir() diff --git a/tests/test_cli/test_stories.py b/tests/test_cli/test_stories.py deleted file mode 100644 index 52b312e..0000000 --- a/tests/test_cli/test_stories.py +++ /dev/null @@ -1,366 +0,0 @@ -from datetime import datetime, timezone -from pathlib import Path - -import pytest - -import maxbot.cli.stories -from maxbot.builder import BotBuilder -from maxbot.cli import main -from maxbot.errors import BotError -from maxbot.maxml import markup - - -def test_minimal(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text("{}") - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text("[]") - - result = runner.invoke(main, ["stories", "--bot", bot_file], catch_exceptions=False) - assert result.exit_code == 0, result.output - - -def test_minimal_explicit(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text("{}") - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text("[]") - - result = runner.invoke( - main, ["stories", "--bot", tmp_path, "--stories", stories_file], catch_exceptions=False - ) - assert result.exit_code == 0, result.output - - -def _iter_examples(): - for dir_path in (Path(__file__).parents[2] / "examples").iterdir(): - if dir_path.is_dir(): - if (dir_path / 
"stories.yaml").is_file(): - if (dir_path / "bot.yaml").is_file(): - yield str(dir_path) - - -@pytest.mark.parametrize("project_dir", tuple(_iter_examples())) -def test_examples(runner, project_dir): - result = runner.invoke(main, ["stories", "--bot", project_dir], catch_exceptions=False) - assert result.exit_code == 0, result.output - - -@pytest.mark.parametrize( - "utc_time", ("2023-04-10T19:15:58.104144", "2023-04-10T18:15:58.104144-01:00") -) -def test_utc_time_template(runner, tmp_path, utc_time): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text( - """ - extensions: - datetime: {} - dialog: - - condition: true - response: | - {{ utc_time.isoformat() }} - """ - ) - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text( - f""" - - name: test - turns: - - utc_time: "{utc_time}" - message: hello - response: "2023-04-10T19:15:58.104144+00:00" - """ - ) - - result = runner.invoke(main, ["stories", "--bot", tmp_path], catch_exceptions=False) - assert result.exit_code == 0, result.output - assert "test OK" in result.output, result.output - - -def test_utc_time_entitites(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text( - """ - entities: - - name: date - dialog: - - condition: entities.date - response: | - {{ entities.date.value }} - """ - ) - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text( - """ - - name: test - turns: - - utc_time: '2021-01-01T19:15:58.104144' - message: today - response: '2021-01-01' - """ - ) - - result = runner.invoke(main, ["stories", "--bot", tmp_path], catch_exceptions=False) - assert result.exit_code == 0, result.output - assert "test OK" in result.output, result.output - - -def test_fail(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text( - """ - dialog: - - condition: true - response: | - {{ message.text }} - """ - ) - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text( - """ - - name: test - turns: - - message: hello - response: HELLO - """ - ) - - result = runner.invoke(main, ["stories", "--bot", tmp_path], catch_exceptions=False) - assert result.exit_code == 1, result.output - assert result.output.startswith( - ( - "test FAILED at step [0]\n" - "Expected:\n" - " HELLO\n" - "Actual:\n" - " hello\n" - ) - ), result.output - - -def test_fail_list(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text( - """ - dialog: - - condition: true - response: | - {{ message.text }} - """ - ) - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text( - """ - - name: test - turns: - - message: hello - response: - - hello1 - - hello2 - """ - ) - - result = runner.invoke(main, ["stories", "--bot", tmp_path], catch_exceptions=False) - assert result.exit_code == 1, result.output - assert result.output.startswith( - ( - "test FAILED at step [0]\n" - "Expected:\n" - " hello1\n" - " -or-\n" - " hello2\n" - "Actual:\n" - " hello\n" - ) - ), result.output - - -def test_xfail(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text( - """ - dialog: - - condition: true - response: | - {{ message.text }} - """ - ) - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text( - """ - - xfail: true - name: test - turns: - - message: hello - response: HELLO - """ - ) - - result = runner.invoke(main, ["stories", "--bot", tmp_path], catch_exceptions=False) - assert result.exit_code == 0, result.output - assert "test XFAIL" in result.output, result.output - - -def test_assert_no_message_and_no_rpc(runner, tmp_path, 
monkeypatch): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text("{}") - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text("[]") - - class _StorySchema: - def load_file(self, *args, **kwargs): - return [{"name": "test", "turns": [{"response": ""}]}] - - monkeypatch.setattr(maxbot.cli.stories, "create_story_schema", lambda bot: _StorySchema()) - - with pytest.raises(AssertionError) as excinfo: - runner.invoke(main, ["stories", "--bot", bot_file], catch_exceptions=False) - assert str(excinfo.value) == "Either message or rpc must be provided." - - -def test_utc_time_tick_10sec(): - provider = maxbot.cli.stories.StoryUtcTimeProvider() - provider.tick(datetime(2020, 1, 1, 0, 0)) - provider.tick() - assert datetime(2020, 1, 1, 0, 0, 10, tzinfo=timezone.utc) == provider() - - -def test_rpc_method_validation_error(): - schema = maxbot.cli.stories.create_story_schema(BotBuilder().build()) - with pytest.raises(BotError) as excinfo: - schema.loads( - """ - - name: test - turns: - - rpc: { method: nonexistent } - response: "" - """ - ) - assert str(excinfo.value) == ( - "caused by marshmallow.exceptions.ValidationError: Method not found\n" - ' in "", line 4, column 28:\n' - " turns:\n" - " - rpc: { method: nonexistent }\n" - " ^^^\n" - ' response: ""' - ) - - -def test_rpc_params_validation_error(): - builder = BotBuilder() - builder.use_inline_resources( - """ - rpc: - - method: with_params - params: - - name: required_param - required: true - """ - ) - schema = maxbot.cli.stories.create_story_schema(builder.build()) - with pytest.raises(BotError) as excinfo: - schema.loads( - """ - - name: test - turns: - - rpc: { method: with_params } - response: "" - """ - ) - assert str(excinfo.value) == ( - "caused by marshmallow.exceptions.ValidationError: {'required_param': " - "['Missing data for required field.']}\n" - ' in "", line 4, column 18:\n' - " turns:\n" - " - rpc: { method: with_params }\n" - " ^^^\n" - ' response: ""' - ) - - -def test_turn_no_message_and_no_rpc(): - schema = maxbot.cli.stories.create_story_schema(BotBuilder().build()) - with pytest.raises(BotError) as excinfo: - schema.loads( - """ - - name: test - turns: - - response: "" - """ - ) - assert str(excinfo.value) == ( - "caused by marshmallow.exceptions.ValidationError: " - "Exactly one of 'message' or 'rpc' is required.\n" - ' in "", line 4, column 13:\n' - " turns:\n" - ' - response: ""\n' - " ^^^\n" - ) - - -def test_match_first(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text( - """ - dialog: - - condition: true - response: | - {{ message.text }} - """ - ) - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text( - """ - - name: test - turns: - - message: hello - response: - - hello - - hello2 - """ - ) - - result = runner.invoke(main, ["stories", "--bot", tmp_path], catch_exceptions=False) - assert result.exit_code == 0, result.output - assert "test OK" in result.output, result.output - - -def test_match_second(runner, tmp_path): - bot_file = tmp_path / "bot.yaml" - bot_file.write_text( - """ - dialog: - - condition: true - response: | - {{ message.text }} - """ - ) - - stories_file = tmp_path / "stories.yaml" - stories_file.write_text( - """ - - name: test - turns: - - message: hello - response: - - hello1 - - hello - """ - ) - - result = runner.invoke(main, ["stories", "--bot", tmp_path], catch_exceptions=False) - assert result.exit_code == 0, result.output - assert "test OK" in result.output, result.output - - -def test_markup_value_rendered_comparator_false(): 
- assert not maxbot.cli.stories.markup_value_rendered_comparator(markup.Value(), 1) diff --git a/tests/test_context.py b/tests/test_context.py index 621148f..fd8913d 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -6,6 +6,7 @@ EntitiesProxy, EntitiesResult, IntentsResult, + JournalledDict, RecognizedEntity, RecognizedIntent, RpcContext, @@ -248,3 +249,16 @@ def test_utc_time(): def test_today(): ctx = TurnContext(dialog=None, message={"text": "hello world"}) assert ctx.create_scenario_context({})["utc_today"] == ctx.utc_time.date() + + +def test_journalled_dict_len(): + d = JournalledDict() + assert len(d) == 0 + d["a"] = 1 + assert len(d) == 1 + d["a"] = 2 + assert len(d) == 1 + d["b"] = 2 + assert len(d) == 2 + del d["a"] + assert len(d) == 1 diff --git a/tests/test_extensions/test_babel.py b/tests/test_extensions/test_babel.py index 900853d..41fb10a 100644 --- a/tests/test_extensions/test_babel.py +++ b/tests/test_extensions/test_babel.py @@ -209,6 +209,16 @@ def test_unknown_tz_config(utc_time, filter_name): ) +def test_format_time_str(utc_time): + assert _babel()["format_time"](utc_time.time().isoformat(), locale="en") == "12:00:00\u202fPM" + + +def test_unknow_tz(utc_time): + with pytest.raises(BotError) as excinfo: + assert _babel()["format_datetime"](utc_time, tz="UnKnown", locale="en") + assert str(excinfo.value) == "caused by builtins.LookupError: unknown timezone 'UnKnown'" + + def _babel(config={}): builder = Mock() BabelExtension(builder, config) diff --git a/tests/test_extensions/test_rest.py b/tests/test_extensions/test_rest.py index 13096ef..788f0b6 100644 --- a/tests/test_extensions/test_rest.py +++ b/tests/test_extensions/test_rest.py @@ -1,13 +1,22 @@ from base64 import b64encode +from datetime import timedelta +import httpx import pytest import respx +import maxbot.extensions.rest from maxbot.bot import MaxBot from maxbot.errors import BotError +_GB_TIMEOUT = ( + maxbot.extensions.rest.RestExtension.ConfigSchema() + .load({})["garbage_collector_timeout"] + .total_seconds() +) -@pytest.mark.parametrize("method", ("get", "post", "put", "delete")) + +@pytest.mark.parametrize("method", ("get", "post", "put", "delete", "patch")) def test_tag(method): bot = MaxBot.inline( """ @@ -15,7 +24,7 @@ def test_tag(method): rest: {} dialog: - condition: true - response: |- + response: | {% """ + method.upper() + """ "http://127.0.0.1/endpoint" %} @@ -25,6 +34,34 @@ def test_tag(method): _test_mock_common(bot, method) +def test_duplicate_services(): + with pytest.raises(BotError) as excinfo: + MaxBot.inline( + """ + extensions: + rest: + services: + - name: my_service + base_url: http://127.0.0.1/ + - name: My_Service + base_url: http://127.0.0.2/ + dialog: + - condition: true + response: |- + {% set _ = rest_call(service='my_service') %} + test + """ + ) + assert str(excinfo.value) == ( + "Duplicate REST service names: 'My_Service' and 'my_service'\n in" + ' "", line 7, column 25:\n' + " base_url: http://127.0.0.1/\n" + " - name: My_Service\n" + " ^^^\n" + " base_url: http://127.0.0.2/" + ) + + def test_service_args(): bot = MaxBot.inline( """ @@ -35,7 +72,7 @@ def test_service_args(): base_url: http://127.0.0.1/endpoint dialog: - condition: true - response: |- + response: | {% set _ = rest_call(service='my_service') %} test """ @@ -43,6 +80,24 @@ def test_service_args(): _test_mock_common(bot, "get") +def test_service_case_insensitive(): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: my_service + base_url: http://127.0.0.1/endpoint + 
dialog: + - condition: true + response: |- + {% set _ = rest_call(service='mY_seRVICe') %} + test + """ + ) + _test_mock_common(bot, "get") + + def test_service_not_found(): bot = MaxBot.inline( """ @@ -50,7 +105,7 @@ def test_service_not_found(): rest: {} dialog: - condition: true - response: |- + response: | {% set _ = rest_call(service='my_service') %} test """ @@ -70,7 +125,7 @@ def test_service_in_url(): base_url: http://127.0.0.1 dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" %} test """ @@ -78,6 +133,24 @@ def test_service_in_url(): _test_mock_common(bot, "get") +def test_service_in_url_case_insensitive(): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: my_service + base_url: http://127.0.0.1 + dialog: + - condition: true + response: |- + {% GET "My_SERVICe://endpoint" %} + test + """ + ) + _test_mock_common(bot, "get") + + def test_method_default_get(): bot = MaxBot.inline( """ @@ -85,7 +158,7 @@ def test_method_default_get(): rest: {} dialog: - condition: true - response: |- + response: | {% set _ = rest_call(url="http://127.0.0.1/endpoint") %} test """ @@ -100,7 +173,7 @@ def test_method_default_post(): rest: {} dialog: - condition: true - response: |- + response: | {% set _ = rest_call(url="http://127.0.0.1/endpoint", body={"a": 1}) %} test """ @@ -119,7 +192,7 @@ def test_method_service(): method: post dialog: - condition: true - response: |- + response: | {% set _ = rest_call(url="my_service://endpoint") %} test """ @@ -138,7 +211,7 @@ def test_method_args(): method: post dialog: - condition: true - response: |- + response: | {% set _ = rest_call(url="my_service://endpoint", method="put") %} test """ @@ -156,7 +229,7 @@ def test_url_and_base_url(): base_url: http://127.0.0.1 dialog: - condition: true - response: |- + response: | {% set _ = rest_call(service="my_service", url="endpoint") %} test """ @@ -174,7 +247,7 @@ def test_base_url(): base_url: http://127.0.0.1/endpoint dialog: - condition: true - response: |- + response: | {% set _ = rest_call(service="my_service") %} test """ @@ -189,7 +262,7 @@ def test_url(): rest: {} dialog: - condition: true - response: |- + response: | {% set _ = rest_call(url="http://127.0.0.1/endpoint") %} test """ @@ -204,7 +277,7 @@ def test_url_not_specified(): rest: {} dialog: - condition: true - response: |- + response: | {% set _ = rest_call() %} test """ @@ -225,7 +298,7 @@ def test_headers(): headers: {"a": "a", "b": "b"} dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" headers {"b": "2", "c": "3"} %} test """ @@ -251,7 +324,7 @@ def test_parameters(): parameters: {"a": "a", "b": "b"} dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" parameters {"b": "2", "c": "3"} %} test """ @@ -274,7 +347,7 @@ def test_timeout_default(): base_url: http://127.0.0.1 dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" %} test """ @@ -287,6 +360,37 @@ def _match(request): _test_mock_common(bot, "get", additional_matcher=_match) +def test_timeout_config(): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: my_service + base_url: http://127.0.0.1 + timeout: + default: 5.1 + pool: 1 + dialog: + - condition: true + response: |- + {% GET "my_service://endpoint" %} + test + """ + ) + + def _match(request): + assert request.extensions["timeout"] == { + "connect": 5.1, + "pool": 1, + "read": 5.1, + "write": 5.1, + } + return True + + _test_mock_common(bot, "get", 
additional_matcher=_match) + + def test_timeout_service(): bot = MaxBot.inline( """ @@ -296,9 +400,10 @@ def test_timeout_service(): - name: my_service base_url: http://127.0.0.1 timeout: 6 + timeout: 5.5 dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" %} test """ @@ -320,9 +425,10 @@ def test_timeout_args(): - name: my_service base_url: http://127.0.0.1 timeout: 6 + timeout: 5.5 dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" timeout 7 %} test """ @@ -335,6 +441,78 @@ def _match(request): _test_mock_common(bot, "get", additional_matcher=_match) +def test_limits_config(monkeypatch): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: my_service + base_url: http://127.0.0.1 + limits: + max_keepalive_connections: 1 + max_connections: 2 + keepalive_expiry: 3 + dialog: + - condition: true + response: |- + {% GET "my_service://endpoint" %} + test + """ + ) + limits = [] + httpx_AsyncClient_ctor = httpx.AsyncClient.__init__ + + def hook_AsyncClient_ctor(self, *args, **kwargs): + limits.append(kwargs.get("limits")) + httpx_AsyncClient_ctor(self, *args, **kwargs) + + monkeypatch.setattr(httpx.AsyncClient, "__init__", hook_AsyncClient_ctor) + + _test_mock_common(bot, "get") + assert limits == [ + httpx.Limits(max_connections=2, max_keepalive_connections=1, keepalive_expiry=3.0) + ] + + +def test_limits_service(monkeypatch): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: my_service + base_url: http://127.0.0.1 + limits: + max_keepalive_connections: 4 + max_connections: 5 + keepalive_expiry: 6 + limits: + max_keepalive_connections: 1 + max_connections: 2 + keepalive_expiry: 3 + dialog: + - condition: true + response: |- + {% GET "my_service://endpoint" %} + test + """ + ) + limits = [] + httpx_AsyncClient_ctor = httpx.AsyncClient.__init__ + + def hook_AsyncClient_ctor(self, *args, **kwargs): + limits.append(kwargs.get("limits")) + httpx_AsyncClient_ctor(self, *args, **kwargs) + + monkeypatch.setattr(httpx.AsyncClient, "__init__", hook_AsyncClient_ctor) + + _test_mock_common(bot, "get") + assert limits == [ + httpx.Limits(max_connections=5, max_keepalive_connections=4, keepalive_expiry=6.0) + ] + + def test_auth_service(): bot = MaxBot.inline( """ @@ -348,7 +526,7 @@ def test_auth_service(): password: mypassword dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" %} test """ @@ -376,7 +554,7 @@ def test_auth_args(): password: mypassword dialog: - condition: true - response: |- + response: | {% GET "my_service://endpoint" auth {"user": "myuser2", "password": "mypassword2"} %} test """ @@ -398,7 +576,7 @@ def test_200_json(): rest: {} dialog: - condition: true - response: |- + response: | {% GET "http://127.0.0.1/endpoint" %} {{ rest.ok|tojson }}|{{ rest.status_code }}|{{ rest.json.success|tojson }} """ @@ -414,7 +592,7 @@ def test_server_error(on_error, respx_mock): rest: {} dialog: - condition: true - response: |- + response: | {% GET "http://127.0.0.1/endpoint" """ + on_error + """ %} @@ -435,7 +613,7 @@ def test_server_error_continue(): rest: {} dialog: - condition: true - response: |- + response: | {% GET "http://127.0.0.1/endpoint" on_error "continue" %} {{ rest.ok|tojson }}|{{ rest.status_code }} """ @@ -450,7 +628,7 @@ def test_invalid_on_error(): rest: {} dialog: - condition: true - response: |- + response: | {% GET "http://127.0.0.1/endpoint" on_error "try_again" %} {{ rest.ok|tojson }}|{{ rest.status_code }} """ @@ -460,6 +638,451 @@ def 
test_invalid_on_error(): assert str(excinfo.value) == "on_error invalid value: try_again" +def test_network_error_continue(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: |- + {% GET "http://127.0.0.1/endpoint" on_error "continue" %} + {{ rest.ok }} + """ + ) + assert _test_mock_network_error(bot, error=httpx.ConnectError) == [{"text": "False"}] + + +def test_network_error_break(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: |- + {% GET "http://127.0.0.1/endpoint" %} + {{ rest.ok }} + """ + ) + with pytest.raises(BotError) as excinfo: + _test_mock_network_error(bot, error=httpx.TimeoutException) + assert str(excinfo.value) == "caused by httpx.TimeoutException: REST call failed: Mock Error" + + +def test_cache_args(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 %} + """ + ) + _test_cache(bot) + + +def test_cache_service(): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: localhost + cache: 1 + base_url: http://127.0.0.1/ + dialog: + - condition: true + response: | + {% GET "localhost://endpoint" %} + """ + ) + _test_cache(bot) + + +def test_cache_expired_args(monkeypatch): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 %} + """ + ) + mock = {"now": maxbot.extensions.rest._now()} + monkeypatch.setattr(maxbot.extensions.rest, "_now", lambda: mock["now"]) + + _test_cache( + bot, + cached_successfully=False, + between_calls=lambda: mock.update(now=mock["now"] + timedelta(seconds=2)), + ) + + +def test_cache_expired_service(monkeypatch): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: localhost + cache: 1 + base_url: http://127.0.0.1/ + dialog: + - condition: true + response: | + {% GET "localhost://endpoint" %} + """ + ) + mock = {"now": maxbot.extensions.rest._now()} + monkeypatch.setattr(maxbot.extensions.rest, "_now", lambda: mock["now"]) + + _test_cache( + bot, + cached_successfully=False, + between_calls=lambda: mock.update(now=mock["now"] + timedelta(seconds=2)), + ) + + +def test_cache_garbage_collector_ignored(monkeypatch): + bot = MaxBot.inline( + f""" + extensions: + rest: {{}} + dialog: + - condition: true + response: | + {{% GET "http://127.0.0.1/endpoint" cache {_GB_TIMEOUT + 2} %}} + """ + ) + mock = {"now": maxbot.extensions.rest._now()} + monkeypatch.setattr(maxbot.extensions.rest, "_now", lambda: mock["now"]) + + _test_cache( + bot, + between_calls=lambda: mock.update(now=mock["now"] + timedelta(seconds=_GB_TIMEOUT + 1)), + ) + + +def test_cache_garbage_collector_expired(monkeypatch): + bot = MaxBot.inline( + f""" + extensions: + rest: {{}} + dialog: + - condition: true + response: | + {{% GET "http://127.0.0.1/endpoint" cache 1 %}} + """ + ) + mock = {"now": maxbot.extensions.rest._now()} + monkeypatch.setattr(maxbot.extensions.rest, "_now", lambda: mock["now"]) + + _test_cache( + bot, + cached_successfully=False, + between_calls=lambda: mock.update(now=mock["now"] + timedelta(seconds=_GB_TIMEOUT + 1)), + ) + + +def test_cache_garbage_collector_time(monkeypatch): + mock = {"now": maxbot.extensions.rest._now()} + monkeypatch.setattr(maxbot.extensions.rest, "_now", lambda: mock["now"]) + bot = MaxBot.inline( + f""" + extensions: + rest: {{}} + dialog: + - condition: true + response: | + {{% GET "http://127.0.0.1/endpoint" cache 
{_GB_TIMEOUT + 1} %}} + """ + ) + + mock.update(now=mock["now"] - timedelta(seconds=_GB_TIMEOUT + 3)) + _test_cache( + bot, + between_calls=lambda: mock.update(now=mock["now"] + timedelta(seconds=_GB_TIMEOUT)), + ) + + +@pytest.mark.parametrize( + "field", + ( + "headers", + "parameters", + "body", + "auth", + ), +) +def test_cache_dict_reorder(field): + bot = MaxBot.inline( + f""" + extensions: + rest: {{}} + dialog: + - condition: message.text == "1" + response: | + {{% GET "http://127.0.0.1/endpoint" cache 1 {field} {{"user": "v1", "password": "v2"}} %}} + - condition: message.text == "2" + response: | + {{% GET "http://127.0.0.1/endpoint" cache 1 {field} {{"password": "v2", "user": "v1"}} %}} + """ + ) + _test_cache(bot) + + +@pytest.mark.parametrize( + "field", + ( + "headers", + "parameters", + "body", + "auth", + ), +) +def test_cache_dict_mismatch(field): + bot = MaxBot.inline( + f""" + extensions: + rest: {{}} + dialog: + - condition: message.text == "1" + response: | + {{% GET "http://127.0.0.1/endpoint" cache 1 {field} {{"user": "v1", "password": "v2"}} %}} + - condition: message.text == "2" + response: | + {{% GET "http://127.0.0.1/endpoint" cache 1 {field} {{"user": "v1", "password": "MISMATCH"}} %}} + """ + ) + _test_cache(bot, cached_successfully=False) + + +def test_cache_url_service_match(): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: localhost + cache: 1 + base_url: http://127.0.0.1/ + dialog: + - condition: message.text == "1" + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 %} + - condition: message.text == "2" + response: | + {% GET "localhost://endpoint" %} + """ + ) + _test_cache(bot) + + +def test_cache_url_mismatch(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: message.text == "1" + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 %} + - condition: message.text == "2" + response: | + {% GET "http://127.0.0.1/endpoinX" cache 1 %} + """ + ) + with respx.mock: + route2 = respx.get("http://127.0.0.1/endpoinX").respond(json={}) + _test_cache(bot, cached_successfully=False, route2=route2) + + +def test_cache_on_error_ignore(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: message.text == "1" + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 on_error "continue" %} + - condition: message.text == "2" + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 on_error "break_flow" %} + """ + ) + _test_cache(bot) + + +def test_cache_method_mismatch(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: message.text == "1" + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 %} + - condition: message.text == "2" + response: | + {% POST "http://127.0.0.1/endpoint" cache 1 %} + """ + ) + with respx.mock: + route2 = respx.post("http://127.0.0.1/endpoint").respond(json={}) + _test_cache(bot, cached_successfully=False, route2=route2) + + +def test_cache_timeout_ignore(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: message.text == "1" + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 timeout 1 %} + - condition: message.text == "2" + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 timeout 2 %} + """ + ) + _test_cache(bot) + + +def test_cache_error(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: | + {% GET "http://127.0.0.1/endpoint" cache 1 on_error "continue" %} + """ + ) + with respx.mock: + route = 
respx.get("http://127.0.0.1/endpoint").respond(status_code=500) + _test_cache_mocked(bot, route, cached_successfully=False) + + +def test_misprint_service_name(): + bot = MaxBot.inline( + """ + extensions: + rest: + services: + - name: my_service1 + base_url: http://127.0.0.1 + dialog: + - condition: true + response: |- + {% GET "my_service2://endpoint" %} + {{ rest.ok }} + """ + ) + with pytest.raises(BotError) as excinfo: + bot.process_message("hey bot") + assert str(excinfo.value) == ( + "Unknown schema ('my_service2') in URL 'my_service2://endpoint'\n" + "Must be one of: http, https, my_service1" + ) + + +def test_request_body_urlencoded(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: |- + {% POST "http://127.0.0.1/endpoint" + body {"k": "v"}|urlencode + %} + test + """ + ) + + def _match(request): + assert b"".join(b for b in request.stream) == b"k=v" + return True + + _test_mock_common(bot, "post", additional_matcher=_match) + + +def test_request_body_urlencoded_content_type(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: |- + {% POST "http://127.0.0.1/endpoint" + body {"k": "v"} + headers {'Content-Type': 'application/x-www-form-urlencoded'} + %} + test + """ + ) + + def _match(request): + assert b"".join(b for b in request.stream) == b"k=v" + return True + + _test_mock_common(bot, "post", additional_matcher=_match) + + +def test_request_body_json(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: |- + {% POST "http://127.0.0.1/endpoint" + body {"k": "v"} + %} + test + """ + ) + + def _match(request): + assert b"".join(b for b in request.stream) == b'{"k": "v"}' + return True + + _test_mock_common(bot, "post", additional_matcher=_match) + + +def test_request_body_raw(): + bot = MaxBot.inline( + """ + extensions: + rest: {} + dialog: + - condition: true + response: |- + {% POST "http://127.0.0.1/endpoint" + body "k|v" + %} + test + """ + ) + + def _match(request): + assert b"".join(b for b in request.stream) == b"k|v" + return True + + _test_mock_common(bot, "post", additional_matcher=_match) + + @respx.mock def _test_mock_common( bot, @@ -477,3 +1100,28 @@ def _test_mock_common( assert route.call_count == 1 if additional_matcher is not None: additional_matcher(route.calls.last.request) + + +@respx.mock +def _test_cache(bot, **kwargs): + route = respx.get("http://127.0.0.1/endpoint").respond(json={}) + _test_cache_mocked(bot, route, **kwargs) + + +def _test_cache_mocked( + bot, route, cached_successfully=True, between_calls=lambda: None, route2=None +): + bot.process_message("1") + assert route.call_count == 1 + between_calls() + bot.process_message("2") + call_count = route.call_count + (route2.call_count if route2 else 0) + assert call_count == (1 if cached_successfully else 2) + + +@respx.mock +def _test_mock_network_error(bot, error): + route = respx.request("GET", "http://127.0.0.1/endpoint").mock(side_effect=error) + commands = bot.process_message("hey bot") + assert route.call_count == 1 + return commands diff --git a/tests/test_flows/test_base.py b/tests/test_flows/test_base.py new file mode 100644 index 0000000..8dd2e5c --- /dev/null +++ b/tests/test_flows/test_base.py @@ -0,0 +1,9 @@ +from unittest.mock import Mock + +from maxbot.flows._base import FlowComponent, FlowResult + + +def test_reset_state(): + ctx = Mock() + FlowComponent("test", Mock(return_value=FlowResult.DONE))(ctx) + 
ctx.set_state_variable.assert_called_once_with("test", None) diff --git a/tests/test_flows/test_dialog_tree.py b/tests/test_flows/test_dialog_tree.py index bc3528d..84c4d66 100644 --- a/tests/test_flows/test_dialog_tree.py +++ b/tests/test_flows/test_dialog_tree.py @@ -786,3 +786,69 @@ async def test_node_removed_gc_stack(): assert await model(ctx, state) == FlowResult.DONE assert ctx.commands == [{"text": "root triggered"}] assert state == {"node_stack": []} + + +async def test_response_digressing_false(): + model = DialogTree( + DialogNodeSchema(many=True).loads( + """ + - condition: true + response: "digressing={{ digressing }}" + """ + ) + ) + ctx, state = make_context() + assert await model(ctx, state) == FlowResult.DONE + assert ctx.commands == [{"text": "digressing=False"}] + + +async def test_response_digressing(): + model = DialogTree( + DialogNodeSchema(many=True).loads( + """ + - label: root + condition: true + followup: + - condition: false + response: unexpected + response: root triggered + - condition: true + response: "digressing={{ digressing }}" + """ + ) + ) + ctx, state = make_context(state={"node_stack": [["root", "followup"]]}) + assert await model(ctx, state) == FlowResult.LISTEN + assert ctx.commands == [ + {"text": "digressing=True"}, + {"text": "root triggered"}, + ] + + +async def test_response_digressing_with_slot_filling(): + model = DialogTree( + DialogNodeSchema(many=True).loads( + """ + - label: root + condition: true + followup: + - condition: false + response: unexpected + response: root triggered + - label: node_for_digression + condition: true + slot_filling: + - name: slot1 + check_for: true + found: found + response: "digressing={{ digressing }}" + """ + ) + ) + ctx, state = make_context(state={"node_stack": [["root", "followup"]]}) + assert await model(ctx, state) == FlowResult.LISTEN + assert ctx.commands == [ + {"text": "found"}, + {"text": "digressing=True"}, + {"text": "root triggered"}, + ] diff --git a/tests/test_flows/test_journal_dialog_tree.py b/tests/test_flows/test_journal_dialog_tree.py index 85e7c02..663f0fa 100644 --- a/tests/test_flows/test_journal_dialog_tree.py +++ b/tests/test_flows/test_journal_dialog_tree.py @@ -51,7 +51,10 @@ async def test_journal_response_jump_to(): _, event, _, _ = ctx.journal_events assert event == { "type": "response", - "payload": {"node": {"condition": "true"}, "control_command": "jump_to"}, + "payload": { + "node": {"condition": "true"}, + "control_command": {"jump_to": {"node": "label1", "transition": "condition"}}, + }, } @@ -74,7 +77,7 @@ async def test_journal_response_listen(): ) = ctx.journal_events assert event == { "type": "response", - "payload": {"node": {"condition": "true"}, "control_command": "listen"}, + "payload": {"node": {"condition": "true"}, "control_command": {"listen": {}}}, } @@ -97,7 +100,7 @@ async def test_journal_response_end(): ) = ctx.journal_events assert event == { "type": "response", - "payload": {"node": {"condition": "true"}, "control_command": "end"}, + "payload": {"node": {"condition": "true"}, "control_command": {"end": {}}}, } @@ -123,7 +126,7 @@ async def test_journal_response_followup(): "type": "response", "payload": { "node": {"condition": "true", "label": "root1"}, - "control_command": "followup", + "control_command": {"followup": {}}, }, } @@ -173,7 +176,7 @@ async def test_journal_response_default_end(): assert event == {"type": "response", "payload": {"node": {"condition": "true"}, "end": {}}} -async def test_journal_response_return_after_digression(): +async def 
test_journal_digression(): model = DialogTree( DialogNodeSchema(many=True).loads( """ @@ -194,8 +197,12 @@ async def test_journal_response_return_after_digression(): components_state={"label1": {"slot_in_focus": "slot1"}}, ) assert await model(ctx, state) == FlowResult.LISTEN - _, _, event, _, _ = ctx.journal_events - assert event == { + _, event1, _, event2, _, _ = ctx.journal_events + assert event1 == { + "type": "digression_from", + "payload": {"node": {"condition": "true", "label": "label1"}}, + } + assert event2 == { "type": "response", "payload": {"return_after_digression": {}, "node": {"condition": "true"}}, } diff --git a/tests/test_flows/test_journal_slot_filling.py b/tests/test_flows/test_journal_slot_filling.py index 251ab3c..48efca1 100644 --- a/tests/test_flows/test_journal_slot_filling.py +++ b/tests/test_flows/test_journal_slot_filling.py @@ -10,7 +10,7 @@ def make_context(state=None): dialog=None, message={"text": "hello"}, entities=EntitiesResult(), - state=StateVariables(slots={}, components={"xxx": state} if state else {}), + state=StateVariables(components={"xxx": state} if state else {}), ) return ctx, ctx.state.components.setdefault("xxx", {}) @@ -27,8 +27,9 @@ async def test_journal_slot_filling(): ) ctx, state = make_context() await model(ctx, state) - (event,) = ctx.journal_events - assert event == {"type": "slot_filling", "payload": {"slot": "slot1", "value": True}} + event1, event2 = ctx.journal_events + assert event1 == {"type": "slot_filling", "payload": {"slot": "slot1"}} + assert event2 == {"type": "assign", "payload": {"slots": "slot1", "value": True}} async def test_journal_found(): @@ -45,6 +46,7 @@ async def test_journal_found(): ctx, state = make_context() await model(ctx, state) ( + _, _, event, ) = ctx.journal_events @@ -68,10 +70,7 @@ async def test_journal_found_control_command(control_command): ) ctx, state = make_context() await model(ctx, state) - ( - _, - event, - ) = ctx.journal_events + event = ctx.journal_events[2] # slot_filling, assing, [, delete] assert event == { "type": "found", "payload": {"slot": "slot1", "control_command": control_command}, diff --git a/tests/test_flows/test_journal_state.py b/tests/test_flows/test_journal_state.py new file mode 100644 index 0000000..afb983f --- /dev/null +++ b/tests/test_flows/test_journal_state.py @@ -0,0 +1,132 @@ +import pytest + +from maxbot.context import StateVariables, TurnContext +from maxbot.flows._base import FlowResult +from maxbot.flows.dialog_flow import DialogFlow +from maxbot.flows.dialog_tree import DialogNodeSchema, DialogTree + + +def make_context(state=None, components_state=None): + ctx = TurnContext( + dialog=None, + message={"text": "hello"}, + state=StateVariables(components=components_state or {}), + ) + if state is not None: + ctx.state.components["ROOT"] = state + return ctx, ctx.state.components.setdefault("ROOT", {}) + + +@pytest.mark.parametrize( + "kind", + ( + "slots", + "user", + ), +) +async def test_journal_two_nodes(kind): + model = DialogTree( + DialogNodeSchema(many=True).loads( + """ + - condition: true + response: | + {% set """ + + kind + + """.slot1 = 1 %} + + - condition: false + label: target_node + response: | + {% set """ + + kind + + """.slot1 = 2 %} + + """ + ) + ) + ctx, state = make_context() + assert await model(ctx, state) == FlowResult.DONE + _, _, event1, _, _, event2 = ctx.journal_events + assert event1 == {"type": "assign", "payload": {kind: "slot1", "value": 1}} + assert event2 == {"type": "assign", "payload": {kind: "slot1", "value": 2}} + + 
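[Editorial note, not part of the patch] The new tests/test_flows/test_journal_state.py content added above and below this note asserts that template-level writes to `slots.*` and `user.*` are recorded as `assign` and `delete` journal events. A minimal sketch of the pattern those tests exercise, using only names that appear in this diff (TurnContext, StateVariables, DialogFlow, ctx.journal_events); treat it as an illustration under those assumptions, not an authoritative API reference:

    # Sketch derived from the tests in this diff; event shapes may differ in other maxbot versions.
    import asyncio

    from maxbot.context import StateVariables, TurnContext
    from maxbot.flows.dialog_flow import DialogFlow

    async def demo_journal_events():
        # Build a bare turn context, exactly as the tests' make_context() helper does.
        ctx = TurnContext(
            dialog=None,
            message={"text": "hello"},
            state=StateVariables(components={}),
        )
        df = DialogFlow()
        df.load_inline_resources(
            """
            dialog:
              - condition: true
                response: |
                  {% set slots.slot1 = 1 %}
            """
        )
        await df.turn(ctx)
        # Per these tests, ctx.journal_events now includes entries such as
        #   {"type": "assign", "payload": {"slots": "slot1", "value": 1}}
        # and a matching {"type": "delete", "payload": {"slots": "slot1"}}
        # once the slot is cleared at the end of the flow (see test_journal_clear).
        return ctx.journal_events

    events = asyncio.run(demo_journal_events())
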
+@pytest.mark.parametrize( + "kind", + ( + "slots", + "user", + ), +) +async def test_journal_equal_changes(kind): + model = DialogTree( + DialogNodeSchema(many=True).loads( + """ + - condition: true + response: | + {% set """ + + kind + + """.slot1 = 1 %} + {% set """ + + kind + + """.slot1 = 1 %} + """ + ) + ) + ctx, state = make_context() + assert await model(ctx, state) == FlowResult.DONE + + _, _, event1, event2 = ctx.journal_events + assert event1 == {"type": "assign", "payload": {kind: "slot1", "value": 1}} + assert event2 == {"type": "assign", "payload": {kind: "slot1", "value": 1}} + + +@pytest.mark.parametrize( + "kind", + ( + "slots", + "user", + ), +) +async def test_journal_delete(kind): + model = DialogTree( + DialogNodeSchema(many=True).loads( + """ + - condition: true + response: | + {% set """ + + kind + + """.slot1 = 1 %} + + - condition: false + label: target_node + response: | + {% delete """ + + kind + + """.slot1 %} + + """ + ) + ) + ctx, state = make_context() + assert await model(ctx, state) == FlowResult.DONE + _, _, event1, _, _, event2 = ctx.journal_events + assert event1 == {"type": "assign", "payload": {kind: "slot1", "value": 1}} + assert event2 == {"type": "delete", "payload": {kind: "slot1"}} + + +async def test_journal_clear(): + ctx, state = make_context() + df = DialogFlow() + df.load_inline_resources( + """ + dialog: + - condition: true + response: | + {% set slots.slot1 = 1 %} + """ + ) + await df.turn(ctx) + + _, _, _, event = ctx.journal_events + assert event == {"type": "delete", "payload": {"slots": "slot1"}} diff --git a/tests/test_flows/test_slot_filling.py b/tests/test_flows/test_slot_filling.py index 2e9a6d5..5968ca1 100644 --- a/tests/test_flows/test_slot_filling.py +++ b/tests/test_flows/test_slot_filling.py @@ -19,7 +19,7 @@ def make_context(state=None, intents=None, entities=None): message={"text": "hello"}, intents=IntentsResult.resolve(intents or []), entities=entities or EntitiesResult(), - state=StateVariables(slots={}, components={"xxx": state} if state else {}), + state=StateVariables(components={"xxx": state} if state else {}), ) return ctx, ctx.state.components.setdefault("xxx", {}) diff --git a/tests/test_maxml/test_http.py b/tests/test_maxml/test_http.py new file mode 100644 index 0000000..e102e61 --- /dev/null +++ b/tests/test_maxml/test_http.py @@ -0,0 +1,118 @@ +import pytest + +from maxbot.errors import BotError +from maxbot.maxml import PoolLimitSchema, TimeoutSchema, fields +from maxbot.schemas import ResourceSchema + + +class Config(ResourceSchema): + timeout = fields.Nested(TimeoutSchema()) + limits = fields.Nested(PoolLimitSchema()) + + +def test_timeout_empty(): + data = Config().loads( + """ + timeout: {} + """ + ) + assert data["timeout"].connect == 5.0 + assert data["timeout"].read == 5.0 + assert data["timeout"].write == 5.0 + assert data["timeout"].pool == 5.0 + + +def test_timeout_short_syntax(): + data = Config().loads( + """ + timeout: 1.2 + """ + ) + assert data["timeout"].connect == 1.2 + assert data["timeout"].read == 1.2 + assert data["timeout"].write == 1.2 + assert data["timeout"].pool == 1.2 + + +def test_timeout_error(): + with pytest.raises(BotError) as excinfo: + Config().loads( + """ + timeout: abc + """ + ) + assert str(excinfo.value) == ( + "caused by marshmallow.exceptions.ValidationError: Invalid input type.\n" + ' in "", line 2, column 20:\n' + " timeout: abc\n" + " ^^^\n" + ) + + +def test_timeout_default(): + data = Config().loads( + """ + timeout: + default: 3.6 + """ + ) + assert 
data["timeout"].connect == 3.6 + assert data["timeout"].read == 3.6 + assert data["timeout"].write == 3.6 + assert data["timeout"].pool == 3.6 + + +def test_timeout_default_connect_pool(): + data = Config().loads( + """ + timeout: + default: 3.6 + connect: 10.0 + pool: 1.0 + """ + ) + assert data["timeout"].connect == 10.0 + assert data["timeout"].read == 3.6 + assert data["timeout"].write == 3.6 + assert data["timeout"].pool == 1.0 + + +def test_timeout_connect_read_write_pool(): + data = Config().loads( + """ + timeout: + connect: 1.0 + read: 2.0 + write: 3.0 + pool: 4.0 + """ + ) + assert data["timeout"].connect == 1.0 + assert data["timeout"].read == 2.0 + assert data["timeout"].write == 3.0 + assert data["timeout"].pool == 4.0 + + +def test_limits_empty(): + data = Config().loads( + """ + limits: {} + """ + ) + assert data["limits"].max_keepalive_connections == 20 + assert data["limits"].max_connections == 100 + assert data["limits"].keepalive_expiry == 5.0 + + +def test_limits(): + data = Config().loads( + """ + limits: + max_keepalive_connections: 1 + max_connections: 2 + keepalive_expiry: 3 + """ + ) + assert data["limits"].max_keepalive_connections == 1 + assert data["limits"].max_connections == 2 + assert data["limits"].keepalive_expiry == 3.0 diff --git a/tests/test_maxml/test_timedelta.py b/tests/test_maxml/test_timedelta.py new file mode 100644 index 0000000..307888e --- /dev/null +++ b/tests/test_maxml/test_timedelta.py @@ -0,0 +1,56 @@ +import pytest + +from maxbot.errors import BotError +from maxbot.maxml import TimeDeltaSchema, fields +from maxbot.schemas import ResourceSchema + + +class Config(ResourceSchema): + timedelta = fields.Nested(TimeDeltaSchema()) + + +def test_empty(): + data = Config().loads("timedelta: {}") + assert data["timedelta"].total_seconds() == 0 + + +def test_short(): + data = Config().loads("timedelta: 5") + assert data["timedelta"].total_seconds() == 5 + + +def test_short_error(): + with pytest.raises(BotError) as excinfo: + Config().loads("timedelta: x") + assert str(excinfo.value) == ( + "caused by marshmallow.exceptions.ValidationError: Invalid input type.\n" + ' in "", line 1, column 12:\n' + " timedelta: x\n" + " ^^^\n" + ) + + +def test_all(): + data = Config().loads( + """ + timedelta: + weeks: 1 + days: 2 + hours: 3 + minutes: 4 + seconds: 5 + microseconds: 6 + milliseconds: 7 + """ + ) + assert data["timedelta"].total_seconds() == 788645.007006 + + +def test_default(): + class Config(ResourceSchema): + timedelta = fields.Nested( + TimeDeltaSchema(), load_default=TimeDeltaSchema.VALUE_TYPE(seconds=5) + ) + + data = Config().loads("{}") + assert data["timedelta"].total_seconds() == 5 diff --git a/tests/test_nlu.py b/tests/test_nlu.py index 4c8fc8a..5322a27 100644 --- a/tests/test_nlu.py +++ b/tests/test_nlu.py @@ -166,6 +166,11 @@ def test_dateparser_entities(spacy_nlp): assert time.value == "18:00:00" assert time.literal == "February 22, 2022 at 6pm" + (date,) = entity_recognizer(spacy_nlp("1984")) + assert date.name == "latent_date" + assert date.value.startswith("1984-") + assert date.literal == "1984" + @freeze_time("2023-04-08") def test_dateparser_entities_prefer_future(spacy_nlp): diff --git a/tests/test_persistence_manager.py b/tests/test_persistence_manager.py new file mode 100644 index 0000000..bca8748 --- /dev/null +++ b/tests/test_persistence_manager.py @@ -0,0 +1,163 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.exc import StatementError +from sqlalchemy.orm import Session + +from maxbot.maxml import markup 
+from maxbot.persistence_manager import DialogTable, RequestType, SQLAlchemyManager +from maxbot.webapp import Factory + + +@pytest.fixture +def event(): + return {"channel_name": "test", "user_id": 123} + + +def test_state_create(event): + persistence_manager = SQLAlchemyManager() + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"] = "value1" + tracker.get_state().slots["slot1"] = "value2" + tracker.get_state().components["flow1"] = "value3" + + with Session(persistence_manager.engine) as session: + user = session.scalars(select(DialogTable)).one() + assert user.dialog_id + assert user.channel_name == "test" + assert user.user_id == "123" + v1, v2, v3 = user.variables + assert v1.name == "user.user1" + assert v1.value == "value1" + assert v2.name == "slots.slot1" + assert v2.value == "value2" + assert v3.name == "components.flow1" + assert v3.value == "value3" + + +def test_state_update(event): + persistence_manager = SQLAlchemyManager() + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"] = "value1" + + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"] = "value2" + + with Session(persistence_manager.engine) as session: + user = session.scalars(select(DialogTable)).one() + (v,) = user.variables + assert v.name == "user.user1" + assert v.value == "value2" + + +def test_state_update_inplace(event): + persistence_manager = SQLAlchemyManager() + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"] = {"key": "value1"} + + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"]["key"] = "value2" + + with Session(persistence_manager.engine) as session: + user = session.scalars(select(DialogTable)).one() + (v,) = user.variables + assert v.name == "user.user1" + assert v.value == {"key": "value2"} + + +def test_state_delete_using_del(event): + persistence_manager = SQLAlchemyManager() + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"] = "value1" + + with persistence_manager(event) as tracker: + del tracker.get_state().user["user1"] + + with Session(persistence_manager.engine) as session: + user = session.scalars(select(DialogTable)).one() + assert len(user.variables) == 0 + + +def test_state_keep_none(event): + persistence_manager = SQLAlchemyManager() + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"] = "value1" + + with persistence_manager(event) as tracker: + tracker.get_state().user["user1"] = None + + with Session(persistence_manager.engine) as session: + user = session.scalars(select(DialogTable)).one() + (v,) = user.variables + assert v.name == "user.user1" + assert v.value is None + + +def test_history_message(event): + persistence_manager = SQLAlchemyManager() + with persistence_manager(event) as tracker: + tracker.set_message_history({}, []) + + with persistence_manager(event) as tracker: + (turn,) = tracker.user.history + assert turn.request_date + assert turn.request_type == RequestType.message + assert turn.request == {} + assert turn.response == [] + + +def test_history_rpc(event): + persistence_manager = SQLAlchemyManager() + with persistence_manager(event) as tracker: + tracker.set_rpc_history({}, []) + + with persistence_manager(event) as tracker: + (turn,) = tracker.user.history + assert turn.request_date + assert turn.request_type == RequestType.rpc + assert turn.request == {} + assert turn.response == [] + + +def _default(tmp_path): + return SQLAlchemyManager() + + +def 
_default_mp(tmp_path): + persistence_manager = Factory._create_default_mp_persistence_manager(tmp_path / "pytest.db") + persistence_manager.create_tables() + return persistence_manager + + +@pytest.mark.parametrize("factory", [_default, _default_mp]) +def test_history_maxml(event, factory, tmp_path): + v = markup.Value( + [ + markup.Item(markup.TEXT, "line 1"), + markup.Item(markup.START_TAG, "br"), + markup.Item(markup.END_TAG, "br"), + markup.Item(markup.TEXT, "line 2"), + ] + ) + persistence_manager = factory(tmp_path) + with persistence_manager(event) as tracker: + tracker.set_rpc_history({}, [{"text": v}]) + + with persistence_manager(event) as tracker: + (turn,) = tracker.user.history + assert turn.request_date + assert turn.request_type == RequestType.rpc + assert turn.request == {} + assert turn.response == [{"text": "line 1
line 2"}] + + +@pytest.mark.parametrize("factory", [_default, _default_mp]) +def test_history_json_not_serializable(event, factory, tmp_path): + class Value: + pass + + persistence_manager = factory(tmp_path) + with pytest.raises(StatementError) as excinfo: + with persistence_manager(event) as tracker: + tracker.set_rpc_history({}, [{"custom": Value()}]) + + assert "Object of type Value is not JSON serializable" in str(excinfo.value) diff --git a/tests/test_resolver.py b/tests/test_resolver.py new file mode 100644 index 0000000..261ff44 --- /dev/null +++ b/tests/test_resolver.py @@ -0,0 +1,19 @@ +import pytest + +from maxbot.resolver import BotResolver, pkgutil + + +def test_raise_unknown_source(): + with pytest.raises(RuntimeError) as excinfo: + BotResolver("XyZ")() + assert str(excinfo.value) == ( + "'XyZ' file or directory not found, " + """import causes error ModuleNotFoundError("No module named 'XyZ'")""" + ) + + +def test_raise_invalid_type(monkeypatch): + monkeypatch.setattr(pkgutil, "resolve_name", lambda spec: 1) + with pytest.raises(RuntimeError) as excinfo: + BotResolver("XyZ")() + assert str(excinfo.value) == "A valid MaxBot instance was not obtained from 'XyZ'" diff --git a/tests/test_state_store.py b/tests/test_state_store.py deleted file mode 100644 index 8c0d653..0000000 --- a/tests/test_state_store.py +++ /dev/null @@ -1,100 +0,0 @@ -import pytest -from sqlalchemy import select -from sqlalchemy.orm import Session - -from maxbot.state_store import DialogTable, SQLAlchemyStateStore - - -@pytest.fixture -def event(): - return {"channel_name": "test", "user_id": 123} - - -def test_create(event): - state_store = SQLAlchemyStateStore() - with state_store(event) as state: - state.user["user1"] = "value1" - state.slots["slot1"] = "value2" - state.components["flow1"] = "value3" - - with Session(state_store.engine) as session: - user = session.scalars(select(DialogTable)).one() - assert user.dialog_id - assert user.channel_name == "test" - assert user.user_id == "123" - v1, v2, v3 = user.variables - assert v1.name == "user.user1" - assert v1.value == "value1" - assert v2.name == "slots.slot1" - assert v2.value == "value2" - assert v3.name == "components.flow1" - assert v3.value == "value3" - - -def test_update(event): - state_store = SQLAlchemyStateStore() - with state_store(event) as state: - state.user["user1"] = "value1" - - with state_store(event) as state: - state.user["user1"] = "value2" - - with Session(state_store.engine) as session: - user = session.scalars(select(DialogTable)).one() - (v,) = user.variables - assert v.name == "user.user1" - assert v.value == "value2" - - -def test_update_inplace(event): - state_store = SQLAlchemyStateStore() - with state_store(event) as state: - state.user["user1"] = {"key": "value1"} - - with state_store(event) as state: - state.user["user1"]["key"] = "value2" - - with Session(state_store.engine) as session: - user = session.scalars(select(DialogTable)).one() - (v,) = user.variables - assert v.name == "user.user1" - assert v.value == {"key": "value2"} - - -def test_delete_using_del(event): - state_store = SQLAlchemyStateStore() - with state_store(event) as state: - state.user["user1"] = "value1" - - with state_store(event) as state: - del state.user["user1"] - - with Session(state_store.engine) as session: - user = session.scalars(select(DialogTable)).one() - assert len(user.variables) == 0 - - -def test_keep_none(event): - state_store = SQLAlchemyStateStore() - with state_store(event) as state: - state.user["user1"] = "value1" - - with 
state_store(event) as state: - state.user["user1"] = None - - with Session(state_store.engine) as session: - user = session.scalars(select(DialogTable)).one() - (v,) = user.variables - assert v.name == "user.user1" - assert v.value is None - - -# def test_from_config(): -# services = Services().loads( -# """ -# sqlalchemy: -# url: sqlite:// -# """ -# ) -# state_store = SQLAlchemyStateStore(services.sqlalchemy) -# assert state_store.engine diff --git a/tests/test_stories/test_pytest.py b/tests/test_stories/test_pytest.py new file mode 100644 index 0000000..a8b9070 --- /dev/null +++ b/tests/test_stories/test_pytest.py @@ -0,0 +1,101 @@ +from unittest.mock import Mock + +import pytest + +import maxbot.stories.pytest + + +@pytest.fixture(autouse=True) +def mock_resolve_bot(monkeypatch): + monkeypatch.setattr(maxbot.stories.pytest, "BotResolver", Mock()) + + +def test_file(testdir, monkeypatch): + stories_file = testdir.tmpdir / "stories.yaml" + stories_file.write_text("", encoding="utf8") + + stories = Mock() + stories.load = Mock(return_value=[{"name": f"story{i}", "markers": []} for i in range(2)]) + monkeypatch.setattr(maxbot.stories.pytest, "Stories", Mock(return_value=stories)) + + result = testdir.runpytest("-p", "maxbot_stories", "--bot", "my_bot", stories_file) + result.stdout.fnmatch_lines(["*2 passed*"]) + + +def test_directory(testdir, monkeypatch): + stories_dir = testdir.tmpdir / "stories" + stories_dir.mkdir() + for i in range(3): + (stories_dir / f"{i}.yaml").write_text("", encoding="utf8") + + stories = Mock() + stories.load = Mock(return_value=[{"name": f"story{i}", "markers": []} for i in range(2)]) + monkeypatch.setattr(maxbot.stories.pytest, "Stories", Mock(return_value=stories)) + + result = testdir.runpytest("-p", "maxbot_stories", "--bot", "my_bot", stories_dir) + result.stdout.fnmatch_lines(["*6 passed*"]) + + +def test_fail(testdir, monkeypatch): + stories_file = testdir.tmpdir / "stories.yaml" + stories_file.write_text("", encoding="utf8") + + stories = Mock() + stories.load = Mock(return_value=[{"name": "story1", "markers": []}]) + stories.run = Mock(side_effect=RuntimeError()) + + class _MismatchError(Exception): + pass + + stories.MismatchError = _MismatchError + monkeypatch.setattr(maxbot.stories.pytest, "Stories", Mock(return_value=stories)) + + result = testdir.runpytest("-p", "maxbot_stories", "--bot", "my_bot", stories_file) + result.stdout.fnmatch_lines( + [ + "*FAILED stories.yaml::story1 - RuntimeError*", + "*1 failed*", + ] + ) + + +def test_xfail(testdir, monkeypatch): + stories_file = testdir.tmpdir / "stories.yaml" + stories_file.write_text("", encoding="utf8") + + stories = Mock() + stories.load = Mock(return_value=[{"name": "story1", "markers": ["xfail"]}]) + stories.run = Mock(side_effect=RuntimeError()) + + class _MismatchError(Exception): + pass + + stories.MismatchError = _MismatchError + monkeypatch.setattr(maxbot.stories.pytest, "Stories", Mock(return_value=stories)) + + result = testdir.runpytest("-p", "maxbot_stories", "--bot", "my_bot", stories_file) + result.stdout.no_fnmatch_line("*FAILED stories.yaml::story1 - RuntimeError*") + result.stdout.fnmatch_lines(["*1 xfailed*"]) + + +def test_mismatch(testdir, monkeypatch): + stories_file = testdir.tmpdir / "stories.yaml" + stories_file.write_text("", encoding="utf8") + + stories = Mock() + stories.load = Mock(return_value=[{"name": "story1", "markers": []}]) + + class _MismatchError(Exception): + message = "XyZ" + + stories.MismatchError = _MismatchError + stories.run = 
Mock(side_effect=_MismatchError()) + monkeypatch.setattr(maxbot.stories.pytest, "Stories", Mock(return_value=stories)) + + result = testdir.runpytest("-p", "maxbot_stories", "--bot", "my_bot", stories_file) + result.stdout.fnmatch_lines( + [ + "*XyZ*", + "*1 failed*", + ] + ) diff --git a/tests/test_stories/test_stories.py b/tests/test_stories/test_stories.py new file mode 100644 index 0000000..99467fb --- /dev/null +++ b/tests/test_stories/test_stories.py @@ -0,0 +1,318 @@ +from datetime import datetime, timezone + +import pytest + +from maxbot import MaxBot +from maxbot.errors import BotError +from maxbot.maxml import markup +from maxbot.stories import Stories, StoryUtcTimeProvider, markup_value_rendered_comparator + + +@pytest.mark.parametrize( + "utc_time", ("2023-04-10T19:15:58.104144", "2023-04-10T18:15:58.104144-01:00") +) +def test_utc_time_template(tmp_path, utc_time): + stories = Stories( + MaxBot.inline( + """ + dialog: + - condition: true + response: "{{ utc_time.isoformat() }}" + """ + ) + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + f""" + - name: test + turns: + - utc_time: "{utc_time}" + message: hello + response: "2023-04-10T19:15:58.104144+00:00" + """ + ) + (s,) = stories.load(stories_file) + stories.run(s) + + +def test_utc_time_entitites(tmp_path): + stories = Stories( + MaxBot.inline( + """ + entities: + - name: date + dialog: + - condition: entities.date + response: "{{ entities.date.value }}" + """ + ) + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - utc_time: '2021-01-01T19:15:58.104144' + message: today + response: '2021-01-01' + """ + ) + (s,) = stories.load(stories_file) + stories.run(s) + + +def test_fail(tmp_path): + stories = Stories( + MaxBot.inline( + """ + dialog: + - condition: true + response: "{{ message.text }}" + """ + ) + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - message: hello + response: HELLO + """ + ) + + (s,) = stories.load(stories_file) + with pytest.raises(stories.MismatchError) as exinfo: + stories.run(s) + + assert ( + "Mismatch at step [0]\n" + "Expected:\n" + " HELLO\n" + "Actual:\n" + " hello" + ) == exinfo.value.message + + +def test_fail_list(tmp_path): + stories = Stories( + MaxBot.inline( + """ + dialog: + - condition: true + response: "{{ message.text }}" + """ + ) + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - message: hello + response: + - hello1 + - hello2 + """ + ) + + (s,) = stories.load(stories_file) + with pytest.raises(stories.MismatchError) as exinfo: + stories.run(s) + + assert ( + "Mismatch at step [0]\n" + "Expected:\n" + " hello1\n" + " -or-\n" + " hello2\n" + "Actual:\n" + " hello" + ) == exinfo.value.message + + +def test_utc_time_tick_10sec(): + provider = StoryUtcTimeProvider() + provider.tick(datetime(2020, 1, 1, 0, 0)) + provider.tick() + assert datetime(2020, 1, 1, 0, 0, 10, tzinfo=timezone.utc) == provider() + + +def test_rpc_method_validation_error(tmp_path): + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - rpc: { method: nonexistent } + response: "" + """ + ) + + with pytest.raises(BotError) as excinfo: + Stories(MaxBot.builder().build()).load(stories_file) + + lines = str(excinfo.value).splitlines() + assert lines[0] == "caused by marshmallow.exceptions.ValidationError: Method not found" + assert lines[1].endswith(", line 4, column 28:") + 
assert lines[2:] == [ + " turns:", + " - rpc: { method: nonexistent }", + " ^^^", + ' response: ""', + ] + + +def test_rpc_params_validation_error(tmp_path): + bot = MaxBot.inline( + """ + rpc: + - method: with_params + params: + - name: required_param + required: true + """ + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - rpc: { method: with_params } + response: "" + """ + ) + + with pytest.raises(BotError) as excinfo: + Stories(bot).load(stories_file) + + lines = str(excinfo.value).splitlines() + assert lines[0] == ( + "caused by marshmallow.exceptions.ValidationError: {'required_param': " + "['Missing data for required field.']}" + ) + assert lines[1].endswith(", line 4, column 18:") + assert lines[2:] == [ + " turns:", + " - rpc: { method: with_params }", + " ^^^", + ' response: ""', + ] + + +def test_turn_no_message_and_no_rpc(tmp_path): + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - response: "" + """ + ) + + with pytest.raises(BotError) as excinfo: + Stories(MaxBot.builder().build()).load(stories_file) + + lines = str(excinfo.value).splitlines() + assert lines[0] == ( + "caused by marshmallow.exceptions.ValidationError: " + "Exactly one of 'message' or 'rpc' is required." + ) + assert lines[1].endswith(", line 4, column 13:") + assert lines[2:] == [ + " turns:", + ' - response: ""', + " ^^^", + ] + + +def test_match_first(tmp_path): + stories = Stories( + MaxBot.inline( + """ + dialog: + - condition: true + response: | + {{ message.text }} + """ + ) + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - message: hello + response: + - hello + - hello2 + """ + ) + + (s,) = stories.load(stories_file) + stories.run(s) + + +def test_match_second(tmp_path): + stories = Stories( + MaxBot.inline( + """ + dialog: + - condition: true + response: | + {{ message.text }} + """ + ) + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - message: hello + response: + - hello1 + - hello + """ + ) + + (s,) = stories.load(stories_file) + stories.run(s) + + +def test_rpc(tmp_path): + stories = Stories( + MaxBot.inline( + """ + rpc: + - method: test + dialog: + - condition: rpc.test + response: success + """ + ) + ) + + stories_file = tmp_path / "stories.yaml" + stories_file.write_text( + """ + - name: test + turns: + - rpc: { method: test } + response: success + """ + ) + + (s,) = stories.load(stories_file) + stories.run(s) + + +def test_markup_value_rendered_comparator_false(): + assert not markup_value_rendered_comparator(markup.Value(), 1) diff --git a/tests/test_user_locks.py b/tests/test_user_locks/test_asyncio.py similarity index 100% rename from tests/test_user_locks.py rename to tests/test_user_locks/test_asyncio.py diff --git a/tests/test_user_locks/test_mp.py b/tests/test_user_locks/test_mp.py new file mode 100644 index 0000000..806691a --- /dev/null +++ b/tests/test_user_locks/test_mp.py @@ -0,0 +1,296 @@ +import asyncio +import logging +from contextlib import asynccontextmanager, contextmanager +from multiprocessing import current_process, get_context +from os import unlink +from tempfile import NamedTemporaryFile +from time import sleep +from unittest.mock import MagicMock + +import pytest + +from maxbot.user_locks.mp import ( + _EOF, + MultiProcessLocks, + MultiProcessLocksServer, + ServerClosedConnectionError, + UnixSocketStreams, +) + + +@pytest.fixture +def 
streams(): + with NamedTemporaryFile(prefix="maxbot-pytest-", suffix="-.sock") as f: + path = f.name + return UnixSocketStreams(path) + + +@pytest.fixture +def spawn_ctx(): + return get_context("spawn") + + +@contextmanager +def server_process(streams, spawn_ctx, args=tuple(), server_stop_=None): + server_ready, server_stop = spawn_ctx.Event(), server_stop_ or spawn_ctx.Event() + server_proc = spawn_ctx.Process( + target=MultiProcessLocksServer(streams.start_server, server_ready, server_stop), + args=args, + ) + server_proc.start() + server_ready.wait() + try: + yield + finally: + server_stop.set() + server_proc.join() + + +@asynccontextmanager +async def mp_locks(open_connection): + locks = MultiProcessLocks(open_connection) + yield locks + await locks.disconnect() + + +def _common_client(open_connection, dialog, results): + async def _impl(): + async with mp_locks(open_connection) as locks: + for _ in range(4): + async with locks(dialog): + results.append(1) + results.append(2) + + asyncio.run(_impl()) + + +def test_concurrent_different_processes(streams, spawn_ctx): + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + results = spawn_ctx.Manager().list() + + with server_process(streams, spawn_ctx): + proc_clients = [] + for i in range(2): + proc_clients.append( + spawn_ctx.Process( + target=_common_client, + args=(streams.open_connection, dialog, results), + name=f"Client {i}", + ) + ) + + for p in proc_clients: + p.start() + + for p in proc_clients: + p.join() + + assert results[:] == [1, 2] * 8 + + +async def test_concurrent_one_process(streams, spawn_ctx): + with server_process(streams, spawn_ctx): + async with mp_locks(streams.open_connection) as locks: + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + + results = [] + + async def _request(f1, f2, f3, f4): + async with locks(dialog): + await f1 + results.append(1) + await f2 + results.append(2) + async with locks(dialog): + await f3 + results.append(1) + await f4 + results.append(2) + + fs = [asyncio.get_event_loop().create_future() for _ in range(12)] + + async def _wake(): + for f in fs: + f.set_result(True) + await asyncio.sleep(0.01) + + await asyncio.gather( + asyncio.create_task(_request(fs[0], fs[3], fs[6], fs[9]), name="r1"), + asyncio.create_task(_request(fs[1], fs[4], fs[7], fs[10]), name="r2"), + asyncio.create_task(_request(fs[2], fs[5], fs[8], fs[11]), name="r3"), + asyncio.create_task(_wake(), name="wake"), + ) + assert results == [1, 2] * 6 + + +async def test_locked_exception(streams, spawn_ctx): + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + with server_process(streams, spawn_ctx): + async with mp_locks(streams.open_connection) as locks: + with pytest.raises(RuntimeError): + async with locks(dialog): + raise RuntimeError() + + async with locks(dialog): + pass + + +def _locked_exit_client(open_connection, dialog, locked_event): + async def _impl(): + async with mp_locks(open_connection) as locks: + async with locks(dialog): + locked_event.set() + Event().wait() + + asyncio.run(_impl()) + + +async def test_locked_kill(streams, spawn_ctx): + with server_process(streams, spawn_ctx): + async with mp_locks(streams.open_connection) as locks: + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + + locked_event = spawn_ctx.Event() + p = spawn_ctx.Process( + target=_locked_exit_client, + args=(streams.open_connection, dialog, locked_event), + name="LockedExit", + ) + p.start() + locked_event.wait() + p.kill() + p.join() + + async with locks(dialog): + 
pass + + +async def test_recursive_lock(streams, spawn_ctx): + with server_process(streams, spawn_ctx): + try: + async with mp_locks(streams.open_connection) as locks: + locks._for_current_process = MagicMock() + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + try: + async with locks(dialog): + with pytest.raises(ServerClosedConnectionError) as excinfo: + async with locks(dialog): + pass + except ServerClosedConnectionError: + pass + except BrokenPipeError: + pass + + +async def test_brokenpipe2serverclosedconnectoin(streams, spawn_ctx): + with server_process(streams, spawn_ctx): + try: + async with mp_locks(streams.open_connection) as locks: + locks._for_current_process = MagicMock() + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + try: + async with locks(dialog): + with pytest.raises(ServerClosedConnectionError) as excinfo: + async with locks(dialog): + pass + except ServerClosedConnectionError: + pass + with pytest.raises(ServerClosedConnectionError) as excinfo: + async with locks(dialog): + pass + except BrokenPipeError: + pass + + +async def test_is_not_locked(streams, spawn_ctx): + with server_process(streams, spawn_ctx): + try: + async with mp_locks(streams.open_connection) as locks: + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + async with locks(dialog): + pass + + await locks._release(b"") + + with pytest.raises(ServerClosedConnectionError) as excinfo: + async with locks(dialog): + pass + except BrokenPipeError: + pass + + +async def test_unexpected_op(streams, spawn_ctx): + with server_process(streams, spawn_ctx): + try: + async with mp_locks(streams.open_connection) as locks: + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + async with locks(dialog): + pass + + locks._writer.write(b"?" + _EOF) + await locks._writer.drain() + + with pytest.raises(ServerClosedConnectionError) as excinfo: + async with locks(dialog): + pass + except BrokenPipeError: + pass + + +async def test_different_user(streams, spawn_ctx): + with server_process(streams, spawn_ctx): + async with mp_locks(streams.open_connection) as locks: + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + async with locks(dialog): + async with locks({**dialog, **{"user_id": "2"}}): + pass + + +def patch_ACQ_ACQUIRED(): + import maxbot.user_locks.mp + + maxbot.user_locks.mp._ACQ_ACQUIRED = b"?" 
+ + +async def test_unexpected_server_answer(streams, spawn_ctx): + with server_process( + streams, + spawn_ctx, + args=( + [ + patch_ACQ_ACQUIRED, + ], + ), + ): + async with mp_locks(streams.open_connection) as locks: + dialog = {"channel_name": "channel_test", "user_id": "user_test"} + with pytest.raises(AssertionError) as excinfo: + async with locks(dialog): + pass + + assert "Unexpected server answer: b'?'" == str(excinfo.value) + + +def send_sigint(): + from signal import SIGINT, raise_signal + + raise_signal(SIGINT) + + +async def test_sigint(streams, spawn_ctx): + server_stop_ = spawn_ctx.Event() + with server_process( + streams, + spawn_ctx, + args=( + [ + send_sigint, + ], + ), + server_stop_=server_stop_, + ): + for _ in range(10): + if server_stop_.is_set(): + break + sleep(0.1) + assert server_stop_.is_set() diff --git a/tests/test_webapp.py b/tests/test_webapp.py new file mode 100644 index 0000000..0555298 --- /dev/null +++ b/tests/test_webapp.py @@ -0,0 +1,237 @@ +import asyncio +import logging +import os +import pickle +from unittest.mock import ANY, AsyncMock, Mock + +import pytest +import sanic +from sanic import Sanic + +from maxbot.bot import MaxBot +from maxbot.channels import ChannelsCollection +from maxbot.errors import BotError +from maxbot.rpc import RpcManager +from maxbot.webapp import Factory, run_webapp + + +@pytest.fixture(autouse=True) +def mock_sanic_run(monkeypatch): + monkeypatch.setattr(Sanic, "run", Mock()) + monkeypatch.setattr(Sanic, "serve", Mock()) + + +@pytest.fixture +def bot(): + # we need at least one channel to run the bot + channel = Mock() + channel.configure_mock(name="my_channel") + bot = MaxBot(channels=ChannelsCollection([channel])) + return bot + + +@pytest.fixture +def after_server_start(monkeypatch): + monkeypatch.setattr(Sanic, "after_server_start", Mock()) + return _create_listener_execute("after_server_start") + + +@pytest.fixture +def before_server_stop(monkeypatch): + monkeypatch.setattr(Sanic, "before_server_stop", Mock()) + return _create_listener_execute("before_server_stop") + + +@pytest.fixture +def main_process_start(monkeypatch): + monkeypatch.setattr(Sanic, "main_process_start", Mock()) + return _create_listener_execute("main_process_start") + + +@pytest.fixture +def main_process_ready(monkeypatch): + monkeypatch.setattr(Sanic, "main_process_ready", Mock()) + return _create_listener_execute("main_process_ready") + + +@pytest.fixture +def main_process_stop(monkeypatch): + monkeypatch.setattr(Sanic, "main_process_stop", Mock()) + return _create_listener_execute("main_process_stop") + + +def test_run_webapp(bot): + run_webapp(bot, None, "localhost", 8080, single_process=True) + + assert Sanic.run.call_args.args == ("localhost", 8080) + + ch = bot.channels.my_channel + assert ch.blueprint.called + + +async def test_report_started(bot, after_server_start, caplog): + run_webapp(bot, None, "localhost", 8080, single_process=True) + + with caplog.at_level(logging.INFO): + await after_server_start() + assert ( + "Started webhooks updater on http://localhost:8080. Press 'Ctrl-C' to exit." + ) in caplog.text + + +async def test_report_started_mp(bot, after_server_start, caplog): + run_webapp(bot, None, "localhost", 8080, single_process=False, fast=True) + + with caplog.at_level(logging.INFO): + app = Mock() + app.m = Mock() + app.m.name = "SanicServer-0-0" + app.m.state = {"starts": 1} + await after_server_start(app) + assert ( + "Started webhooks updater on http://localhost:8080. Press 'Ctrl-C' to exit." 
+ ) in caplog.text + + +def test_no_channels(): + bot = MaxBot() + with pytest.raises(BotError) as excinfo: + run_webapp(bot, None, "localhost", 8080, single_process=True) + assert excinfo.value.message == ( + "At least one channel is required to run a bot. " + "Please, fill the 'channels' section of your bot.yaml." + ) + + +def test_rpc_enabled(bot, monkeypatch): + monkeypatch.setattr(RpcManager, "blueprint", Mock()) + + bot.dialog_manager.load_inline_resources( + """ + rpc: + - method: say_hello + """ + ) + run_webapp(bot, None, "localhost", 8080, single_process=True) + + assert bot.rpc.blueprint.called + + +def test_rpc_disabled(bot, monkeypatch): + monkeypatch.setattr(RpcManager, "blueprint", Mock()) + + run_webapp(bot, None, "localhost", 8080, single_process=True) + + assert not bot.rpc.blueprint.called + + +async def test_autoreload(bot, after_server_start, before_server_stop): + run_webapp(bot, None, "localhost", 8080, autoreload=True, single_process=True) + + app = Mock() + await after_server_start(app) + app.add_task.assert_called_with(bot.autoreloader, name="autoreloader") + + app = AsyncMock() + await before_server_stop(app) + app.cancel_task.assert_called_with("autoreloader") + + +async def test_public_url_missing(bot, after_server_start, caplog): + bot.channels.my_channel.configure_mock(name="my_channel") + + with caplog.at_level(logging.WARNING): + run_webapp(bot, None, "localhost", 8080, single_process=True) + await after_server_start() + + assert ( + "Make sure you have a public URL that is forwarded to -> " + "http://localhost:8080/my_channel and register webhook for it." + ) in caplog.text + + +def test_public_url_present(bot): + run_webapp(bot, None, "localhost", 8080, public_url="https://example.com", single_process=True) + + ch = bot.channels.my_channel + kw = ch.blueprint.call_args.kwargs + assert kw["public_url"] == "https://example.com" + + +def test_base_file_name(bot): + fname = Factory(bot, None, "localhost", 8080, None, None, None, None).base_file_name + with open(fname, "w") as f: + pass + os.unlink(fname) + + +def _create_mock(): + return Mock() + + +def test_pickle_bot(): + bot = _create_mock() + factory1 = Factory(bot, _create_mock, "localhost", 8080, None, None, None, None) + factory2 = pickle.loads(pickle.dumps(factory1)) + assert factory1.bot is not None + assert factory2.bot is not None + assert id(factory1.bot) != id(factory2.bot) + + +_LOG_INITED = False + + +def _init_logging(): + global _LOG_INITED + _LOG_INITED = True + + +def test_pickle_init_logging(): + global _LOG_INITED + _LOG_INITED = False + factory1 = Factory(None, _create_mock, "localhost", 8080, _init_logging, None, None, None) + factory2 = pickle.loads(pickle.dumps(factory1)) + assert _LOG_INITED + + +def test_sanic_21_force_single_process(bot, monkeypatch): + monkeypatch.setattr(sanic, "__version__", "21.0.0") + run_webapp(bot, None, "localhost", 8080, fast=True, single_process=False) + assert Sanic.run.call_args.kwargs == {"motd": False, "workers": 1} + + +def test_single_process(bot): + run_webapp(bot, None, "localhost", 8080, single_process=True) + assert Sanic.run.call_args.kwargs == {"motd": False, "single_process": True} + + +def test_multi_processes_fast(bot): + run_webapp(bot, None, "localhost", 8080, fast=True, single_process=False) + Sanic.serve.assert_called_once() + + +def test_multi_processes_workers(bot): + run_webapp(bot, None, "localhost", 8080, workers=16, single_process=False) + Sanic.serve.assert_called_once() + + +async def test_main_process( + bot, 
main_process_start, main_process_ready, main_process_stop, caplog +): + run_webapp(bot, None, "localhost", 8080, init_logging=_init_logging, single_process=False) + with caplog.at_level(logging.INFO): + await main_process_start() + await main_process_ready() + await main_process_stop() + + assert "Sanic multi-process server starting..." in caplog.text + assert "Sanic multi-process server stopping..." in caplog.text + + +def _create_listener_execute(listener_name): + async def execute(app=None): + for call in getattr(Sanic, listener_name).call_args_list: + (coro,) = call.args + await coro(app or Mock(), loop=Mock()) + + return execute diff --git a/website/docs/coding-guides/channels.md b/website/docs/coding-guides/channels.md index 5f2abbc..8f48348 100644 --- a/website/docs/coding-guides/channels.md +++ b/website/docs/coding-guides/channels.md @@ -193,3 +193,9 @@ builder.use_inline_resources(""" """) bot = builder.build() ``` + +## Webhooks + +Channels can receive data in two modes: webhook and polling. Polling means that the bot polls the server at some intervals to see if there are any changes. Please note: only the built-in `telegram` channel can run in polling mode. Webhook mode means that the bot has an external web address that will be called when there is new data from the messenger. + +The processing time of an incoming webhook request is limited. We recommend completing the handling within 5 seconds. Keep in mind that if you take too long to process a message from a user, the messenger may resend the same message. In order not to process the same message again, it will be necessary to check the identifiers of the processed messages. This code can be represented as `middleware` and, for example, use a database as storage. diff --git a/website/docs/design-guides/channel-setting.md b/website/docs/design-guides/channel-setting.md new file mode 100644 index 0000000..dcd9c61 --- /dev/null +++ b/website/docs/design-guides/channel-setting.md @@ -0,0 +1,68 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Channel setting + +## Overview + +The latest version of `Maxbot` has a built-in support of `Telegram`, `Viber` and `VK` messengers, but you can always [add the messenger you need](/coding-guides/channels). +Each messenger has its own differences in controls, features, etc. If you want to expand the possibilities of dialogue, for example, add sending location to `Viber` or sending money to `VK`, read about it in [Coding Guide](/category/coding-guides). + +## Telegram + +If a bot in `Telegram` messenger has already created, the steps 1 - 5 described below can be skipped. +To set up bot integration with `Telegram` it is necessary: +1. Open the `Telegram` messenger. +2. Go [@BotFacther](https://t.me/botfather) and send `/start` command to the dialog. +3. Send the `/newbot` command. +4. Set the user and system name for the new bot. +5. Save the resulting bot token. +6. Specify parameter in the bot resources. You can find all of them in [Telegram schema](/design-reference/channels/#telegram): + +```yaml +channels: + telegram: + api_token: 110201543:AAHdqTcvCH1vGWJxfSeofSAs0K5PALDsaw +``` + +## Viber + +Connection to `Viber` is similar to connection to `Telegram`. You need to create a bot and get an API token, see [instruction](https://developers.viber.com/docs/api/rest-bot-api/#get-started). + +If the company has already created a bot in Viber, the steps 1 - 4 described below can be skipped. +To set up bot integration with `Viber` you need to: +1. 
Go to [link](https://partners.viber.com/login) and enter the control panel. +2. Click on `Create Bot Account`. +3. Fill in the data for the bot and click on `Create`. +4. The account will be created. The next window will display the token that you need to use in `Maxbot` to integrate the bot with `Viber` messenger. +5. Specify parameter in the bot resources. You can find all of them in [Viber schema](/design-reference/channels/#viber): + +```yaml +channels: + viber: + api_token: 511c56j76j44dcb2-d80de780c65cd798-cbdf5833d0a9aa3c +``` + +## VK + +If the company has already had a `VK` community to integrate with the bot, the steps 1 - 5 described below can be skipped. Steps 6 to 11 are mandatory. +To configure the integration of the bot with the `VK` community it is necessary: +1. Log in to the `VK` social network. +2. Go to the section `Communities`. +3. In the section `Communities` click on the button `Create community`. +4. In the opened window fill in the name and theme of the community. +5. Click on `Create community`. +6. On the community page click on the link `Manage`. +7. In the list of sections of settings click on `API usage`-> `Callback API` -> `Secret key`. +8. Fill in the `Secret key` field by yourself. This is a `secret_key` parameter. +9. Select a tab `Access tokens` -> `Show`. This is a `access_token` parameter. +10. Look at the URL of your community. The numbers in the string is the `group_id` parameter. +11. Specify parameters in the bot resources. You can find all of them in [VK schema](/design-reference/channels/#vk): + +```yaml +channels: + vk: + secret_key: 9rLP4x4xLX4fg + access_token: 99fa2e47fe664da97f3d166c7c32c226370ef9d6e5026a97f3d166c7c32c226370026160db99 + group_id: 123456789 +``` \ No newline at end of file diff --git a/website/docs/design-guides/digressions.md b/website/docs/design-guides/digressions.md index 7358c24..4dfe70d 100644 --- a/website/docs/design-guides/digressions.md +++ b/website/docs/design-guides/digressions.md @@ -281,6 +281,16 @@ For example, this node will never be triggered during digressions. response: | {# ... #} ``` +Also, the `digressing` condition can be checked in response. Next example shows how to prevent from jumping to a specific node in digression: + +```yaml +- condition: intents.some_intent + response: | + ... + {% if not digressing %} + + {% endif %} +``` ### Break digression chain {#end-command} diff --git a/website/docs/design-guides/slot-filling.md b/website/docs/design-guides/slot-filling.md index 0f0d0d0..0d08b7d 100644 --- a/website/docs/design-guides/slot-filling.md +++ b/website/docs/design-guides/slot-filling.md @@ -747,3 +747,171 @@ Let's take an overall look to the slot filling flow. When a user input is receiv * A [digression](digressions.md) into the root level dialog nodes. * `not_found` response of previously prompted slot. * `prompt` response of the first empty slot is processed. + +## Advices on response scenario + +When you are working with digressions from `slot_filling`, it is worth paying attention to the specifics of the bot's work. + +For instance, consider the example below: + + + + +```yaml +dialog: + - condition: intents.restaurant_opening_hours + label: restaurant_opening_hours + response: | + The restaurant is open from 8am to 8pm. + - condition: intents.reservation + label: reservation + slot_filling: + - name: date + check_for: entities.date + prompt: | + What day would you like to come in? 
+ - name: time + check_for: entities.time + prompt: | + What time do you want the reservation to be made for? + - name: guests + check_for: entities.number + prompt: | + How many people will be dining? + response: | + OK. I am making you a reservation for {{ slots.guests }} + on {{ slots.date }} at {{ slots.time }}.
+ Do you want to make one more reservation? + {% delete slots.date %} + {% delete slots.time %} + {% delete slots.guests %} + followup: + - condition: entities.yes + response: | + + - condition: entities.no + response: Ok, have a good day! +``` + +
+ + + ```yaml + channels: + telegram: + api_token: !ENV ${TELEGRAM_API_KEY} + intents: + - name: reservation + examples: + - i'd like to make a reservation + - I want to reserve a table for dinner + - Can 3 of us get a table for lunch on May 29, 2022 at 5pm? + - do you have openings for next Wednesday at 7? + - Is there availability for 4 on Tuesday night? + - i'd like to come in for brunch tomorrow + - can i reserve a table? + - name: restaurant_opening_hours + examples: + - When does the restaurant close? + - When is the restaurant open? + - What are the restaurant opening hours + - Restaurant openin hours + - What time do you close? + - When do you close? + - When do you open? + - At what time do you open? + entities: + - name: yes + values: + - name: all + phrases: + - yes + - name: no + values: + - name: all + phrases: + - no + dialog: + - condition: intents.restaurant_opening_hours + label: restaurant_opening_hours + response: | + The restaurant is open from 8am to 8pm. + - condition: intents.reservation + label: reservation + slot_filling: + - name: date + check_for: entities.date + prompt: | + What day would you like to come in? + - name: time + check_for: entities.time + prompt: | + What time do you want the reservation to be made for? + - name: guests + check_for: entities.number + prompt: | + How many people will be dining? + response: | + OK. I am making you a reservation for {{ slots.guests }} + on {{ slots.date }} at {{ slots.time }}.
+ Do you want to make one more reservation? + {% delete slots.date %} + {% delete slots.time %} + {% delete slots.guests %} + followup: + - condition: entities.yes + response: | + + - condition: entities.no + response: Ok, have a good day! + ``` + +
+
+ +The conversation could go like this: + +``` +🧑 I'd like to make a reservation for 6 people tomorrow at 5 pm +🤖 OK. I am making you a reservation for 6 on 2023-06-22 at 17:00:00. Do you want to make one more reservation? +🧑 When does the restaurant close? +🤖 The restaurant is open from 8am to 8pm. +🤖 What day would you like to come in? +``` +Slots with reservation data were processed and then reset. The client made a digression to the `restaurant_opening_hours` node, and after he was returned to node `reservation`. Since the slots were reset, the recording is started over. + +To avoid this behavior, you need to ask the question "Do you want to make one more reservation?" in a separate node: + +```yaml +... + response: | + OK. I am making you a reservation for {{ slots.guests }} + on {{ slots.date }} at {{ slots.time }}.
+ {% delete slots.date %} + {% delete slots.time %} + {% delete slots.guests %} + + - condition: false + label: ask_make_new_reservation + response: | + Do you want to make one more reservation? + followup: + - condition: entities.yes + response: | + + - condition: entities.no + response: Ok, have a good day! +``` +``` +🧑 I'd like to make a reservation for 6 people tomorrow at 5 pm +🤖 OK. I am making you a reservation for 6 on 2023-06-22 at 17:00:00. +🤖 Do you want to make one more reservation? +🧑 When does the restaurant close? +🤖 The restaurant is open from 8am to 8pm. +🤖 Do you want to make one more reservation? +``` \ No newline at end of file diff --git a/website/docs/design-guides/stories.md b/website/docs/design-guides/stories.md index 2206067..ec562b3 100644 --- a/website/docs/design-guides/stories.md +++ b/website/docs/design-guides/stories.md @@ -3,6 +3,7 @@ You can use the stories mechanism to test the bot. This is a mechanism that verifies that the bot will react in the expected way to events known in advance from the user. Stories are written in a YAML file, the full format of which can be found in [Design Reference](/design-reference/stories.md). +Multiple stories files can be grouped in a directory. ## Example of using the stories @@ -31,24 +32,61 @@ and expects to receive it back: response: lorem ipsum ``` -Now we can start the stories engine with the `stories` command of the `maxbot` utility. -We can only specify the path to `bot.yaml` in the `-B` argument (like the `run` command): +You will need [pytest](https://pytest.org/) to run stories. +It's a great framework for testing. +MaxBot is registered in it as plugin `maxbot_stories` to run stories. +`pytest` is not installed as a dependency to `maxbot` by default. +`pytest` must be installed additionally: ```bash -$ maxbot stories -B examples/echo/bot.yaml -echo OK +$ pip install pytest ``` -`OK` after the name of the story (`echo`) indicates that all the steps in the story were completed in accordance with the expected behavior of the bot. +**Note:** if in process of running stories you get warning message: +``` +PytestConfigWarning: Unknown config option: asyncio_mode +``` +You can either ignore it or install the `pytest-asyncio` package to remove the warning: +```bash +$ pip install pytest-asyncio +``` + +Now we can start the stories engine. +If `--bot` option is specified when starting `pytest`, +then `pytest` will interpret the file passed to it as a YAML stories file. + +```bash +$ pytest --bot bot.yaml stories.yaml +============================ test session starts ============================ +platform openbsd7 -- Python 3.9.17, pytest-7.3.1, pluggy-1.0.0 +plugins: anyio-3.7.0, asyncio-0.20.3, cov-4.1.0, respx-0.20.1, maxbot-0.2.0 +collected 1 item + +stories.yaml . [100%] -The stories file can be located anywhere in the file system hierarchy. -To run an arbitrary stories file for execution, you need to specify the path to it in the `-S` argument: +============================ 1 passed in 6.75s ============================== +``` +`1 passed` indicates that one story was successfully completed. 
+If we need more runtime details we can add the `-v` option: ```bash -$ maxbot stories -B examples/echo/bot.yaml -S examples/echo/stories.yaml -echo OK +$ pytest --bot bot.yaml stories.yaml +============================ test session starts ============================ +platform openbsd7 -- Python 3.9.17, pytest-7.3.1, pluggy-1.0.0 +plugins: anyio-3.7.0, asyncio-0.20.3, cov-4.1.0, respx-0.20.1, maxbot-0.2.0 +collected 1 item + +stories.yaml::echo PASSED [100%] + +============================ 1 passed in 6.75s ============================== ``` +The value of option `--bot` can be a path for bot file or directory or the Maxbot instance to load. +The instance can be in the form 'module:name'. +Module can be a dotted import. +Name is not required if it is 'bot'. +This value is the same as the value of option `-B` / `--bot` when starting [maxbot run](/design-reference/cli#run). + ### Error detection Modify the bot code so that it returns only the first 6 characters of the user's message: @@ -78,31 +116,50 @@ splitting the story with three steps into three separate stories: response: lorem ipsum ``` -And run the `maxbot` utility again with the `stories` command: +Let's run stories again, but +now specify `-x` option - terminate testing at first mismatch: ```bash -$ maxbot stories -B examples/echo/bot.yaml -hello OK -how are you? FAILED at step [0] +$ pytest -x --bot bot.yaml stories.yaml +============================ test session starts ============================ +platform openbsd7 -- Python 3.9.17, pytest-7.3.1, pluggy-1.0.0 +plugins: anyio-3.7.0, asyncio-0.20.3, cov-4.1.0, respx-0.20.1, maxbot-0.2.0 +collected 3 items + +stories.yaml .F + +============================ FAILURES ======================================= +____________________________ how are you? ___________________________________ +Mismatch at step [0] Expected: how are you? Actual: how ar -Aborted! +============================ short test summary info ======================== +FAILED stories.yaml::how are you? +!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!! +=========================== 1 failed, 1 passed in 2.19s ===================== + ``` +The result: +* 3 stories were collected for testing +* first story executed successfully +* second story executed with mismatch, +* testing being stopped (third story was not run) + In the console output of the utility we see: In the steps of the `hello` story all bot responses are as expected. In the steps of the `how are you?` story at first step (steps are numbered from 0) bot response is not as expected. We expect a text message from the bot with the content "how are you?", but received a "how ar" message. -If we don't want stories to fail, we can mark the problematic story with the `xfail` field: +If we don't want stories to fail, we can mark the problematic story with the `xfail` mark: ```yaml - name: hello turns: - message: hello response: hello - name: how are you? - xfail: true + markers: ["xfail"] turns: - message: how are you? response: how are you? @@ -112,24 +169,104 @@ If we don't want stories to fail, we can mark the problematic story with the `xf response: lorem ipsum ``` -And run the `maxbot` utility again with the `stories` command: +And run stories again with the `stories` command with same command-line options: ```bash -$ maxbot stories -B examples/echo/bot.yaml -how are you? XFAIL at step [0] -Expected: - how are you? 
-Actual: - how ar -lorem ipsum FAILED at step [0] +$ pytest -x --bot bot.yaml stories.yaml +============================ test session starts ============================ +platform openbsd7 -- Python 3.9.17, pytest-7.3.1, pluggy-1.0.0 +plugins: anyio-3.7.0, asyncio-0.20.3, cov-4.1.0, respx-0.20.1, maxbot-0.2.0 +collected 3 items + +stories.yaml .xF + +============================ FAILURES ======================================= +____________________________ lorem ipsum ____________________________________ +Mismatch at step [0] Expected: lorem ipsum Actual: lorem -Aborted! +============================ short test summary info ======================== +FAILED stories.yaml::lorem ipsum +!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!! +============================ 1 failed, 1 passed, 1 xfailed in 2.20s ========= ``` -We see that the result of the "how are you?" story has changed from `FAILED` to `XFAIL`. +We see that the result of the "how are you?" story has changed from `failed` to `xfailed`. And now the next story "lorem ipsum" has started after it. -If running the `stories` command results in "FAILED", the return code of the `maxbot` utility will be different from 0. -Otherwise the `maxbot` utility will return 0. + +### Selective stories run + +For example, we have such stories (`stories.yaml` file context): +```yaml +- name: branch_1_story_1 + turns: + - message: test1 + response: test1 +- name: branch_2_story_2 + turns: + - message: test2 + response: test2 +- name: branch_2_story_3 + turns: + - message: test3 + response: test3 +- name: branch_2_story_4 + turns: + - message: test4 + response: test4 +``` + +To run only history `branch_1_story_1` you can use the command line parameter `-k` +```bash +$ pytest --bot bot.yaml -k "branch_1_story_1" stories.yaml +============================ test session starts ============================ +platform openbsd7 -- Python 3.9.17, pytest-7.3.1, pluggy-1.0.0 +plugins: anyio-3.7.0, asyncio-0.20.3, cov-4.1.0, respx-0.20.1, maxbot-0.2.0 +collected 4 items / 3 deselected / 1 selected + +stories.yaml . [100%] + +============================ 1 passed, 3 deselected in 7.07s ================ +``` + +But you can also use more complex expressions. +For example: +you can run all stories with an `branch_2_` substring in the name, excluding `branch_2_story_3`. +```bash +$ pytest -v --bot bot.yaml -k "branch_2_ and not branch_2_story_3" stories.yaml + +============================ test session starts ============================ +platform openbsd7 -- Python 3.9.17, pytest-7.3.1, pluggy-1.0.0 +plugins: anyio-3.7.0, asyncio-0.20.3, cov-4.1.0, respx-0.20.1, maxbot-0.2.0 +collected 4 items / 2 deselected / 2 selected + +stories.yaml::branch_2_story_2 PASSED [ 50%] +stories.yaml::branch_2_story_4 PASSED [100%] + +============================ 2 passed, 2 deselected in 6.07s ================ +``` + +### Directory usage + +As the number of stories increases, storing them in a single file becomes inconvenient. +Therefore, you can create a separate directory in which to place YAML files with stories +Let's split our previous example into three files: +`stories/1.yaml`, `stories/2.yaml` and `stories/3.yaml`. 
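+
+For example, `stories/1.yaml` would then hold just the first story from the example above (the other two files are split the same way):
+```yaml
+- name: hello
+  turns:
+    - message: hello
+      response: hello
+```
+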
+Now let's run `pytest` to execute all stories from the `stories/` directory:
+```bash
+$ pytest -v --bot bot.yaml stories/
+============================ test session starts ============================
+platform openbsd7 -- Python 3.9.17, pytest-7.3.1, pluggy-1.0.0
+cachedir: .pytest_cache
+plugins: anyio-3.7.0, asyncio-0.20.3, cov-4.1.0, respx-0.20.1, maxbot-0.2.0
+collected 3 items
+
+stories/1.yaml::hello PASSED [ 33%]
+stories/2.yaml::how are you? PASSED [ 66%]
+stories/3.yaml::lorem ipsum PASSED [100%]
+
+============================ 3 passed in 3.34s ==============================
+
+```
diff --git a/website/docs/design-reference/channels.md b/website/docs/design-reference/channels.md
new file mode 100644
index 0000000..1f64f48
--- /dev/null
+++ b/website/docs/design-reference/channels.md
@@ -0,0 +1,58 @@
+# Builtin channels
+
+This document describes the MaxBot builtin channels and their configuration settings.
+You can expand your bot by [implementing a new channel](/coding-guides/channels.md)!
+
+## Telegram
+
+[Telegram Messenger](https://telegram.org/) is an instant messaging service.
+To make your bot available in this messenger you need to add the `telegram` channel to the `channels` section.
+The channel has the following configuration settings:
+
+| Name | Type | Description | See also |
+| ------------- | ---- | ----------- | -------- |
+| `api_token`\* | [String](/design-reference/strings.md) | Authentication token to access telegram bot api | https://core.telegram.org/bots#how-do-i-create-a-bot |
+| `timeout` | [Timeout](/design-reference/timeout.md) | Default HTTP request timeouts | https://www.python-httpx.org/advanced/#timeout-configuration |
+| `limits` | [Pool limits](/design-reference/pool-limits.md) | Pool limit configuration | https://www.python-httpx.org/advanced/#pool-limit-configuration |
+
+## Facebook
+
+[Facebook Messenger](https://www.messenger.com/) is an instant messaging application and platform developed by [Meta Platforms](https://meta.com/).
+To make your bot available in this messenger you need to add the `facebook` channel to the `channels` section.
+The channel has the following configuration settings:
+
+| Name | Type | Description | See also |
+| ---------------- | ---- | ----------- | -------- |
+| `app_secret`\* | [String](/design-reference/strings.md) | Facebook `App Secret` | https://developers.facebook.com/docs/facebook-login/security/#appsecret |
+| `access_token`\* | [String](/design-reference/strings.md) | Facebook `App Access Tokens` | https://developers.facebook.com/docs/facebook-login/security/#appsecret |
+| `timeout` | [Timeout](/design-reference/timeout.md) | Default HTTP request timeouts | https://www.python-httpx.org/advanced/#timeout-configuration |
+| `limits` | [Pool limits](/design-reference/pool-limits.md) | Pool limit configuration | https://www.python-httpx.org/advanced/#pool-limit-configuration |
+
+## Viber
+
+[Rakuten Viber](http://viber.com/) is an instant messaging application.
+To make your bot available in this messenger you need to add the `viber` channel to the `channels` section.
+The channel has the following configuration settings: + +| Name | Type | Description | See also | +| ------------- | ---- | ----------- | -------- | +| `api_token`\* | [String](/design-reference/strings.md) | Authentication token to access viber bot api | https://developers.viber.com/docs/api/rest-bot-api/#authentication-token | +| `name` | [String](/design-reference/strings.md) | Bot name | https://developers.viber.com/docs/api/python-bot-api/#firstly-lets-import-and-configure-our-bot https://developers.viber.com/docs/api/python-bot-api/#userprofile-object | +| `avatar` | [String](/design-reference/strings.md) | Bot avatar | https://developers.viber.com/docs/api/python-bot-api/#firstly-lets-import-and-configure-our-bot https://developers.viber.com/docs/api/python-bot-api/#userprofile-object | +| `timeout` | [Timeout](/design-reference/timeout.md) | Default HTTP request timeouts | https://www.python-httpx.org/advanced/#timeout-configuration | +| `limits` | [Pool limits](/design-reference/pool-limits.md) | Pool limit configuration | https://www.python-httpx.org/advanced/#pool-limit-configuration | + +## VK + +[VK](http://vk.com/) is online social media and social networking service. +To make your bot available in this messenger you need to add channel `vk` to the `channels` section. +The channel has the following configuration settings: + +| Name | Type | Description | See also | +| ---------------- | ---- | ----------- | -------- | +| `access_token`\* | [String](/design-reference/strings.md) | Authentication token to access VK bot api | https://dev.vk.com/api/access-token/authcode-flow-user | +| `group_id` | [Integer](/design-reference/numbers.md) | `group_id` for VK page, if present, the incoming messages will be checked against it and use for set webhook | | +| `secret_key` | [String](/design-reference/strings.md) | Secret key, use for set webhook | https://dev.vk.com/method/groups.addCallbackServer | +| `server_title` | [String](/design-reference/strings.md) | Server title, use for set webhook | https://dev.vk.com/method/groups.addCallbackServer | +| `timeout` | [Timeout](/design-reference/timeout.md) | Default HTTP request timeouts | https://www.python-httpx.org/advanced/#timeout-configuration | +| `limits` | [Pool limits](/design-reference/pool-limits.md) | Pool limit configuration | https://www.python-httpx.org/advanced/#pool-limit-configuration | diff --git a/website/docs/design-reference/cli.md b/website/docs/design-reference/cli.md index 819a762..65a34c1 100644 --- a/website/docs/design-reference/cli.md +++ b/website/docs/design-reference/cli.md @@ -41,16 +41,7 @@ Run the bot. | -q, --quiet | Do not log to console. | | --journal-file FILENAME | Write the journal to the file | | --journal-output [json\|yaml] | Journal file format [default: json] | +| --workers | Number of web application worker processes to spawn. | +| --fast | Set the number of web application workers to max allowed. | +| --single-process | Run web application in a single process. | | --help | Show this message and exit. | - -## stories - -Run bot stories. -You can use the stories mechanism to test the bot. -This is a mechanism that verifies that the bot will react in the expected way to events known in advance from the user. - -| Name | Description | -| ------------------ | ------------------------------ | -| -B, --bot TEXT | Path for bot file or directory or the Maxbot instance to load. The instance can be in the form 'module:name'. Module can be a dotted import. Name is not required if it is 'bot'. 
[required] | -| -S, --stories FILE | Path to YAML file with stories | -| --help | Show this message and exit. | diff --git a/website/docs/design-reference/context.md b/website/docs/design-reference/context.md index a9d6d8f..8b80f7c 100644 --- a/website/docs/design-reference/context.md +++ b/website/docs/design-reference/context.md @@ -21,7 +21,7 @@ Special variables related to the [digression](/design-guides/digressions.md) flo | Name | Type | Description | | ----------- | ----------- | ----------- | -| `digressing` | [Boolean](/design-reference/booleans.md) | The variable is set to true during the digression. You can check it to [prevent digressions into](/design-guides/digressions.md#check-digressing) a particular root node. | +| `digressing` | [Boolean](/design-reference/booleans.md) | The variable is set to true during the digression. You can check it to [prevent digressions](/design-guides/digressions.md#check-digressing) into a particular root node or execute some commands. | | `returning` | [Boolean](/design-reference/booleans.md) | The variable is set to true when returning after digression. Use it to add [custom return message](/design-guides/digressions.md#custom-return-message) to node response. | The following special variables can help you check and set values in slots. diff --git a/website/docs/design-reference/mp.md b/website/docs/design-reference/mp.md new file mode 100644 index 0000000..f7a93e3 --- /dev/null +++ b/website/docs/design-reference/mp.md @@ -0,0 +1,45 @@ +# Multiprocess WEB application + +To receive incoming messages, the bot can work in two modes: `webhooks` and `polling` (see [run --updater argument](/design-reference/cli.md#run)). +The current document only describes operation in `webhooks` mode. +Mode `webhooks` implies that the bot acts as an HTTP server. +The server receives and processes incoming HTTP requests. +Such server is called WEB application. + +The bot handles incoming requests with concurrent code (asynchronous I/O). +If your bot can't handle the incoming load you can scale incoming requests processing to multiple processes. +You can control the number of processes handling incoming requests with command line argument [--workers](/design-reference/cli.md#run). +Or you can specify [--fast](/design-reference/cli.md#run) option to automatically determine the number of processes for best performance. + +Working in multi-process mode has its own characteristics and imposes a number of requirements. +If your bot code is not designed to work in several processes, you can force it to work in single process mode with [--single-process option](/design-reference/cli.md#run). + +## Workers + +The processes responsible for handling incoming HTTP requests are called workers. +The bot implements a user locking mechanism, +so all incoming user messages are processed sequentially for each user (including sending reply commands). +Messages from different users can be processed simultaneously both in one worker (concurrent code, asynchronous I/O) +and in different workers (processes). +You can implement your own custom user lock that will meet the same requirements. + +## Persistence storage + +By default, we use SQLite as our storage engine. +It allows multiple processes to work on a single database file. +But this engine does not allow parallel writes. +This can cause serious performance issues. +Therefore, we recommend using a different database engine (such as PostgreSQL, MySQL) for highly loaded solutions. 
+For details, see [SQLite FAQ](https://www.sqlite.org/faq.html#q5). + +A unique persistent storage file in a temporary directory is generated for each bot run. +It is not deleted when the bot is stopped. +You can delete the storage files yourself after stopping the bot. +For example (unix-like systems): files can be deleted with `rm /tmp/maxbot-*.db` command. + +## Known issues + +MaxBot uses [sanic](https://sanic.dev/en/) as a web application framework. +Unfortunately, there are unresolved issues in `sanic` that lead to various random exceptions when trying to shut down the bot by pressing Ctrl-C. +Examples of such exceptions can be found [here](https://community.sanicframework.org/t/random-exceptions-with-workermanager/1154). + diff --git a/website/docs/design-reference/pool-limits.md b/website/docs/design-reference/pool-limits.md new file mode 100644 index 0000000..d6f32c9 --- /dev/null +++ b/website/docs/design-reference/pool-limits.md @@ -0,0 +1,13 @@ +# HTTP pool limit + +With the pool limit object, you can control the size of the connection pool. +We use the `Limits` from the [httpx](https://www.python-httpx.org/) library as the value. + +## Fields + +| Name | Type | Description | Default value | +| --------------------------- | --------------------------------------- | ---------------------------------------------------- | ------------- | +| `max_keepalive_connections` | [Integer](/design-reference/numbers.md) | Number of allowable keep-alive connections | 20 | +| `max_connections` | [Integer](/design-reference/numbers.md) | Maximum number of allowable connections | 100 | +| `keepalive_expiry` | [Float](/design-reference/numbers.md) | Time limit on idle keep-alive connections in seconds | 5.0 | + diff --git a/website/docs/design-reference/stories.md b/website/docs/design-reference/stories.md index 369326d..59b979b 100644 --- a/website/docs/design-reference/stories.md +++ b/website/docs/design-reference/stories.md @@ -2,7 +2,8 @@ You can use the stories mechanism to test the bot. This is a mechanism that verifies that the bot will react in the expected way to events known in advance from the user. -Events are grouped into separate stories, which are all described together in one file as a list. +Events are grouped into separate stories, which are all described together in file as a list. +Multiple stories files can be grouped in a directory. ## `StorySchema` @@ -10,9 +11,9 @@ Each story is an object and has the following set of fields: | Name | Type | Description | | --------- | ---------------------------------------- | ----------------------------------------------------------------- | -| `xfail` | [Boolean](/design-reference/booleans.md) | The flag means that you expect the story to fail for some reason. | | `name`\* | [String](/design-reference/strings.md) | Printable name. | | `turns`\* | List of [TurnSchema](#turnschema) | List of story turns. | +| `markers` | List of [Strings](/design-reference/strings.md) | List of [pytest marks](https://docs.pytest.org/en/stable/how-to/mark.html) | ## `TurnSchema` @@ -33,6 +34,17 @@ If the `TurnSchema.utc_time` field contains a value with any time zone other tha After a step with the `TurnSchema.utc_time` field explicitly specified, all subsequent steps without this field will be shifted forward by 10 seconds. If the bot can respond with one of the predefined answers -(e.g. chosen by the bot's script using `random`) +(e.g. 
chosen by the bot's script using [`random`](/design-reference/lists/#filter-random))
 then the `response` field should contain a list of strings.
 One list item for each possible bot response.
+
+```yaml
+- name: seasons-random
+  turns:
+    - message: season
+      response:
+        - "Spring"
+        - "Summer"
+        - "Autumn"
+        - "Winter"
+```
diff --git a/website/docs/design-reference/timedelta.md b/website/docs/design-reference/timedelta.md
new file mode 100644
index 0000000..5eea153
--- /dev/null
+++ b/website/docs/design-reference/timedelta.md
@@ -0,0 +1,59 @@
+# `timedelta` object representation
+
+A [timedelta](https://docs.python.org/3/library/datetime.html#timedelta-objects) object represents a duration, the difference between two dates or times.
+
+## Fields
+
+A `timedelta` object is defined as a set of fields:
+* weeks
+* days
+* hours
+* minutes
+* seconds
+* microseconds
+* milliseconds
+
+All fields are of type [integer](/design-reference/numbers.md) and default to `0`.
+
+## How to set a value
+
+Let's look at the ways of setting a value using the [rest](/extensions/rest.md) extension configuration as an example.
+
+You can use a short syntax to specify a value in seconds:
+```yaml
+extensions:
+  rest:
+    services:
+      - name: my_server
+        cache: 10
+```
+For the above example, the results of successful HTTP requests will be cached for 10 seconds.
+It's equivalent to the following:
+```yaml
+extensions:
+  rest:
+    services:
+      - name: my_server
+        cache:
+          seconds: 10
+```
+
+If we want to cache results for 90 minutes (5400 seconds), we can set it like this:
+```yaml
+extensions:
+  rest:
+    services:
+      - name: my_server
+        cache:
+          hours: 1
+          minutes: 30
+```
+Or like this:
+```yaml
+extensions:
+  rest:
+    services:
+      - name: my_server
+        cache:
+          minutes: 90
+```
diff --git a/website/docs/design-reference/timeout.md b/website/docs/design-reference/timeout.md
new file mode 100644
index 0000000..a083644
--- /dev/null
+++ b/website/docs/design-reference/timeout.md
@@ -0,0 +1,68 @@
+# HTTP request timeout
+
+The timeout object is used to control the maximum time to wait for an HTTP request to complete.
+We use the `Timeout` from the `httpx` library as the value.
+More details on the use of timeouts can be found in the [documentation for the httpx library](https://www.python-httpx.org/advanced/#timeout-configuration).
+
+
+## Fields
+
+There are four different types of timeouts that may occur. These are `connect`, `read`, `write`, and `pool` timeouts.
+
+* The `connect` timeout specifies the maximum amount of time to wait until a socket connection to the requested host is established.
+* The `read` timeout specifies the maximum duration to wait for a chunk of data to be received (for example, a chunk of the response body).
+* The `write` timeout specifies the maximum duration to wait for a chunk of data to be sent (for example, a chunk of the request body).
+* The `pool` timeout specifies the maximum duration to wait for acquiring a connection from the connection pool.
+
+All fields are of type [float](/design-reference/numbers.md) and contain values in seconds.
+
+## How to set a value
+
+Let's look at the ways of setting a value using the [rest](/extensions/rest.md) extension configuration as an example.
+ +If no value is set: +```yaml +extensions: + rest: {} +``` +5 seconds timeout will be applied by default: +```python +httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0) +``` + +You can use a short syntax to fill the values of all fields with the same value: +```yaml +extensions: + rest: + timeout: 6.5 +``` +So the value of all fields will be filled with the value of 6.5 seconds +```python +httpx.Timeout(connect=6.5, read=6.5, write=6.5, pool=6.5) +``` + +You can fill the value of selected fields: +```yaml +extensions: + rest: + timeout: + read: 1.0 + write: 1.0 +``` +The values of the other fields will be filled in by default: +```python +httpx.Timeout(connect=5.0, read=1.0, write=1.0, pool=5.0) +``` + +The default value can also be changed: +```yaml +extensions: + rest: + timeout: + connect: 3.5 + default: 2.0 +``` +For the above example, we obtain the following value: +```python +httpx.Timeout(connect=3.5, read=2.0, write=2.0, pool=2.0) +``` diff --git a/website/docs/extensions/rest.md b/website/docs/extensions/rest.md new file mode 100644 index 0000000..148de0a --- /dev/null +++ b/website/docs/extensions/rest.md @@ -0,0 +1,191 @@ +# rest extension + +The `rest` extension is designed to perform HTTP requests, the input and output of which are described by JSON objects. +This extension provides global function `rest_call` and Jinja tags for every possible HTTP method: +`GET`, `POST`, `PUT`, `DELETE` and `PATCH`. + +## Extension configuration + +| Name | Type | Description | +| --------------------------- | ----------------------------------------------- | --------------------------------- | +| `services` | Sequence (list) of [Service](#service) | Pre-configured remote server data | +| `timeout` | [Timeout](/design-reference/timeout.md) | Default HTTP request timeout | +| `limits` | [Pool limits](/design-reference/pool-limits.md) | HTTP pool limits configuration | +| `garbage_collector_timeout` | [Time delta](/design-reference/timedelta.md) | Garbage collector timeout | + +Value `garbage_collector_timeout` sets the time interval not more often than garbage collector will be triggered. +Garbage collector removes cached results older than `garbage_collector_timeout`. +Default value: 1 hour. + +### `Service` + +Each `Service` object has the following format: + +| Field name | Type | Description | +| ------------ | ----------------------------------------------- | -------------------------------------------------------------- | +| `name`\* | [String](/design-reference/strings.md) | Unique name (case-insensitive) | +| `method` | [String](/design-reference/strings.md) | HTTP request method: `get`, `post`, `put`, `delete` or `patch` | +| `base_url` | [String](/design-reference/strings.md) | Basic part of the URL | +| `auth` | [Auth](#auth) | Authentication data | +| `headers` | [Dictionary](/design-reference/dictionaries.md) | HTTP headers | +| `parameters` | [Dictionary](/design-reference/dictionaries.md) | URL parameters | +| `timeout` | [Timeout](/design-reference/timeout.md) | Default HTTP request timeout | +| `limits` | [Pool limits](/design-reference/pool-limits.md) | HTTP pool limits configuration | +| `cache` | [Time delta](/design-reference/timedelta.md) | Time for which a successful requests is cached | + +If `limits` field is not specified, the value of the `limits` will be taken from [extension configuration](#extension-configuration). +If both values are missing, the default values will be used (see [Pool limits](/design-reference/pool-limits.md)). 
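+
+For illustration, here is a sketch of a single service entry that combines several of the fields above (the service name, URL and credentials are placeholders, not a real API):
+```yaml
+extensions:
+  rest:
+    services:
+      - name: my_service
+        method: get
+        base_url: https://api.example.com/
+        auth:
+          user: demo
+          password: secret_password
+        headers:
+          Accept: application/json
+        timeout:
+          connect: 3.0
+          default: 10.0
+        limits:
+          max_connections: 50
+        cache:
+          minutes: 5
+```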
+
+If the `cache` field is specified, all results of successful HTTP requests to this service will be cached.
+Be careful with this feature. We **recommend caching only** `GET` requests.
+
+### `Auth`
+
+The `Auth` object has the following format:
+
+| Field name | Type | Description |
+| ------------ | -------------------------------------- | ------------- |
+| `user`\* | [String](/design-reference/strings.md) | User name |
+| `password`\* | [String](/design-reference/strings.md) | User password |
+
+## `rest_call` function
+
+Arguments of the `rest_call` function:
+
+| Name | Type | Description |
+| ------------ | -------------------------------------- | ----------- |
+| `service` | [String](/design-reference/strings.md) | Unique name of [service](#service) from configuration (case-insensitive) |
+| `url` | [String](/design-reference/strings.md) | URL |
+| `method` | [String](/design-reference/strings.md) | HTTP request method: `get`, `post`, `put`, `delete` or `patch` |
+| `auth` | [Auth](#auth) | Authentication data |
+| `body` | [Dictionary](/design-reference/dictionaries.md) or [String](/design-reference/strings.md) | HTTP request body |
+| `headers` | [Dictionary](/design-reference/dictionaries.md) | HTTP headers |
+| `parameters` | [Dictionary](/design-reference/dictionaries.md) | URL parameters |
+| `timeout` | [Integer](/design-reference/numbers.md) | Request timeout in seconds |
+| `on_error` | [String](/design-reference/strings.md) | Function behavior when an error occurs: `continue` or `break_flow` |
+| `cache` | [Integer](/design-reference/numbers.md) | Time (in seconds) for which a successful request is cached |
+
+The `rest_call` function returns a dictionary:
+
+| Field name | Type | Description |
+| --------------- | ---------------------------------------- | ------------------------- |
+| `ok`\* | [Boolean](/design-reference/booleans.md) | Request success flag |
+| `status_code`\* | [Integer](/design-reference/numbers.md) | HTTP response status code |
+| `json`\* | Any | Response data |
+
+
+### Arguments `service` and `url`
+
+When calling the `rest_call` function, there are two ways to refer to a service:
+* specify the `service` argument
+* specify the `url` argument in the format `service://[url]`
+
+For example, the following bot responds to the user with a link to the latest release from GitHub.
+```yaml
+extensions:
+  rest:
+    services:
+      - name: api_github
+        base_url: https://api.github.com/
+dialog:
+- condition: true
+  response: |
+    {% set rest = rest_call(method="get", url="api_github://repos/maxbot-ai/maxbot/releases/latest") %}
+    {{ rest.json.html_url }}
+```
+
+We can make the same call by explicitly referring to the service by name:
+```
+    {% set rest = rest_call(method="get", service="api_github", url="repos/maxbot-ai/maxbot/releases/latest") %}
+```
+
+Or do not use the extension configuration at all:
+```yaml
+extensions:
+  rest: {}
+dialog:
+- condition: true
+  response: |
+    {% set rest = rest_call(method="get", url="https://api.github.com/repos/maxbot-ai/maxbot/releases/latest") %}
+    {{ rest.json.html_url }}
+```
+
+### Argument `method`
+
+HTTP request method: `get`, `post`, `put`, `delete` or `patch`.
+Default value: `post` if the `body` argument is given, otherwise `get`.
+
+### Argument `auth`
+
+Authentication data.
+If no argument is passed, then the value is taken from the service.
+If the value is not set anywhere, then the request is made without authentication.
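+
+As a sketch, explicit authentication for a single call might look like this (assuming the `auth` argument accepts a mapping with `user` and `password` keys, just like the service configuration; the URL and credentials are placeholders):
+```yaml
+extensions:
+  rest: {}
+dialog:
+- condition: true
+  response: |
+    {% set rest = rest_call(method="get", url="https://api.example.com/private", auth={"user": "demo", "password": "secret_password"}, on_error="continue") %}
+    {{ "OK" if rest.ok else "Request failed" }}
+```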
+ +### Argument `body` + +If value of `body` is a [string](/design-reference/strings.md), it is passed to body of HTTP request as is. + +If value of `body` is a [dictionary](/design-reference/dictionaries.md): +* if `Content-Type` is explicitly specified in `headers` by `application/x-www-form-urlencoded`, the `body` will be URL-encoded; +* in the opposite case, the `body` will be JSON-encoded. + +### Argument `headers` + +HTTP request headers. +The dictionary passed in function argument is merged with dictionary from the service. +Values from function argument have higher precedence. + +### Argument `parameters` + +URL parameters. +The dictionary passed in function argument is merged with dictionary from the service. +Values from function argument have higher precedence. + + +### Argument `timeout` + +Request timeout in seconds. Value 0 is ignored. +If no argument is passed, then the value is taken from the service or from extension configuration. +If the value is not set anywhere, then the request is executed with a timeout of 5 seconds. + +### Argument `on_error` + +This argument controls the behavior of the function in case of an error in the execution of an HTTP request: +* `continue`: function will return control. + The error signal is the `ok` field, which has the value `False` on the returned dictionary. +* `break_flow` (default value): function will throw an exception `BotError`. + +### Argument `cache` + +If `cache` argument (or the corresponding setting in [Service](#service)) is given, the result of a successful HTTP request will be cached. +Be careful with this feature. We **recommend to cache only** `GET` requests. + +## Using Jinja tags + +You can use Jinja tags instead of `rest_call` function: `GET`, `POST`, `PUT`, `DELETE` or `PATCH`. +These tags correspond to the available HTTP request methods. + +As a result of the tag operation, a local variable `rest` will be created. +It will be filled with the return value of function `rest_call`. + +Let's rewrite the example where the bot replies to the user with a link to the latest release from GitHub. +```yaml +extensions: + rest: + services: + - name: api_github + base_url: https://api.github.com/ +dialog: +- condition: true + response: | + {% GET "api_github://repos/maxbot-ai/maxbot/releases/latest" %} + {{ rest.json.html_url }} +``` + +The tag is followed by a string with the [value of the URL](#arguments-service-and-url). +The remaining arguments of `rest_call` function can be specified further in pairs: `argument_name` (unquoted), `value`. + +For example, increasing the timeout to 15 seconds looks like this: +``` + {% GET "api_github://repos/maxbot-ai/maxbot/releases/latest" timeout 15 %} +``` diff --git a/website/docs/getting-started/creating-bots.md b/website/docs/getting-started/creating-bots.md index 0f94f03..f988aa8 100644 --- a/website/docs/getting-started/creating-bots.md +++ b/website/docs/getting-started/creating-bots.md @@ -144,7 +144,7 @@ These are called "flow collections" and can sometimes be useful. ## Channels -Channels are a way to integrate your bot with various messaging platforms. You must configure at least one channel to create a bot. MaxBot provides pre-built channels for Telegram, Viber Chatbots, VKontakte and there will be more soon. +Channels are a way to integrate your bot with various messaging platforms. You must configure at least one channel to create a bot. MaxBot provides pre-built channels for `Telegram`, `Viber Chatbots`, `VK` and there will be more soon. 
 Just add the channel configuration to the bot resources to integrate the bot with that channel.
 
 The process of integrating your bot with a specific messaging platform is covered in the schema description of the corresponding channel. For example, [TelegramSchema](/design-reference/resources.md#telegramschema) description says that you should go to [@BotFacther](https://t.me/botfather), get an API token and specify it in the bot resources.
@@ -153,7 +153,7 @@ channels:
   telegram:
     api_token: 110201543:AAHdqTcvCH1vGWJxfSeofSAs0K5PALDsaw
 ```
-
+See the [Channel settings](/design-guides/channel-setting) guide for information about other channels.
 ## Intents
 
 Intents are purposes or goals that are expressed in a user input, such as answering a question or making a reservation. By recognizing the intent expressed in a user's input, the bot can choose the correct dialog flow for responding to it. For example, you might want to define `intents.buy_something` intent to recognize when the user wants to make a purchase.
@@ -200,6 +200,15 @@ Rule-based entities cover commonly used categories, such as numbers or dates. Yo
 
 Phrase and regex entities are defined in bot resources. An entity definition includes a set of entity values that represent vocabulary that is often used in the context of a given intent. For each value, you define a set of recognition rules, which can be either phrases or regular expressions.
 
+You can also save entity values to state variables. There are two kinds of state variables:
+
+* Slot variables are used by the bot as short-term memory to keep the conversation on the current topic only.
+* User variables are long-term memory that is used to personalize the entire communication process.
+
+State variables retain information across dialog turns. Use state variables to collect information from the user and then refer back to it later in the conversation.
+
+For more information about state variables, see the [State Variables](/design-guides/state.md) guide.
+
 ### Phrase Entities {#phrase-entities}
 
 You define an entity (`entities.menu`), and then one or more values for that entity (`standard`, `vegetarian`, `cake`). For each value, you specify a bunch of phrases with which this value can be mentioned in the user input, e.g. "cake shop", "desserts" and "bakery offerings" for the `cake` value etc.
@@ -230,6 +239,7 @@ dialog:
 ```
 
 MaxBot recognizes pieces of information in the user input that closely match the phrases that you defined for the entity as mentions of that entity.
+MaxBot compares phrases token by token; tokens are extracted based on the linguistic features of the particular language ([details](https://spacy.io/usage/linguistic-features#tokenization)).
 
 ### Regex Entities {#regex-entities}
 
@@ -249,7 +259,7 @@ entities:
 
 Note, that we use single-quoted strings to specify regular expressions. This avoids escaping mess when your string contains a lot of backslashes.
 
-MaxBot looks for patterns matching your regular expression in the user input, and identifies any matches as mentions of that entity.
+MaxBot matches the regular expression character by character against the user input, and identifies any matches as mentions of that entity.
 
 ### Using entities
 
@@ -365,7 +375,7 @@ Use the syntax `entities.city.boston` to trigger node when *the entity value* is
 
 The node is used if the state variable expression that you specify is true. Use the syntax with [comparison operators](/design-reference/booleans.md#comparisons) like, `dialog.channel_name == "telegram"` or `slots.guests > 5`.
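+
+As an illustration, a node guarded by a state variable might look like the following sketch (the `book_table` intent, the `guests` slot and the response text are hypothetical):
+```yaml
+- condition: intents.book_table and slots.guests > 5
+  response: |
+    That is a big group! Let me look for a large table.
+```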
-Make sure state variables are initialized before usage, otherwise use them with the [default filter](/design-reference/jinja.md#filter-default): `slots.counter|default(0) > 1`.
+Make sure state variables are initialized before usage (see [State Variables](/design-guides/state.md)), otherwise use them with the [default filter](/design-reference/jinja.md#filter-default): `slots.counter|default(0) > 1`.
 
 For node conditions, state variables is typically used with an `and` operator and another condition value because something in the user input must trigger the node.
 
@@ -400,7 +410,7 @@ See [Protocol](/design-reference/protocol.md) reference to learn more about diff
 
 The `true` condition is always evaluated to true. You can use it at the end of a list of nodes to catch any requests that did not match any of the previous conditions.
 
-The `message` expression is used in to trigger node if the MaxBot receives the *any type of message* from the user, i.e. text message, image file, button click, etc. This is the opposite of `rpc` methods, which are called by integrations and not by users.
+The `message` expression is used to trigger a node when MaxBot receives *any type of message* from the user, i.e. a text message, image file, button click, etc. This is the opposite of [`RPC`](/design-guides/rpc) (remote procedure call) methods, which are called by integrations and not by users.
 
 ```yaml
 - condition: message
diff --git a/website/docs/getting-started/quick-start.md b/website/docs/getting-started/quick-start.md
index b959698..d7282f2 100644
--- a/website/docs/getting-started/quick-start.md
+++ b/website/docs/getting-started/quick-start.md
@@ -108,7 +108,7 @@ The bot will be available through the [Telegram Messenger](https://core.telegram
 To integrate with the messenger, contact [@BotFather](https://t.me/botfather) and ask it to create a bot for you and generate an API token. Then specify API token in the bot resources. Refer [official docs](https://core.telegram.org/bots#6-botfather) for more information about telegram bots.
 :::
 
-Save the bot resources as `bot.yaml` or something similar. Run the MaxBot CLI app passing the path to the `bot.yaml` as a parameter.
+Save the bot resources as `bot.yaml` or something similar. Run the MaxBot Command Line Interface (CLI) app, passing the path to `bot.yaml` as a parameter.
 
 ```bash
 $ maxbot run --bot bot.yaml
@@ -140,3 +140,12 @@ The output in your console will look like this
 Press `Ctrl-C` to exit MaxBot CLI app.
 
 Congratulations! You have successfully created and launched a simple bot and chatted with it.
+
+
+## Examples
+
+You can find many basic bot examples in this reference. If you want more complex ones, check out the list below. These examples show advanced Maxbot features such as custom messenger controls and integration with various REST services, databases and so on.
+
+- [Bank Bot example](https://github.com/maxbot-ai/bank_bot).
+- [Taxi Bot example](https://github.com/maxbot-ai/taxi_bot).
+- [Transport Bot example](https://github.com/maxbot-ai/transport_bot).
\ No newline at end of file diff --git a/website/sidebars.js b/website/sidebars.js index d67658c..e7f48e8 100644 --- a/website/sidebars.js +++ b/website/sidebars.js @@ -57,6 +57,7 @@ const sidebars = { }, collapsed: true, items: [ + 'design-guides/channel-setting', 'design-guides/dialog-tree', 'design-guides/slot-filling', 'design-guides/digressions', @@ -93,7 +94,12 @@ const sidebars = { 'design-reference/numbers', ] }, - 'design-reference/stories' + 'design-reference/stories', + 'design-reference/timeout', + 'design-reference/pool-limits', + 'design-reference/timedelta', + 'design-reference/channels', + 'design-reference/mp', ], }, { @@ -108,6 +114,7 @@ const sidebars = { 'extensions/babel', 'extensions/rasa', 'extensions/jinja_loader', + 'extensions/rest', ], }, {