# Source code for faust.transport.producer

"""Producer.

The Producer is responsible for:

   - Holds reference to the transport that created it
   - ... and the app via ``self.transport.app``.
   - Sending messages.
"""

import asyncio
import time
from typing import Any, Awaitable, Mapping, Optional, cast

from mode import Seconds, Service, get_logger
from mode.threads import ServiceThread

from faust.types import AppT, HeadersArg
from faust.types.transports import ProducerBufferT, ProducerT, TransportT
from faust.types.tuples import TP, FutureMessage, RecordMetadata

__all__ = ["Producer"]
logger = get_logger(__name__)


class ProducerBuffer(Service, ProducerBufferT):
    """Buffer of messages waiting to be produced.

    Messages are added with :meth:`put` and drained either by the
    :meth:`_handle_pending` background task (non-threaded mode) or by
    handing them to the threaded producer's event queue.
    """

    # Set externally by Producer.__init__; put()/size read
    # app.conf.producer_threaded to pick the buffering strategy.
    app: Optional[AppT] = None

    # Threshold used by wait_until_ebb(): above this many pending
    # messages, stream processing pauses while we flush.
    max_messages = 100

    # Lazily bound to threaded_producer.event_queue on first threaded put().
    queue: Optional[asyncio.Queue] = None

    def __post_init__(self) -> None:
        # NOTE(review): mode.Service is assumed to invoke __post_init__
        # during __init__ — confirm against mode's Service docs.
        # pending: messages not yet produced (non-threaded mode only).
        self.pending = asyncio.Queue()
        # Set by _handle_pending after each publish; flush_atmost waits
        # on this event to observe progress.
        self.message_sent = asyncio.Event()

    def put(self, fut: FutureMessage) -> None:
        """Add message to buffer.

        The message will be eventually produced, you can await
        the future to wait for that to happen.
        """
        if self.app.conf.producer_threaded:
            if not self.queue:
                # threaded_producer is assigned externally (by Producer)
                # when producer_threaded is enabled.
                self.queue = self.threaded_producer.event_queue
            # Cross-thread handoff: schedule the put on the producer
            # thread's own event loop.
            asyncio.run_coroutine_threadsafe(
                self.queue.put(fut), self.threaded_producer.thread_loop
            )
        else:
            self.pending.put_nowait(fut)

    async def on_stop(self) -> None:
        # Drain any remaining messages before the service stops.
        await self.flush()

    async def flush(self) -> None:
        """Flush all messages (draining the buffer)."""
        await self.flush_atmost(None)

    async def flush_atmost(self, max_messages: Optional[int]) -> int:
        """Flush at most ``max_messages`` messages (all if :const:`None`).

        Returns the number of messages confirmed sent while waiting.
        Gives up early (returning the current count) if no message
        completes within 0.1 seconds.
        """
        flushed_messages = 0
        while True:
            # Refuse to wait on a buffer nobody is draining: the
            # _handle_pending task only runs while the service does.
            if self.state != "running" and self.size:
                raise RuntimeError("Cannot flush: Producer not Running")
            if self.size != 0 and (
                (max_messages is None or flushed_messages < max_messages)
            ):
                self.message_sent.clear()
                try:
                    # Wait for _handle_pending to publish one message.
                    await asyncio.wait_for(self.message_sent.wait(), timeout=0.1)
                    flushed_messages += 1
                except asyncio.TimeoutError:
                    return flushed_messages
            else:
                return flushed_messages

    async def _send_pending(self, fut: FutureMessage) -> None:
        # Publish one buffered message on its channel without waiting
        # for the broker acknowledgement (wait=False).
        await fut.message.channel.publish_message(fut, wait=False)

    async def wait_until_ebb(self) -> None:
        """Wait until buffer is of an acceptable size.

        Modifying a table key is using the Python dictionary API,
        and as ``__getitem__`` is synchronous we have to add
        pending messages to a buffer.

        The ``__getitem__`` method cannot drain the buffer as doing
        so requires trampolining into the event loop.

        To solve this, we have the conductor wait until the buffer
        is of an acceptable size before resuming stream processing flow.
        """
        if self.size > self.max_messages:
            logger.warning(f"producer buffer full size {self.size}")
            start_time = time.time()
            await self.flush_atmost(self.max_messages)
            end_time = time.time()
            logger.info(f"producer flush took {end_time - start_time}")

    @Service.task
    async def _handle_pending(self) -> None:
        # Background task: drain self.pending one message at a time,
        # setting message_sent after each publish so flush_atmost can
        # count progress.
        get_pending = self.pending.get
        send_pending = self._send_pending
        while not self.should_stop:
            msg = await get_pending()
            await send_pending(msg)
            self.message_sent.set()

    @property
    def size(self) -> int:
        """Current buffer size (messages waiting to be produced)."""
        if self.app.conf.producer_threaded:
            if not self.queue:
                # Threaded queue not bound yet: nothing buffered.
                return 0
            # Peek at asyncio.Queue's internal deque to read the length
            # without consuming items (private API, hence the ignore).
            queue_items = self.queue._queue  # type: ignore
        else:
            queue_items = self.pending._queue
        queue_items = cast(list, queue_items)
        return len(queue_items)


class Producer(Service, ProducerT):
    """Base Producer.

    Holds a reference to the transport that created it (and thereby the
    app via ``self.transport.app``), and defines the abstract interface
    concrete transport producers must implement.
    """

    app: AppT

    _api_version: str

    #: Set to the producer service thread when
    #: ``producer_threaded`` is enabled.
    threaded_producer: Optional[ServiceThread] = None

    def __init__(
        self,
        transport: TransportT,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs: Any,
    ) -> None:
        self.transport = transport
        self.app = self.transport.app
        # Snapshot producer-related settings from app configuration.
        conf = self.transport.app.conf
        self.client_id = conf.broker_client_id
        self.linger_ms = int(conf.producer_linger * 1000)
        self.max_batch_size = conf.producer_max_batch_size
        self.acks = conf.producer_acks
        self.max_request_size = conf.producer_max_request_size
        self.compression_type = conf.producer_compression_type
        self.request_timeout = conf.producer_request_timeout
        self.ssl_context = conf.ssl_context
        self.credentials = conf.broker_credentials
        self.partitioner = conf.producer_partitioner
        api_version = self._api_version = conf.producer_api_version
        assert api_version is not None
        super().__init__(loop=loop, **kwargs)
        self.buffer = ProducerBuffer(loop=self.loop, beacon=self.beacon)
        if conf.producer_threaded:
            self.threaded_producer = self.create_threaded_producer()
            self.buffer.threaded_producer = self.threaded_producer
        # Must be set unconditionally: ProducerBuffer.put() and .size
        # read self.app.conf on every call, threaded or not.
        self.buffer.app = self.app

    async def on_start(self) -> None:
        """Service is starting."""
        await self.add_runtime_dependency(self.buffer)

    async def send(
        self,
        topic: str,
        key: Optional[bytes],
        value: Optional[bytes],
        partition: Optional[int],
        timestamp: Optional[float],
        headers: Optional[HeadersArg],
        *,
        transactional_id: Optional[str] = None,
    ) -> Awaitable[RecordMetadata]:
        """Schedule message to be sent by producer."""
        raise NotImplementedError()

    def send_soon(self, fut: FutureMessage) -> None:
        """Add message to the producer buffer for eventual delivery."""
        self.buffer.put(fut)

    async def send_and_wait(
        self,
        topic: str,
        key: Optional[bytes],
        value: Optional[bytes],
        partition: Optional[int],
        timestamp: Optional[float],
        headers: Optional[HeadersArg],
        *,
        transactional_id: Optional[str] = None,
    ) -> RecordMetadata:
        """Send message and wait for it to be transmitted."""
        raise NotImplementedError()

    async def flush(self) -> None:
        """Flush all in-flight messages."""
        # XXX subclasses must call self.buffer.flush() here.
        ...

    async def create_topic(
        self,
        topic: str,
        partitions: int,
        replication: int,
        *,
        config: Optional[Mapping[str, Any]] = None,
        timeout: Seconds = 1000.0,
        retention: Optional[Seconds] = None,
        compacting: Optional[bool] = None,
        deleting: Optional[bool] = None,
        ensure_created: bool = False,
    ) -> None:
        """Create/declare topic on server."""
        raise NotImplementedError()

    def key_partition(self, topic: str, key: bytes) -> TP:
        """Hash key to determine partition."""
        raise NotImplementedError()

    async def begin_transaction(self, transactional_id: str) -> None:
        """Begin transaction by id."""
        raise NotImplementedError()

    async def commit_transaction(self, transactional_id: str) -> None:
        """Commit transaction by id."""
        raise NotImplementedError()

    async def abort_transaction(self, transactional_id: str) -> None:
        """Abort and rollback transaction by id."""
        raise NotImplementedError()

    async def stop_transaction(self, transactional_id: str) -> None:
        """Stop transaction by id."""
        raise NotImplementedError()

    async def maybe_begin_transaction(self, transactional_id: str) -> None:
        """Begin transaction by id, if not already started."""
        raise NotImplementedError()

    async def commit_transactions(
        self,
        tid_to_offset_map: Mapping[str, Mapping[TP, int]],
        group_id: str,
        start_new_transaction: bool = True,
    ) -> None:
        """Commit transactions."""
        raise NotImplementedError()

    def supports_headers(self) -> bool:
        """Return :const:`True` if headers are supported by this transport."""
        return False