Третий коммит, добавление share, share_kb, а также ADMIN_ID

This commit is contained in:
2025-07-22 13:50:14 +03:00
parent 849feb7beb
commit b98123f4dc
1479 changed files with 323549 additions and 11 deletions

View File

@@ -0,0 +1,41 @@
from typing import Any, Dict, Optional, overload
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
class FSMContext:
def __init__(self, storage: BaseStorage, key: StorageKey) -> None:
self.storage = storage
self.key = key
async def set_state(self, state: StateType = None) -> None:
await self.storage.set_state(key=self.key, state=state)
async def get_state(self) -> Optional[str]:
return await self.storage.get_state(key=self.key)
async def set_data(self, data: Dict[str, Any]) -> None:
await self.storage.set_data(key=self.key, data=data)
async def get_data(self) -> Dict[str, Any]:
return await self.storage.get_data(key=self.key)
@overload
async def get_value(self, key: str) -> Optional[Any]: ...
@overload
async def get_value(self, key: str, default: Any) -> Any: ...
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
return await self.storage.get_value(storage_key=self.key, dict_key=key, default=default)
async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]:
if data:
kwargs.update(data)
return await self.storage.update_data(key=self.key, data=kwargs)
async def clear(self) -> None:
await self.set_state(state=None)
await self.set_data({})

View File

@@ -0,0 +1,113 @@
from typing import Any, Awaitable, Callable, Dict, Optional, cast
from aiogram import Bot
from aiogram.dispatcher.middlewares.base import BaseMiddleware
from aiogram.dispatcher.middlewares.user_context import EVENT_CONTEXT_KEY, EventContext
from aiogram.fsm.context import FSMContext
from aiogram.fsm.storage.base import (
DEFAULT_DESTINY,
BaseEventIsolation,
BaseStorage,
StorageKey,
)
from aiogram.fsm.strategy import FSMStrategy, apply_strategy
from aiogram.types import TelegramObject
class FSMContextMiddleware(BaseMiddleware):
def __init__(
self,
storage: BaseStorage,
events_isolation: BaseEventIsolation,
strategy: FSMStrategy = FSMStrategy.USER_IN_CHAT,
) -> None:
self.storage = storage
self.strategy = strategy
self.events_isolation = events_isolation
async def __call__(
self,
handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
bot: Bot = cast(Bot, data["bot"])
context = self.resolve_event_context(bot, data)
data["fsm_storage"] = self.storage
if context:
# Bugfix: https://github.com/aiogram/aiogram/issues/1317
# State should be loaded after lock is acquired
async with self.events_isolation.lock(key=context.key):
data.update({"state": context, "raw_state": await context.get_state()})
return await handler(event, data)
return await handler(event, data)
def resolve_event_context(
self,
bot: Bot,
data: Dict[str, Any],
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
event_context: EventContext = cast(EventContext, data.get(EVENT_CONTEXT_KEY))
return self.resolve_context(
bot=bot,
chat_id=event_context.chat_id,
user_id=event_context.user_id,
thread_id=event_context.thread_id,
business_connection_id=event_context.business_connection_id,
destiny=destiny,
)
def resolve_context(
self,
bot: Bot,
chat_id: Optional[int],
user_id: Optional[int],
thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
destiny: str = DEFAULT_DESTINY,
) -> Optional[FSMContext]:
if chat_id is None:
chat_id = user_id
if chat_id is not None and user_id is not None:
chat_id, user_id, thread_id = apply_strategy(
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
strategy=self.strategy,
)
return self.get_context(
bot=bot,
chat_id=chat_id,
user_id=user_id,
thread_id=thread_id,
business_connection_id=business_connection_id,
destiny=destiny,
)
return None
def get_context(
self,
bot: Bot,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
business_connection_id: Optional[str] = None,
destiny: str = DEFAULT_DESTINY,
) -> FSMContext:
return FSMContext(
storage=self.storage,
key=StorageKey(
user_id=user_id,
chat_id=chat_id,
bot_id=bot.id,
thread_id=thread_id,
business_connection_id=business_connection_id,
destiny=destiny,
),
)
async def close(self) -> None:
await self.storage.close()
await self.events_isolation.close()

View File

@@ -0,0 +1,952 @@
from __future__ import annotations
import inspect
from collections import defaultdict
from dataclasses import dataclass, replace
from enum import Enum, auto
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, Union, overload
from typing_extensions import Self
from aiogram import loggers
from aiogram.dispatcher.dispatcher import Dispatcher
from aiogram.dispatcher.event.bases import NextMiddlewareType
from aiogram.dispatcher.event.handler import CallableObject, CallbackType
from aiogram.dispatcher.flags import extract_flags_from_object
from aiogram.dispatcher.router import Router
from aiogram.exceptions import SceneException
from aiogram.filters import StateFilter
from aiogram.fsm.context import FSMContext
from aiogram.fsm.state import State
from aiogram.fsm.storage.memory import MemoryStorageRecord
from aiogram.types import TelegramObject, Update
from aiogram.utils.class_attrs_resolver import (
ClassAttrsResolver,
get_sorted_mro_attrs_resolver,
)
class HistoryManager:
def __init__(self, state: FSMContext, destiny: str = "scenes_history", size: int = 10):
self._size = size
self._state = state
self._history_state = FSMContext(
storage=state.storage, key=replace(state.key, destiny=destiny)
)
async def push(self, state: Optional[str], data: Dict[str, Any]) -> None:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
history.append({"state": state, "data": data})
if len(history) > self._size:
history = history[-self._size :]
loggers.scene.debug("Push state=%s data=%s to history", state, data)
await self._history_state.update_data(history=history)
async def pop(self) -> Optional[MemoryStorageRecord]:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
if not history:
return None
record = history.pop()
state = record["state"]
data = record["data"]
if not history:
await self._history_state.set_data({})
else:
await self._history_state.update_data(history=history)
loggers.scene.debug("Pop state=%s data=%s from history", state, data)
return MemoryStorageRecord(state=state, data=data)
async def get(self) -> Optional[MemoryStorageRecord]:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
if not history:
return None
return MemoryStorageRecord(**history[-1])
async def all(self) -> List[MemoryStorageRecord]:
history_data = await self._history_state.get_data()
history = history_data.setdefault("history", [])
return [MemoryStorageRecord(**item) for item in history]
async def clear(self) -> None:
loggers.scene.debug("Clear history")
await self._history_state.set_data({})
async def snapshot(self) -> None:
state = await self._state.get_state()
data = await self._state.get_data()
await self.push(state, data)
async def _set_state(self, state: Optional[str], data: Dict[str, Any]) -> None:
await self._state.set_state(state)
await self._state.set_data(data)
async def rollback(self) -> Optional[str]:
previous_state = await self.pop()
if not previous_state:
await self._set_state(None, {})
return None
loggers.scene.debug(
"Rollback to state=%s data=%s",
previous_state.state,
previous_state.data,
)
await self._set_state(previous_state.state, previous_state.data)
return previous_state.state
class ObserverDecorator:
def __init__(
self,
name: str,
filters: tuple[CallbackType, ...],
action: SceneAction | None = None,
after: Optional[After] = None,
) -> None:
self.name = name
self.filters = filters
self.action = action
self.after = after
def _wrap_filter(self, target: Type[Scene] | CallbackType) -> None:
handlers = getattr(target, "__aiogram_handler__", None)
if not handlers:
handlers = []
setattr(target, "__aiogram_handler__", handlers)
handlers.append(
HandlerContainer(
name=self.name,
handler=target,
filters=self.filters,
after=self.after,
)
)
def _wrap_action(self, target: CallbackType) -> None:
assert self.action is not None, "Scene action is not specified"
action = getattr(target, "__aiogram_action__", None)
if action is None:
action = defaultdict(dict)
setattr(target, "__aiogram_action__", action)
action[self.action][self.name] = CallableObject(target)
def __call__(self, target: CallbackType) -> CallbackType:
if inspect.isfunction(target):
if self.action is None:
self._wrap_filter(target)
else:
self._wrap_action(target)
else:
raise TypeError("Only function or method is allowed")
return target
def leave(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.leave)
def enter(self, target: Type[Scene]) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.enter, target)
def exit(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.exit)
def back(self) -> ActionContainer:
return ActionContainer(self.name, self.filters, SceneAction.back)
class SceneAction(Enum):
enter = auto()
leave = auto()
exit = auto()
back = auto()
class ActionContainer:
def __init__(
self,
name: str,
filters: Tuple[CallbackType, ...],
action: SceneAction,
target: Optional[Union[Type[Scene], str]] = None,
) -> None:
self.name = name
self.filters = filters
self.action = action
self.target = target
async def execute(self, wizard: SceneWizard) -> None:
if self.action == SceneAction.enter and self.target is not None:
await wizard.goto(self.target)
elif self.action == SceneAction.leave:
await wizard.leave()
elif self.action == SceneAction.exit:
await wizard.exit()
elif self.action == SceneAction.back:
await wizard.back()
class HandlerContainer:
def __init__(
self,
name: str,
handler: CallbackType,
filters: Tuple[CallbackType, ...],
after: Optional[After] = None,
) -> None:
self.name = name
self.handler = handler
self.filters = filters
self.after = after
@dataclass()
class SceneConfig:
state: Optional[str]
"""Scene state"""
handlers: List[HandlerContainer]
"""Scene handlers"""
actions: Dict[SceneAction, Dict[str, CallableObject]]
"""Scene actions"""
reset_data_on_enter: Optional[bool] = None
"""Reset scene data on enter"""
reset_history_on_enter: Optional[bool] = None
"""Reset scene history on enter"""
callback_query_without_state: Optional[bool] = None
"""Allow callback query without state"""
attrs_resolver: ClassAttrsResolver = get_sorted_mro_attrs_resolver
"""
Attributes resolver.
.. danger::
This attribute should only be changed when you know what you are doing.
.. versionadded:: 3.19.0
"""
async def _empty_handler(*args: Any, **kwargs: Any) -> None:
pass
class SceneHandlerWrapper:
def __init__(
self,
scene: Type[Scene],
handler: CallbackType,
after: Optional[After] = None,
) -> None:
self.scene = scene
self.handler = CallableObject(handler)
self.after = after
async def __call__(
self,
event: TelegramObject,
**kwargs: Any,
) -> Any:
state: FSMContext = kwargs["state"]
scenes: ScenesManager = kwargs["scenes"]
event_update: Update = kwargs["event_update"]
scene = self.scene(
wizard=SceneWizard(
scene_config=self.scene.__scene_config__,
manager=scenes,
state=state,
update_type=event_update.event_type,
event=event,
data=kwargs,
)
)
result = await self.handler.call(scene, event, **kwargs)
if self.after:
action_container = ActionContainer(
"after",
(),
self.after.action,
self.after.scene,
)
await action_container.execute(scene.wizard)
return result
def __await__(self) -> Self:
return self
def __str__(self) -> str:
result = f"SceneHandlerWrapper({self.scene}, {self.handler.callback}"
if self.after:
result += f", after={self.after}"
result += ")"
return result
class Scene:
"""
Represents a scene in a conversation flow.
A scene is a specific state in a conversation where certain actions can take place.
Each scene has a set of filters that determine when it should be triggered,
and a set of handlers that define the actions to be executed when the scene is active.
.. note::
This class is not meant to be used directly. Instead, it should be subclassed
to define custom scenes.
"""
__scene_config__: ClassVar[SceneConfig]
"""Scene configuration."""
def __init__(
self,
wizard: SceneWizard,
) -> None:
self.wizard = wizard
self.wizard.scene = self
def __init_subclass__(cls, **kwargs: Any) -> None:
state_name = kwargs.pop("state", None)
reset_data_on_enter = kwargs.pop("reset_data_on_enter", None)
reset_history_on_enter = kwargs.pop("reset_history_on_enter", None)
callback_query_without_state = kwargs.pop("callback_query_without_state", None)
attrs_resolver = kwargs.pop("attrs_resolver", None)
super().__init_subclass__(**kwargs)
handlers: list[HandlerContainer] = []
actions: defaultdict[SceneAction, Dict[str, CallableObject]] = defaultdict(dict)
for base in cls.__bases__:
if not issubclass(base, Scene):
continue
parent_scene_config = getattr(base, "__scene_config__", None)
if not parent_scene_config:
continue
if reset_data_on_enter is None:
reset_data_on_enter = parent_scene_config.reset_data_on_enter
if reset_history_on_enter is None:
reset_history_on_enter = parent_scene_config.reset_history_on_enter
if callback_query_without_state is None:
callback_query_without_state = parent_scene_config.callback_query_without_state
if attrs_resolver is None:
attrs_resolver = parent_scene_config.attrs_resolver
if attrs_resolver is None:
attrs_resolver = get_sorted_mro_attrs_resolver
for name, value in attrs_resolver(cls):
if scene_handlers := getattr(value, "__aiogram_handler__", None):
handlers.extend(scene_handlers)
if isinstance(value, ObserverDecorator):
handlers.append(
HandlerContainer(
value.name,
_empty_handler,
value.filters,
after=value.after,
)
)
if hasattr(value, "__aiogram_action__"):
for action, action_handlers in value.__aiogram_action__.items():
actions[action].update(action_handlers)
cls.__scene_config__ = SceneConfig(
state=state_name,
handlers=handlers,
actions=dict(actions),
reset_data_on_enter=reset_data_on_enter,
reset_history_on_enter=reset_history_on_enter,
callback_query_without_state=callback_query_without_state,
attrs_resolver=attrs_resolver,
)
@classmethod
def add_to_router(cls, router: Router) -> None:
"""
Adds the scene to the given router.
:param router:
:return:
"""
scene_config = cls.__scene_config__
used_observers = set()
for handler in scene_config.handlers:
router.observers[handler.name].register(
SceneHandlerWrapper(
cls,
handler.handler,
after=handler.after,
),
*handler.filters,
flags=extract_flags_from_object(handler.handler),
)
used_observers.add(handler.name)
for observer_name in used_observers:
if scene_config.callback_query_without_state and observer_name == "callback_query":
continue
router.observers[observer_name].filter(StateFilter(scene_config.state))
@classmethod
def as_router(cls, name: Optional[str] = None) -> Router:
"""
Returns the scene as a router.
:return: new router
"""
if name is None:
name = (
f"Scene '{cls.__module__}.{cls.__qualname__}' "
f"for state {cls.__scene_config__.state!r}"
)
router = Router(name=name)
cls.add_to_router(router)
return router
@classmethod
def as_handler(cls, **kwargs: Any) -> CallbackType:
"""
Create an entry point handler for the scene, can be used to simplify the handler
that starts the scene.
>>> router.message.register(MyScene.as_handler(), Command("start"))
"""
async def enter_to_scene_handler(event: TelegramObject, scenes: ScenesManager) -> None:
await scenes.enter(cls, **kwargs)
return enter_to_scene_handler
class SceneWizard:
"""
A class that represents a wizard for managing scenes in a Telegram bot.
Instance of this class is passed to each scene as a parameter.
So, you can use it to transition between scenes, get and set data, etc.
.. note::
This class is not meant to be used directly. Instead, it should be used
as a parameter in the scene constructor.
"""
def __init__(
self,
scene_config: SceneConfig,
manager: ScenesManager,
state: FSMContext,
update_type: str,
event: TelegramObject,
data: Dict[str, Any],
):
"""
A class that represents a wizard for managing scenes in a Telegram bot.
:param scene_config: The configuration of the scene.
:param manager: The scene manager.
:param state: The FSMContext object for storing the state of the scene.
:param update_type: The type of the update event.
:param event: The TelegramObject represents the event.
:param data: Additional data for the scene.
"""
self.scene_config = scene_config
self.manager = manager
self.state = state
self.update_type = update_type
self.event = event
self.data = data
self.scene: Optional[Scene] = None
async def enter(self, **kwargs: Any) -> None:
"""
Enter method is used to transition into a scene in the SceneWizard class.
It sets the state, clears data and history if specified,
and triggers entering event of the scene.
:param kwargs: Additional keyword arguments.
:return: None
"""
loggers.scene.debug("Entering scene %r", self.scene_config.state)
if self.scene_config.reset_data_on_enter:
await self.state.set_data({})
if self.scene_config.reset_history_on_enter:
await self.manager.history.clear()
await self.state.set_state(self.scene_config.state)
await self._on_action(SceneAction.enter, **kwargs)
async def leave(self, _with_history: bool = True, **kwargs: Any) -> None:
"""
Leaves the current scene.
This method is used to exit a scene and transition to the next scene.
:param _with_history: Whether to include history in the snapshot. Defaults to True.
:param kwargs: Additional keyword arguments.
:return: None
"""
loggers.scene.debug("Leaving scene %r", self.scene_config.state)
if _with_history:
await self.manager.history.snapshot()
await self._on_action(SceneAction.leave, **kwargs)
async def exit(self, **kwargs: Any) -> None:
"""
Exit the current scene and enter the default scene/state.
:param kwargs: Additional keyword arguments.
:return: None
"""
loggers.scene.debug("Exiting scene %r", self.scene_config.state)
await self.manager.history.clear()
await self._on_action(SceneAction.exit, **kwargs)
await self.manager.enter(None, _check_active=False, **kwargs)
async def back(self, **kwargs: Any) -> None:
"""
This method is used to go back to the previous scene.
:param kwargs: Keyword arguments that can be passed to the method.
:return: None
"""
loggers.scene.debug("Back to previous scene from scene %s", self.scene_config.state)
await self.leave(_with_history=False, **kwargs)
new_scene = await self.manager.history.rollback()
await self.manager.enter(new_scene, _check_active=False, **kwargs)
async def retake(self, **kwargs: Any) -> None:
"""
This method allows to re-enter the current scene.
:param kwargs: Additional keyword arguments to pass to the scene.
:return: None
"""
assert self.scene_config.state is not None, "Scene state is not specified"
await self.goto(self.scene_config.state, **kwargs)
async def goto(self, scene: Union[Type[Scene], str], **kwargs: Any) -> None:
"""
The `goto` method transitions to a new scene.
It first calls the `leave` method to perform any necessary cleanup
in the current scene, then calls the `enter` event to enter the specified scene.
:param scene: The scene to transition to. Can be either a `Scene` instance
or a string representing the scene.
:param kwargs: Additional keyword arguments to pass to the `enter`
method of the scene manager.
:return: None
"""
await self.leave(**kwargs)
await self.manager.enter(scene, _check_active=False, **kwargs)
async def _on_action(self, action: SceneAction, **kwargs: Any) -> bool:
if not self.scene:
raise SceneException("Scene is not initialized")
loggers.scene.debug("Call action %r in scene %r", action.name, self.scene_config.state)
action_config = self.scene_config.actions.get(action, {})
if not action_config:
loggers.scene.debug(
"Action %r not found in scene %r", action.name, self.scene_config.state
)
return False
event_type = self.update_type
if event_type not in action_config:
loggers.scene.debug(
"Action %r for event %r not found in scene %r",
action.name,
event_type,
self.scene_config.state,
)
return False
await action_config[event_type].call(self.scene, self.event, **{**self.data, **kwargs})
return True
async def set_data(self, data: Dict[str, Any]) -> None:
"""
Sets custom data in the current state.
:param data: A dictionary containing the custom data to be set in the current state.
:return: None
"""
await self.state.set_data(data=data)
async def get_data(self) -> Dict[str, Any]:
"""
This method returns the data stored in the current state.
:return: A dictionary containing the data stored in the scene state.
"""
return await self.state.get_data()
@overload
async def get_value(self, key: str) -> Optional[Any]:
"""
This method returns the value from key in the data of the current state.
:param key: The keyname of the item you want to return the value from.
:return: A dictionary containing the data stored in the scene state.
"""
pass
@overload
async def get_value(self, key: str, default: Any) -> Any:
"""
This method returns the value from key in the data of the current state.
:param key: The keyname of the item you want to return the value from.
:param default: Default value to return, if ``key`` was not found.
:return: A dictionary containing the data stored in the scene state.
"""
pass
async def get_value(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
return await self.state.get_value(key, default)
async def update_data(
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""
This method updates the data stored in the current state
:param data: Optional dictionary of data to update.
:param kwargs: Additional key-value pairs of data to update.
:return: Dictionary of updated data
"""
if data:
kwargs.update(data)
return await self.state.update_data(data=kwargs)
async def clear_data(self) -> None:
"""
Clears the data.
:return: None
"""
await self.set_data({})
class ScenesManager:
"""
The ScenesManager class is responsible for managing scenes in an application.
It provides methods for entering and exiting scenes, as well as retrieving the active scene.
"""
def __init__(
self,
registry: SceneRegistry,
update_type: str,
event: TelegramObject,
state: FSMContext,
data: Dict[str, Any],
) -> None:
self.registry = registry
self.update_type = update_type
self.event = event
self.state = state
self.data = data
self.history = HistoryManager(self.state)
async def _get_scene(self, scene_type: Optional[Union[Type[Scene], str]]) -> Scene:
scene_type = self.registry.get(scene_type)
return scene_type(
wizard=SceneWizard(
scene_config=scene_type.__scene_config__,
manager=self,
state=self.state,
update_type=self.update_type,
event=self.event,
data=self.data,
),
)
async def _get_active_scene(self) -> Optional[Scene]:
state = await self.state.get_state()
try:
return await self._get_scene(state)
except SceneException:
return None
async def enter(
self,
scene_type: Optional[Union[Type[Scene], str]],
_check_active: bool = True,
**kwargs: Any,
) -> None:
"""
Enters the specified scene.
:param scene_type: Optional Type[Scene] or str representing the scene type to enter.
:param _check_active: Optional bool indicating whether to check if
there is an active scene to exit before entering the new scene. Defaults to True.
:param kwargs: Additional keyword arguments to pass to the scene's wizard.enter() method.
:return: None
"""
if _check_active:
active_scene = await self._get_active_scene()
if active_scene is not None:
await active_scene.wizard.exit(**kwargs)
try:
scene = await self._get_scene(scene_type)
except SceneException:
if scene_type is not None:
raise
await self.state.set_state(None)
else:
await scene.wizard.enter(**kwargs)
async def close(self, **kwargs: Any) -> None:
"""
Close method is used to exit the currently active scene in the ScenesManager.
:param kwargs: Additional keyword arguments passed to the scene's exit method.
:return: None
"""
scene = await self._get_active_scene()
if not scene:
return
await scene.wizard.exit(**kwargs)
class SceneRegistry:
"""
A class that represents a registry for scenes in a Telegram bot.
"""
def __init__(self, router: Router, register_on_add: bool = True) -> None:
"""
Initialize a new instance of the SceneRegistry class.
:param router: The router instance used for scene registration.
:param register_on_add: Whether to register the scenes to the router when they are added.
"""
self.router = router
self.register_on_add = register_on_add
self._scenes: Dict[Optional[str], Type[Scene]] = {}
self._setup_middleware(router)
def _setup_middleware(self, router: Router) -> None:
if isinstance(router, Dispatcher):
# Small optimization for Dispatcher
# - we don't need to set up middleware for all observers
router.update.outer_middleware(self._update_middleware)
return
for observer in router.observers.values():
if observer.event_name in {"update", "error"}:
continue
observer.outer_middleware(self._middleware)
async def _update_middleware(
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
assert isinstance(event, Update), "Event must be an Update instance"
data["scenes"] = ScenesManager(
registry=self,
update_type=event.event_type,
event=event.event,
state=data["state"],
data=data,
)
return await handler(event, data)
async def _middleware(
self,
handler: NextMiddlewareType[TelegramObject],
event: TelegramObject,
data: Dict[str, Any],
) -> Any:
update: Update = data["event_update"]
data["scenes"] = ScenesManager(
registry=self,
update_type=update.event_type,
event=event,
state=data["state"],
data=data,
)
return await handler(event, data)
def add(self, *scenes: Type[Scene], router: Optional[Router] = None) -> None:
"""
This method adds the specified scenes to the registry
and optionally registers it to the router.
If a scene with the same state already exists in the registry, a SceneException is raised.
.. warning::
If the router is not specified, the scenes will not be registered to the router.
You will need to include the scenes manually to the router or use the register method.
:param scenes: A variable length parameter that accepts one or more types of scenes.
These scenes are instances of the Scene class.
:param router: An optional parameter that specifies the router
to which the scenes should be added.
:return: None
"""
if not scenes:
raise ValueError("At least one scene must be specified")
for scene in scenes:
if scene.__scene_config__.state in self._scenes:
raise SceneException(
f"Scene with state {scene.__scene_config__.state!r} already exists"
)
self._scenes[scene.__scene_config__.state] = scene
if router:
router.include_router(scene.as_router())
elif self.register_on_add:
self.router.include_router(scene.as_router())
def register(self, *scenes: Type[Scene]) -> None:
"""
Registers one or more scenes to the SceneRegistry.
:param scenes: One or more scene classes to register.
:return: None
"""
self.add(*scenes, router=self.router)
def get(self, scene: Optional[Union[Type[Scene], str]]) -> Type[Scene]:
"""
This method returns the registered Scene object for the specified scene.
The scene parameter can be either a Scene object or a string representing
the name of the scene. If a Scene object is provided, the state attribute
of the SceneConfig object associated with the Scene object will be used as the scene name.
If None or an invalid type is provided, a SceneException will be raised.
If the specified scene is not registered in the SceneRegistry object,
a SceneException will be raised.
:param scene: A Scene object or a string representing the name of the scene.
:return: The registered Scene object corresponding to the given scene parameter.
"""
if inspect.isclass(scene) and issubclass(scene, Scene):
scene = scene.__scene_config__.state
if isinstance(scene, State):
scene = scene.state
if scene is not None and not isinstance(scene, str):
raise SceneException("Scene must be a subclass of Scene or a string")
try:
return self._scenes[scene]
except KeyError:
raise SceneException(f"Scene {scene!r} is not registered")
@dataclass
class After:
action: SceneAction
scene: Optional[Union[Type[Scene], str]] = None
@classmethod
def exit(cls) -> After:
return cls(action=SceneAction.exit)
@classmethod
def back(cls) -> After:
return cls(action=SceneAction.back)
@classmethod
def goto(cls, scene: Optional[Union[Type[Scene], str]]) -> After:
return cls(action=SceneAction.enter, scene=scene)
class ObserverMarker:
def __init__(self, name: str) -> None:
self.name = name
def __call__(
self,
*filters: CallbackType,
after: Optional[After] = None,
) -> ObserverDecorator:
return ObserverDecorator(
self.name,
filters,
after=after,
)
def enter(self, *filters: CallbackType) -> ObserverDecorator:
return ObserverDecorator(self.name, filters, action=SceneAction.enter)
def leave(self) -> ObserverDecorator:
return ObserverDecorator(self.name, (), action=SceneAction.leave)
def exit(self) -> ObserverDecorator:
return ObserverDecorator(self.name, (), action=SceneAction.exit)
def back(self) -> ObserverDecorator:
return ObserverDecorator(self.name, (), action=SceneAction.back)
class OnMarker:
"""
The `OnMarker` class is used as a marker class to define different
types of events in the Scenes.
Attributes:
- :code:`message`: Event marker for handling `Message` events.
- :code:`edited_message`: Event marker for handling edited `Message` events.
- :code:`channel_post`: Event marker for handling channel `Post` events.
- :code:`edited_channel_post`: Event marker for handling edited channel `Post` events.
- :code:`inline_query`: Event marker for handling `InlineQuery` events.
- :code:`chosen_inline_result`: Event marker for handling chosen `InlineResult` events.
- :code:`callback_query`: Event marker for handling `CallbackQuery` events.
- :code:`shipping_query`: Event marker for handling `ShippingQuery` events.
- :code:`pre_checkout_query`: Event marker for handling `PreCheckoutQuery` events.
- :code:`poll`: Event marker for handling `Poll` events.
- :code:`poll_answer`: Event marker for handling `PollAnswer` events.
- :code:`my_chat_member`: Event marker for handling my chat `Member` events.
- :code:`chat_member`: Event marker for handling chat `Member` events.
- :code:`chat_join_request`: Event marker for handling chat `JoinRequest` events.
- :code:`error`: Event marker for handling `Error` events.
.. note::
This is a marker class and does not contain any methods or implementation logic.
"""
message = ObserverMarker("message")
edited_message = ObserverMarker("edited_message")
channel_post = ObserverMarker("channel_post")
edited_channel_post = ObserverMarker("edited_channel_post")
inline_query = ObserverMarker("inline_query")
chosen_inline_result = ObserverMarker("chosen_inline_result")
callback_query = ObserverMarker("callback_query")
shipping_query = ObserverMarker("shipping_query")
pre_checkout_query = ObserverMarker("pre_checkout_query")
poll = ObserverMarker("poll")
poll_answer = ObserverMarker("poll_answer")
my_chat_member = ObserverMarker("my_chat_member")
chat_member = ObserverMarker("chat_member")
chat_join_request = ObserverMarker("chat_join_request")
on = OnMarker()

View File

@@ -0,0 +1,172 @@
import inspect
from typing import Any, Iterator, Optional, Tuple, Type, no_type_check
from aiogram.types import TelegramObject
class State:
"""
State object
"""
def __init__(self, state: Optional[str] = None, group_name: Optional[str] = None) -> None:
self._state = state
self._group_name = group_name
self._group: Optional[Type[StatesGroup]] = None
@property
def group(self) -> "Type[StatesGroup]":
if not self._group:
raise RuntimeError("This state is not in any group.")
return self._group
@property
def state(self) -> Optional[str]:
if self._state is None or self._state == "*":
return self._state
if self._group_name is None and self._group:
group = self._group.__full_group_name__
elif self._group_name:
group = self._group_name
else:
group = "@"
return f"{group}:{self._state}"
def set_parent(self, group: "Type[StatesGroup]") -> None:
if not issubclass(group, StatesGroup):
raise ValueError("Group must be subclass of StatesGroup")
self._group = group
def __set_name__(self, owner: "Type[StatesGroup]", name: str) -> None:
if self._state is None:
self._state = name
self.set_parent(owner)
def __str__(self) -> str:
return f"<State '{self.state or ''}'>"
__repr__ = __str__
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
if self.state == "*":
return True
return raw_state == self.state
def __eq__(self, other: Any) -> bool:
if isinstance(other, self.__class__):
return self.state == other.state
if isinstance(other, str):
return self.state == other
return NotImplemented
def __hash__(self) -> int:
return hash(self.state)
class StatesGroupMeta(type):
__parent__: "Optional[Type[StatesGroup]]"
__childs__: "Tuple[Type[StatesGroup], ...]"
__states__: Tuple[State, ...]
__state_names__: Tuple[str, ...]
__all_childs__: Tuple[Type["StatesGroup"], ...]
__all_states__: Tuple[State, ...]
__all_states_names__: Tuple[str, ...]
@no_type_check
def __new__(mcs, name, bases, namespace, **kwargs):
cls = super().__new__(mcs, name, bases, namespace)
states = []
childs = []
for name, arg in namespace.items():
if isinstance(arg, State):
states.append(arg)
elif inspect.isclass(arg) and issubclass(arg, StatesGroup):
child = cls._prepare_child(arg)
childs.append(child)
cls.__parent__ = None
cls.__childs__ = tuple(childs)
cls.__states__ = tuple(states)
cls.__state_names__ = tuple(state.state for state in states)
cls.__all_childs__ = cls._get_all_childs()
cls.__all_states__ = cls._get_all_states()
# In order to ensure performance, we calculate this parameter
# in advance already during the production of the class.
# Depending on the relationship, it should be recalculated
cls.__all_states_names__ = cls._get_all_states_names()
return cls
@property
def __full_group_name__(cls) -> str:
if cls.__parent__:
return ".".join((cls.__parent__.__full_group_name__, cls.__name__))
return cls.__name__
def _prepare_child(cls, child: Type["StatesGroup"]) -> Type["StatesGroup"]:
"""Prepare child.
While adding `cls` for its children, we also need to recalculate
the parameter `__all_states_names__` for each child
`StatesGroup`. Since the child class appears before the
parent, at the time of adding the parent, the child's
`__all_states_names__` is already recorded without taking into
account the name of current parent.
"""
child.__parent__ = cls # type: ignore[assignment]
child.__all_states_names__ = child._get_all_states_names()
return child
def _get_all_childs(cls) -> Tuple[Type["StatesGroup"], ...]:
result = cls.__childs__
for child in cls.__childs__:
result += child.__childs__
return result
def _get_all_states(cls) -> Tuple[State, ...]:
result = cls.__states__
for group in cls.__childs__:
result += group.__all_states__
return result
def _get_all_states_names(cls) -> Tuple[str, ...]:
return tuple(state.state for state in cls.__all_states__ if state.state)
def __contains__(cls, item: Any) -> bool:
if isinstance(item, str):
return item in cls.__all_states_names__
if isinstance(item, State):
return item in cls.__all_states__
if isinstance(item, StatesGroupMeta):
return item in cls.__all_childs__
return False
def __str__(self) -> str:
return f"<StatesGroup '{self.__full_group_name__}'>"
def __iter__(self) -> Iterator[State]:
return iter(self.__all_states__)
class StatesGroup(metaclass=StatesGroupMeta):
@classmethod
def get_root(cls) -> Type["StatesGroup"]:
if cls.__parent__ is None:
return cls
return cls.__parent__.get_root()
def __call__(self, event: TelegramObject, raw_state: Optional[str] = None) -> bool:
return raw_state in type(self).__all_states_names__
def __str__(self) -> str:
return f"StatesGroup {type(self).__full_group_name__}"
default_state = State()
any_state = State(state="*")

View File

@@ -0,0 +1,212 @@
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Literal, Optional, Union, overload
from aiogram.fsm.state import State
StateType = Optional[Union[str, State]]
DEFAULT_DESTINY = "default"
@dataclass(frozen=True)
class StorageKey:
bot_id: int
chat_id: int
user_id: int
thread_id: Optional[int] = None
business_connection_id: Optional[str] = None
destiny: str = DEFAULT_DESTINY
class KeyBuilder(ABC):
"""Base class for key builder."""
@abstractmethod
def build(
self,
key: StorageKey,
part: Optional[Literal["data", "state", "lock"]] = None,
) -> str:
"""
Build key to be used in storage's db queries
:param key: contextual key
:param part: part of the record
:return: key to be used in storage's db queries
"""
pass
class DefaultKeyBuilder(KeyBuilder):
"""
Simple key builder with default prefix.
Generates a colon-joined string with prefix, chat_id, user_id,
optional bot_id, business_connection_id, destiny and field.
Format:
:code:`<prefix>:<bot_id?>:<business_connection_id?>:<chat_id>:<user_id>:<destiny?>:<field?>`
"""
def __init__(
self,
*,
prefix: str = "fsm",
separator: str = ":",
with_bot_id: bool = False,
with_business_connection_id: bool = False,
with_destiny: bool = False,
) -> None:
"""
:param prefix: prefix for all records
:param separator: separator
:param with_bot_id: include Bot id in the key
:param with_business_connection_id: include business connection id
:param with_destiny: include destiny key
"""
self.prefix = prefix
self.separator = separator
self.with_bot_id = with_bot_id
self.with_business_connection_id = with_business_connection_id
self.with_destiny = with_destiny
def build(
self,
key: StorageKey,
part: Optional[Literal["data", "state", "lock"]] = None,
) -> str:
parts = [self.prefix]
if self.with_bot_id:
parts.append(str(key.bot_id))
if self.with_business_connection_id and key.business_connection_id:
parts.append(str(key.business_connection_id))
parts.append(str(key.chat_id))
if key.thread_id:
parts.append(str(key.thread_id))
parts.append(str(key.user_id))
if self.with_destiny:
parts.append(key.destiny)
elif key.destiny != DEFAULT_DESTINY:
error_message = (
"Default key builder is not configured to use key destiny other than the default."
"\n\nProbably, you should set `with_destiny=True` in for DefaultKeyBuilder."
)
raise ValueError(error_message)
if part:
parts.append(part)
return self.separator.join(parts)
class BaseStorage(ABC):
"""
Base class for all FSM storages
"""
@abstractmethod
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
"""
Set state for specified key
:param key: storage key
:param state: new state
"""
pass
@abstractmethod
async def get_state(self, key: StorageKey) -> Optional[str]:
"""
Get key state
:param key: storage key
:return: current state
"""
pass
@abstractmethod
async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
"""
Write data (replace)
:param key: storage key
:param data: new data
"""
pass
@abstractmethod
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
"""
Get current data for key
:param key: storage key
:return: current data
"""
pass
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]:
"""
Get single value from data by key
:param storage_key: storage key
:param dict_key: value key
:return: value stored in key of dict or ``None``
"""
pass
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any:
"""
Get single value from data by key
:param storage_key: storage key
:param dict_key: value key
:param default: default value to return
:return: value stored in key of dict or default
"""
pass
async def get_value(
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None
) -> Optional[Any]:
data = await self.get_data(storage_key)
return data.get(dict_key, default)
async def update_data(self, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Update date in the storage for key (like dict.update)
:param key: storage key
:param data: partial data
:return: new data
"""
current_data = await self.get_data(key=key)
current_data.update(data)
await self.set_data(key=key, data=current_data)
return current_data.copy()
@abstractmethod
async def close(self) -> None: # pragma: no cover
"""
Close storage (database connection, file or etc.)
"""
pass
class BaseEventIsolation(ABC):
@abstractmethod
@asynccontextmanager
async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
"""
Isolate events with lock.
Will be used as context manager
:param key: storage key
:return: An async generator
"""
yield None
@abstractmethod
async def close(self) -> None:
pass

View File

@@ -0,0 +1,87 @@
from asyncio import Lock
from collections import defaultdict
from contextlib import asynccontextmanager
from copy import copy
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, DefaultDict, Dict, Hashable, Optional, overload
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
BaseEventIsolation,
BaseStorage,
StateType,
StorageKey,
)
@dataclass
class MemoryStorageRecord:
data: Dict[str, Any] = field(default_factory=dict)
state: Optional[str] = None
class MemoryStorage(BaseStorage):
"""
Default FSM storage, stores all data in :class:`dict` and loss everything on shutdown
.. warning::
Is not recommended using in production in due to you will lose all data
when your bot restarts
"""
def __init__(self) -> None:
self.storage: DefaultDict[StorageKey, MemoryStorageRecord] = defaultdict(
MemoryStorageRecord
)
async def close(self) -> None:
pass
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
self.storage[key].state = state.state if isinstance(state, State) else state
async def get_state(self, key: StorageKey) -> Optional[str]:
return self.storage[key].state
async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
self.storage[key].data = data.copy()
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
return self.storage[key].data.copy()
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str) -> Optional[Any]: ...
@overload
async def get_value(self, storage_key: StorageKey, dict_key: str, default: Any) -> Any: ...
async def get_value(
self, storage_key: StorageKey, dict_key: str, default: Optional[Any] = None
) -> Optional[Any]:
data = self.storage[storage_key].data
return copy(data.get(dict_key, default))
class DisabledEventIsolation(BaseEventIsolation):
@asynccontextmanager
async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
yield
async def close(self) -> None:
pass
class SimpleEventIsolation(BaseEventIsolation):
def __init__(self) -> None:
# TODO: Unused locks cleaner is needed
self._locks: DefaultDict[Hashable, Lock] = defaultdict(Lock)
@asynccontextmanager
async def lock(self, key: StorageKey) -> AsyncGenerator[None, None]:
lock = self._locks[key]
async with lock:
yield
async def close(self) -> None:
self._locks.clear()

View File

@@ -0,0 +1,130 @@
from typing import Any, Dict, Optional, cast
from motor.motor_asyncio import AsyncIOMotorClient
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
BaseStorage,
DefaultKeyBuilder,
KeyBuilder,
StateType,
StorageKey,
)
class MongoStorage(BaseStorage):
"""
MongoDB storage required :code:`motor` package installed (:code:`pip install motor`)
"""
def __init__(
self,
client: AsyncIOMotorClient,
key_builder: Optional[KeyBuilder] = None,
db_name: str = "aiogram_fsm",
collection_name: str = "states_and_data",
) -> None:
"""
:param client: Instance of AsyncIOMotorClient
:param key_builder: builder that helps to convert contextual key to string
:param db_name: name of the MongoDB database for FSM
:param collection_name: name of the collection for storing FSM states and data
"""
if key_builder is None:
key_builder = DefaultKeyBuilder()
self._client = client
self._database = self._client[db_name]
self._collection = self._database[collection_name]
self._key_builder = key_builder
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> "MongoStorage":
"""
Create an instance of :class:`MongoStorage` with specifying the connection string
:param url: for example :code:`mongodb://user:password@host:port`
:param connection_kwargs: see :code:`motor` docs
:param kwargs: arguments to be passed to :class:`MongoStorage`
:return: an instance of :class:`MongoStorage`
"""
if connection_kwargs is None:
connection_kwargs = {}
client = AsyncIOMotorClient(url, **connection_kwargs)
return cls(client=client, **kwargs)
async def close(self) -> None:
"""Cleanup client resources and disconnect from MongoDB."""
self._client.close()
def resolve_state(self, value: StateType) -> Optional[str]:
if value is None:
return None
if isinstance(value, State):
return value.state
return str(value)
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
document_id = self._key_builder.build(key)
if state is None:
updated = await self._collection.find_one_and_update(
filter={"_id": document_id},
update={"$unset": {"state": 1}},
projection={"_id": 0},
return_document=True,
)
if updated == {}:
await self._collection.delete_one({"_id": document_id})
else:
await self._collection.update_one(
filter={"_id": document_id},
update={"$set": {"state": self.resolve_state(state)}},
upsert=True,
)
async def get_state(self, key: StorageKey) -> Optional[str]:
document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id})
if document is None:
return None
return document.get("state")
async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
document_id = self._key_builder.build(key)
if not data:
updated = await self._collection.find_one_and_update(
filter={"_id": document_id},
update={"$unset": {"data": 1}},
projection={"_id": 0},
return_document=True,
)
if updated == {}:
await self._collection.delete_one({"_id": document_id})
else:
await self._collection.update_one(
filter={"_id": document_id},
update={"$set": {"data": data}},
upsert=True,
)
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
document_id = self._key_builder.build(key)
document = await self._collection.find_one({"_id": document_id})
if document is None or not document.get("data"):
return {}
return cast(Dict[str, Any], document["data"])
async def update_data(self, key: StorageKey, data: Dict[str, Any]) -> Dict[str, Any]:
document_id = self._key_builder.build(key)
update_with = {f"data.{key}": value for key, value in data.items()}
update_result = await self._collection.find_one_and_update(
filter={"_id": document_id},
update={"$set": update_with},
upsert=True,
return_document=True,
projection={"_id": 0},
)
if not update_result:
await self._collection.delete_one({"_id": document_id})
return update_result.get("data", {})

View File

@@ -0,0 +1,169 @@
import json
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Callable, Dict, Optional, cast
from redis.asyncio.client import Redis
from redis.asyncio.connection import ConnectionPool
from redis.asyncio.lock import Lock
from redis.typing import ExpiryT
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import (
BaseEventIsolation,
BaseStorage,
DefaultKeyBuilder,
KeyBuilder,
StateType,
StorageKey,
)
DEFAULT_REDIS_LOCK_KWARGS = {"timeout": 60}
_JsonLoads = Callable[..., Any]
_JsonDumps = Callable[..., str]
class RedisStorage(BaseStorage):
"""
Redis storage required :code:`redis` package installed (:code:`pip install redis`)
"""
def __init__(
self,
redis: Redis,
key_builder: Optional[KeyBuilder] = None,
state_ttl: Optional[ExpiryT] = None,
data_ttl: Optional[ExpiryT] = None,
json_loads: _JsonLoads = json.loads,
json_dumps: _JsonDumps = json.dumps,
) -> None:
"""
:param redis: Instance of Redis connection
:param key_builder: builder that helps to convert contextual key to string
:param state_ttl: TTL for state records
:param data_ttl: TTL for data records
"""
if key_builder is None:
key_builder = DefaultKeyBuilder()
self.redis = redis
self.key_builder = key_builder
self.state_ttl = state_ttl
self.data_ttl = data_ttl
self.json_loads = json_loads
self.json_dumps = json_dumps
@classmethod
def from_url(
cls, url: str, connection_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any
) -> "RedisStorage":
"""
Create an instance of :class:`RedisStorage` with specifying the connection string
:param url: for example :code:`redis://user:password@host:port/db`
:param connection_kwargs: see :code:`redis` docs
:param kwargs: arguments to be passed to :class:`RedisStorage`
:return: an instance of :class:`RedisStorage`
"""
if connection_kwargs is None:
connection_kwargs = {}
pool = ConnectionPool.from_url(url, **connection_kwargs)
redis = Redis(connection_pool=pool)
return cls(redis=redis, **kwargs)
def create_isolation(self, **kwargs: Any) -> "RedisEventIsolation":
return RedisEventIsolation(redis=self.redis, key_builder=self.key_builder, **kwargs)
async def close(self) -> None:
await self.redis.aclose(close_connection_pool=True)
async def set_state(
self,
key: StorageKey,
state: StateType = None,
) -> None:
redis_key = self.key_builder.build(key, "state")
if state is None:
await self.redis.delete(redis_key)
else:
await self.redis.set(
redis_key,
cast(str, state.state if isinstance(state, State) else state),
ex=self.state_ttl,
)
async def get_state(
self,
key: StorageKey,
) -> Optional[str]:
redis_key = self.key_builder.build(key, "state")
value = await self.redis.get(redis_key)
if isinstance(value, bytes):
return value.decode("utf-8")
return cast(Optional[str], value)
async def set_data(
self,
key: StorageKey,
data: Dict[str, Any],
) -> None:
redis_key = self.key_builder.build(key, "data")
if not data:
await self.redis.delete(redis_key)
return
await self.redis.set(
redis_key,
self.json_dumps(data),
ex=self.data_ttl,
)
async def get_data(
self,
key: StorageKey,
) -> Dict[str, Any]:
redis_key = self.key_builder.build(key, "data")
value = await self.redis.get(redis_key)
if value is None:
return {}
if isinstance(value, bytes):
value = value.decode("utf-8")
return cast(Dict[str, Any], self.json_loads(value))
class RedisEventIsolation(BaseEventIsolation):
def __init__(
self,
redis: Redis,
key_builder: Optional[KeyBuilder] = None,
lock_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if key_builder is None:
key_builder = DefaultKeyBuilder()
if lock_kwargs is None:
lock_kwargs = DEFAULT_REDIS_LOCK_KWARGS
self.redis = redis
self.key_builder = key_builder
self.lock_kwargs = lock_kwargs
@classmethod
def from_url(
cls,
url: str,
connection_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> "RedisEventIsolation":
if connection_kwargs is None:
connection_kwargs = {}
pool = ConnectionPool.from_url(url, **connection_kwargs)
redis = Redis(connection_pool=pool)
return cls(redis=redis, **kwargs)
@asynccontextmanager
async def lock(
self,
key: StorageKey,
) -> AsyncGenerator[None, None]:
redis_key = self.key_builder.build(key, "lock")
async with self.redis.lock(name=redis_key, **self.lock_kwargs, lock_class=Lock):
yield None
async def close(self) -> None:
pass

View File

@@ -0,0 +1,37 @@
from enum import Enum, auto
from typing import Optional, Tuple
class FSMStrategy(Enum):
"""
FSM strategy for storage key generation.
"""
USER_IN_CHAT = auto()
"""State will be stored for each user in chat."""
CHAT = auto()
"""State will be stored for each chat globally without separating by users."""
GLOBAL_USER = auto()
"""State will be stored globally for each user globally."""
USER_IN_TOPIC = auto()
"""State will be stored for each user in chat and topic."""
CHAT_TOPIC = auto()
"""State will be stored for each chat and topic, but not separated by users."""
def apply_strategy(
strategy: FSMStrategy,
chat_id: int,
user_id: int,
thread_id: Optional[int] = None,
) -> Tuple[int, int, Optional[int]]:
if strategy == FSMStrategy.CHAT:
return chat_id, chat_id, None
if strategy == FSMStrategy.GLOBAL_USER:
return user_id, user_id, None
if strategy == FSMStrategy.USER_IN_TOPIC:
return chat_id, user_id, thread_id
if strategy == FSMStrategy.CHAT_TOPIC:
return chat_id, chat_id, thread_id
return chat_id, user_id, None