Source code for faust.livecheck.patches.aiohttp
"""LiveCheck :pypi:`aiohttp` integration."""
from contextlib import ExitStack
from types import SimpleNamespace
from typing import Any, List, Optional, no_type_check
import aiohttp
from aiohttp import web
from faust.livecheck.locals import current_test_stack
from faust.livecheck.models import TestExecution
__all__ = ["patch_all", "patch_aiohttp_session", "LiveCheckMiddleware"]
[docs]def patch_all() -> None:
"""Patch all :pypi:`aiohttp` functions to integrate with LiveCheck."""
patch_aiohttp_session()
[docs]def patch_aiohttp_session() -> None:
"""Patch :class:`aiohttp.ClientSession` to integrate with LiveCheck.
If there is any currently active test, we will
use that to forward LiveCheck HTTP headers to the new HTTP request.
"""
from aiohttp import TraceConfig, client
# monkeypatch to remove ridiculous "do not subclass" warning.
def __init_subclass__() -> None: ...
client.ClientSession.__init_subclass__ = __init_subclass__ # type: ignore
async def _on_request_start(
session: aiohttp.ClientSession,
trace_config_ctx: SimpleNamespace,
params: aiohttp.TraceRequestStartParams,
) -> None:
test = current_test_stack.top
if test is not None:
params.headers.update(test.as_headers())
class ClientSession(client.ClientSession):
def __init__(
self, trace_configs: Optional[List[TraceConfig]] = None, **kwargs: Any
) -> None:
super().__init__(
trace_configs=self._faust_trace_configs(trace_configs), **kwargs
)
@no_type_check
def _faust_trace_configs(
self, configs: List[TraceConfig] = None
) -> List[TraceConfig]:
if configs is None:
configs = []
trace_config = aiohttp.TraceConfig()
trace_config.on_request_start.append(_on_request_start)
configs.append(trace_config)
return configs
client.ClientSession = ClientSession # type: ignore
[docs]@web.middleware
class LiveCheckMiddleware:
"""LiveCheck support for :pypi:`aiohttp` web servers.
This middleware is applied to all incoming web requests,
and is used to extract LiveCheck HTTP headers.
If the web request is configured with the correct set of LiveCheck
headers, we will use that to set the "current test" context.
"""
async def __call__(self, request: web.Request, handler: Any) -> Any:
"""Call to handle new web request."""
related_test = TestExecution.from_headers(request.headers)
with ExitStack() as stack:
if related_test:
stack.enter_context(current_test_stack.push(related_test))
return await handler(request)