Source code for faust.web.cache.cache

"""Cache interface."""

import hashlib
from contextlib import suppress
from functools import wraps
from typing import Any, Callable, ClassVar, Mapping, Optional, Type, Union, cast
from urllib.parse import quote

from mode.utils.compat import want_bytes
from mode.utils.logging import get_logger
from mode.utils.times import Seconds, want_seconds

from faust.types.web import CacheBackendT, CacheT, Request, Response, View

logger = get_logger(__name__)

IDENT: str = "faustweb.cache.view"


class Cache(CacheT):
    """Cache interface.

    Caches HTTP view responses in a cache backend.  The main entry point
    is the :meth:`view` decorator; the remaining methods are the hooks it
    uses to decide what is cacheable, to build cache keys and to talk to
    the backend.

    Arguments:
        timeout: Default expiry used by :meth:`set_view` when no
            per-call timeout is given.
        include_headers: Stored on the instance as ``include_headers``.
        key_prefix: Default key prefix used by :meth:`key_for_request`
            when no explicit prefix is passed (defaults to ``""``).
        backend: Backend to use instead of ``view.app.cache``.
        **kwargs: Accepted and ignored.
    """

    ident: ClassVar[str] = IDENT

    def __init__(
        self,
        timeout: Optional[Seconds] = None,
        include_headers: bool = False,
        key_prefix: Optional[str] = None,
        backend: Optional[Union[Type[CacheBackendT], str]] = None,
        **kwargs: Any,
    ) -> None:
        self.timeout = timeout
        # NOTE(review): nothing in this class reads ``self.include_headers``;
        # :meth:`view` only honours its own ``include_headers`` argument.
        # Confirm whether an instance-level fallback was intended.
        self.include_headers = include_headers
        self.key_prefix = key_prefix or ""
        self.backend = backend

    def view(
        self,
        timeout: Optional[Seconds] = None,
        include_headers: bool = False,
        key_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> Callable[[Callable], Callable]:
        """Decorate view to be cached.

        Arguments:
            timeout: Expiry passed to :meth:`set_view` for responses
                stored by the decorated view.
            include_headers: If true, request headers become part of the
                cache key.
            key_prefix: Prefix for the cache key; :const:`None` means use
                the instance's ``key_prefix``.
            **kwargs: Accepted and ignored.

        Returns:
            A decorator wrapping an ``async`` view method taking
            ``(view, request, *args, **kwargs)``.
        """

        def _inner(fun: Callable) -> Callable:
            @wraps(fun)
            async def cached(
                view: View, request: Request, *args: Any, **kwargs: Any
            ) -> Response:
                # ``key`` stays None when the request is not cacheable,
                # which also disables storing the response below.
                key: Optional[str] = None
                is_head = request.method.upper() == "HEAD"

                if self.can_cache_request(request):
                    # Look up under the GET key first, even for HEAD
                    # requests, so a cached GET response is reused.
                    key = self.key_for_request(
                        request, key_prefix, "GET", include_headers
                    )
                    response = await self.get_view(key, view)
                    if response is not None:
                        logger.info("Found cached response for %r", key)
                        return response
                    if is_head:
                        head_key = self.key_for_request(
                            request, key_prefix, "HEAD", include_headers
                        )
                        response = await self.get_view(head_key, view)
                        if response is not None:
                            # Log the key that actually produced the hit.
                            logger.info(
                                "Found cached HEAD response for %r", head_key
                            )
                            return response
                    logger.info("No cache found for %r", key)

                res = await fun(view, request, *args, **kwargs)

                if key is not None and self.can_cache_response(request, res):
                    if is_head:
                        # HEAD responses are stored under their own key.
                        key = self.key_for_request(
                            request, key_prefix, "HEAD", include_headers
                        )
                    # Logged after the HEAD switch so the message names
                    # the key that is really written.
                    logger.info("Saving cache for key %r", key)
                    await self.set_view(key, view, res, timeout)
                return res

            return cached

        return _inner

    async def get_view(self, key: str, view: View) -> Optional[Response]:
        """Get cached value for HTTP view request.

        Returns:
            The response rebuilt by ``view.bytes_to_response``, or
            :const:`None` if the backend has no payload for ``key`` or
            raised ``backend.Unavailable``.
        """
        backend = self._view_backend(view)
        # A backend outage is treated the same as a cache miss.
        with suppress(backend.Unavailable):
            payload = await backend.get(key)
            if payload is not None:
                return view.bytes_to_response(payload)
        return None

    def _view_backend(self, view: View) -> CacheBackendT:
        """Return the explicit backend if set, else ``view.app.cache``."""
        return cast(CacheBackendT, self.backend or view.app.cache)

    async def set_view(
        self,
        key: str,
        view: View,
        response: Response,
        timeout: Optional[Seconds] = None,
    ) -> None:
        """Set cached value for HTTP view request.

        Arguments:
            key: Cache key to store under.
            view: View used to serialize the response and to find the
                backend.
            response: Response to store.
            timeout: Expiry; falls back to the instance ``timeout`` when
                :const:`None`.  If both are :const:`None`, :const:`None`
                is passed to the backend.
        """
        backend = self._view_backend(view)
        _timeout = timeout if timeout is not None else self.timeout
        # ``backend.Unavailable`` is swallowed here: a failed write does
        # not propagate to the caller.
        with suppress(backend.Unavailable):
            return await backend.set(
                key,
                view.response_to_bytes(response),
                want_seconds(_timeout) if _timeout is not None else None,
            )

    def can_cache_request(self, request: Request) -> bool:
        """Return :const:`True` if we can cache this type of HTTP request.

        This implementation accepts every request.
        """
        return True

    def can_cache_response(self, request: Request, response: Response) -> bool:
        """Return :const:`True` for HTTP status codes we CAN cache.

        This implementation only accepts status 200.
        """
        return response.status == 200

    def key_for_request(
        self,
        request: Request,
        prefix: Optional[str] = None,
        method: Optional[str] = None,
        include_headers: bool = False,
    ) -> str:
        """Return a cache key created from web request.

        Arguments:
            request: Request to derive the key from.
            prefix: Key prefix; :const:`None` means the instance's
                ``key_prefix``.
            method: HTTP method to put in the key; :const:`None` means
                ``request.method``.
            include_headers: If true, the request headers are hashed
                into the key; otherwise an empty mapping is used.
        """
        actual_method: str = request.method if method is None else method
        headers = request.headers if include_headers else {}
        if prefix is None:
            prefix = self.key_prefix
        return self.build_key(request, actual_method, prefix, headers)

    def build_key(
        self, request: Request, method: str, prefix: str, headers: Mapping[str, str]
    ) -> str:
        """Build cache key from web request and environment.

        The key has the form ``{ident}.{prefix}.{method}.{url}.{context}``
        where ``url`` is the MD5 hex digest of the request URL (converted
        with ``iri_to_uri``) and ``context`` is the MD5 hex digest of the
        concatenated header names and values, in mapping iteration order.
        """
        # MD5 is used only to derive a compact key, hence ``nosec``.
        context = hashlib.md5(  # nosec
            b"".join(want_bytes(k) + want_bytes(v) for k, v in headers.items()),
        ).hexdigest()
        url = hashlib.md5(  # nosec
            iri_to_uri(str(request.url)).encode("ascii"),
        ).hexdigest()
        return f"{self.ident}.{prefix}.{method}.{url}.{context}"
def iri_to_uri(iri: str) -> str:
    """Convert IRI to URI.

    Percent-encodes every character of ``iri`` that is not allowed to
    appear literally in a URI, leaving existing ``%`` escapes untouched.

    Arguments:
        iri: The IRI (or URI) string to convert.

    Returns:
        The percent-encoded string.
    """
    # Characters left unescaped: the RFC 3986 "reserved" set
    #
    #     gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@"
    #     sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
    #                  / "*" / "+" / "," / ";" / "="
    #
    # plus "~", the only RFC 3986 "unreserved" character
    # (ALPHA / DIGIT / "-" / "." / "_" / "~") that urllib.parse.quote()
    # does not already treat as safe, plus "%", which RFC 3987
    # section 3.1 says must not be converted.
    keep_literal = """/#%[]=:;$&()+,!?*@'~"""
    return quote(iri, safe=keep_literal)