Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions taskiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from taskiq.events import TaskiqEvents
from taskiq.exceptions import (
NoResultError,
RejectError,
ResultGetError,
ResultIsReadyError,
SecurityError,
Expand Down Expand Up @@ -45,7 +44,6 @@
"Context",
"AsyncBroker",
"TaskiqError",
"RejectError",
"TaskiqState",
"TaskiqResult",
"ZeroMQBroker",
Expand Down
9 changes: 4 additions & 5 deletions taskiq/acks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import dataclasses
from typing import Awaitable, Callable, Optional, Union
from typing import Awaitable, Callable, Union

from pydantic import BaseModel

@dataclasses.dataclass
class AckableMessage:

class AckableMessage(BaseModel):
"""
Message that can be acknowledged.

Expand All @@ -18,4 +18,3 @@ class AckableMessage:

data: bytes
ack: Callable[[], Union[None, Awaitable[None]]]
reject: Optional[Callable[[], Union[None, Awaitable[None]]]] = None
27 changes: 27 additions & 0 deletions taskiq/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from copy import copy
from typing import TYPE_CHECKING

from taskiq.abc.broker import AsyncBroker
from taskiq.exceptions import NoResultError, TaskRejectedError
from taskiq.message import TaskiqMessage

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -15,3 +17,28 @@ def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None:
self.broker = broker
self.state: "TaskiqState" = None # type: ignore
self.state = broker.state

async def requeue(self) -> None:
"""
Requeue task.

This fuction creates a task with
the same message and sends it using
current broker.

:raises NoResultError: to not store result for current task.
"""
message = copy(self.message)
requeue_count = int(message.labels.get("X-Taskiq-requeue", 0))
requeue_count += 1
message.labels["X-Taskiq-requeue"] = str(requeue_count)
await self.broker.kick(self.broker.formatter.dumps(self.message))
raise NoResultError()

def reject(self) -> None:
"""
Raise reject error.

:raises TaskRejectedError: to reject current message.
"""
raise TaskRejectedError()
4 changes: 2 additions & 2 deletions taskiq/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ class NoResultError(TaskiqError):
"""Error if user does not want to set result."""


class RejectError(TaskiqError):
"""Error is thrown if message should be rejected."""
class TaskRejectedError(TaskiqError):
"""Task was rejected."""
60 changes: 37 additions & 23 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from concurrent.futures import Executor
from logging import getLogger
from time import time
from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints
from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints

import anyio
from taskiq_dependencies import DependencyGraph

from taskiq.abc.broker import AckableMessage, AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.context import Context
from taskiq.exceptions import NoResultError, RejectError
from taskiq.exceptions import NoResultError
from taskiq.message import TaskiqMessage
from taskiq.receiver.params_parser import parse_params
from taskiq.result import TaskiqResult
Expand All @@ -22,18 +22,23 @@
QUEUE_DONE = b"-1"


def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
def _run_sync(
target: Callable[..., Any],
args: List[Any],
kwargs: Dict[str, Any],
) -> Any:
"""
Runs function synchronously.

We use this function, because
we cannot pass kwargs in loop.run_with_executor().

:param target: function to execute.
:param message: received message from broker.
:param args: list of function's args.
:param kwargs: dict of function's kwargs.
:return: result of function's execution.
"""
return target(*message.args, **message.kwargs)
return target(*args, **kwargs)


class Receiver:
Expand Down Expand Up @@ -124,20 +129,16 @@ async def callback( # noqa: C901, WPS213, WPS217
taskiq_msg.task_name,
taskiq_msg.task_id,
)

# If broker has an ability to ack messages.
if isinstance(message, AckableMessage):
await maybe_awaitable(message.ack())

result = await self.run_task(
target=self.broker.available_tasks[taskiq_msg.task_name].original_func,
message=taskiq_msg,
)

# If broker has an ability to ack or reject messages.
if isinstance(message, AckableMessage):
# If we received an error for negative acknowledgement.
if message.reject is not None and isinstance(result.error, RejectError):
await maybe_awaitable(message.reject())
# Otherwise we positively acknowledge the message.
else:
await maybe_awaitable(message.ack())

for middleware in self.broker.middlewares:
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
Expand Down Expand Up @@ -182,14 +183,18 @@ async def run_task( # noqa: C901, WPS210
"""
loop = asyncio.get_running_loop()
returned = None
found_exception = None
found_exception: "Optional[BaseException]" = None
signature = None
if self.validate_params:
signature = self.task_signatures.get(message.task_name)
dependency_graph = self.dependency_graphs.get(message.task_name)
parse_params(signature, self.task_hints.get(message.task_name) or {}, message)

dep_ctx = None
# Kwargs are defined in another variable,
# because we want to update them with
# kwargs resolved by dependency injector.
kwargs = {}
if dependency_graph:
# Create a context for dependency resolving.
broker_ctx = self.broker.custom_dependency_context
Expand All @@ -201,25 +206,34 @@ async def run_task( # noqa: C901, WPS210
)
dep_ctx = dependency_graph.async_ctx(broker_ctx)
# Resolve all function's dependencies.
dep_kwargs = await dep_ctx.resolve_kwargs()
for key, val in dep_kwargs.items():
if key not in message.kwargs:
message.kwargs[key] = val
kwargs = await dep_ctx.resolve_kwargs()

# We udpate kwargs with kwargs from network.
kwargs.update(message.kwargs)

# Start a timer.
start_time = time()
try:
# If the function is a coroutine we await it.
# If the function is a coroutine, we await it.
if asyncio.iscoroutinefunction(target):
returned = await target(*message.args, **message.kwargs)
returned = await target(*message.args, **kwargs)
else:
# If this is a synchronous function we
# If this is a synchronous function, we
# run it in executor.
returned = await loop.run_in_executor(
self.executor,
_run_sync,
target,
message,
message.args,
kwargs,
)
except NoResultError as no_res_exc:
found_exception = no_res_exc
logger.warning(
"Task %s with id %s skipped setting result.",
message.task_name,
message.task_id,
)
except BaseException as exc: # noqa: WPS424
found_exception = exc
logger.error(
Expand Down
79 changes: 1 addition & 78 deletions tests/receiver/test_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from taskiq.abc.broker import AckableMessage, AsyncBroker
from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.exceptions import NoResultError, RejectError, TaskiqResultTimeoutError
from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError
from taskiq.message import TaskiqMessage
from taskiq.receiver import Receiver
from taskiq.result import TaskiqResult
Expand Down Expand Up @@ -260,83 +260,6 @@ async def ack_callback() -> None:
assert acked


@pytest.mark.anyio
async def test_callback_success_reject() -> None:
"""
Test that if reject error is thrown,
broker would reject a message.
"""
broker = InMemoryBroker()
rejected = False

@broker.task
async def my_task() -> None:
raise RejectError()

def reject_callback() -> None:
nonlocal rejected
rejected = True

receiver = get_receiver(broker)

broker_message = broker.formatter.dumps(
TaskiqMessage(
task_id="task_id",
task_name=my_task.task_name,
labels={},
args=[],
kwargs={},
),
)

await receiver.callback(
AckableMessage(
data=broker_message.message,
ack=lambda: None,
reject=reject_callback,
),
)
assert rejected


@pytest.mark.anyio
async def test_callback_no_reject_func() -> None:
"""
Test that if broker doesn't support rejects,
it acks message instead.
"""
broker = InMemoryBroker()
acked = False

@broker.task
async def my_task() -> None:
raise RejectError()

def ack_callback() -> None:
nonlocal acked
acked = True

receiver = get_receiver(broker)

broker_message = broker.formatter.dumps(
TaskiqMessage(
task_id="task_id",
task_name=my_task.task_name,
labels={},
args=[],
kwargs={},
),
)

await receiver.callback(
AckableMessage(
data=broker_message.message,
ack=ack_callback,
),
)
assert acked


@pytest.mark.anyio
async def test_callback_wrong_format() -> None:
"""Test that wrong format of a message won't thow an error."""
Expand Down
22 changes: 22 additions & 0 deletions tests/test_requeue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from taskiq import Context, InMemoryBroker, TaskiqDepends


@pytest.mark.anyio
async def test_requeue() -> None:
broker = InMemoryBroker()

runs_count = 0

@broker.task
async def task(context: Context = TaskiqDepends()) -> None:
nonlocal runs_count
runs_count += 1
if runs_count < 2:
await context.requeue()

kicked = await task.kiq()
await kicked.wait_result()

assert runs_count == 2