# Source code for faust.agents.replies
"""Agent replies: waiting for replies, sending them, etc."""
import asyncio
from collections import defaultdict
from typing import Any, AsyncIterator, MutableMapping, MutableSet, NamedTuple, Optional
from weakref import WeakSet
from mode import Service
from faust.types import AppT, ChannelT, TopicT
from .models import ReqRepResponse
__all__ = ["ReplyPromise", "BarrierState", "ReplyConsumer"]
from ..models import maybe_model
class ReplyTuple(NamedTuple):
    """A single received reply, paired with the id of the request it answers.

    Instances are what :meth:`BarrierState.fulfill` places on the barrier's
    result queue.
    """

    # Identifier correlating this reply with the request that produced it.
    correlation_id: str
    # The reply payload.
    value: Any
class ReplyPromise(asyncio.Future):
    """Reply promise can be :keyword:`await`-ed to wait until result ready.

    Arguments:
        reply_to: Name of the topic the reply is expected on.
        correlation_id: Identifier of the request this promise waits for.
            Must be non-empty (subclasses may relax this by overriding
            :meth:`_verify_correlation_id`).
        **kwargs: Passed on to :class:`asyncio.Future` (e.g. ``loop``).

    Raises:
        ValueError: If ``correlation_id`` is empty.
    """

    reply_to: str
    correlation_id: str

    def __init__(self, reply_to: str, correlation_id: str = "", **kwargs: Any) -> None:
        self.reply_to = reply_to
        self._verify_correlation_id(correlation_id)
        self.correlation_id = correlation_id
        # Subclass hook runs before the underlying future is initialized.
        self.__post_init__()
        super().__init__(**kwargs)

    def __post_init__(self) -> None:
        """Subclass hook called from ``__init__``; does nothing by default."""
        ...

    def _verify_correlation_id(self, correlation_id: str) -> None:
        """Raise :exc:`ValueError` if ``correlation_id`` is empty."""
        if not correlation_id:
            raise ValueError("ReplyPromise missing correlation_id argument.")

    def fulfill(self, correlation_id: str, value: Any) -> None:
        """Fulfill promise: a reply was received.

        If the promise is already done (resolved earlier, or cancelled
        while the reply was still in flight) the reply is dropped instead
        of raising :exc:`asyncio.InvalidStateError`.
        """
        # If it wasn't for BarrierState we would just use .set_result()
        # directly, but BarrierState.fulfill requires the correlation_id
        # to be sent with it.  That way it can mark that part of the map
        # operation as completed.
        # Internal invariant check: the caller looked this promise up by id.
        assert correlation_id == self.correlation_id
        if not self.done():
            self.set_result(value)
class BarrierState(ReplyPromise):
    """State of pending/complete barrier.

    A barrier is a synchronization primitive that will wait until
    a group of coroutines have completed.

    The barrier itself is a future that resolves to ``True`` once
    :meth:`finalize` has been called and at least ``total`` replies
    have been received through :meth:`fulfill`.
    """

    #: This is the size while the messages are being sent.
    #: (it's a tentative total, added to until the total is finalized).
    size: int = 0

    #: This is the actual total when all messages have been sent.
    #: It's set by :meth:`finalize`.
    total: int = 0

    #: The number of results we have received.
    fulfilled: int = 0

    #: Internal queue where results are added to.
    #: ``None`` entries are sentinels used only to wake up :meth:`iterate`.
    _results: asyncio.Queue

    #: Set of pending replies that this barrier is composed of.
    pending: MutableSet[ReplyPromise]

    def __post_init__(self) -> None:
        self.pending = set()
        # NOTE(review): the queue is bounded and filled with put_nowait(),
        # so more than 1000 unconsumed entries raises asyncio.QueueFull.
        self._results = asyncio.Queue(maxsize=1000)

    def _verify_correlation_id(self, correlation_id: str) -> None:
        pass  # barrier does not require a correlation id.

    def add(self, p: ReplyPromise) -> None:
        """Add promise to barrier.

        Note:
            You can only add promises before the barrier is finalized
            using :meth:`finalize`.
        """
        self.pending.add(p)
        self.size += 1

    def finalize(self) -> None:
        """Finalize this barrier.

        After finalization you can not grow or shrink the size
        of the barrier.
        """
        self.total = self.size
        # The barrier may have been filled up already at this point,
        if self.fulfilled >= self.total:
            self._signal_complete()

    def fulfill(self, correlation_id: str, value: Any) -> None:
        """Fulfill one of the promises in this barrier.

        Once all promises in this barrier is fulfilled, the barrier
        will be ready.
        """
        # ReplyConsumer calls this whenever a new reply is received.
        self._results.put_nowait(ReplyTuple(correlation_id, value))
        self.fulfilled += 1
        # total is only non-zero after finalize() on a non-empty barrier.
        if self.total and self.fulfilled >= self.total:
            self._signal_complete()

    def _signal_complete(self) -> None:
        """Resolve the barrier (at most once) and wake up :meth:`iterate`."""
        # A surplus reply, a repeated finalize() or a cancelled barrier
        # must not raise asyncio.InvalidStateError.
        if not self.done():
            self.set_result(True)
        self._results.put_nowait(None)  # always wake-up .iterate()

    def get_nowait(self) -> ReplyTuple:
        """Return next reply, or raise :exc:`asyncio.QueueEmpty`."""
        get_nowait = self._results.get_nowait
        # Skip however many sentinels are queued; the queue raises
        # QueueEmpty itself once nothing is left.
        while True:
            value = get_nowait()
            if value is not None:
                return value

    async def iterate(self) -> AsyncIterator[ReplyTuple]:
        """Iterate over results as they arrive."""
        get = self._results.get
        get_nowait = self._results.get_nowait
        is_done = self.done
        # Phase 1: block on the queue until the barrier future is done.
        while not is_done():
            value = await get()
            if value is not None:
                yield value
        # Phase 2: drain whatever is still queued, dropping sentinels.
        while True:
            try:
                value = get_nowait()
            except asyncio.QueueEmpty:
                break
            if value is not None:
                yield value
class ReplyConsumer(Service):
    """Consumer responsible for redelegation of replies received.

    Promises are registered by correlation id with :meth:`add`; one
    background fetcher per reply topic reads replies and hands each one
    to the promises waiting on its correlation id.
    """

    #: Promises waiting for a reply, keyed by correlation id.
    #: Held in weak sets, so a promise nobody references is dropped.
    _waiting: MutableMapping[str, MutableSet[ReplyPromise]]

    #: Fetcher future per reply topic name; ``None`` while being set up.
    _fetchers: MutableMapping[str, Optional[asyncio.Future]]

    def __init__(self, app: AppT, **kwargs: Any) -> None:
        self.app = app
        self._waiting = defaultdict(WeakSet)
        self._fetchers = {}
        super().__init__(**kwargs)

    async def on_start(self) -> None:
        """Call when reply consumer starts."""
        if self.app.conf.reply_create_topic:
            await self._start_fetcher(self.app.conf.reply_to)

    async def add(self, correlation_id: str, promise: ReplyPromise) -> None:
        """Register promise to start tracking when it arrives."""
        reply_topic = promise.reply_to
        if reply_topic not in self._fetchers:
            await self._start_fetcher(reply_topic)
        self._waiting[correlation_id].add(promise)

    async def _start_fetcher(self, topic_name: str) -> None:
        """Declare the reply topic and start draining it, once per topic."""
        if topic_name not in self._fetchers:
            # set the key as a lock, so it doesn't happen twice
            self._fetchers[topic_name] = None
            # declare the topic
            topic = self._reply_topic(topic_name)
            await topic.maybe_declare()
            self.app.topics.add(topic)
            # NOTE(review): presumably gives the new topic time to become
            # available before consuming from it -- confirm.
            await self.sleep(3.0)
            # then create the future
            self._fetchers[topic_name] = self.add_future(self._drain_replies(topic))

    async def _drain_replies(self, channel: ChannelT) -> None:
        """Deliver every reply read from ``channel`` to its waiting promises."""
        async for reply in channel.stream():
            correlation_id = reply.correlation_id
            # pop() rather than indexing: indexing the defaultdict would
            # create a permanent empty entry for every unknown id, and
            # delivered ids would never be removed from the mapping.
            waiting = self._waiting.pop(correlation_id, None)
            if waiting is None:
                continue
            for promise in waiting:
                promise.fulfill(correlation_id, maybe_model(reply.value))

    def _reply_topic(self, topic: str) -> TopicT:
        """Return the topic description used for reply topic ``topic``."""
        return self.app.topic(
            topic,
            partitions=1,
            replicas=0,
            deleting=True,
            retention=self.app.conf.reply_expires,
            value_type=ReqRepResponse,
        )