"""Web driver using :pypi:`aiohttp`."""
from pathlib import Path
from typing import Any, Callable, Mapping, MutableMapping, Optional, Union, cast
import aiohttp_cors
from aiohttp import __version__ as aiohttp_version
from aiohttp.payload import Payload
from aiohttp.web import (
Application,
AppRunner,
BaseSite,
Request,
Response,
TCPSite,
UnixSite,
)
from aiohttp_cors import CorsConfig, ResourceOptions
from mode import Service
from mode.threads import ServiceThread
from faust.types import AppT
from faust.types.web import ResourceOptions as _ResourceOptions
from faust.utils import json as _json
from faust.web import base
__all__ = ["Web"]
_bytes = bytes
NON_OPTIONS_METHODS = frozenset({"GET", "PUT", "POST", "DELETE"})
def _prepare_cors_options(opts: Mapping[str, Any]) -> Mapping[str, ResourceOptions]:
return {k: _faust_to_aiohttp_options(v) for k, v in opts.items()}
def _faust_to_aiohttp_options(opts: ResourceOptions) -> ResourceOptions:
if isinstance(opts, _ResourceOptions):
return ResourceOptions(**opts._asdict())
return opts
class ServerThread(ServiceThread):
"""A web server running in a dedicated thread."""
def __init__(self, web: "Web", **kwargs: Any) -> None:
self.web = web
super().__init__(**kwargs)
async def on_start(self) -> None:
"""Call in parent thread when the service thread is starting."""
await self.web.start_server()
async def on_thread_stop(self) -> None:
"""Call in thread when the service stops."""
# on_stop() executes in parent thread, on_thread_stop in the thread.
await self.web.stop_server()
class Server(Service):
"""Web server service."""
def __init__(self, web: "Web", **kwargs: Any) -> None:
self.web = web
super().__init__(**kwargs)
async def on_start(self) -> None:
"""Call when the web server starts."""
await self.web.start_server()
async def on_stop(self) -> None:
"""Call when the web server stops."""
await self.web.stop_server()
[docs]class Web(base.Web):
"""Web server and framework implementation using :pypi:`aiohttp`."""
driver_version = f"aiohttp={aiohttp_version}"
handler_shutdown_timeout: float = 60.0
cors_options: Mapping[str, ResourceOptions]
#: We serve the web server in a separate thread (and separate event loop).
_thread: Optional[Service] = None
_transport_handlers: Mapping[str, Callable[[], BaseSite]]
_cors: Optional[CorsConfig] = None
def __init__(self, app: AppT, **kwargs: Any) -> None:
super().__init__(app, **kwargs)
self.web_app: Application = Application()
self.cors_options = _prepare_cors_options(app.conf.web_cors_options or {})
self._runner: AppRunner = AppRunner(self.web_app, access_log=None)
self._transport_handlers = {
"tcp": self._new_transport_tcp,
"unix": self._new_transport_unix,
}
@property
def cors(self) -> CorsConfig:
"""Return CORS config object."""
if self._cors is None:
self._cors = aiohttp_cors.setup(self.web_app, defaults=self.cors_options)
return self._cors
[docs] async def on_start(self) -> None:
"""Call when the embedded web server starts.
Only used for `faust worker`, not when using :meth:`wsgi`.
"""
cors = self.cors
assert cors
self.init_server()
server_cls = ServerThread if self.app.conf.web_in_thread else Server
self._thread = server_cls(self, loop=self.loop, beacon=self.beacon)
self.add_dependency(self._thread)
[docs] async def wsgi(self) -> Any:
"""Call WSGI handler.
Used by :pypi:`gunicorn` and other WSGI compatible hosts
to access the Faust web entry point.
"""
self.init_server()
return self.web_app
[docs] def text(
self,
value: str,
*,
content_type: Optional[str] = None,
status: int = 200,
reason: Optional[str] = None,
headers: MutableMapping = None,
) -> base.Response:
"""Create text response, using "text/plain" content-type."""
response = Response(
text=value,
content_type=content_type,
status=status,
reason=reason,
headers=headers,
)
return cast(base.Response, response)
[docs] def html(
self,
value: str,
*,
content_type: Optional[str] = None,
status: int = 200,
reason: Optional[str] = None,
headers: MutableMapping = None,
) -> base.Response:
"""Create HTML response from string, ``text/html`` content-type."""
return self.text(
value,
status=status,
content_type=content_type or "text/html",
reason=reason,
headers=headers,
)
[docs] def json(
self,
value: Any,
*,
content_type: Optional[str] = None,
status: int = 200,
reason: Optional[str] = None,
headers: MutableMapping = None,
) -> Any:
"""Create new JSON response.
Accepts any JSON-serializable value and will automatically
serialize it for you.
The content-type is set to "application/json".
"""
ctype = content_type or "application/json"
payload: Any = _json.dumps(value)
# normal json returns str, orjson returns bytes
if isinstance(payload, bytes):
return self.bytes(
payload,
content_type=ctype,
status=status,
reason=reason,
headers=headers,
)
else:
return self.text(
payload,
content_type=ctype,
status=status,
reason=reason,
headers=headers,
)
[docs] def bytes(
self,
value: _bytes,
*,
content_type: Optional[str] = None,
status: int = 200,
reason: Optional[str] = None,
headers: MutableMapping = None,
) -> base.Response:
"""Create new ``bytes`` response - for binary data."""
response = Response(
body=value,
content_type=content_type,
status=status,
reason=reason,
headers=headers,
)
return cast(base.Response, response)
[docs] async def read_request_content(self, request: base.Request) -> _bytes:
"""Return the request body as bytes."""
return await cast(Request, request).content.read()
[docs] def route(
self,
pattern: str,
handler: Callable,
cors_options: Mapping[str, ResourceOptions] = None,
) -> None:
"""Add route for web view or handler."""
async_handler = self._wrap_into_asyncdef(handler)
if cors_options or self.cors_options:
for method in NON_OPTIONS_METHODS & handler.get_methods():
r = self.web_app.router.add_route(method, pattern, async_handler)
self.cors.add(r, _prepare_cors_options(cors_options or {}))
else:
for method in handler.get_methods():
self.web_app.router.add_route(method, pattern, async_handler)
def _wrap_into_asyncdef(self, handler: Callable) -> Callable:
# get rid of pesky "DeprecationWarning: Bare functions are
# deprecated, use async ones" warnings.
# The handler is actually a class that defines `async def __call__`
# but aiohttp doesn't recognize it as such and emits the warning.
# To avoid that we just wrap it in an `async def` function
async def _dispatch(request: base.Request) -> base.Response:
return await handler(request)
_dispatch.__doc__ = handler.__doc__
return _dispatch
[docs] def add_static(self, prefix: str, path: Union[Path, str], **kwargs: Any) -> None:
"""Add route for static assets."""
self.web_app.router.add_static(prefix, str(path), **kwargs)
[docs] def bytes_to_response(self, s: _bytes) -> base.Response:
"""Deserialize byte string back into a response object."""
status, headers, body = self._bytes_to_response(s)
response = Response(
body=body,
status=status,
headers=headers,
)
return cast(base.Response, response)
[docs] def response_to_bytes(self, response: base.Response) -> _bytes:
"""Convert response to serializable byte string.
The result is a byte string that can be deserialized
using :meth:`bytes_to_response`.
"""
resp = cast(Response, response)
if resp.body is None:
body = b""
elif isinstance(resp.body, Payload):
raise NotImplementedError("Does not support Payload")
else:
body = resp.body
return self._response_to_bytes(
resp.status,
resp.headers,
body,
)
def _create_site(self) -> BaseSite:
return self._new_transport(self.app.conf.web_transport.scheme)
def _new_transport(self, type_: str) -> BaseSite:
return self._transport_handlers[type_]()
def _new_transport_tcp(self) -> BaseSite:
return TCPSite(
self._runner,
self.app.conf.web_bind,
self.app.conf.web_port,
ssl_context=self.app.conf.web_ssl_context,
)
def _new_transport_unix(self) -> BaseSite:
return UnixSite(
self._runner,
self.app.conf.web_transport.path,
)
[docs] async def start_server(self) -> None:
"""Start the web server."""
await self._runner.setup()
site = self._create_site()
await site.start()
[docs] async def stop_server(self) -> None:
"""Stop the web server."""
if self._runner:
await self._runner.cleanup()
await self._cleanup_app()
async def _cleanup_app(self) -> None:
if self.web_app is not None:
self.log.info("Cleanup")
await self.web_app.cleanup()
@property
def _app(self) -> Application:
# XXX compat alias
return self.web_app