from typing import Any, Awaitable, Callable, Dict, Optional, cast from aiogram import Bot from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.dispatcher.middlewares.user_context import EVENT_CONTEXT_KEY, EventContext from aiogram.fsm.context import FSMContext from aiogram.fsm.storage.base import ( DEFAULT_DESTINY, BaseEventIsolation, BaseStorage, StorageKey, ) from aiogram.fsm.strategy import FSMStrategy, apply_strategy from aiogram.types import TelegramObject class FSMContextMiddleware(BaseMiddleware): def __init__( self, storage: BaseStorage, events_isolation: BaseEventIsolation, strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT, ) -> None: self.storage = storage self.strategy = strategy self.events_isolation = events_isolation async def __call__( self, handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], event: TelegramObject, data: Dict[str, Any], ) -> Any: bot: Bot = cast(Bot, data["bot"]) context = self.resolve_event_context(bot, data) data["fsm_storage"] = self.storage if context: # Bugfix: https://github.com/aiogram/aiogram/issues/1317 # State should be loaded after lock is acquired async with self.events_isolation.lock(key=context.key): data.update({"state": context, "raw_state": await context.get_state()}) return await handler(event, data) return await handler(event, data) def resolve_event_context( self, bot: Bot, data: Dict[str, Any], destiny: str = DEFAULT_DESTINY, ) -> Optional[FSMContext]: event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY)) return self.resolve_context( bot=bot, chat_id=event_context.chat_id, user_id=event_context.user_id, thread_id=event_context.thread_id, business_connection_id=event_context.business_connection_id, destiny=destiny, ) def resolve_context( self, bot: Bot, chat_id: Optional[int], user_id: Optional[int], thread_id: Optional[int] = None, business_connection_id: Optional[str] = None, destiny: str = DEFAULT_DESTINY, ) -> Optional[FSMContext]: if chat_id is None: chat_id = user_id if chat_id is not None and user_id is not None: chat_id, user_id, thread_id = apply_strategy( chat_id=chat_id, user_id=user_id, thread_id=thread_id, strategy=self.strategy, ) return self.get_context( bot=bot, chat_id=chat_id, user_id=user_id, thread_id=thread_id, business_connection_id=business_connection_id, destiny=destiny, ) return None def get_context( self, bot: Bot, chat_id: int, user_id: int, thread_id: Optional[int] = None, business_connection_id: Optional[str] = None, destiny: str = DEFAULT_DESTINY, ) -> FSMContext: return FSMContext( storage=self.storage, key=StorageKey( user_id=user_id, chat_id=chat_id, bot_id=bot.id, thread_id=thread_id, business_connection_id=business_connection_id, destiny=destiny, ), ) async def close(self) -> None: await self.storage.close() await self.events_isolation.close()