Source code for faust.transport.conductor

"""The conductor delegates messages from the consumer to the streams."""

import asyncio
import os
import typing
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    MutableMapping,
    MutableSet,
    Optional,
    Set,
    Tuple,
    cast,
)

from mode import Service, get_logger
from mode.utils.futures import notify

from faust.exceptions import KeyDecodeError, ValueDecodeError
from faust.types import TP, AppT, EventT, K, Message, V
from faust.types.topics import TopicT
from faust.types.transports import ConductorT, ConsumerCallback, TPorTopicSet
from faust.types.tuples import tp_set_to_map
from faust.utils.tracing import traced_from_parent_span

if typing.TYPE_CHECKING:  # pragma: no cover
    from faust.topics import Topic as _Topic
else:

    class _Topic: ...  # noqa


NO_CYTHON = bool(os.environ.get("NO_CYTHON", False))

if not NO_CYTHON:  # pragma: no cover
    try:
        from ._cython.conductor import ConductorHandler
    except ImportError:
        ConductorHandler = None
else:  # pragma: no cover
    ConductorHandler = None

__all__ = ["Conductor", "ConductorCompiler"]

logger = get_logger(__name__)


class ConductorCompiler:  # pragma: no cover
    """Compile a function to handle the messages for a topic+partition."""

    def build(
        self, conductor: "Conductor", tp: TP, channels: MutableSet[_Topic]
    ) -> ConsumerCallback:
        """Generate closure used to deliver messages."""
        # This method localizes variables and attribute access
        # for better performance.  This is part of the inner loop
        # of a Faust worker, so tiny improvements here have a big impact.
        topic, partition = tp
        app = conductor.app
        len_: Callable[[Any], int] = len

        # We divide `stream_buffer_maxsize` by Queue.pressure_ratio to
        # find a limit to the number of messages we will buffer
        # before considering the buffer to be under high pressure.
        # When the buffer is under high pressure, we call
        # Consumer.on_buffer_full(tp) to remove this topic partition
        # from the fetcher.
        # We still accept anything that's currently in the fetcher (it's
        # already in memory so we are just moving the data) without
        # blocking, but signal the fetcher to stop retrieving any more
        # data for this partition.
        consumer_on_buffer_full = app.consumer.on_buffer_full
        # when the buffer drops down to half we re-enable fetching
        # from the partition.
        consumer_on_buffer_drop = app.consumer.on_buffer_drop
        # flow control will completely block the streams from processing
        # more data, and is used during rebalancing.
        acquire_flow_control: Callable = app.flow_control.acquire
        # when streams send new messages as a side effect, the producer
        # buffer can sometimes fill up; in that case we block
        # the streams to wait until the buffer is free.
        wait_until_producer_ebb = app.producer.buffer.wait_until_ebb
        # This sensor method is called every time the buffer is full.
        on_topic_buffer_full = app.sensors.on_topic_buffer_full

        # callback called when the queue is under high pressure /
        # about to become full.
        def on_pressure_high() -> None:
            on_topic_buffer_full(tp)
            consumer_on_buffer_full(tp)

        # callback used when pressure drops.
        # added to Queue._pending_pressure_drop_callbacks
        # when the buffer is under high pressure/full.
        def on_pressure_drop() -> None:
            consumer_on_buffer_drop(tp)

        async def on_message(message: Message) -> None:
            # when a message is received we find all channels
            # that subscribe to this message
            await acquire_flow_control()
            await wait_until_producer_ebb()
            channels_n = len_(channels)
            if channels_n:
                # we increment the reference count for this message in bulk
                # immediately, so that nothing will get a chance to decref
                # to zero before we've had the chance to pass it to all
                # channels.
                message.incref(channels_n)
                event: Optional[EventT] = None
                event_keyid: Optional[Tuple[K, V]] = None

                # Forward the message to all channels subscribing to this
                # topic.  Keep track of the channels we delivered to, so
                # that if a DecodeError is raised we can propagate that
                # error to the remaining channels.
                delivered: Set[_Topic] = set()
                full: typing.List[Tuple[EventT, _Topic]] = []
                try:
                    for chan in channels:
                        keyid = chan.key_type, chan.value_type
                        if event is None:
                            # first channel deserializes the payload:
                            event = await chan.decode(message, propagate=True)
                            event_keyid = keyid

                            queue = chan.queue
                            if queue.full():
                                full.append((event, chan))
                                continue
                            queue.put_nowait_enhanced(
                                event,
                                on_pressure_high=on_pressure_high,
                                on_pressure_drop=on_pressure_drop,
                            )
                        else:
                            # subsequent channels may have a different
                            # key/value type pair, meaning they can all
                            # deserialize the message in different ways
                            dest_event: EventT
                            if keyid == event_keyid:
                                # Reuse the event if it uses the same keypair:
                                dest_event = event
                            else:
                                dest_event = await chan.decode(
                                    message, propagate=True
                                )
                            queue = chan.queue
                            if queue.full():
                                full.append((dest_event, chan))
                                continue
                            queue.put_nowait_enhanced(
                                dest_event,
                                on_pressure_high=on_pressure_high,
                                on_pressure_drop=on_pressure_drop,
                            )
                        delivered.add(chan)
                    if full:
                        for _, dest_chan in full:
                            on_topic_buffer_full(dest_chan)
                        await asyncio.wait(
                            [
                                asyncio.ensure_future(dest_chan.put(dest_event))
                                for dest_event, dest_chan in full
                            ],
                            return_when=asyncio.ALL_COMPLETED,
                        )
                except KeyDecodeError as exc:
                    remaining = channels - delivered
                    message.ack(app.consumer, n=len(remaining))
                    for channel in remaining:
                        await channel.on_key_decode_error(exc, message)
                        delivered.add(channel)
                except ValueDecodeError as exc:
                    remaining = channels - delivered
                    message.ack(app.consumer, n=len(remaining))
                    for channel in remaining:
                        await channel.on_value_decode_error(exc, message)
                        delivered.add(channel)

        return on_message
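# A minimal sketch (not part of the Faust source) of the pressure-threshold
# arithmetic described in ``build`` above.  ``stream_buffer_maxsize`` is a
# real Faust setting; the concrete numbers and the ``pressure_ratio`` value
# are assumptions for illustration only:
#
#     maxsize = 4096                    # e.g. app.conf.stream_buffer_maxsize
#     pressure_ratio = 4                # assumed Queue.pressure_ratio
#     threshold = maxsize // pressure_ratio   # -> 1024 buffered messages
#
# Once a partition's queue holds more than ``threshold`` messages,
# ``on_pressure_high`` fires and the fetcher stops retrieving data for that
# TP; when the buffer drains again, ``on_pressure_drop`` re-enables fetching.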
class Conductor(ConductorT, Service):
    """Manages the channels that subscribe to topics.

    - Consumes messages from topic using a single consumer.
    - Forwards messages to all channels subscribing to a topic.
    """

    logger = logger

    #: Fast index to see if Topic is registered.
    _topics: MutableSet[TopicT]

    #: Map of (topic, partition) to set of channels that subscribe
    #: to that TP.
    _tp_index: MutableMapping[TP, MutableSet[TopicT]]

    #: Map str topic name to set of channels that subscribe
    #: to that topic.
    _topic_name_index: MutableMapping[str, MutableSet[TopicT]]

    #: For every TP assigned we compile a callback closure,
    #: and when we receive a message for a TP we look up which callback
    #: to call here.
    _tp_to_callback: MutableMapping[TP, ConsumerCallback]

    #: Lock used to synchronize access to _tp_to_callback.
    #: Resubscriptions and updates to the indices may modify the mapping,
    #: and while that is happening, the mapping should not be accessed by
    #: message handlers.
    _tp_to_callback_lock: asyncio.Lock

    #: Whenever a change is made, i.e. a Topic is added/removed, we notify
    #: the background task responsible for resubscribing.
    _subscription_changed: Optional[asyncio.Event]

    _subscription_done: Optional[asyncio.Future]

    _acking_topics: Set[str]

    _compiler: ConductorCompiler

    # `_resubscribe_sleep_lock_seconds` trades off between the latency of
    # receiving messages for newly added topics and the cost of
    # resubscribing to topics.  Note that this resubscription flow only
    # occurs when the topic list has changed (see the
    # `_subscription_changed` event).  This mechanism attempts to coalesce
    # topic list changes that happen in quick succession and prevents the
    # framework from constantly resubscribing to topics after every change.
    #
    # If the value is set too low and an agent is adding topics very
    # frequently, resubscription will happen very often and will issue
    # unnecessary work on the async loop.
    # If the value is set too high, it will take a long time for a newly
    # added agent to start receiving messages; this time is bounded by the
    # value of `_resubscribe_sleep_lock_seconds`, barring something hogging
    # the async loop.
    _resubscribe_sleep_lock_seconds: float = 45.0

    def __init__(self, app: AppT, **kwargs: Any) -> None:
        Service.__init__(self, **kwargs)
        self.app = app
        self._topics = set()
        self._topic_name_index = defaultdict(set)
        self._tp_index = defaultdict(set)
        self._tp_to_callback = {}
        self._tp_to_callback_lock = asyncio.Lock()
        self._acking_topics = set()
        self._subscription_changed = None
        self._subscription_done = None
        self._compiler = ConductorCompiler()
        # This callback is called whenever the Consumer has
        # fetched a new record from Kafka.
        # We compile this down to a closure having variables already
        # localized as an optimization.
        self.on_message: ConsumerCallback
        self.on_message = self._compile_message_handler()
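    # Illustration (not Faust source) of what the indices above may look
    # like after a rebalance, assuming a single hypothetical topic "orders"
    # with two assigned partitions and one subscribing channel:
    #
    #     _topic_name_index == {"orders": {<Topic: orders>}}
    #     _tp_index == {TP("orders", 0): {<Topic: orders>},
    #                   TP("orders", 1): {<Topic: orders>}}
    #     _tp_to_callback == {TP("orders", 0): <on_message closure>, ...}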
    async def commit(self, topics: TPorTopicSet) -> bool:
        """Commit offsets in topics."""
        return await self.app.consumer.commit(topics)
    def acks_enabled_for(self, topic: str) -> bool:
        """Return :const:`True` if acks are enabled for topic by name."""
        return topic in self._acking_topics
    def _compile_message_handler(self) -> ConsumerCallback:
        # This method localizes variables and attribute access
        # for better performance.  This is part of the inner loop
        # of a Faust worker, so tiny improvements here have a big impact.
        get_callback_for_tp = self._tp_to_callback.__getitem__

        if self.app.client_only:

            async def on_message(message: Message) -> None:
                tp = TP(topic=message.topic, partition=0)
                async with self._tp_to_callback_lock:
                    callback = get_callback_for_tp(tp)
                return await callback(message)

        else:

            async def on_message(message: Message) -> None:
                async with self._tp_to_callback_lock:
                    callback = get_callback_for_tp(message.tp)
                return await callback(message)

        return on_message

    @Service.task
    async def _subscriber(self) -> None:  # pragma: no cover
        # The first time we start, we wait for agents and tables to start
        # up and register their streams.  This way we won't have N
        # subscription requests at the start.
        if self.app.client_only or self.app.producer_only:
            self.log.info("Not waiting for agent/table startups...")
        else:
            self.log.info("Waiting for agents to start...")
            await self.app.agents.wait_until_agents_started()
            self.log.info("Waiting for tables to be registered...")
            await self.app.tables.wait_until_tables_registered()
        if not self.should_stop:
            # tell the consumer to subscribe to the topics.
            await self.app.consumer.subscribe(await self._update_indices())
            notify(self._subscription_done)

            # Now we wait for changes
            ev = self._subscription_changed = asyncio.Event()
            while not self.should_stop:
                # Wait for something to add/remove topics from subscription.
                await ev.wait()
                if self.app.rebalancing:
                    # we do not want to perform a resubscribe if the
                    # application is rebalancing.
                    ev.clear()
                else:
                    # The change could be in reaction to something like
                    # "all agents restarting"; in that case it would be bad
                    # if we resubscribed over and over, so we sleep for
                    # `_resubscribe_sleep_lock_seconds` to make sure any
                    # further subscription requests will happen during the
                    # same rebalance.
                    await self.sleep(self._resubscribe_sleep_lock_seconds)
                    # Clear the event before updating indices.  This way,
                    # new events that get triggered during the update will
                    # be handled the next time around.
                    ev.clear()
                    subscribed_topics = await self._update_indices()
                    await self.app.consumer.subscribe(subscribed_topics)
                    # wake up anything waiting for the subscription
                    # to be done.
                    notify(self._subscription_done)
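    # Sketch of the coalescing behaviour of ``_subscriber`` above (the
    # timeline is illustrative, using the default of 45 seconds):
    #
    #     t=0s   agent A adds a topic  -> _subscription_changed is set
    #     t=1s   agent B adds a topic  -> event already set, no extra work
    #     t=45s  the task wakes up, clears the event, rebuilds the indices
    #            once, and issues a single consumer.subscribe() call.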
    async def wait_for_subscriptions(self) -> None:
        """Wait for consumer to be subscribed."""
        if self._subscription_done is None:
            self._subscription_done = asyncio.Future(loop=self.loop)
        await self._subscription_done
    async def maybe_wait_for_subscriptions(self) -> None:
        if self._subscription_done is not None:
            await self._subscription_done
    async def _update_indices(self) -> Iterable[str]:
        async with self._tp_to_callback_lock:
            self._topic_name_index.clear()
            self._tp_to_callback.clear()
            # Make a (shallow) copy of the topics, so new additions to the
            # set won't poison the iterator.  Additions can come in while
            # this function yields during an await.
            topics = list(self._topics)
            for channel in topics:
                if channel.internal:
                    await channel.maybe_declare()
                for topic in channel.topics:
                    if channel.acks:
                        self._acking_topics.add(topic)
                    self._topic_name_index[topic].add(channel)
            self._update_callback_map()
            return self._topic_name_index
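    # Note: ``_update_indices`` returns the ``_topic_name_index`` mapping
    # itself; iterating over a mapping yields its keys, so the caller
    # (``consumer.subscribe``) receives the topic names to subscribe to.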
    async def on_partitions_assigned(self, assigned: Set[TP]) -> None:
        """Call when cluster is rebalancing and partitions are assigned."""
        T = traced_from_parent_span()
        self._tp_index.clear()
        T(self._update_tp_index)(assigned)
        T(self._update_callback_map)()
    async def on_client_only_start(self) -> None:
        tp_index = self._tp_index
        for topic in self._topics:
            for subtopic in topic.topics:
                tp = TP(topic=subtopic, partition=0)
                tp_index[tp].add(topic)
        self._update_callback_map()
    def _update_tp_index(self, assigned: Set[TP]) -> None:
        assignmap = tp_set_to_map(assigned)
        tp_index = self._tp_index
        for topic in self._topics:
            if topic.active_partitions is not None:
                # Isolated Partitions: one agent per partition.
                if topic.active_partitions:
                    if assigned:
                        assert topic.active_partitions.issubset(assigned)
                    for tp in topic.active_partitions:
                        tp_index[tp].add(topic)
            else:
                # Default: one agent receives messages for all partitions.
                for subtopic in topic.topics:
                    for tp in assignmap[subtopic]:
                        tp_index[tp].add(topic)

    def _update_callback_map(self) -> None:
        self._tp_to_callback.update(
            (tp, self._build_handler(tp, cast(MutableSet[_Topic], channels)))
            for tp, channels in self._tp_index.items()
        )

    def _build_handler(
        self, tp: TP, channels: MutableSet[_Topic]
    ) -> ConsumerCallback:
        if ConductorHandler is not None:  # pragma: no cover
            return ConductorHandler(self, tp, channels)
        else:
            return self._compiler.build(self, tp, channels)
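    # Example (values assumed) of the grouping ``tp_set_to_map`` performs
    # for ``_update_tp_index`` above:
    #
    #     assigned = {TP("orders", 0), TP("orders", 1), TP("payments", 0)}
    #     tp_set_to_map(assigned)
    #     # -> {"orders": {TP("orders", 0), TP("orders", 1)},
    #     #     "payments": {TP("payments", 0)}}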
    def clear(self) -> None:
        """Clear all subscriptions."""
        self._topics.clear()
        self._topic_name_index.clear()
        self._tp_index.clear()
        self._tp_to_callback.clear()
        self._acking_topics.clear()
    def __contains__(self, value: Any) -> bool:
        return value in self._topics

    def __iter__(self) -> Iterator[TopicT]:
        return iter(self._topics)

    def __len__(self) -> int:
        return len(self._topics)

    def __hash__(self) -> int:
        return object.__hash__(self)
    def add(self, topic: TopicT) -> None:
        """Register topic to be subscribed."""
        if topic not in self._topics:
            self._topics.add(topic)
            if self._topic_contain_unsubscribed_topics(topic):
                self._flag_changes()
    def _topic_contain_unsubscribed_topics(self, topic: TopicT) -> bool:
        index = self._topic_name_index
        return bool(any(t not in index for t in topic.topics))
    def discard(self, topic: Any) -> None:
        """Unregister topic from conductor."""
        self._topics.discard(topic)
        self._flag_changes()
    def _flag_changes(self) -> None:
        if self._subscription_changed is not None:
            self._subscription_changed.set()
        if self._subscription_done is None:
            self._subscription_done = asyncio.Future(loop=self.loop)

    @property
    def label(self) -> str:
        """Return label for use in logs."""
        return f"{type(self).__name__}({len(self._topics)})"

    @property
    def shortlabel(self) -> str:
        """Return short label for use in logs."""
        return type(self).__name__

    @property
    def acking_topics(self) -> Set[str]:
        return self._acking_topics
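# Example usage (not part of this module): application code normally reaches
# the conductor through ``app.topics``.  The conductor behaves like a set of
# topics; the topic name below is a hypothetical example:
#
#     topic = app.topic("orders")
#     app.topics.add(topic)        # register; flags a subscription change
#     assert topic in app.topics   # set-like membership test
#     await app.topics.wait_for_subscriptions()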