# Source code for faust.agents.replies

"""Agent replies: waiting for replies, sending them, etc."""

import asyncio
from collections import defaultdict
from typing import Any, AsyncIterator, MutableMapping, MutableSet, NamedTuple, Optional
from weakref import WeakSet

from mode import Service

from faust.types import AppT, ChannelT, TopicT

from .models import ReqRepResponse

# Public API of this module (the names exported by ``import *``).
__all__ = ["ReplyPromise", "BarrierState", "ReplyConsumer"]

from ..models import maybe_model


class ReplyTuple(NamedTuple):
    """A single reply received for a barrier: ``(correlation_id, value)``."""

    #: Identifier of the request this reply answers.
    correlation_id: str

    #: Payload of the reply.
    value: Any


class ReplyPromise(asyncio.Future):
    """Reply promise can be :keyword:`await`-ed to wait until result ready."""

    #: Name of the topic the reply is expected to arrive on.
    reply_to: str

    #: Identifier used to match an incoming reply to this promise.
    correlation_id: str

    def __init__(self, reply_to: str, correlation_id: str = "", **kwargs: Any) -> None:
        """Create promise waiting for a reply on ``reply_to``.

        Arguments:
            reply_to: Name of the reply topic.
            correlation_id: Id of the request; required unless a subclass
                overrides :meth:`_verify_correlation_id`.
            **kwargs: Passed on to :class:`asyncio.Future` (e.g. ``loop``).

        Raises:
            ValueError: if ``correlation_id`` is empty (default check).
        """
        self.reply_to = reply_to
        self._verify_correlation_id(correlation_id)
        self.correlation_id = correlation_id
        # Subclass hook; note it runs before the Future itself is initialized.
        self.__post_init__()
        super().__init__(**kwargs)

    def __post_init__(self) -> None:
        """Initialize extra subclass state (no-op by default)."""
        ...

    def _verify_correlation_id(self, correlation_id: str) -> None:
        if not correlation_id:
            raise ValueError("ReplyPromise missing correlation_id argument.")

    def fulfill(self, correlation_id: str, value: Any) -> None:
        """Fulfill promise: a reply was received.

        If the promise is already done (a second reply carrying the same
        correlation id, or a promise the caller cancelled, e.g. on timeout)
        the reply is ignored rather than raising
        :exc:`asyncio.InvalidStateError` into the code delivering replies.
        """
        # If it wasn't for BarrierState we would just use .set_result()
        # directly, but BarrierState.fulfill requires the correlation_id
        # to be sent with it. That way it can mark that part of the map
        # operation as completed.
        assert correlation_id == self.correlation_id
        if not self.done():
            self.set_result(value)
class BarrierState(ReplyPromise):
    """State of pending/complete barrier.

    A barrier is a synchronization primitive that will wait until
    a group of coroutines have completed.
    """

    #: This is the size while the messages are being sent.
    #: (it's a tentative total, added to until the total is finalized).
    size: int = 0

    #: This is the actual total when all messages have been sent.
    #: It's set by :meth:`finalize`.
    total: int = 0

    #: The number of results we have received.
    fulfilled: int = 0

    #: Internal queue where results are added to.
    #: ``None`` entries are wake-up sentinels for :meth:`iterate`.
    _results: asyncio.Queue

    #: Set of pending replies that this barrier is composed of.
    pending: MutableSet[ReplyPromise]

    def __post_init__(self) -> None:
        self.pending = set()
        # Unbounded on purpose: results are only ever added with
        # put_nowait(), so a bounded queue cannot apply backpressure -- it
        # can only raise QueueFull when many replies arrive before anyone
        # consumes them (e.g. the barrier is awaited before iterate()).
        self._results = asyncio.Queue()

    def _verify_correlation_id(self, correlation_id: str) -> None:
        pass  # barrier does not require a correlation id.

    def add(self, p: ReplyPromise) -> None:
        """Add promise to barrier.

        Note:
            You can only add promises before the barrier is finalized
            using :meth:`finalize`.
        """
        self.pending.add(p)
        self.size += 1

    def finalize(self) -> None:
        """Finalize this barrier.

        After finalization you can not grow or shrink the size
        of the barrier.
        """
        self.total = self.size
        # The barrier may have been filled up already at this point,
        if self.fulfilled >= self.total:
            self._mark_complete()

    def fulfill(self, correlation_id: str, value: Any) -> None:
        """Fulfill one of the promises in this barrier.

        Once all promises in this barrier are fulfilled, the barrier
        will be ready.
        """
        # ReplyConsumer calls this whenever a new reply is received.
        self._results.put_nowait(ReplyTuple(correlation_id, value))
        self.fulfilled += 1
        # total is 0 until finalize() runs, so we cannot complete early.
        if self.total and self.fulfilled >= self.total:
            self._mark_complete()

    def _mark_complete(self) -> None:
        """Resolve the barrier (at most once) and wake up :meth:`iterate`."""
        # Guard: extra replies after completion, a repeated finalize(), or a
        # cancelled barrier would otherwise raise InvalidStateError here.
        if not self.done():
            self.set_result(True)
        self._results.put_nowait(None)  # always wake-up .iterate()

    def get_nowait(self) -> ReplyTuple:
        """Return next reply, or raise :exc:`asyncio.QueueEmpty`.

        Wake-up sentinels (``None``) are skipped, however many are queued.
        """
        while True:
            # Raises asyncio.QueueEmpty once the queue is drained.
            value = self._results.get_nowait()
            if value is not None:
                return value

    async def iterate(self) -> AsyncIterator[ReplyTuple]:
        """Iterate over results as they arrive.

        Waits for replies until the barrier is done, then yields any
        replies still left in the queue.
        """
        get = self._results.get
        get_nowait = self._results.get_nowait
        is_done = self.done
        while not is_done():
            value = await get()
            if value is not None:
                yield value
        while 1:
            try:
                value = get_nowait()
            except asyncio.QueueEmpty:
                break
            else:
                if value is not None:
                    yield value
class ReplyConsumer(Service):
    """Consumer responsible for redelegation of replies received.

    Keeps a registry of pending :class:`ReplyPromise` objects keyed by
    correlation id and runs one background fetcher per reply topic that
    fulfills the matching promises as replies are consumed.
    """

    #: Pending promises keyed by correlation id.  Promises are held in a
    #: WeakSet, so this registry does not keep them alive on its own.
    _waiting: MutableMapping[str, MutableSet[ReplyPromise]]

    #: Fetcher futures keyed by reply topic name.  A ``None`` value marks a
    #: fetcher that is currently being started.
    _fetchers: MutableMapping[str, Optional[asyncio.Future]]

    def __init__(self, app: AppT, **kwargs: Any) -> None:
        self.app = app
        self._waiting = defaultdict(WeakSet)
        self._fetchers = {}
        super().__init__(**kwargs)

    async def on_start(self) -> None:
        """Call when reply consumer starts."""
        # Otherwise fetchers are started lazily by add().
        if self.app.conf.reply_create_topic:
            await self._start_fetcher(self.app.conf.reply_to)

    async def add(self, correlation_id: str, promise: ReplyPromise) -> None:
        """Register promise to start tracking when it arrives."""
        reply_topic = promise.reply_to
        if reply_topic not in self._fetchers:
            await self._start_fetcher(reply_topic)
        self._waiting[correlation_id].add(promise)

    async def _start_fetcher(self, topic_name: str) -> None:
        """Declare the reply topic ``topic_name`` and start draining it."""
        if topic_name in self._fetchers:
            return
        # set the key as a lock, so it doesn't happen twice
        self._fetchers[topic_name] = None
        try:
            # declare the topic
            topic = self._reply_topic(topic_name)
            await topic.maybe_declare()
            self.app.topics.add(topic)
            # NOTE(review): fixed delay before consuming; presumably gives
            # the topic time to become available -- confirm.
            await self.sleep(3.0)
        except BaseException:
            # Release the lock so a later add()/on_start() can retry,
            # instead of leaving the topic permanently without a fetcher.
            self._fetchers.pop(topic_name, None)
            raise
        # then create the future
        self._fetchers[topic_name] = self.add_future(self._drain_replies(topic))

    async def _drain_replies(self, channel: ChannelT) -> None:
        """Fulfill waiting promises for every reply read from ``channel``."""
        async for reply in channel.stream():
            # pop() rather than indexing the defaultdict: indexing would
            # create (and keep forever) an entry for every correlation id,
            # including replies nobody here is waiting for.  Each correlation
            # id is answered once, so the entry is no longer needed.
            waiting = self._waiting.pop(reply.correlation_id, None)
            if not waiting:
                continue
            for promise in waiting:
                # A promise may already be done (e.g. cancelled by a timeout);
                # fulfilling it could raise and would end this drain loop.
                if not promise.done():
                    promise.fulfill(reply.correlation_id, maybe_model(reply.value))

    def _reply_topic(self, topic: str) -> TopicT:
        """Build the topic object used to receive replies on ``topic``."""
        return self.app.topic(
            topic,
            partitions=1,
            replicas=0,
            deleting=True,
            retention=self.app.conf.reply_expires,
            value_type=ReqRepResponse,
        )