From e2ce68bae8c1362fd01f898db14f747eba4c22af Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 10 Jun 2023 21:16:49 +0400 Subject: [PATCH] Added context methods to reject and requeue. Signed-off-by: Pavel Kirilin --- taskiq/__init__.py | 2 - taskiq/acks.py | 9 ++-- taskiq/context.py | 27 +++++++++++ taskiq/exceptions.py | 4 +- taskiq/receiver/receiver.py | 60 +++++++++++++++---------- tests/receiver/test_receiver.py | 79 +-------------------------------- tests/test_requeue.py | 22 +++++++++ 7 files changed, 93 insertions(+), 110 deletions(-) create mode 100644 tests/test_requeue.py diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 7557e0b1..8aa986ca 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -14,7 +14,6 @@ from taskiq.events import TaskiqEvents from taskiq.exceptions import ( NoResultError, - RejectError, ResultGetError, ResultIsReadyError, SecurityError, @@ -45,7 +44,6 @@ "Context", "AsyncBroker", "TaskiqError", - "RejectError", "TaskiqState", "TaskiqResult", "ZeroMQBroker", diff --git a/taskiq/acks.py b/taskiq/acks.py index 2c89b3c7..fddb29bf 100644 --- a/taskiq/acks.py +++ b/taskiq/acks.py @@ -1,9 +1,9 @@ -import dataclasses -from typing import Awaitable, Callable, Optional, Union +from typing import Awaitable, Callable, Union +from pydantic import BaseModel -@dataclasses.dataclass -class AckableMessage: + +class AckableMessage(BaseModel): """ Message that can be acknowledged. @@ -18,4 +18,3 @@ class AckableMessage: data: bytes ack: Callable[[], Union[None, Awaitable[None]]] - reject: Optional[Callable[[], Union[None, Awaitable[None]]]] = None diff --git a/taskiq/context.py b/taskiq/context.py index bfc59a0b..cb1a2d85 100644 --- a/taskiq/context.py +++ b/taskiq/context.py @@ -1,6 +1,8 @@ +from copy import copy from typing import TYPE_CHECKING from taskiq.abc.broker import AsyncBroker +from taskiq.exceptions import NoResultError, TaskRejectedError from taskiq.message import TaskiqMessage if TYPE_CHECKING: # pragma: no cover @@ -15,3 +17,28 @@ def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None: self.broker = broker self.state: "TaskiqState" = None # type: ignore self.state = broker.state + + async def requeue(self) -> None: + """ + Requeue task. + + This fuction creates a task with + the same message and sends it using + current broker. + + :raises NoResultError: to not store result for current task. + """ + message = copy(self.message) + requeue_count = int(message.labels.get("X-Taskiq-requeue", 0)) + requeue_count += 1 + message.labels["X-Taskiq-requeue"] = str(requeue_count) + await self.broker.kick(self.broker.formatter.dumps(self.message)) + raise NoResultError() + + def reject(self) -> None: + """ + Raise reject error. + + :raises TaskRejectedError: to reject current message. + """ + raise TaskRejectedError() diff --git a/taskiq/exceptions.py b/taskiq/exceptions.py index ea45647d..a4eec6ab 100644 --- a/taskiq/exceptions.py +++ b/taskiq/exceptions.py @@ -38,5 +38,5 @@ class NoResultError(TaskiqError): """Error if user does not want to set result.""" -class RejectError(TaskiqError): - """Error is thrown if message should be rejected.""" +class TaskRejectedError(TaskiqError): + """Task was rejected.""" diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 59b4a63f..4a94af05 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -3,7 +3,7 @@ from concurrent.futures import Executor from logging import getLogger from time import time -from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints +from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints import anyio from taskiq_dependencies import DependencyGraph @@ -11,7 +11,7 @@ from taskiq.abc.broker import AckableMessage, AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware from taskiq.context import Context -from taskiq.exceptions import NoResultError, RejectError +from taskiq.exceptions import NoResultError from taskiq.message import TaskiqMessage from taskiq.receiver.params_parser import parse_params from taskiq.result import TaskiqResult @@ -22,7 +22,11 @@ QUEUE_DONE = b"-1" -def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any: +def _run_sync( + target: Callable[..., Any], + args: List[Any], + kwargs: Dict[str, Any], +) -> Any: """ Runs function synchronously. @@ -30,10 +34,11 @@ def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any: we cannot pass kwargs in loop.run_with_executor(). :param target: function to execute. - :param message: received message from broker. + :param args: list of function's args. + :param kwargs: dict of function's kwargs. :return: result of function's execution. """ - return target(*message.args, **message.kwargs) + return target(*args, **kwargs) class Receiver: @@ -124,20 +129,16 @@ async def callback( # noqa: C901, WPS213, WPS217 taskiq_msg.task_name, taskiq_msg.task_id, ) + + # If broker has an ability to ack messages. + if isinstance(message, AckableMessage): + await maybe_awaitable(message.ack()) + result = await self.run_task( target=self.broker.available_tasks[taskiq_msg.task_name].original_func, message=taskiq_msg, ) - # If broker has an ability to ack or reject messages. - if isinstance(message, AckableMessage): - # If we received an error for negative acknowledgement. - if message.reject is not None and isinstance(result.error, RejectError): - await maybe_awaitable(message.reject()) - # Otherwise we positively acknowledge the message. - else: - await maybe_awaitable(message.ack()) - for middleware in self.broker.middlewares: if middleware.__class__.post_execute != TaskiqMiddleware.post_execute: await maybe_awaitable(middleware.post_execute(taskiq_msg, result)) @@ -182,7 +183,7 @@ async def run_task( # noqa: C901, WPS210 """ loop = asyncio.get_running_loop() returned = None - found_exception = None + found_exception: "Optional[BaseException]" = None signature = None if self.validate_params: signature = self.task_signatures.get(message.task_name) @@ -190,6 +191,10 @@ async def run_task( # noqa: C901, WPS210 parse_params(signature, self.task_hints.get(message.task_name) or {}, message) dep_ctx = None + # Kwargs are defined in another variable, + # because we want to update them with + # kwargs resolved by dependency injector. + kwargs = {} if dependency_graph: # Create a context for dependency resolving. broker_ctx = self.broker.custom_dependency_context @@ -201,25 +206,34 @@ async def run_task( # noqa: C901, WPS210 ) dep_ctx = dependency_graph.async_ctx(broker_ctx) # Resolve all function's dependencies. - dep_kwargs = await dep_ctx.resolve_kwargs() - for key, val in dep_kwargs.items(): - if key not in message.kwargs: - message.kwargs[key] = val + kwargs = await dep_ctx.resolve_kwargs() + + # We udpate kwargs with kwargs from network. + kwargs.update(message.kwargs) + # Start a timer. start_time = time() try: - # If the function is a coroutine we await it. + # If the function is a coroutine, we await it. if asyncio.iscoroutinefunction(target): - returned = await target(*message.args, **message.kwargs) + returned = await target(*message.args, **kwargs) else: - # If this is a synchronous function we + # If this is a synchronous function, we # run it in executor. returned = await loop.run_in_executor( self.executor, _run_sync, target, - message, + message.args, + kwargs, ) + except NoResultError as no_res_exc: + found_exception = no_res_exc + logger.warning( + "Task %s with id %s skipped setting result.", + message.task_name, + message.task_id, + ) except BaseException as exc: # noqa: WPS424 found_exception = exc logger.error( diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 07c5e163..ebba1ca9 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -8,7 +8,7 @@ from taskiq.abc.broker import AckableMessage, AsyncBroker from taskiq.abc.middleware import TaskiqMiddleware from taskiq.brokers.inmemory_broker import InMemoryBroker -from taskiq.exceptions import NoResultError, RejectError, TaskiqResultTimeoutError +from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError from taskiq.message import TaskiqMessage from taskiq.receiver import Receiver from taskiq.result import TaskiqResult @@ -260,83 +260,6 @@ async def ack_callback() -> None: assert acked -@pytest.mark.anyio -async def test_callback_success_reject() -> None: - """ - Test that if reject error is thrown, - broker would reject a message. - """ - broker = InMemoryBroker() - rejected = False - - @broker.task - async def my_task() -> None: - raise RejectError() - - def reject_callback() -> None: - nonlocal rejected - rejected = True - - receiver = get_receiver(broker) - - broker_message = broker.formatter.dumps( - TaskiqMessage( - task_id="task_id", - task_name=my_task.task_name, - labels={}, - args=[], - kwargs={}, - ), - ) - - await receiver.callback( - AckableMessage( - data=broker_message.message, - ack=lambda: None, - reject=reject_callback, - ), - ) - assert rejected - - -@pytest.mark.anyio -async def test_callback_no_reject_func() -> None: - """ - Test that if broker doesn't support rejects, - it acks message instead. - """ - broker = InMemoryBroker() - acked = False - - @broker.task - async def my_task() -> None: - raise RejectError() - - def ack_callback() -> None: - nonlocal acked - acked = True - - receiver = get_receiver(broker) - - broker_message = broker.formatter.dumps( - TaskiqMessage( - task_id="task_id", - task_name=my_task.task_name, - labels={}, - args=[], - kwargs={}, - ), - ) - - await receiver.callback( - AckableMessage( - data=broker_message.message, - ack=ack_callback, - ), - ) - assert acked - - @pytest.mark.anyio async def test_callback_wrong_format() -> None: """Test that wrong format of a message won't thow an error.""" diff --git a/tests/test_requeue.py b/tests/test_requeue.py new file mode 100644 index 00000000..f5089a8f --- /dev/null +++ b/tests/test_requeue.py @@ -0,0 +1,22 @@ +import pytest + +from taskiq import Context, InMemoryBroker, TaskiqDepends + + +@pytest.mark.anyio +async def test_requeue() -> None: + broker = InMemoryBroker() + + runs_count = 0 + + @broker.task + async def task(context: Context = TaskiqDepends()) -> None: + nonlocal runs_count + runs_count += 1 + if runs_count < 2: + await context.requeue() + + kicked = await task.kiq() + await kicked.wait_result() + + assert runs_count == 2