Third commit: add share, share_kb, and ADMIN_ID

2025-07-22 13:50:14 +03:00
parent 849feb7beb
commit b98123f4dc
1479 changed files with 323549 additions and 11 deletions

View File

@@ -0,0 +1 @@
5baf9cbb44b21e13843de6f8ef936fc5a810d49cfda56ab2649bf012a4cb36cd

View File

@@ -0,0 +1 @@
0455129b185e981b5b96ac738f31f7c74dc57f1696953cae0083b3f18679fe73

View File

@@ -0,0 +1 @@
fc5f7b39a827f8a4327775ab5c50ac6e98f73c126fdec6094d8ea44e6bdd263e

View File

@@ -0,0 +1 @@
c3ad073fa4d540a9abb3f9c79bc16548d622457d04068ec7caaf62994983363a

View File

@@ -0,0 +1 @@
ee1b6686067213d1ea59b3e9c47534afb90021d4f692939741ad4069d0e1d96f

View File

@@ -0,0 +1,264 @@
__version__ = "3.11.18"
from typing import TYPE_CHECKING, Tuple
from . import hdrs as hdrs
from .client import (
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorDNSError,
ClientConnectorError,
ClientConnectorSSLError,
ClientError,
ClientHttpProxyError,
ClientOSError,
ClientPayloadError,
ClientProxyConnectionError,
ClientRequest,
ClientResponse,
ClientResponseError,
ClientSession,
ClientSSLError,
ClientTimeout,
ClientWebSocketResponse,
ClientWSTimeout,
ConnectionTimeoutError,
ContentTypeError,
Fingerprint,
InvalidURL,
InvalidUrlClientError,
InvalidUrlRedirectClientError,
NamedPipeConnector,
NonHttpUrlClientError,
NonHttpUrlRedirectClientError,
RedirectClientError,
RequestInfo,
ServerConnectionError,
ServerDisconnectedError,
ServerFingerprintMismatch,
ServerTimeoutError,
SocketTimeoutError,
TCPConnector,
TooManyRedirects,
UnixConnector,
WSMessageTypeError,
WSServerHandshakeError,
request,
)
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
from .helpers import BasicAuth, ChainMapProxy, ETag
from .http import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
WebSocketError as WebSocketError,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMsgType as WSMsgType,
)
from .multipart import (
BadContentDispositionHeader as BadContentDispositionHeader,
BadContentDispositionParam as BadContentDispositionParam,
BodyPartReader as BodyPartReader,
MultipartReader as MultipartReader,
MultipartWriter as MultipartWriter,
content_disposition_filename as content_disposition_filename,
parse_content_disposition as parse_content_disposition,
)
from .payload import (
PAYLOAD_REGISTRY as PAYLOAD_REGISTRY,
AsyncIterablePayload as AsyncIterablePayload,
BufferedReaderPayload as BufferedReaderPayload,
BytesIOPayload as BytesIOPayload,
BytesPayload as BytesPayload,
IOBasePayload as IOBasePayload,
JsonPayload as JsonPayload,
Payload as Payload,
StringIOPayload as StringIOPayload,
StringPayload as StringPayload,
TextIOPayload as TextIOPayload,
get_payload as get_payload,
payload_type as payload_type,
)
from .payload_streamer import streamer as streamer
from .resolver import (
AsyncResolver as AsyncResolver,
DefaultResolver as DefaultResolver,
ThreadedResolver as ThreadedResolver,
)
from .streams import (
EMPTY_PAYLOAD as EMPTY_PAYLOAD,
DataQueue as DataQueue,
EofStream as EofStream,
FlowControlDataQueue as FlowControlDataQueue,
StreamReader as StreamReader,
)
from .tracing import (
TraceConfig as TraceConfig,
TraceConnectionCreateEndParams as TraceConnectionCreateEndParams,
TraceConnectionCreateStartParams as TraceConnectionCreateStartParams,
TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams,
TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams,
TraceConnectionReuseconnParams as TraceConnectionReuseconnParams,
TraceDnsCacheHitParams as TraceDnsCacheHitParams,
TraceDnsCacheMissParams as TraceDnsCacheMissParams,
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams,
TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams,
TraceRequestChunkSentParams as TraceRequestChunkSentParams,
TraceRequestEndParams as TraceRequestEndParams,
TraceRequestExceptionParams as TraceRequestExceptionParams,
TraceRequestHeadersSentParams as TraceRequestHeadersSentParams,
TraceRequestRedirectParams as TraceRequestRedirectParams,
TraceRequestStartParams as TraceRequestStartParams,
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
)
if TYPE_CHECKING:
# At runtime these are lazy-loaded at the bottom of the file.
from .worker import (
GunicornUVLoopWebWorker as GunicornUVLoopWebWorker,
GunicornWebWorker as GunicornWebWorker,
)
__all__: Tuple[str, ...] = (
"hdrs",
# client
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorDNSError",
"ClientConnectorError",
"ClientConnectorSSLError",
"ClientError",
"ClientHttpProxyError",
"ClientOSError",
"ClientPayloadError",
"ClientProxyConnectionError",
"ClientResponse",
"ClientRequest",
"ClientResponseError",
"ClientSSLError",
"ClientSession",
"ClientTimeout",
"ClientWebSocketResponse",
"ClientWSTimeout",
"ConnectionTimeoutError",
"ContentTypeError",
"Fingerprint",
"FlowControlDataQueue",
"InvalidURL",
"InvalidUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlClientError",
"NonHttpUrlRedirectClientError",
"RedirectClientError",
"RequestInfo",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
"UnixConnector",
"NamedPipeConnector",
"WSServerHandshakeError",
"request",
# cookiejar
"CookieJar",
"DummyCookieJar",
# formdata
"FormData",
# helpers
"BasicAuth",
"ChainMapProxy",
"ETag",
# http
"HttpVersion",
"HttpVersion10",
"HttpVersion11",
"WSMsgType",
"WSCloseCode",
"WSMessage",
"WebSocketError",
# multipart
"BadContentDispositionHeader",
"BadContentDispositionParam",
"BodyPartReader",
"MultipartReader",
"MultipartWriter",
"content_disposition_filename",
"parse_content_disposition",
# payload
"AsyncIterablePayload",
"BufferedReaderPayload",
"BytesIOPayload",
"BytesPayload",
"IOBasePayload",
"JsonPayload",
"PAYLOAD_REGISTRY",
"Payload",
"StringIOPayload",
"StringPayload",
"TextIOPayload",
"get_payload",
"payload_type",
# payload_streamer
"streamer",
# resolver
"AsyncResolver",
"DefaultResolver",
"ThreadedResolver",
# streams
"DataQueue",
"EMPTY_PAYLOAD",
"EofStream",
"StreamReader",
# tracing
"TraceConfig",
"TraceConnectionCreateEndParams",
"TraceConnectionCreateStartParams",
"TraceConnectionQueuedEndParams",
"TraceConnectionQueuedStartParams",
"TraceConnectionReuseconnParams",
"TraceDnsCacheHitParams",
"TraceDnsCacheMissParams",
"TraceDnsResolveHostEndParams",
"TraceDnsResolveHostStartParams",
"TraceRequestChunkSentParams",
"TraceRequestEndParams",
"TraceRequestExceptionParams",
"TraceRequestHeadersSentParams",
"TraceRequestRedirectParams",
"TraceRequestStartParams",
"TraceResponseChunkReceivedParams",
# workers (imported lazily with __getattr__)
"GunicornUVLoopWebWorker",
"GunicornWebWorker",
"WSMessageTypeError",
)
def __dir__() -> Tuple[str, ...]:
return __all__ + ("__doc__",)
def __getattr__(name: str) -> object:
global GunicornUVLoopWebWorker, GunicornWebWorker
# Importing gunicorn takes a long time (>100ms), so only import if actually needed.
if name in ("GunicornUVLoopWebWorker", "GunicornWebWorker"):
try:
from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw
except ImportError:
return None
GunicornUVLoopWebWorker = guv # type: ignore[misc]
GunicornWebWorker = gw # type: ignore[misc]
return guv if name == "GunicornUVLoopWebWorker" else gw
raise AttributeError(f"module {__name__} has no attribute {name}")
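# A minimal sketch of the lazy-loading behavior above, from user code
# (assumes gunicorn is installed; all names are real attributes of this
# module):
#
#   import aiohttp
#   "GunicornWebWorker" in dir(aiohttp)        # True: listed in __all__
#   cls = aiohttp.GunicornWebWorker            # first access triggers __getattr__
#   cls is aiohttp.GunicornWebWorker           # True: cached as a module global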

View File

@@ -0,0 +1,158 @@
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t
cdef extern from "../vendor/llhttp/build/llhttp.h":
struct llhttp__internal_s:
int32_t _index
void* _span_pos0
void* _span_cb0
int32_t error
const char* reason
const char* error_pos
void* data
void* _current
uint64_t content_length
uint8_t type
uint8_t method
uint8_t http_major
uint8_t http_minor
uint8_t header_state
uint8_t lenient_flags
uint8_t upgrade
uint8_t finish
uint16_t flags
uint16_t status_code
void* settings
ctypedef llhttp__internal_s llhttp__internal_t
ctypedef llhttp__internal_t llhttp_t
ctypedef int (*llhttp_data_cb)(llhttp_t*, const char *at, size_t length) except -1
ctypedef int (*llhttp_cb)(llhttp_t*) except -1
struct llhttp_settings_s:
llhttp_cb on_message_begin
llhttp_data_cb on_url
llhttp_data_cb on_status
llhttp_data_cb on_header_field
llhttp_data_cb on_header_value
llhttp_cb on_headers_complete
llhttp_data_cb on_body
llhttp_cb on_message_complete
llhttp_cb on_chunk_header
llhttp_cb on_chunk_complete
llhttp_cb on_url_complete
llhttp_cb on_status_complete
llhttp_cb on_header_field_complete
llhttp_cb on_header_value_complete
ctypedef llhttp_settings_s llhttp_settings_t
enum llhttp_errno:
HPE_OK,
HPE_INTERNAL,
HPE_STRICT,
HPE_LF_EXPECTED,
HPE_UNEXPECTED_CONTENT_LENGTH,
HPE_CLOSED_CONNECTION,
HPE_INVALID_METHOD,
HPE_INVALID_URL,
HPE_INVALID_CONSTANT,
HPE_INVALID_VERSION,
HPE_INVALID_HEADER_TOKEN,
HPE_INVALID_CONTENT_LENGTH,
HPE_INVALID_CHUNK_SIZE,
HPE_INVALID_STATUS,
HPE_INVALID_EOF_STATE,
HPE_INVALID_TRANSFER_ENCODING,
HPE_CB_MESSAGE_BEGIN,
HPE_CB_HEADERS_COMPLETE,
HPE_CB_MESSAGE_COMPLETE,
HPE_CB_CHUNK_HEADER,
HPE_CB_CHUNK_COMPLETE,
HPE_PAUSED,
HPE_PAUSED_UPGRADE,
HPE_USER
ctypedef llhttp_errno llhttp_errno_t
enum llhttp_flags:
F_CHUNKED,
F_CONTENT_LENGTH
enum llhttp_type:
HTTP_REQUEST,
HTTP_RESPONSE,
HTTP_BOTH
enum llhttp_method:
HTTP_DELETE,
HTTP_GET,
HTTP_HEAD,
HTTP_POST,
HTTP_PUT,
HTTP_CONNECT,
HTTP_OPTIONS,
HTTP_TRACE,
HTTP_COPY,
HTTP_LOCK,
HTTP_MKCOL,
HTTP_MOVE,
HTTP_PROPFIND,
HTTP_PROPPATCH,
HTTP_SEARCH,
HTTP_UNLOCK,
HTTP_BIND,
HTTP_REBIND,
HTTP_UNBIND,
HTTP_ACL,
HTTP_REPORT,
HTTP_MKACTIVITY,
HTTP_CHECKOUT,
HTTP_MERGE,
HTTP_MSEARCH,
HTTP_NOTIFY,
HTTP_SUBSCRIBE,
HTTP_UNSUBSCRIBE,
HTTP_PATCH,
HTTP_PURGE,
HTTP_MKCALENDAR,
HTTP_LINK,
HTTP_UNLINK,
HTTP_SOURCE,
HTTP_PRI,
HTTP_DESCRIBE,
HTTP_ANNOUNCE,
HTTP_SETUP,
HTTP_PLAY,
HTTP_PAUSE,
HTTP_TEARDOWN,
HTTP_GET_PARAMETER,
HTTP_SET_PARAMETER,
HTTP_REDIRECT,
HTTP_RECORD,
HTTP_FLUSH
ctypedef llhttp_method llhttp_method_t;
void llhttp_settings_init(llhttp_settings_t* settings)
void llhttp_init(llhttp_t* parser, llhttp_type type,
const llhttp_settings_t* settings)
llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len)
int llhttp_should_keep_alive(const llhttp_t* parser)
void llhttp_resume_after_upgrade(llhttp_t* parser)
llhttp_errno_t llhttp_get_errno(const llhttp_t* parser)
const char* llhttp_get_error_reason(const llhttp_t* parser)
const char* llhttp_get_error_pos(const llhttp_t* parser)
const char* llhttp_method_name(llhttp_method_t method)
void llhttp_set_lenient_headers(llhttp_t* parser, int enabled)
void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled)
void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled)
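# A minimal call-sequence sketch (Cython) for the API declared above;
# `my_on_url` is a hypothetical llhttp_data_cb, defined elsewhere as in
# _http_parser.pyx below:
#
#   cdef llhttp_t parser
#   cdef llhttp_settings_t settings
#   llhttp_settings_init(&settings)
#   settings.on_url = my_on_url
#   llhttp_init(&parser, HTTP_REQUEST, &settings)
#   if llhttp_execute(&parser, data, length) != HPE_OK:
#       reason = llhttp_get_error_reason(&parser)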

View File

@@ -0,0 +1,2 @@
cdef extern from "_find_header.h":
int find_header(char *, int)

View File

@@ -0,0 +1,83 @@
# The file is autogenerated from aiohttp/hdrs.py
# Run ./tools/gen.py to update it after the origin changes.
from . import hdrs
cdef tuple headers = (
hdrs.ACCEPT,
hdrs.ACCEPT_CHARSET,
hdrs.ACCEPT_ENCODING,
hdrs.ACCEPT_LANGUAGE,
hdrs.ACCEPT_RANGES,
hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
hdrs.ACCESS_CONTROL_ALLOW_METHODS,
hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
hdrs.ACCESS_CONTROL_MAX_AGE,
hdrs.ACCESS_CONTROL_REQUEST_HEADERS,
hdrs.ACCESS_CONTROL_REQUEST_METHOD,
hdrs.AGE,
hdrs.ALLOW,
hdrs.AUTHORIZATION,
hdrs.CACHE_CONTROL,
hdrs.CONNECTION,
hdrs.CONTENT_DISPOSITION,
hdrs.CONTENT_ENCODING,
hdrs.CONTENT_LANGUAGE,
hdrs.CONTENT_LENGTH,
hdrs.CONTENT_LOCATION,
hdrs.CONTENT_MD5,
hdrs.CONTENT_RANGE,
hdrs.CONTENT_TRANSFER_ENCODING,
hdrs.CONTENT_TYPE,
hdrs.COOKIE,
hdrs.DATE,
hdrs.DESTINATION,
hdrs.DIGEST,
hdrs.ETAG,
hdrs.EXPECT,
hdrs.EXPIRES,
hdrs.FORWARDED,
hdrs.FROM,
hdrs.HOST,
hdrs.IF_MATCH,
hdrs.IF_MODIFIED_SINCE,
hdrs.IF_NONE_MATCH,
hdrs.IF_RANGE,
hdrs.IF_UNMODIFIED_SINCE,
hdrs.KEEP_ALIVE,
hdrs.LAST_EVENT_ID,
hdrs.LAST_MODIFIED,
hdrs.LINK,
hdrs.LOCATION,
hdrs.MAX_FORWARDS,
hdrs.ORIGIN,
hdrs.PRAGMA,
hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.RANGE,
hdrs.REFERER,
hdrs.RETRY_AFTER,
hdrs.SEC_WEBSOCKET_ACCEPT,
hdrs.SEC_WEBSOCKET_EXTENSIONS,
hdrs.SEC_WEBSOCKET_KEY,
hdrs.SEC_WEBSOCKET_KEY1,
hdrs.SEC_WEBSOCKET_PROTOCOL,
hdrs.SEC_WEBSOCKET_VERSION,
hdrs.SERVER,
hdrs.SET_COOKIE,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.URI,
hdrs.UPGRADE,
hdrs.USER_AGENT,
hdrs.VARY,
hdrs.VIA,
hdrs.WWW_AUTHENTICATE,
hdrs.WANT_DIGEST,
hdrs.WARNING,
hdrs.X_FORWARDED_FOR,
hdrs.X_FORWARDED_HOST,
hdrs.X_FORWARDED_PROTO,
)

View File

@@ -0,0 +1,837 @@
#cython: language_level=3
#
# Based on https://github.com/MagicStack/httptools
#
from cpython cimport (
Py_buffer,
PyBUF_SIMPLE,
PyBuffer_Release,
PyBytes_AsString,
PyBytes_AsStringAndSize,
PyObject_GetBuffer,
)
from cpython.mem cimport PyMem_Free, PyMem_Malloc
from libc.limits cimport ULLONG_MAX
from libc.string cimport memcpy
from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy
from yarl import URL as _URL
from aiohttp import hdrs
from aiohttp.helpers import DEBUG, set_exception
from .http_exceptions import (
BadHttpMessage,
BadHttpMethod,
BadStatusLine,
ContentLengthError,
InvalidHeader,
InvalidURLError,
LineTooLong,
PayloadEncodingError,
TransferEncodingError,
)
from .http_parser import DeflateBuffer as _DeflateBuffer
from .http_writer import (
HttpVersion as _HttpVersion,
HttpVersion10 as _HttpVersion10,
HttpVersion11 as _HttpVersion11,
)
from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader
cimport cython
from aiohttp cimport _cparser as cparser
include "_headers.pxi"
from aiohttp cimport _find_header
ALLOWED_UPGRADES = frozenset({"websocket"})
DEF DEFAULT_FREELIST_SIZE = 250
cdef extern from "Python.h":
int PyByteArray_Resize(object, Py_ssize_t) except -1
Py_ssize_t PyByteArray_Size(object) except -1
char* PyByteArray_AsString(object)
__all__ = ('HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage')
cdef object URL = _URL
cdef object URL_build = URL.build
cdef object CIMultiDict = _CIMultiDict
cdef object CIMultiDictProxy = _CIMultiDictProxy
cdef object HttpVersion = _HttpVersion
cdef object HttpVersion10 = _HttpVersion10
cdef object HttpVersion11 = _HttpVersion11
cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1
cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING
cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD
cdef object StreamReader = _StreamReader
cdef object DeflateBuffer = _DeflateBuffer
cdef bytes EMPTY_BYTES = b""
cdef inline object extend(object buf, const char* at, size_t length):
cdef Py_ssize_t s
cdef char* ptr
s = PyByteArray_Size(buf)
PyByteArray_Resize(buf, s + length)
ptr = PyByteArray_AsString(buf)
memcpy(ptr + s, at, length)
DEF METHODS_COUNT = 46;
cdef list _http_method = []
for i in range(METHODS_COUNT):
_http_method.append(
cparser.llhttp_method_name(<cparser.llhttp_method_t> i).decode('ascii'))
cdef inline str http_method_str(int i):
if i < METHODS_COUNT:
return <str>_http_method[i]
else:
return "<unknown>"
cdef inline object find_header(bytes raw_header):
cdef Py_ssize_t size
cdef char *buf
cdef int idx
PyBytes_AsStringAndSize(raw_header, &buf, &size)
idx = _find_header.find_header(buf, size)
if idx == -1:
return raw_header.decode('utf-8', 'surrogateescape')
return headers[idx]
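# For example (a sketch): find_header(b"Content-Type") returns the interned
# hdrs.CONTENT_TYPE string via the generated perfect-hash lookup, while an
# unknown name such as b"X-Custom" falls back to a plain decoded str.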
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawRequestMessage:
cdef readonly str method
cdef readonly str path
cdef readonly object version # HttpVersion
cdef readonly object headers # CIMultiDict
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
cdef readonly object upgrade
cdef readonly object chunked
cdef readonly object url # yarl.URL
def __init__(self, method, path, version, headers, raw_headers,
should_close, compression, upgrade, chunked, url):
self.method = method
self.path = path
self.version = version
self.headers = headers
self.raw_headers = raw_headers
self.should_close = should_close
self.compression = compression
self.upgrade = upgrade
self.chunked = chunked
self.url = url
def __repr__(self):
info = []
info.append(("method", self.method))
info.append(("path", self.path))
info.append(("version", self.version))
info.append(("headers", self.headers))
info.append(("raw_headers", self.raw_headers))
info.append(("should_close", self.should_close))
info.append(("compression", self.compression))
info.append(("upgrade", self.upgrade))
info.append(("chunked", self.chunked))
info.append(("url", self.url))
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
return '<RawRequestMessage(' + sinfo + ')>'
def _replace(self, **dct):
cdef RawRequestMessage ret
ret = _new_request_message(self.method,
self.path,
self.version,
self.headers,
self.raw_headers,
self.should_close,
self.compression,
self.upgrade,
self.chunked,
self.url)
if "method" in dct:
ret.method = dct["method"]
if "path" in dct:
ret.path = dct["path"]
if "version" in dct:
ret.version = dct["version"]
if "headers" in dct:
ret.headers = dct["headers"]
if "raw_headers" in dct:
ret.raw_headers = dct["raw_headers"]
if "should_close" in dct:
ret.should_close = dct["should_close"]
if "compression" in dct:
ret.compression = dct["compression"]
if "upgrade" in dct:
ret.upgrade = dct["upgrade"]
if "chunked" in dct:
ret.chunked = dct["chunked"]
if "url" in dct:
ret.url = dct["url"]
return ret
cdef _new_request_message(str method,
str path,
object version,
object headers,
object raw_headers,
bint should_close,
object compression,
bint upgrade,
bint chunked,
object url):
cdef RawRequestMessage ret
ret = RawRequestMessage.__new__(RawRequestMessage)
ret.method = method
ret.path = path
ret.version = version
ret.headers = headers
ret.raw_headers = raw_headers
ret.should_close = should_close
ret.compression = compression
ret.upgrade = upgrade
ret.chunked = chunked
ret.url = url
return ret
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawResponseMessage:
cdef readonly object version # HttpVersion
cdef readonly int code
cdef readonly str reason
cdef readonly object headers # CIMultiDict
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
cdef readonly object upgrade
cdef readonly object chunked
def __init__(self, version, code, reason, headers, raw_headers,
should_close, compression, upgrade, chunked):
self.version = version
self.code = code
self.reason = reason
self.headers = headers
self.raw_headers = raw_headers
self.should_close = should_close
self.compression = compression
self.upgrade = upgrade
self.chunked = chunked
def __repr__(self):
info = []
info.append(("version", self.version))
info.append(("code", self.code))
info.append(("reason", self.reason))
info.append(("headers", self.headers))
info.append(("raw_headers", self.raw_headers))
info.append(("should_close", self.should_close))
info.append(("compression", self.compression))
info.append(("upgrade", self.upgrade))
info.append(("chunked", self.chunked))
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
return '<RawResponseMessage(' + sinfo + ')>'
cdef _new_response_message(object version,
int code,
str reason,
object headers,
object raw_headers,
bint should_close,
object compression,
bint upgrade,
bint chunked):
cdef RawResponseMessage ret
ret = RawResponseMessage.__new__(RawResponseMessage)
ret.version = version
ret.code = code
ret.reason = reason
ret.headers = headers
ret.raw_headers = raw_headers
ret.should_close = should_close
ret.compression = compression
ret.upgrade = upgrade
ret.chunked = chunked
return ret
@cython.internal
cdef class HttpParser:
cdef:
cparser.llhttp_t* _cparser
cparser.llhttp_settings_t* _csettings
bytes _raw_name
object _name
bytes _raw_value
bint _has_value
object _protocol
object _loop
object _timer
size_t _max_line_size
size_t _max_field_size
size_t _max_headers
bint _response_with_body
bint _read_until_eof
bint _started
object _url
bytearray _buf
str _path
str _reason
list _headers
list _raw_headers
bint _upgraded
list _messages
object _payload
bint _payload_error
object _payload_exception
object _last_error
bint _auto_decompress
int _limit
str _content_encoding
Py_buffer py_buf
def __cinit__(self):
self._cparser = <cparser.llhttp_t*> \
PyMem_Malloc(sizeof(cparser.llhttp_t))
if self._cparser is NULL:
raise MemoryError()
self._csettings = <cparser.llhttp_settings_t*> \
PyMem_Malloc(sizeof(cparser.llhttp_settings_t))
if self._csettings is NULL:
raise MemoryError()
def __dealloc__(self):
PyMem_Free(self._cparser)
PyMem_Free(self._csettings)
cdef _init(
self, cparser.llhttp_type mode,
object protocol, object loop, int limit,
object timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True,
):
cparser.llhttp_settings_init(self._csettings)
cparser.llhttp_init(self._cparser, mode, self._csettings)
self._cparser.data = <void*>self
self._cparser.content_length = 0
self._protocol = protocol
self._loop = loop
self._timer = timer
self._buf = bytearray()
self._payload = None
self._payload_error = 0
self._payload_exception = payload_exception
self._messages = []
self._raw_name = EMPTY_BYTES
self._raw_value = EMPTY_BYTES
self._has_value = False
self._max_line_size = max_line_size
self._max_headers = max_headers
self._max_field_size = max_field_size
self._response_with_body = response_with_body
self._read_until_eof = read_until_eof
self._upgraded = False
self._auto_decompress = auto_decompress
self._content_encoding = None
self._csettings.on_url = cb_on_url
self._csettings.on_status = cb_on_status
self._csettings.on_header_field = cb_on_header_field
self._csettings.on_header_value = cb_on_header_value
self._csettings.on_headers_complete = cb_on_headers_complete
self._csettings.on_body = cb_on_body
self._csettings.on_message_begin = cb_on_message_begin
self._csettings.on_message_complete = cb_on_message_complete
self._csettings.on_chunk_header = cb_on_chunk_header
self._csettings.on_chunk_complete = cb_on_chunk_complete
self._last_error = None
self._limit = limit
cdef _process_header(self):
cdef str value
if self._raw_name is not EMPTY_BYTES:
name = find_header(self._raw_name)
value = self._raw_value.decode('utf-8', 'surrogateescape')
self._headers.append((name, value))
if name is CONTENT_ENCODING:
self._content_encoding = value
self._has_value = False
self._raw_headers.append((self._raw_name, self._raw_value))
self._raw_name = EMPTY_BYTES
self._raw_value = EMPTY_BYTES
cdef _on_header_field(self, char* at, size_t length):
if self._has_value:
self._process_header()
if self._raw_name is EMPTY_BYTES:
self._raw_name = at[:length]
else:
self._raw_name += at[:length]
cdef _on_header_value(self, char* at, size_t length):
if self._raw_value is EMPTY_BYTES:
self._raw_value = at[:length]
else:
self._raw_value += at[:length]
self._has_value = True
cdef _on_headers_complete(self):
self._process_header()
should_close = not cparser.llhttp_should_keep_alive(self._cparser)
upgrade = self._cparser.upgrade
chunked = self._cparser.flags & cparser.F_CHUNKED
raw_headers = tuple(self._raw_headers)
headers = CIMultiDictProxy(CIMultiDict(self._headers))
if self._cparser.type == cparser.HTTP_REQUEST:
allowed = upgrade and headers.get("upgrade", "").lower() in ALLOWED_UPGRADES
if allowed or self._cparser.method == cparser.HTTP_CONNECT:
self._upgraded = True
else:
if upgrade and self._cparser.status_code == 101:
self._upgraded = True
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
encoding = None
enc = self._content_encoding
if enc is not None:
self._content_encoding = None
enc = enc.lower()
if enc in ('gzip', 'deflate', 'br'):
encoding = enc
if self._cparser.type == cparser.HTTP_REQUEST:
method = http_method_str(self._cparser.method)
msg = _new_request_message(
method, self._path,
self.http_version(), headers, raw_headers,
should_close, encoding, upgrade, chunked, self._url)
else:
msg = _new_response_message(
self.http_version(), self._cparser.status_code, self._reason,
headers, raw_headers, should_close, encoding,
upgrade, chunked)
if (
ULLONG_MAX > self._cparser.content_length > 0 or chunked or
self._cparser.method == cparser.HTTP_CONNECT or
(self._cparser.status_code >= 199 and
self._cparser.content_length == 0 and
self._read_until_eof)
):
payload = StreamReader(
self._protocol, timer=self._timer, loop=self._loop,
limit=self._limit)
else:
payload = EMPTY_PAYLOAD
self._payload = payload
if encoding is not None and self._auto_decompress:
self._payload = DeflateBuffer(payload, encoding)
if not self._response_with_body:
payload = EMPTY_PAYLOAD
self._messages.append((msg, payload))
cdef _on_message_complete(self):
self._payload.feed_eof()
self._payload = None
cdef _on_chunk_header(self):
self._payload.begin_http_chunk_receiving()
cdef _on_chunk_complete(self):
self._payload.end_http_chunk_receiving()
cdef object _on_status_complete(self):
pass
cdef inline http_version(self):
cdef cparser.llhttp_t* parser = self._cparser
if parser.http_major == 1:
if parser.http_minor == 0:
return HttpVersion10
elif parser.http_minor == 1:
return HttpVersion11
return HttpVersion(parser.http_major, parser.http_minor)
### Public API ###
def feed_eof(self):
cdef bytes desc
if self._payload is not None:
if self._cparser.flags & cparser.F_CHUNKED:
                raise TransferEncodingError(
                    "Not enough data to satisfy transfer length header.")
            elif self._cparser.flags & cparser.F_CONTENT_LENGTH:
                raise ContentLengthError(
                    "Not enough data to satisfy content length header.")
elif cparser.llhttp_get_errno(self._cparser) != cparser.HPE_OK:
desc = cparser.llhttp_get_error_reason(self._cparser)
raise PayloadEncodingError(desc.decode('latin-1'))
else:
self._payload.feed_eof()
elif self._started:
self._on_headers_complete()
if self._messages:
return self._messages[-1][0]
def feed_data(self, data):
cdef:
size_t data_len
size_t nb
cdef cparser.llhttp_errno_t errno
PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE)
data_len = <size_t>self.py_buf.len
errno = cparser.llhttp_execute(
self._cparser,
<char*>self.py_buf.buf,
data_len)
if errno is cparser.HPE_PAUSED_UPGRADE:
cparser.llhttp_resume_after_upgrade(self._cparser)
nb = cparser.llhttp_get_error_pos(self._cparser) - <char*>self.py_buf.buf
PyBuffer_Release(&self.py_buf)
if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE):
if self._payload_error == 0:
if self._last_error is not None:
ex = self._last_error
self._last_error = None
else:
after = cparser.llhttp_get_error_pos(self._cparser)
before = data[:after - <char*>self.py_buf.buf]
after_b = after.split(b"\r\n", 1)[0]
before = before.rsplit(b"\r\n", 1)[-1]
data = before + after_b
pointer = " " * (len(repr(before))-1) + "^"
ex = parser_error_from_errno(self._cparser, data, pointer)
self._payload = None
raise ex
if self._messages:
messages = self._messages
self._messages = []
else:
messages = ()
if self._upgraded:
return messages, True, data[nb:]
else:
return messages, False, b""
def set_upgraded(self, val):
self._upgraded = val
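# Minimal usage sketch for the parser subclasses below (the protocol and
# loop objects are assumed to exist, as in aiohttp's own server code):
#
#   parser = HttpRequestParser(protocol, loop, 2 ** 16)
#   messages, upgraded, tail = parser.feed_data(b"GET / HTTP/1.1\r\n\r\n")
#   msg, payload = messages[0]   # RawRequestMessage and its payload stream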
cdef class HttpRequestParser(HttpParser):
def __init__(
self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True,
):
self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress)
cdef object _on_status_complete(self):
cdef int idx1, idx2
if not self._buf:
return
self._path = self._buf.decode('utf-8', 'surrogateescape')
try:
idx3 = len(self._path)
if self._cparser.method == cparser.HTTP_CONNECT:
# authority-form,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3
self._url = URL.build(authority=self._path, encoded=True)
elif idx3 > 1 and self._path[0] == '/':
# origin-form,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1
idx1 = self._path.find("?")
if idx1 == -1:
query = ""
idx2 = self._path.find("#")
if idx2 == -1:
path = self._path
fragment = ""
else:
path = self._path[0: idx2]
fragment = self._path[idx2+1:]
else:
path = self._path[0:idx1]
idx1 += 1
idx2 = self._path.find("#", idx1+1)
if idx2 == -1:
query = self._path[idx1:]
fragment = ""
else:
query = self._path[idx1: idx2]
fragment = self._path[idx2+1:]
self._url = URL.build(
path=path,
query_string=query,
fragment=fragment,
encoded=True,
)
else:
                # absolute-form, possibly for a proxy,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
self._url = URL(self._path, encoded=True)
finally:
PyByteArray_Resize(self._buf, 0)
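# A sketch of the three request-target forms handled above:
#   CONNECT example.com:443    -> URL.build(authority="example.com:443")
#   GET /path?q=1#frag         -> URL.build(path="/path", query_string="q=1", fragment="frag")
#   GET http://example.com/x   -> URL("http://example.com/x")  (absolute-form, proxies)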
cdef class HttpResponseParser(HttpParser):
def __init__(
self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True
):
self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress)
        # Use strict parsing in dev mode, so users are warned about broken servers.
if not DEBUG:
cparser.llhttp_set_lenient_headers(self._cparser, 1)
cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1)
cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1)
cdef object _on_status_complete(self):
if self._buf:
self._reason = self._buf.decode('utf-8', 'surrogateescape')
PyByteArray_Resize(self._buf, 0)
else:
self._reason = self._reason or ''
cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
pyparser._started = True
pyparser._headers = []
pyparser._raw_headers = []
PyByteArray_Resize(pyparser._buf, 0)
pyparser._path = None
pyparser._reason = None
return 0
cdef int cb_on_url(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
if length > pyparser._max_line_size:
raise LineTooLong(
'Status line is too long', pyparser._max_line_size, length)
extend(pyparser._buf, at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_status(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef str reason
try:
if length > pyparser._max_line_size:
raise LineTooLong(
'Status line is too long', pyparser._max_line_size, length)
extend(pyparser._buf, at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_header_field(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef Py_ssize_t size
try:
pyparser._on_status_complete()
size = len(pyparser._raw_name) + length
if size > pyparser._max_field_size:
raise LineTooLong(
'Header name is too long', pyparser._max_field_size, size)
pyparser._on_header_field(at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_header_value(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef Py_ssize_t size
try:
size = len(pyparser._raw_value) + length
if size > pyparser._max_field_size:
raise LineTooLong(
'Header value is too long', pyparser._max_field_size, size)
pyparser._on_header_value(at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_status_complete()
pyparser._on_headers_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
return 2
else:
return 0
cdef int cb_on_body(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef bytes body = at[:length]
try:
pyparser._payload.feed_data(body, length)
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if pyparser._payload_exception is not None:
reraised_exc = pyparser._payload_exception(str(underlying_exc))
set_exception(pyparser._payload, reraised_exc, underlying_exc)
pyparser._payload_error = 1
return -1
else:
return 0
cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._started = False
pyparser._on_message_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef int cb_on_chunk_header(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_chunk_header()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_chunk_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser)
cdef bytes desc = cparser.llhttp_get_error_reason(parser)
err_msg = "{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer)
if errno in {cparser.HPE_CB_MESSAGE_BEGIN,
cparser.HPE_CB_HEADERS_COMPLETE,
cparser.HPE_CB_MESSAGE_COMPLETE,
cparser.HPE_CB_CHUNK_HEADER,
cparser.HPE_CB_CHUNK_COMPLETE,
cparser.HPE_INVALID_CONSTANT,
cparser.HPE_INVALID_HEADER_TOKEN,
cparser.HPE_INVALID_CONTENT_LENGTH,
cparser.HPE_INVALID_CHUNK_SIZE,
cparser.HPE_INVALID_EOF_STATE,
cparser.HPE_INVALID_TRANSFER_ENCODING}:
return BadHttpMessage(err_msg)
elif errno == cparser.HPE_INVALID_METHOD:
return BadHttpMethod(error=err_msg)
elif errno in {cparser.HPE_INVALID_STATUS,
cparser.HPE_INVALID_VERSION}:
return BadStatusLine(error=err_msg)
elif errno == cparser.HPE_INVALID_URL:
return InvalidURLError(err_msg)
return BadHttpMessage(err_msg)

View File

@@ -0,0 +1,160 @@
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.exc cimport PyErr_NoMemory
from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc
from cpython.object cimport PyObject_Str
from libc.stdint cimport uint8_t, uint64_t
from libc.string cimport memcpy
from multidict import istr
DEF BUF_SIZE = 16 * 1024 # 16KiB
cdef char BUFFER[BUF_SIZE]
cdef object _istr = istr
# ----------------- writer ---------------------------
cdef struct Writer:
char *buf
Py_ssize_t size
Py_ssize_t pos
cdef inline void _init_writer(Writer* writer):
writer.buf = &BUFFER[0]
writer.size = BUF_SIZE
writer.pos = 0
cdef inline void _release_writer(Writer* writer):
if writer.buf != BUFFER:
PyMem_Free(writer.buf)
cdef inline int _write_byte(Writer* writer, uint8_t ch):
cdef char * buf
cdef Py_ssize_t size
if writer.pos == writer.size:
# reallocate
size = writer.size + BUF_SIZE
if writer.buf == BUFFER:
buf = <char*>PyMem_Malloc(size)
if buf == NULL:
PyErr_NoMemory()
return -1
memcpy(buf, writer.buf, writer.size)
else:
buf = <char*>PyMem_Realloc(writer.buf, size)
if buf == NULL:
PyErr_NoMemory()
return -1
writer.buf = buf
writer.size = size
writer.buf[writer.pos] = <char>ch
writer.pos += 1
return 0
cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
cdef uint64_t utf = <uint64_t> symbol
if utf < 0x80:
return _write_byte(writer, <uint8_t>utf)
elif utf < 0x800:
if _write_byte(writer, <uint8_t>(0xc0 | (utf >> 6))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
elif 0xD800 <= utf <= 0xDFFF:
        # surrogate pair, ignored
return 0
elif utf < 0x10000:
if _write_byte(writer, <uint8_t>(0xe0 | (utf >> 12))) < 0:
return -1
if _write_byte(writer, <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
elif utf > 0x10FFFF:
# symbol is too large
return 0
else:
if _write_byte(writer, <uint8_t>(0xf0 | (utf >> 18))) < 0:
return -1
if _write_byte(writer,
<uint8_t>(0x80 | ((utf >> 12) & 0x3f))) < 0:
return -1
if _write_byte(writer,
<uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
cdef inline int _write_str(Writer* writer, str s):
cdef Py_UCS4 ch
for ch in s:
if _write_utf8(writer, ch) < 0:
return -1
cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s):
cdef Py_UCS4 ch
cdef str out_str
if type(s) is str:
out_str = <str>s
elif type(s) is _istr:
out_str = PyObject_Str(s)
elif not isinstance(s, str):
raise TypeError("Cannot serialize non-str key {!r}".format(s))
else:
out_str = str(s)
for ch in out_str:
if ch == 0x0D or ch == 0x0A:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
if _write_utf8(writer, ch) < 0:
return -1
# --------------- _serialize_headers ----------------------
def _serialize_headers(str status_line, headers):
cdef Writer writer
cdef object key
cdef object val
_init_writer(&writer)
try:
if _write_str(&writer, status_line) < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
for key, val in headers.items():
if _write_str_raise_on_nlcr(&writer, key) < 0:
raise
if _write_byte(&writer, b':') < 0:
raise
if _write_byte(&writer, b' ') < 0:
raise
if _write_str_raise_on_nlcr(&writer, val) < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
return PyBytes_FromStringAndSize(writer.buf, writer.pos)
finally:
_release_writer(&writer)
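# Usage sketch (any mapping with .items() works; aiohttp passes a
# CIMultiDict):
#
#   raw = _serialize_headers("GET / HTTP/1.1", {"Host": "example.com"})
#   assert raw == b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"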

View File

@@ -0,0 +1 @@
e354dd499be171b6125bf56bc3b6c5e2bff2a28af69e3b5d699ddb9af2bafa3c

View File

@@ -0,0 +1 @@
468edd38ebf8dc7000a8d333df1c82035d69a5c9febc0448be3c9c4ad4c4630c

View File

@@ -0,0 +1 @@
1cd3a5e20456b4d04d11835b2bd3c639f14443052a2467b105b0ca07fdb4b25d

View File

@@ -0,0 +1 @@
"""WebSocket protocol versions 13 and 8."""

View File

@@ -0,0 +1,147 @@
"""Helpers for WebSocket protocol versions 13 and 8."""
import functools
import re
from struct import Struct
from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
from ..helpers import NO_EXTENSIONS
from .models import WSHandshakeError
UNPACK_LEN3 = Struct("!Q").unpack_from
UNPACK_CLOSE_CODE = Struct("!H").unpack
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
PACK_RANDBITS = Struct("!L").pack
MSG_SIZE: Final[int] = 2**14
MASK_LEN: Final[int] = 4
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
# Used by _websocket_mask_python
@functools.lru_cache
def _xor_table() -> List[bytes]:
return [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
object of any length. The contents of `data` are masked with `mask`,
as specified in section 5.3 of RFC 6455.
Note that this function mutates the `data` argument.
This pure-python implementation may be replaced by an optimized
version when available.
"""
assert isinstance(data, bytearray), data
assert len(mask) == 4, mask
if data:
_XOR_TABLE = _xor_table()
a, b, c, d = (_XOR_TABLE[n] for n in mask)
data[::4] = data[::4].translate(a)
data[1::4] = data[1::4].translate(b)
data[2::4] = data[2::4].translate(c)
data[3::4] = data[3::4].translate(d)
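# Masking sketch per RFC 6455 section 5.3 (XOR with the 4-byte mask,
# repeated over the payload); XOR is its own inverse, so applying the
# same mask twice restores the original bytes:
#
#   buf = bytearray(b"hello")
#   _websocket_mask_python(b"\x01\x02\x03\x04", buf)   # mutates buf in place
#   _websocket_mask_python(b"\x01\x02\x03\x04", buf)   # buf == b"hello" again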
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
websocket_mask = _websocket_mask_python
else:
try:
from .mask import _websocket_mask_cython # type: ignore[import-not-found]
websocket_mask = _websocket_mask_cython
except ImportError: # pragma: no cover
websocket_mask = _websocket_mask_python
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
r"^(?:;\s*(?:"
r"(server_no_context_takeover)|"
r"(client_no_context_takeover)|"
r"(server_max_window_bits(?:=(\d+))?)|"
r"(client_max_window_bits(?:=(\d+))?)))*$"
)
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
if not extstr:
return 0, False
compress = 0
notakeover = False
for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
defext = ext.group(1)
        # Return compress = 15 when we get bare `permessage-deflate`
if not defext:
compress = 15
break
match = _WS_EXT_RE.match(defext)
if match:
compress = 15
if isserver:
                # The server never fails to detect the compress handshake.
                # The server does not need to send max window bits to the client.
                if match.group(4):
                    compress = int(match.group(4))
                    # Group 3 must match if group 4 matches.
                    # zlib does not support compress wbits=8, so if the
                    # compress level is unsupported,
                    # CONTINUE to the next extension.
if compress > 15 or compress < 9:
compress = 0
continue
if match.group(1):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
else:
if match.group(6):
compress = int(match.group(6))
                    # Group 5 must match if group 6 matches.
                    # zlib does not support compress wbits=8, so if the
                    # compress level is unsupported,
                    # FAIL the parse process.
if compress > 15 or compress < 9:
raise WSHandshakeError("Invalid window size")
if match.group(2):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
        # Fail on the client side if nothing matched.
        elif not isserver:
            raise WSHandshakeError("Extension for deflate not supported: " + ext.group(1))
return compress, notakeover
def ws_ext_gen(
compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    # client_notakeover=False is not used for the server;
    # zlib does not support compress wbits=8
    if compress < 9 or compress > 15:
        raise ValueError(
            "Compress wbits must be between 9 and 15, zlib does not support wbits=8"
        )
enabledext = ["permessage-deflate"]
if not isserver:
enabledext.append("client_max_window_bits")
if compress < 15:
enabledext.append("server_max_window_bits=" + str(compress))
if server_notakeover:
enabledext.append("server_no_context_takeover")
# if client_notakeover:
# enabledext.append('client_no_context_takeover')
return "; ".join(enabledext)

View File

@@ -0,0 +1,3 @@
"""Cython declarations for websocket masking."""
cpdef void _websocket_mask_cython(bytes mask, bytearray data)

View File

@@ -0,0 +1,48 @@
from cpython cimport PyBytes_AsString
# from cpython cimport PyByteArray_AsString  # Cython still does not export that
cdef extern from "Python.h":
char* PyByteArray_AsString(bytearray ba) except NULL
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
"""Note, this function mutates its `data` argument
"""
cdef:
Py_ssize_t data_len, i
# bit operations on signed integers are implementation-specific
unsigned char * in_buf
const unsigned char * mask_buf
uint32_t uint32_msk
uint64_t uint64_msk
assert len(mask) == 4
data_len = len(data)
in_buf = <unsigned char*>PyByteArray_AsString(data)
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
uint32_msk = (<uint32_t*>mask_buf)[0]
    # TODO: align the in_buf pointer to achieve even faster speeds
    # is that needed in Python?! malloc() always aligns to sizeof(long) bytes
if sizeof(size_t) >= 8:
uint64_msk = uint32_msk
uint64_msk = (uint64_msk << 32) | uint32_msk
while data_len >= 8:
(<uint64_t*>in_buf)[0] ^= uint64_msk
in_buf += 8
data_len -= 8
while data_len >= 4:
(<uint32_t*>in_buf)[0] ^= uint32_msk
in_buf += 4
data_len -= 4
for i in range(0, data_len):
in_buf[i] ^= mask_buf[i]
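# The 8-, 4- and 1-byte strides above are equivalent to the pure-Python
# translate()-based masking in helpers._websocket_mask_python: every
# payload byte is XORed with mask[i % 4], per RFC 6455 section 5.3.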

View File

@@ -0,0 +1,84 @@
"""Models for WebSocket protocol versions 13 and 8."""
import json
from enum import IntEnum
from typing import Any, Callable, Final, NamedTuple, Optional, cast
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
class WSCloseCode(IntEnum):
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
ABNORMAL_CLOSURE = 1006
INVALID_TEXT = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
BAD_GATEWAY = 1014
class WSMsgType(IntEnum):
# websocket spec types
CONTINUATION = 0x0
TEXT = 0x1
BINARY = 0x2
PING = 0x9
PONG = 0xA
CLOSE = 0x8
# aiohttp specific types
CLOSING = 0x100
CLOSED = 0x101
ERROR = 0x102
text = TEXT
binary = BINARY
ping = PING
pong = PONG
close = CLOSE
closing = CLOSING
closed = CLOSED
error = ERROR
class WSMessage(NamedTuple):
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
data: Any
extra: Optional[str]
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data.
.. versionadded:: 0.22
"""
return loads(self.data)
# Constructing the tuple directly to avoid the overhead of
# the lambda and arg processing since NamedTuples are constructed
# with a run time built lambda
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
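# Equivalent, slower construction for comparison (a sketch):
#   WS_CLOSED_MESSAGE == WSMessage(type=WSMsgType.CLOSED, data=None, extra=None)  # True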
class WebSocketError(Exception):
"""WebSocket protocol parser error."""
def __init__(self, code: int, message: str) -> None:
self.code = code
super().__init__(code, message)
def __str__(self) -> str:
return cast(str, self.args[1])
class WSHandshakeError(Exception):
"""WebSocket protocol handshake error."""

View File

@@ -0,0 +1,31 @@
"""Reader for WebSocket protocol versions 13 and 8."""
from typing import TYPE_CHECKING
from ..helpers import NO_EXTENSIONS
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)
WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython
else:
try:
from .reader_c import ( # type: ignore[import-not-found]
WebSocketDataQueue as WebSocketDataQueueCython,
WebSocketReader as WebSocketReaderCython,
)
WebSocketReader = WebSocketReaderCython
WebSocketDataQueue = WebSocketDataQueueCython
except ImportError: # pragma: no cover
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)
WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython

View File

@@ -0,0 +1,110 @@
import cython
from .mask cimport _websocket_mask_cython as websocket_mask
cdef unsigned int READ_HEADER
cdef unsigned int READ_PAYLOAD_LENGTH
cdef unsigned int READ_PAYLOAD_MASK
cdef unsigned int READ_PAYLOAD
cdef int OP_CODE_NOT_SET
cdef int OP_CODE_CONTINUATION
cdef int OP_CODE_TEXT
cdef int OP_CODE_BINARY
cdef int OP_CODE_CLOSE
cdef int OP_CODE_PING
cdef int OP_CODE_PONG
cdef int COMPRESSED_NOT_SET
cdef int COMPRESSED_FALSE
cdef int COMPRESSED_TRUE
cdef object UNPACK_LEN3
cdef object UNPACK_CLOSE_CODE
cdef object TUPLE_NEW
cdef object WSMsgType
cdef object WSMessage
cdef object WS_MSG_TYPE_TEXT
cdef object WS_MSG_TYPE_BINARY
cdef set ALLOWED_CLOSE_CODES
cdef set MESSAGE_TYPES_WITH_CONTENT
cdef tuple EMPTY_FRAME
cdef tuple EMPTY_FRAME_ERROR
cdef class WebSocketDataQueue:
cdef unsigned int _size
cdef public object _protocol
cdef unsigned int _limit
cdef object _loop
cdef bint _eof
cdef object _waiter
cdef object _exception
cdef public object _buffer
cdef object _get_buffer
cdef object _put_buffer
cdef void _release_waiter(self)
cpdef void feed_data(self, object data, unsigned int size)
@cython.locals(size="unsigned int")
cdef _read_from_buffer(self)
cdef class WebSocketReader:
cdef WebSocketDataQueue queue
cdef unsigned int _max_msg_size
cdef Exception _exc
cdef bytearray _partial
cdef unsigned int _state
cdef int _opcode
cdef bint _frame_fin
cdef int _frame_opcode
cdef list _payload_fragments
cdef Py_ssize_t _frame_payload_len
cdef bytes _tail
cdef bint _has_mask
cdef bytes _frame_mask
cdef Py_ssize_t _payload_bytes_to_read
cdef unsigned int _payload_len_flag
cdef int _compressed
cdef object _decompressobj
cdef bint _compress
cpdef tuple feed_data(self, object data)
@cython.locals(
is_continuation=bint,
fin=bint,
has_partial=bint,
payload_merged=bytes,
)
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *
@cython.locals(
start_pos=Py_ssize_t,
data_len=Py_ssize_t,
length=Py_ssize_t,
chunk_size=Py_ssize_t,
chunk_len=Py_ssize_t,
data_cstr="const unsigned char *",
first_byte="unsigned char",
second_byte="unsigned char",
f_start_pos=Py_ssize_t,
f_end_pos=Py_ssize_t,
has_mask=bint,
fin=bint,
had_fragments=Py_ssize_t,
payload_bytearray=bytearray,
)
cpdef void _feed_data(self, bytes data) except *
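# This .pxd augments reader_py.py when it is compiled as reader_c: the
# classes above gain C-level attribute storage and typed locals, and
# reader.py then prefers the compiled versions over the pure-Python ones.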

View File

@@ -0,0 +1,468 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Deque, Final, Optional, Set, Tuple, Union
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
# WSMsgType values unpacked so they can be cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1
TUPLE_NEW = tuple.__new__
cython_int = int  # Typed to int in Python, but Cython will use a signed int in the pxd
class WebSocketDataQueue:
"""WebSocketDataQueue resumes and pauses an underlying stream.
It is a destination for WebSocket data.
"""
def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
self._size = 0
self._protocol = protocol
self._limit = limit * 2
self._loop = loop
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._exception: Union[BaseException, None] = None
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
self._get_buffer = self._buffer.popleft
self._put_buffer = self._buffer.append
def is_eof(self) -> bool:
return self._eof
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(
self,
exc: "BaseException",
exc_cause: builtins.BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
def _release_waiter(self) -> None:
if (waiter := self._waiter) is None:
return
self._waiter = None
if not waiter.done():
waiter.set_result(None)
def feed_eof(self) -> None:
self._eof = True
self._release_waiter()
self._exception = None # Break cyclic references
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
self._size += size
self._put_buffer((data, size))
self._release_waiter()
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
async def read(self) -> WSMessage:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
return self._read_from_buffer()
def _read_from_buffer(self) -> WSMessage:
if self._buffer:
data, size = self._get_buffer()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
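# Backpressure sketch for the queue above (protocol and loop are assumed
# to exist; WS_CLOSED_MESSAGE comes from .models):
#
#   queue = WebSocketDataQueue(protocol, limit=2 ** 16, loop=loop)
#   queue.feed_data(WS_CLOSED_MESSAGE, 0)   # pauses reading past 2 * limit
#   msg = await queue.read()                # resumes reading as it drains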
class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._exc: Optional[Exception] = None
self._partial = bytearray()
self._state = READ_HEADER
self._opcode: int = OP_CODE_NOT_SET
self._frame_fin = False
self._frame_opcode: int = OP_CODE_NOT_SET
self._payload_fragments: list[bytes] = []
self._frame_payload_len = 0
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_bytes_to_read = 0
self._payload_len_flag = 0
self._compressed: int = COMPRESSED_NOT_SET
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress
def feed_eof(self) -> None:
self.queue.feed_eof()
    # data can be bytearray on Windows because the proactor event loop uses
    # bytearray, and asyncio types this as Union[bytes, bytearray, memoryview],
    # so we need to coerce data to bytes if it is not already.
def feed_data(
self, data: Union[bytes, bytearray, memoryview]
) -> Tuple[bool, bytes]:
if type(data) is not bytes:
data = bytes(data)
if self._exc is not None:
return True, data
try:
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return EMPTY_FRAME_ERROR
return EMPTY_FRAME
def _handle_frame(
self,
fin: bool,
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
payload: Union[bytes, bytearray],
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
) -> None:
msg: WSMessage
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
# load text/binary
if not fin:
# got partial frame payload
if opcode != OP_CODE_CONTINUATION:
self._opcode = opcode
self._partial += payload
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(self._partial)} "
f"exceeds limit {self._max_msg_size}",
)
return
has_partial = bool(self._partial)
if opcode == OP_CODE_CONTINUATION:
if self._opcode == OP_CODE_NOT_SET:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
opcode = self._opcode
self._opcode = OP_CODE_NOT_SET
            # The previous frame was not finished;
            # we should get a continuation opcode.
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
f"to be zero, got {opcode!r}",
)
assembled_payload: Union[bytes, bytearray]
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(assembled_payload)} "
f"exceeds limit {self._max_msg_size}",
)
            # Decompression must be done after all packets are received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Decompressed message size {self._max_msg_size + left}"
f" exceeds limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
else:
payload_merged = bytes(assembled_payload)
if opcode == OP_CODE_TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
len(payload_merged),
)
else:
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
len(payload_merged),
)
elif opcode == OP_CODE_CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close code: {close_code}",
)
try:
close_message = payload[2:].decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
self.queue.feed_data(msg, 0)
elif opcode == OP_CODE_PING:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
self.queue.feed_data(msg, len(payload))
elif opcode == OP_CODE_PONG:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
self.queue.feed_data(msg, len(payload))
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
def _feed_data(self, data: bytes) -> None:
"""Return the next frame from the socket."""
if self._tail:
data, self._tail = self._tail + data, b""
start_pos: int = 0
data_len = len(data)
data_cstr = data
while True:
# read header
if self._state == READ_HEADER:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_len_flag = length
self._state = READ_PAYLOAD_LENGTH
# read payload length
if self._state == READ_PAYLOAD_LENGTH:
len_flag = self._payload_len_flag
if len_flag == 126:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
self._payload_bytes_to_read = first_byte << 8 | second_byte
elif len_flag > 126:
if data_len - start_pos < 8:
break
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
start_pos += 8
else:
self._payload_bytes_to_read = len_flag
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
# read payload mask
if self._state == READ_PAYLOAD_MASK:
if data_len - start_pos < 4:
break
self._frame_mask = data_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD
if self._state == READ_PAYLOAD:
chunk_len = data_len - start_pos
if self._payload_bytes_to_read >= chunk_len:
f_end_pos = data_len
self._payload_bytes_to_read -= chunk_len
else:
f_end_pos = start_pos + self._payload_bytes_to_read
self._payload_bytes_to_read = 0
had_fragments = self._frame_payload_len
self._frame_payload_len += f_end_pos - start_pos
f_start_pos = start_pos
start_pos = f_end_pos
if self._payload_bytes_to_read != 0:
# If we don't have a complete frame, we need to save the
# data for the next call to feed_data.
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
break
payload: Union[bytes, bytearray]
if had_fragments:
# We have to join the payload fragments to get the payload
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
if self._has_mask:
assert self._frame_mask is not None
payload_bytearray = bytearray(b"".join(self._payload_fragments))
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = b"".join(self._payload_fragments)
self._payload_fragments.clear()
elif self._has_mask:
assert self._frame_mask is not None
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
if type(payload_bytearray) is not bytearray: # pragma: no branch
# Cython will do the conversion for us
# but we need to do it for Python and we
# will always get here in Python
payload_bytearray = bytearray(payload_bytearray)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = data_cstr[f_start_pos:f_end_pos]
self._handle_frame(
self._frame_fin, self._frame_opcode, payload, self._compressed
)
self._frame_payload_len = 0
self._state = READ_HEADER
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
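For reference, the two header bytes decoded in the READ_HEADER state above follow RFC 6455 section 5.2. A minimal, self-contained sketch (illustrative values, not aiohttp code) of how the fin/rsv/opcode and mask/length bits are packed and recovered:

# Illustrative: packing and re-parsing the two WebSocket header bytes
first_byte = 0x80 | 0x1          # FIN=1, RSV1-3=0, opcode=0x1 (TEXT)
payload = b"hi"
second_byte = len(payload)       # mask bit clear (server-to-client), len < 126

fin = (first_byte >> 7) & 1      # 1 -> final frame of the message
rsv1 = (first_byte >> 6) & 1     # 0 -> no permessage-deflate on this frame
opcode = first_byte & 0xF        # 0x1 -> text frame
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
assert (fin, rsv1, opcode, has_mask, length) == (1, 0, 0x1, 0, 2)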


@@ -0,0 +1,468 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Deque, Final, Optional, Set, Tuple, Union
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
# WSMsgType values unpacked so they can be cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1
TUPLE_NEW = tuple.__new__
cython_int = int  # Typed to int in Python, but Cython will use a signed int in the pxd
class WebSocketDataQueue:
"""WebSocketDataQueue resumes and pauses an underlying stream.
It is a destination for WebSocket data.
"""
def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
self._size = 0
self._protocol = protocol
self._limit = limit * 2
self._loop = loop
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._exception: Union[BaseException, None] = None
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
self._get_buffer = self._buffer.popleft
self._put_buffer = self._buffer.append
def is_eof(self) -> bool:
return self._eof
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(
self,
exc: "BaseException",
exc_cause: builtins.BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
def _release_waiter(self) -> None:
if (waiter := self._waiter) is None:
return
self._waiter = None
if not waiter.done():
waiter.set_result(None)
def feed_eof(self) -> None:
self._eof = True
self._release_waiter()
self._exception = None # Break cyclic references
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
self._size += size
self._put_buffer((data, size))
self._release_waiter()
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
async def read(self) -> WSMessage:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
return self._read_from_buffer()
def _read_from_buffer(self) -> WSMessage:
if self._buffer:
data, size = self._get_buffer()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
class WebSocketReader:
def __init__(
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._exc: Optional[Exception] = None
self._partial = bytearray()
self._state = READ_HEADER
self._opcode: int = OP_CODE_NOT_SET
self._frame_fin = False
self._frame_opcode: int = OP_CODE_NOT_SET
self._payload_fragments: list[bytes] = []
self._frame_payload_len = 0
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: Optional[bytes] = None
self._payload_bytes_to_read = 0
self._payload_len_flag = 0
self._compressed: int = COMPRESSED_NOT_SET
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress
def feed_eof(self) -> None:
self.queue.feed_eof()
# data can be bytearray on Windows because proactor event loop uses bytearray
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
# to coerce data to bytes if it is not
def feed_data(
self, data: Union[bytes, bytearray, memoryview]
) -> Tuple[bool, bytes]:
if type(data) is not bytes:
data = bytes(data)
if self._exc is not None:
return True, data
try:
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return EMPTY_FRAME_ERROR
return EMPTY_FRAME
def _handle_frame(
self,
fin: bool,
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
payload: Union[bytes, bytearray],
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
) -> None:
msg: WSMessage
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
# load text/binary
if not fin:
# got partial frame payload
if opcode != OP_CODE_CONTINUATION:
self._opcode = opcode
self._partial += payload
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(self._partial)} "
f"exceeds limit {self._max_msg_size}",
)
return
has_partial = bool(self._partial)
if opcode == OP_CODE_CONTINUATION:
if self._opcode == OP_CODE_NOT_SET:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
opcode = self._opcode
self._opcode = OP_CODE_NOT_SET
# previous frame was non finished
# we should get continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
f"to be zero, got {opcode!r}",
)
assembled_payload: Union[bytes, bytearray]
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {len(assembled_payload)} "
f"exceeds limit {self._max_msg_size}",
)
# Decompression must be done after all frames of the
# message have been received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Decompressed message size {self._max_msg_size + left}"
f" exceeds limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
else:
payload_merged = bytes(assembled_payload)
if opcode == OP_CODE_TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
len(payload_merged),
)
else:
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
len(payload_merged),
)
elif opcode == OP_CODE_CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close code: {close_code}",
)
try:
close_message = payload[2:].decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
self.queue.feed_data(msg, 0)
elif opcode == OP_CODE_PING:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
self.queue.feed_data(msg, len(payload))
elif opcode == OP_CODE_PONG:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
self.queue.feed_data(msg, len(payload))
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
def _feed_data(self, data: bytes) -> None:
"""Return the next frame from the socket."""
if self._tail:
data, self._tail = self._tail + data, b""
start_pos: int = 0
data_len = len(data)
data_cstr = data
while True:
# read header
if self._state == READ_HEADER:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_len_flag = length
self._state = READ_PAYLOAD_LENGTH
# read payload length
if self._state == READ_PAYLOAD_LENGTH:
len_flag = self._payload_len_flag
if len_flag == 126:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
self._payload_bytes_to_read = first_byte << 8 | second_byte
elif len_flag > 126:
if data_len - start_pos < 8:
break
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
start_pos += 8
else:
self._payload_bytes_to_read = len_flag
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
# read payload mask
if self._state == READ_PAYLOAD_MASK:
if data_len - start_pos < 4:
break
self._frame_mask = data_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD
if self._state == READ_PAYLOAD:
chunk_len = data_len - start_pos
if self._payload_bytes_to_read >= chunk_len:
f_end_pos = data_len
self._payload_bytes_to_read -= chunk_len
else:
f_end_pos = start_pos + self._payload_bytes_to_read
self._payload_bytes_to_read = 0
had_fragments = self._frame_payload_len
self._frame_payload_len += f_end_pos - start_pos
f_start_pos = start_pos
start_pos = f_end_pos
if self._payload_bytes_to_read != 0:
# If we don't have a complete frame, we need to save the
# data for the next call to feed_data.
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
break
payload: Union[bytes, bytearray]
if had_fragments:
# We have to join the payload fragments to get the payload
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
if self._has_mask:
assert self._frame_mask is not None
payload_bytearray = bytearray(b"".join(self._payload_fragments))
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = b"".join(self._payload_fragments)
self._payload_fragments.clear()
elif self._has_mask:
assert self._frame_mask is not None
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
if type(payload_bytearray) is not bytearray: # pragma: no branch
# Cython will do the conversion for us
# but we need to do it for Python and we
# will always get here in Python
payload_bytearray = bytearray(payload_bytearray)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = data_cstr[f_start_pos:f_end_pos]
self._handle_frame(
self._frame_fin, self._frame_opcode, payload, self._compressed
)
self._frame_payload_len = 0
self._state = READ_HEADER
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
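The OP_CODE_CLOSE branch above expects the close-frame payload to be a big-endian 16-bit status code followed by an optional UTF-8 reason. A small standalone sketch of that layout (the struct.unpack call stands in for UNPACK_CLOSE_CODE; the exact helper binding is an assumption):

import struct

# Close-frame payload: 2-byte big-endian close code + optional UTF-8 reason.
payload = struct.pack("!H", 1000) + "normal closure".encode("utf-8")

close_code = struct.unpack("!H", payload[:2])[0]  # mirrors UNPACK_CLOSE_CODE
reason = payload[2:].decode("utf-8")
assert close_code == 1000 and reason == "normal closure"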


@@ -0,0 +1,177 @@
"""WebSocket protocol versions 13 and 8."""
import asyncio
import random
import zlib
from functools import partial
from typing import Any, Final, Optional, Union
from ..base_protocol import BaseProtocol
from ..client_exceptions import ClientConnectionResetError
from ..compression_utils import ZLibCompressor
from .helpers import (
MASK_LEN,
MSG_SIZE,
PACK_CLOSE_CODE,
PACK_LEN1,
PACK_LEN2,
PACK_LEN3,
PACK_RANDBITS,
websocket_mask,
)
from .models import WS_DEFLATE_TRAILING, WSMsgType
DEFAULT_LIMIT: Final[int] = 2**16
# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size than the default to reduce the number of executor calls
# since the executor is a significant source of latency and overhead when
# the chunks are small. A size of 5KiB was chosen because it is also the
# same value python-zlib-ng chose to use as the threshold to release the GIL.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024
class WebSocketWriter:
"""WebSocket writer.
The writer is responsible for sending messages to the client. It is
created by the protocol when a connection is established. The writer
should avoid implementing any application logic and should only be
concerned with the low-level details of the WebSocket protocol.
"""
def __init__(
self,
protocol: BaseProtocol,
transport: asyncio.Transport,
*,
use_mask: bool = False,
limit: int = DEFAULT_LIMIT,
random: random.Random = random.Random(),
compress: int = 0,
notakeover: bool = False,
) -> None:
"""Initialize a WebSocket writer."""
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
self.get_random_bits = partial(random.getrandbits, 32)
self.compress = compress
self.notakeover = notakeover
self._closing = False
self._limit = limit
self._output_size = 0
self._compressobj: Any = None # actually compressobj
async def send_frame(
self, message: bytes, opcode: int, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ClientConnectionResetError("Cannot write to closing transport")
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
rsv = 0
# Only compress larger packets (disabled)
# Do small packets need to be compressed?
# if self.compress and opcode < 8 and len(message) > 124:
if (compress or self.compress) and opcode < 8:
# RSV1 (rsv = 0x40) is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
rsv = 0x40
if compress:
# Use a one-off compressor when compression is requested for this frame only
compressobj = self._make_compress_obj(compress)
else: # self.compress
if not self._compressobj:
self._compressobj = self._make_compress_obj(self.compress)
compressobj = self._compressobj
message = (
await compressobj.compress(message)
+ compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING)
# It's critical that we do not return control to the event
# loop until we have finished sending all the compressed
# data. Otherwise we could end up mixing compressed frames
# if there are multiple coroutines compressing data.
msg_length = len(message)
use_mask = self.use_mask
mask_bit = 0x80 if use_mask else 0
# Depending on the message length, the header is assembled differently.
# The first byte is reserved for the opcode and the RSV bits.
first_byte = 0x80 | rsv | opcode
if msg_length < 126:
header = PACK_LEN1(first_byte, msg_length | mask_bit)
header_len = 2
elif msg_length < 65536:
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
header_len = 4
else:
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
header_len = 10
if self.transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
# If we are using a mask, we need to generate it randomly
# and apply it to the message before sending it. A mask is
# a 32-bit value that is applied to the message using a
# bitwise XOR operation. It is used to prevent certain types
# of attacks on the websocket protocol. The mask is only used
# when aiohttp is acting as a client. Servers do not use a mask.
if use_mask:
mask = PACK_RANDBITS(self.get_random_bits())
message = bytearray(message)
websocket_mask(mask, message)
self.transport.write(header + mask + message)
self._output_size += MASK_LEN
elif msg_length > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
else:
self.transport.write(header + message)
self._output_size += header_len + msg_length
# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.
# Once we have written output_size up to the limit, we call the
# drain helper which waits for the transport to be ready to accept
# more data. This is a flow control mechanism to prevent the buffer
# from growing too large. The drain helper will return right away
# if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
if self.protocol._paused:
await self.protocol._drain_helper()
def _make_compress_obj(self, compress: int) -> ZLibCompressor:
return ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode("utf-8")
try:
await self.send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
)
finally:
self._closing = True
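The 2/4/10-byte header branching in send_frame() maps directly onto three struct formats; the bindings below are a plausible reconstruction of PACK_LEN1/2/3, not a quote of aiohttp's helpers:

import struct

def sketch_header(first_byte: int, msg_length: int, mask_bit: int = 0) -> bytes:
    # Mirrors the branching in send_frame(): 7-bit, 16-bit, or 64-bit lengths.
    if msg_length < 126:
        return struct.pack("!BB", first_byte, msg_length | mask_bit)
    if msg_length < 65536:
        return struct.pack("!BBH", first_byte, 126 | mask_bit, msg_length)
    return struct.pack("!BBQ", first_byte, 127 | mask_bit, msg_length)

assert len(sketch_header(0x81, 125)) == 2     # header_len = 2
assert len(sketch_header(0x81, 65535)) == 4   # header_len = 4
assert len(sketch_header(0x81, 65536)) == 10  # header_len = 10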


@@ -0,0 +1,253 @@
import asyncio
import logging
import socket
import zlib
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
TypedDict,
Union,
)
from multidict import CIMultiDict
from yarl import URL
from .typedefs import LooseCookies
if TYPE_CHECKING:
from .web_app import Application
from .web_exceptions import HTTPException
from .web_request import BaseRequest, Request
from .web_response import StreamResponse
else:
BaseRequest = Request = Application = StreamResponse = None
HTTPException = None
class AbstractRouter(ABC):
def __init__(self) -> None:
self._frozen = False
def post_init(self, app: Application) -> None:
"""Post init stage.
Not an abstract method for the sake of backward compatibility,
but if the router wants to be aware of the application
it can override this.
"""
@property
def frozen(self) -> bool:
return self._frozen
def freeze(self) -> None:
"""Freeze router."""
self._frozen = True
@abstractmethod
async def resolve(self, request: Request) -> "AbstractMatchInfo":
"""Return MATCH_INFO for given request"""
class AbstractMatchInfo(ABC):
__slots__ = ()
@property # pragma: no branch
@abstractmethod
def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
"""Execute matched request handler"""
@property
@abstractmethod
def expect_handler(
self,
) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]:
"""Expect handler for 100-continue processing"""
@property # pragma: no branch
@abstractmethod
def http_exception(self) -> Optional[HTTPException]:
"""HTTPException instance raised on router's resolving, or None"""
@abstractmethod # pragma: no branch
def get_info(self) -> Dict[str, Any]:
"""Return a dict with additional info useful for introspection"""
@property # pragma: no branch
@abstractmethod
def apps(self) -> Tuple[Application, ...]:
"""Stack of nested applications.
Top level application is left-most element.
"""
@abstractmethod
def add_app(self, app: Application) -> None:
"""Add application to the nested apps stack."""
@abstractmethod
def freeze(self) -> None:
"""Freeze the match info.
The method is called after route resolution.
After the call .add_app() is forbidden.
"""
class AbstractView(ABC):
"""Abstract class based view."""
def __init__(self, request: Request) -> None:
self._request = request
@property
def request(self) -> Request:
"""Request instance."""
return self._request
@abstractmethod
def __await__(self) -> Generator[Any, None, StreamResponse]:
"""Execute the view handler."""
class ResolveResult(TypedDict):
"""Resolve result.
This is the result returned from an AbstractResolver's
resolve method.
:param hostname: The hostname that was provided.
:param host: The IP address that was resolved.
:param port: The port that was resolved.
:param family: The address family that was resolved.
:param proto: The protocol that was resolved.
:param flags: The flags that were resolved.
"""
hostname: str
host: str
port: int
family: int
proto: int
flags: int
class AbstractResolver(ABC):
"""Abstract DNS resolver."""
@abstractmethod
async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> List[ResolveResult]:
"""Return IP address for given hostname"""
@abstractmethod
async def close(self) -> None:
"""Release resolver"""
if TYPE_CHECKING:
IterableBase = Iterable[Morsel[str]]
else:
IterableBase = Iterable
ClearCookiePredicate = Callable[["Morsel[str]"], bool]
class AbstractCookieJar(Sized, IterableBase):
"""Abstract Cookie Jar."""
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_running_loop()
@property
@abstractmethod
def quote_cookie(self) -> bool:
"""Return True if cookies should be quoted."""
@abstractmethod
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
"""Clear all cookies if no predicate is passed."""
@abstractmethod
def clear_domain(self, domain: str) -> None:
"""Clear all cookies for domain and all subdomains."""
@abstractmethod
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
@abstractmethod
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
"""Return the jar's cookies filtered by their attributes."""
class AbstractStreamWriter(ABC):
"""Abstract stream writer."""
buffer_size: int = 0
output_size: int = 0
length: Optional[int] = 0
@abstractmethod
async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
"""Write chunk into stream."""
@abstractmethod
async def write_eof(self, chunk: bytes = b"") -> None:
"""Write last chunk."""
@abstractmethod
async def drain(self) -> None:
"""Flush the write buffer."""
@abstractmethod
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
"""Enable HTTP body compression"""
@abstractmethod
def enable_chunking(self) -> None:
"""Enable HTTP chunked mode"""
@abstractmethod
async def write_headers(
self, status_line: str, headers: "CIMultiDict[str]"
) -> None:
"""Write HTTP headers"""
class AbstractAccessLogger(ABC):
"""Abstract writer to access log."""
__slots__ = ("logger", "log_format")
def __init__(self, logger: logging.Logger, log_format: str) -> None:
self.logger = logger
self.log_format = log_format
@abstractmethod
def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
"""Emit log to logger."""
@property
def enabled(self) -> bool:
"""Check if logger is enabled."""
return True
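AbstractAccessLogger is the extension point for custom access logs; a minimal concrete subclass is sketched below (wiring it up via access_log_class is standard aiohttp usage):

from aiohttp.abc import AbstractAccessLogger
from aiohttp.web_request import BaseRequest
from aiohttp.web_response import StreamResponse


class PlainAccessLogger(AbstractAccessLogger):
    """Log method, path, status and handling time for each request."""

    def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
        self.logger.info(
            "%s %s -> %s in %.3fs",
            request.method, request.path, response.status, time,
        )

# Typically passed to the server runner, e.g.:
#   web.run_app(app, access_log_class=PlainAccessLogger)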


@@ -0,0 +1,100 @@
import asyncio
from typing import Optional, cast
from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay
class BaseProtocol(asyncio.Protocol):
__slots__ = (
"_loop",
"_paused",
"_drain_waiter",
"_connection_lost",
"_reading_paused",
"transport",
)
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
self._reading_paused = False
self.transport: Optional[asyncio.Transport] = None
@property
def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None
@property
def writing_paused(self) -> bool:
return self._paused
def pause_writing(self) -> None:
assert not self._paused
self._paused = True
def resume_writing(self) -> None:
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
def pause_reading(self) -> None:
if not self._reading_paused and self.transport is not None:
try:
self.transport.pause_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = True
def resume_reading(self) -> None:
if self._reading_paused and self.transport is not None:
try:
self.transport.resume_reading()
except (AttributeError, NotImplementedError, RuntimeError):
pass
self._reading_paused = False
def connection_made(self, transport: asyncio.BaseTransport) -> None:
tr = cast(asyncio.Transport, transport)
tcp_nodelay(tr, True)
self.transport = tr
def connection_lost(self, exc: Optional[BaseException]) -> None:
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
set_exception(
waiter,
ConnectionError("Connection lost"),
exc,
)
async def _drain_helper(self) -> None:
if self.transport is None:
raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
waiter = self._loop.create_future()
self._drain_waiter = waiter
await asyncio.shield(waiter)
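A toy driver for the write flow control above: pause_writing() parks _drain_helper() on a future, and resume_writing() releases it. This exercises the private helper purely for illustration, not the way aiohttp wires it internally:

import asyncio

from aiohttp.base_protocol import BaseProtocol


async def main() -> None:
    loop = asyncio.get_running_loop()
    proto = BaseProtocol(loop)
    proto.connection_made(asyncio.Transport())  # bare transport; no real socket
    proto.pause_writing()                       # simulate a full write buffer
    task = asyncio.create_task(proto._drain_helper())
    await asyncio.sleep(0)                      # let the drain waiter park
    proto.resume_writing()                      # buffer drained: wake the waiter
    await task


asyncio.run(main())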

File diff suppressed because it is too large


@@ -0,0 +1,421 @@
"""HTTP related errors."""
import asyncio
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from multidict import MultiMapping
from .typedefs import StrOrURL
if TYPE_CHECKING:
import ssl
SSLContext = ssl.SSLContext
else:
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError: # pragma: no cover
ssl = SSLContext = None # type: ignore[assignment]
if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None
__all__ = (
"ClientError",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
"ClientSSLError",
"ClientConnectorDNSError",
"ClientConnectorSSLError",
"ClientConnectorCertificateError",
"ConnectionTimeoutError",
"SocketTimeoutError",
"ServerConnectionError",
"ServerTimeoutError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ClientResponseError",
"ClientHttpProxyError",
"WSServerHandshakeError",
"ContentTypeError",
"ClientPayloadError",
"InvalidURL",
"InvalidUrlClientError",
"RedirectClientError",
"NonHttpUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlRedirectClientError",
"WSMessageTypeError",
)
class ClientError(Exception):
"""Base class for client connection errors."""
class ClientResponseError(ClientError):
"""Base class for exceptions that occur after getting a response.
request_info: An instance of RequestInfo.
history: A sequence of responses, if redirects occurred.
status: HTTP status code.
message: Error message.
headers: Response headers.
"""
def __init__(
self,
request_info: RequestInfo,
history: Tuple[ClientResponse, ...],
*,
code: Optional[int] = None,
status: Optional[int] = None,
message: str = "",
headers: Optional[MultiMapping[str]] = None,
) -> None:
self.request_info = request_info
if code is not None:
if status is not None:
raise ValueError(
"Both code and status arguments are provided; "
"code is deprecated, use status instead"
)
warnings.warn(
"code argument is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
if status is not None:
self.status = status
elif code is not None:
self.status = code
else:
self.status = 0
self.message = message
self.headers = headers
self.history = history
self.args = (request_info, history)
def __str__(self) -> str:
return "{}, message={!r}, url={!r}".format(
self.status,
self.message,
str(self.request_info.real_url),
)
def __repr__(self) -> str:
args = f"{self.request_info!r}, {self.history!r}"
if self.status != 0:
args += f", status={self.status!r}"
if self.message != "":
args += f", message={self.message!r}"
if self.headers is not None:
args += f", headers={self.headers!r}"
return f"{type(self).__name__}({args})"
@property
def code(self) -> int:
warnings.warn(
"code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
return self.status
@code.setter
def code(self, value: int) -> None:
warnings.warn(
"code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
self.status = value
class ContentTypeError(ClientResponseError):
"""ContentType found is not valid."""
class WSServerHandshakeError(ClientResponseError):
"""websocket server handshake error."""
class ClientHttpProxyError(ClientResponseError):
"""HTTP proxy error.
Raised in :class:`aiohttp.connector.TCPConnector` if
proxy responds with status other than ``200 OK``
on ``CONNECT`` request.
"""
class TooManyRedirects(ClientResponseError):
"""Client was redirected too many times."""
class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
"""ConnectionResetError"""
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
class ClientConnectorError(ClientOSError):
"""Client connector error.
Raised in :class:`aiohttp.connector.TCPConnector` if
a connection can not be established.
"""
def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None:
self._conn_key = connection_key
self._os_error = os_error
super().__init__(os_error.errno, os_error.strerror)
self.args = (connection_key, os_error)
@property
def os_error(self) -> OSError:
return self._os_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> Optional[int]:
return self._conn_key.port
@property
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
return self._conn_key.ssl
def __str__(self) -> str:
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
self, "default" if self.ssl is True else self.ssl, self.strerror
)
# OSError.__reduce__ does too much black magic
__reduce__ = BaseException.__reduce__
class ClientConnectorDNSError(ClientConnectorError):
"""DNS resolution failed during client connection.
Raised in :class:`aiohttp.connector.TCPConnector` if
DNS resolution fails.
"""
class ClientProxyConnectionError(ClientConnectorError):
"""Proxy connection error.
Raised in :class:`aiohttp.connector.TCPConnector` if
connection to proxy can not be established.
"""
class UnixClientConnectorError(ClientConnectorError):
"""Unix connector error.
Raised in :py:class:`aiohttp.connector.UnixConnector`
if connection to unix socket can not be established.
"""
def __init__(
self, path: str, connection_key: ConnectionKey, os_error: OSError
) -> None:
self._path = path
super().__init__(connection_key, os_error)
@property
def path(self) -> str:
return self._path
def __str__(self) -> str:
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
self, "default" if self.ssl is True else self.ssl, self.strerror
)
class ServerConnectionError(ClientConnectionError):
"""Server connection errors."""
class ServerDisconnectedError(ServerConnectionError):
"""Server disconnected."""
def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None:
if message is None:
message = "Server disconnected"
self.args = (message,)
self.message = message
class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
"""Server timeout error."""
class ConnectionTimeoutError(ServerTimeoutError):
"""Connection timeout error."""
class SocketTimeoutError(ServerTimeoutError):
"""Socket timeout error."""
class ServerFingerprintMismatch(ServerConnectionError):
"""SSL certificate does not match expected fingerprint."""
def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None:
self.expected = expected
self.got = got
self.host = host
self.port = port
self.args = (expected, got, host, port)
def __repr__(self) -> str:
return "<{} expected={!r} got={!r} host={!r} port={!r}>".format(
self.__class__.__name__, self.expected, self.got, self.host, self.port
)
class ClientPayloadError(ClientError):
"""Response payload error."""
class InvalidURL(ClientError, ValueError):
"""Invalid URL.
URL used for fetching is malformed, e.g. it doesn't contain a host
part.
"""
# Derive from ValueError for backward compatibility
def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None:
# The type of url is not yarl.URL because the exception can be raised
# on URL(url) call
self._url = url
self._description = description
if description:
super().__init__(url, description)
else:
super().__init__(url)
@property
def url(self) -> StrOrURL:
return self._url
@property
def description(self) -> "str | None":
return self._description
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self}>"
def __str__(self) -> str:
if self._description:
return f"{self._url} - {self._description}"
return str(self._url)
class InvalidUrlClientError(InvalidURL):
"""Invalid URL client error."""
class RedirectClientError(ClientError):
"""Client redirect error."""
class NonHttpUrlClientError(ClientError):
"""Non http URL client error."""
class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError):
"""Invalid URL redirect client error."""
class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError):
"""Non http URL redirect client error."""
class ClientSSLError(ClientConnectorError):
"""Base error for ssl.*Errors."""
if ssl is not None:
cert_errors = (ssl.CertificateError,)
cert_errors_bases = (
ClientSSLError,
ssl.CertificateError,
)
ssl_errors = (ssl.SSLError,)
ssl_error_bases = (ClientSSLError, ssl.SSLError)
else: # pragma: no cover
cert_errors = tuple()
cert_errors_bases = (
ClientSSLError,
ValueError,
)
ssl_errors = tuple()
ssl_error_bases = (ClientSSLError,)
class ClientConnectorSSLError(*ssl_error_bases): # type: ignore[misc]
"""Response ssl error."""
class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore[misc]
"""Response certificate error."""
def __init__(
self, connection_key: ConnectionKey, certificate_error: Exception
) -> None:
self._conn_key = connection_key
self._certificate_error = certificate_error
self.args = (connection_key, certificate_error)
@property
def certificate_error(self) -> Exception:
return self._certificate_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> Optional[int]:
return self._conn_key.port
@property
def ssl(self) -> bool:
return self._conn_key.is_ssl
def __str__(self) -> str:
return (
"Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} "
"[{0.certificate_error.__class__.__name__}: "
"{0.certificate_error.args}]".format(self)
)
class WSMessageTypeError(TypeError):
"""WebSocket message type is not valid."""


@@ -0,0 +1,308 @@
import asyncio
from contextlib import suppress
from typing import Any, Optional, Tuple
from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
SocketTimeoutError,
)
from .helpers import (
_EXC_SENTINEL,
EMPTY_BODY_STATUS_CODES,
BaseTimerContext,
set_exception,
)
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]):
"""Helper class to adapt between Protocol and StreamReader."""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
BaseProtocol.__init__(self, loop=loop)
DataQueue.__init__(self, loop)
self._should_close = False
self._payload: Optional[StreamReader] = None
self._skip_payload = False
self._payload_parser = None
self._timer = None
self._tail = b""
self._upgraded = False
self._parser: Optional[HttpResponseParser] = None
self._read_timeout: Optional[float] = None
self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
self._timeout_ceil_threshold: Optional[float] = 5
@property
def upgraded(self) -> bool:
return self._upgraded
@property
def should_close(self) -> bool:
return bool(
self._should_close
or (self._payload is not None and not self._payload.is_eof())
or self._upgraded
or self._exception is not None
or self._payload_parser is not None
or self._buffer
or self._tail
)
def force_close(self) -> None:
self._should_close = True
def close(self) -> None:
self._exception = None # Break cyclic references
transport = self.transport
if transport is not None:
transport.close()
self.transport = None
self._payload = None
self._drop_timeout()
def is_connected(self) -> bool:
return self.transport is not None and not self.transport.is_closing()
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()
original_connection_error = exc
reraised_exc = original_connection_error
connection_closed_cleanly = original_connection_error is None
if self._payload_parser is not None:
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()
uncompleted = None
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception as underlying_exc:
if self._payload is not None:
client_payload_exc_msg = (
f"Response payload is not completed: {underlying_exc !r}"
)
if not connection_closed_cleanly:
client_payload_exc_msg = (
f"{client_payload_exc_msg !s}. "
f"{original_connection_error !r}"
)
set_exception(
self._payload,
ClientPayloadError(client_payload_exc_msg),
underlying_exc,
)
if not self.is_eof():
if isinstance(original_connection_error, OSError):
reraised_exc = ClientOSError(*original_connection_error.args)
if connection_closed_cleanly:
reraised_exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close = True as a side effect;
# we do it anyway below
underlying_non_eof_exc = (
_EXC_SENTINEL
if connection_closed_cleanly
else original_connection_error
)
assert underlying_non_eof_exc is not None
assert reraised_exc is not None
self.set_exception(reraised_exc, underlying_non_eof_exc)
self._should_close = True
self._parser = None
self._payload = None
self._payload_parser = None
self._reading_paused = False
super().connection_lost(reraised_exc)
def eof_received(self) -> None:
# should call parser.feed_eof() most likely
self._drop_timeout()
def pause_reading(self) -> None:
super().pause_reading()
self._drop_timeout()
def resume_reading(self) -> None:
super().resume_reading()
self._reschedule_timeout()
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc, exc_cause)
def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
# parser: WebSocketReader
# payload: WebSocketDataQueue
# but they are not generic enough
# Need an ABC for both types
self._payload = payload
self._payload_parser = parser
self._drop_timeout()
if self._tail:
data, self._tail = self._tail, b""
self.data_received(data)
def set_response_params(
self,
*,
timer: Optional[BaseTimerContext] = None,
skip_payload: bool = False,
read_until_eof: bool = False,
auto_decompress: bool = True,
read_timeout: Optional[float] = None,
read_bufsize: int = 2**16,
timeout_ceil_threshold: float = 5,
max_line_size: int = 8190,
max_field_size: int = 8190,
) -> None:
self._skip_payload = skip_payload
self._read_timeout = read_timeout
self._timeout_ceil_threshold = timeout_ceil_threshold
self._parser = HttpResponseParser(
self,
self._loop,
read_bufsize,
timer=timer,
payload_exception=ClientPayloadError,
response_with_body=not skip_payload,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
max_line_size=max_line_size,
max_field_size=max_field_size,
)
if self._tail:
data, self._tail = self._tail, b""
self.data_received(data)
def _drop_timeout(self) -> None:
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
self._read_timeout_handle = None
def _reschedule_timeout(self) -> None:
timeout = self._read_timeout
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
if timeout:
self._read_timeout_handle = self._loop.call_later(
timeout, self._on_read_timeout
)
else:
self._read_timeout_handle = None
def start_timeout(self) -> None:
self._reschedule_timeout()
@property
def read_timeout(self) -> Optional[float]:
return self._read_timeout
@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout
def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
set_exception(self._payload, exc)
def data_received(self, data: bytes) -> None:
self._reschedule_timeout()
if not data:
return
# custom payload parser - currently always WebSocketReader
if self._payload_parser is not None:
eof, tail = self._payload_parser.feed_data(data)
if eof:
self._payload = None
self._payload_parser = None
if tail:
self.data_received(tail)
return
if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data
return
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(); in that case the transport
# is already closed
self.transport.close()
# should_close is True after the call
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return
self._upgraded = upgraded
payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True
self._payload = payload
if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()
if upgraded and tail:
self.data_received(tail)
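The _drop_timeout()/_reschedule_timeout() pair above is a cancel-then-rearm call_later pattern; a standalone sketch of the same idea (names here are illustrative, not aiohttp API):

import asyncio
from typing import Optional


class TimeoutTracker:
    """Cancel-then-rearm read timeout, as ResponseHandler does per data chunk."""

    def __init__(self, loop: asyncio.AbstractEventLoop, timeout: float) -> None:
        self._loop = loop
        self._timeout = timeout
        self._handle: Optional[asyncio.TimerHandle] = None

    def reschedule(self) -> None:
        # Cancel any pending timer before arming a new one, mirroring
        # _reschedule_timeout() on every data_received() call.
        if self._handle is not None:
            self._handle.cancel()
        self._handle = self._loop.call_later(self._timeout, self._on_timeout)

    def drop(self) -> None:
        if self._handle is not None:
            self._handle.cancel()
            self._handle = None

    def _on_timeout(self) -> None:
        print(f"no data received within {self._timeout} seconds")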

File diff suppressed because it is too large


@@ -0,0 +1,428 @@
"""WebSocket client for asyncio."""
import asyncio
import sys
from types import TracebackType
from typing import Any, Optional, Type, cast
import attr
from ._websocket.reader import WebSocketDataQueue
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
from .streams import EofStream
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
JSONDecoder,
JSONEncoder,
)
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
@attr.s(frozen=True, slots=True)
class ClientWSTimeout:
ws_receive = attr.ib(type=Optional[float], default=None)
ws_close = attr.ib(type=Optional[float], default=None)
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0)
class ClientWebSocketResponse:
def __init__(
self,
reader: WebSocketDataQueue,
writer: WebSocketWriter,
protocol: Optional[str],
response: ClientResponse,
timeout: ClientWSTimeout,
autoclose: bool,
autoping: bool,
loop: asyncio.AbstractEventLoop,
*,
heartbeat: Optional[float] = None,
compress: int = 0,
client_notakeover: bool = False,
) -> None:
self._response = response
self._conn = response.connection
self._writer = writer
self._reader = reader
self._protocol = protocol
self._closed = False
self._closing = False
self._close_code: Optional[int] = None
self._timeout = timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
self._heartbeat_when: float = 0.0
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: Optional[asyncio.TimerHandle] = None
self._loop = loop
self._waiting: bool = False
self._close_wait: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: Optional[asyncio.Task[None]] = None
self._reset_heartbeat()
def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
def _reset_heartbeat(self) -> None:
if self._heartbeat is None:
return
self._cancel_pong_response_cb()
loop = self._loop
assert loop is not None
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
now = loop.time()
when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
self._heartbeat_when = when
if self._heartbeat_cb is None:
# We do not cancel the previous heartbeat_cb here because
# it generates a significant amount of TimerHandle churn
# which causes asyncio to rebuild the heap frequently.
# Instead _send_heartbeat() will reschedule the next
# heartbeat if it fires too early.
self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
def _send_heartbeat(self) -> None:
self._heartbeat_cb = None
loop = self._loop
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
coro = self._writer.send_frame(b"", WSMsgType.PING)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
ping_task = loop.create_task(coro)
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None
def _pong_not_received(self) -> None:
self._handle_ping_pong_exception(
ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
)
def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()
def _set_closing(self) -> None:
"""Set the connection to closing.
Cancel any heartbeat timers and set the closing flag.
"""
self._closing = True
self._cancel_heartbeat()
@property
def closed(self) -> bool:
return self._closed
@property
def close_code(self) -> Optional[int]:
return self._close_code
@property
def protocol(self) -> Optional[str]:
return self._protocol
@property
def compress(self) -> int:
return self._compress
@property
def client_notakeover(self) -> bool:
return self._client_notakeover
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""extra info from connection transport"""
conn = self._response.connection
if conn is None:
return default
transport = conn.transport
if transport is None:
return default
return transport.get_extra_info(name, default)
def exception(self) -> Optional[BaseException]:
return self._exception
async def ping(self, message: bytes = b"") -> None:
await self._writer.send_frame(message, WSMsgType.PING)
async def pong(self, message: bytes = b"") -> None:
await self._writer.send_frame(message, WSMsgType.PONG)
async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket."""
await self._writer.send_frame(message, opcode, compress)
async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)
async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)
async def send_json(
self,
data: Any,
compress: Optional[int] = None,
*,
dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
) -> None:
await self.send_str(dumps(data), compress=compress)
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting and not self._closing:
assert self._loop is not None
self._close_wait = self._loop.create_future()
self._set_closing()
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._close_wait
if self._closed:
return False
self._set_closed()
try:
await self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if self._close_code:
self._response.close()
return True
while True:
try:
async with async_timeout.timeout(self._timeout.ws_close):
msg = await self._reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if msg.type is WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
receive_timeout = timeout or self._timeout.ws_receive
while True:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
return WS_CLOSED_MESSAGE
elif self._closing:
await self.close()
return WS_CLOSED_MESSAGE
try:
self._waiting = True
try:
if receive_timeout:
                        # Entering the context manager and creating
                        # a Timeout() object can take almost 50% of the
                        # run time in this loop, so we avoid it if
                        # there is no read timeout.
async with async_timeout.timeout(receive_timeout):
msg = await self._reader.read()
else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except EofStream:
self._close_code = WSCloseCode.OK
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except ClientError:
# Likely ServerDisconnectedError when connection is lost
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._set_closing()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type not in _INTERNAL_RECEIVE_TYPES:
            # If it's not a close/closing/ping/pong message,
            # we can return it immediately.
return msg
if msg.type is WSMsgType.CLOSE:
self._set_closing()
self._close_code = msg.data
if not self._closed and self._autoclose:
await self.close()
elif msg.type is WSMsgType.CLOSING:
self._set_closing()
elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return cast(str, msg.data)
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
)
return cast(bytes, msg.data)
async def receive_json(
self,
*,
loads: JSONDecoder = DEFAULT_JSON_DECODER,
timeout: Optional[float] = None,
) -> Any:
data = await self.receive_str(timeout=timeout)
return loads(data)
def __aiter__(self) -> "ClientWebSocketResponse":
return self
async def __anext__(self) -> WSMessage:
msg = await self.receive()
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
return msg
async def __aenter__(self) -> "ClientWebSocketResponse":
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await self.close()
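
# Usage sketch (illustrative only): a minimal receive loop over the
# ClientWebSocketResponse API above, obtained via ClientSession.ws_connect.
# The echo endpoint URL is an assumed placeholder.
#
#     import asyncio
#     import aiohttp
#
#     async def main() -> None:
#         async with aiohttp.ClientSession() as session:
#             async with session.ws_connect("wss://echo.example/ws") as ws:
#                 await ws.send_str("hello")
#                 async for msg in ws:  # stops on CLOSE/CLOSING/CLOSED
#                     if msg.type is aiohttp.WSMsgType.TEXT:
#                         print("received:", msg.data)
#                         await ws.close()
#
#     asyncio.run(main())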

View File

@@ -0,0 +1,173 @@
import asyncio
import zlib
from concurrent.futures import Executor
from typing import Optional, cast
try:
try:
import brotlicffi as brotli
except ImportError:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
MAX_SYNC_CHUNK_SIZE = 1024
def encoding_to_mode(
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
) -> int:
if encoding == "gzip":
return 16 + zlib.MAX_WBITS
return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS
class ZlibBaseHandler:
def __init__(
self,
mode: int,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
self._mode = mode
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size
class ZLibCompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
level: Optional[int] = None,
wbits: Optional[int] = None,
strategy: int = zlib.Z_DEFAULT_STRATEGY,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=(
encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
else wbits
),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
if level is None:
self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy)
else:
self._compressor = zlib.compressobj(
wbits=self._mode, strategy=strategy, level=level
)
self._compress_lock = asyncio.Lock()
def compress_sync(self, data: bytes) -> bytes:
return self._compressor.compress(data)
async def compress(self, data: bytes) -> bytes:
"""Compress the data and returned the compressed bytes.
Note that flush() must be called after the last call to compress()
If the data size is large than the max_sync_chunk_size, the compression
will be done in the executor. Otherwise, the compression will be done
in the event loop.
"""
async with self._compress_lock:
# To ensure the stream is consistent in the event
# there are multiple writers, we need to lock
# the compressor so that only one writer can
# compress at a time.
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)
def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
return self._compressor.flush(mode)
class ZLibDecompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=encoding_to_mode(encoding, suppress_deflate_header),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
self._decompressor = zlib.decompressobj(wbits=self._mode)
def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
return self._decompressor.decompress(data, max_length)
async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
"""Decompress the data and return the decompressed bytes.
        If the data size is larger than max_sync_chunk_size, the decompression
will be done in the executor. Otherwise, the decompression will be done
in the event loop.
"""
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._decompressor.decompress, data, max_length
)
return self.decompress_sync(data, max_length)
def flush(self, length: int = 0) -> bytes:
return (
self._decompressor.flush(length)
if length > 0
else self._decompressor.flush()
)
@property
def eof(self) -> bool:
return self._decompressor.eof
@property
def unconsumed_tail(self) -> bytes:
return self._decompressor.unconsumed_tail
@property
def unused_data(self) -> bytes:
return self._decompressor.unused_data
class BrotliDecompressor:
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
if not HAS_BROTLI:
raise RuntimeError(
"The brotli decompression is not available. "
"Please install `Brotli` module"
)
self._obj = brotli.Decompressor()
def decompress_sync(self, data: bytes) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))
def flush(self) -> bytes:
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""

File diff suppressed because it is too large

View File

@@ -0,0 +1,495 @@
import asyncio
import calendar
import contextlib
import datetime
import heapq
import itertools
import os # noqa
import pathlib
import pickle
import re
import time
import warnings
from collections import defaultdict
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import (
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
Union,
cast,
)
from yarl import URL
from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address
from .typedefs import LooseCookies, PathLike, StrOrURL
__all__ = ("CookieJar", "DummyCookieJar")
CookieItem = Union[str, "Morsel[str]"]
# We cache these string methods here as their use is in performance-critical code.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
class CookieJar(AbstractCookieJar):
"""Implements cookie storage adhering to RFC 6265."""
DATE_TOKENS_RE = re.compile(
r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
)
DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
DATE_MONTH_RE = re.compile(
"(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
re.I,
)
DATE_YEAR_RE = re.compile(r"(\d{2,4})")
# calendar.timegm() fails for timestamps after datetime.datetime.max
# Minus one as a loss of precision occurs when timestamp() is called.
MAX_TIME = (
int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
)
try:
calendar.timegm(time.gmtime(MAX_TIME))
except (OSError, ValueError):
# Hit the maximum representable time on Windows
# https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
# Throws ValueError on PyPy 3.9, OSError elsewhere
MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
except OverflowError:
# #4515: datetime.max may not be representable on 32-bit platforms
MAX_TIME = 2**31 - 1
    # Precompute MAX_TIME - 1 so later checks avoid a subtraction (3x faster)
SUB_MAX_TIME = MAX_TIME - 1
def __init__(
self,
*,
unsafe: bool = False,
quote_cookie: bool = True,
treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
super().__init__(loop=loop)
self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
SimpleCookie
)
self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
defaultdict(dict)
)
self._host_only_cookies: Set[Tuple[str, str]] = set()
self._unsafe = unsafe
self._quote_cookie = quote_cookie
if treat_as_secure_origin is None:
treat_as_secure_origin = []
elif isinstance(treat_as_secure_origin, URL):
treat_as_secure_origin = [treat_as_secure_origin.origin()]
elif isinstance(treat_as_secure_origin, str):
treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
else:
treat_as_secure_origin = [
URL(url).origin() if isinstance(url, str) else url.origin()
for url in treat_as_secure_origin
]
self._treat_as_secure_origin = treat_as_secure_origin
self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
self._expirations: Dict[Tuple[str, str, str], float] = {}
@property
def quote_cookie(self) -> bool:
return self._quote_cookie
def save(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
with file_path.open(mode="wb") as f:
pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
def load(self, file_path: PathLike) -> None:
file_path = pathlib.Path(file_path)
with file_path.open(mode="rb") as f:
self._cookies = pickle.load(f)
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
if predicate is None:
self._expire_heap.clear()
self._cookies.clear()
self._morsel_cache.clear()
self._host_only_cookies.clear()
self._expirations.clear()
return
now = time.time()
to_del = [
key
for (domain, path), cookie in self._cookies.items()
for name, morsel in cookie.items()
if (
(key := (domain, path, name)) in self._expirations
and self._expirations[key] <= now
)
or predicate(morsel)
]
if to_del:
self._delete_cookies(to_del)
def clear_domain(self, domain: str) -> None:
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
def __iter__(self) -> "Iterator[Morsel[str]]":
self._do_expiration()
for val in self._cookies.values():
yield from val.values()
def __len__(self) -> int:
"""Return number of cookies.
This function does not iterate self to avoid unnecessary expiration
checks.
"""
return sum(len(cookie.values()) for cookie in self._cookies.values())
def _do_expiration(self) -> None:
"""Remove expired cookies."""
if not (expire_heap_len := len(self._expire_heap)):
return
        # If the expiration heap grows larger than the number of expirations
# times two, we clean it up to avoid keeping expired entries in
# the heap and consuming memory. We guard this with a minimum
# threshold to avoid cleaning up the heap too often when there are
# only a few scheduled expirations.
if (
expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
and expire_heap_len > len(self._expirations) * 2
):
# Remove any expired entries from the expiration heap
# that do not match the expiration time in the expirations
# as it means the cookie has been re-added to the heap
# with a different expiration time.
self._expire_heap = [
entry
for entry in self._expire_heap
if self._expirations.get(entry[1]) == entry[0]
]
heapq.heapify(self._expire_heap)
now = time.time()
to_del: List[Tuple[str, str, str]] = []
# Find any expired cookies and add them to the to-delete list
while self._expire_heap:
when, cookie_key = self._expire_heap[0]
if when > now:
break
heapq.heappop(self._expire_heap)
# Check if the cookie hasn't been re-added to the heap
# with a different expiration time as it will be removed
# later when it reaches the top of the heap and its
# expiration time is met.
if self._expirations.get(cookie_key) == when:
to_del.append(cookie_key)
if to_del:
self._delete_cookies(to_del)
def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
for domain, path, name in to_del:
self._host_only_cookies.discard((domain, name))
self._cookies[(domain, path)].pop(name, None)
self._morsel_cache[(domain, path)].pop(name, None)
self._expirations.pop((domain, path, name), None)
def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
cookie_key = (domain, path, name)
if self._expirations.get(cookie_key) == when:
# Avoid adding duplicates to the heap
return
heapq.heappush(self._expire_heap, (when, cookie_key))
self._expirations[cookie_key] = when
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
hostname = response_url.raw_host
if not self._unsafe and is_ip_address(hostname):
# Don't accept cookies from IPs
return
if isinstance(cookies, Mapping):
cookies = cookies.items()
for name, cookie in cookies:
if not isinstance(cookie, Morsel):
tmp = SimpleCookie()
tmp[name] = cookie # type: ignore[assignment]
cookie = tmp[name]
domain = cookie["domain"]
# ignore domains with trailing dots
if domain and domain[-1] == ".":
domain = ""
del cookie["domain"]
if not domain and hostname is not None:
# Set the cookie's domain to the response hostname
# and set its host-only-flag
self._host_only_cookies.add((hostname, name))
domain = cookie["domain"] = hostname
if domain and domain[0] == ".":
# Remove leading dot
domain = domain[1:]
cookie["domain"] = domain
if hostname and not self._is_domain_match(domain, hostname):
# Setting cookies for different domains is not allowed
continue
path = cookie["path"]
if not path or path[0] != "/":
# Set the cookie's path to the response path
path = response_url.path
if not path.startswith("/"):
path = "/"
else:
# Cut everything from the last slash to the end
path = "/" + path[1 : path.rfind("/")]
cookie["path"] = path
path = path.rstrip("/")
if max_age := cookie["max-age"]:
try:
delta_seconds = int(max_age)
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
self._expire_cookie(max_age_expiration, domain, path, name)
except ValueError:
cookie["max-age"] = ""
elif expires := cookie["expires"]:
if expire_time := self._parse_date(expires):
self._expire_cookie(expire_time, domain, path, name)
else:
cookie["expires"] = ""
key = (domain, path)
if self._cookies[key].get(name) != cookie:
# Don't blow away the cache if the same
# cookie gets set again
self._cookies[key][name] = cookie
self._morsel_cache[key].pop(name, None)
self._do_expiration()
def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
"""Returns this jar's cookies filtered by their attributes."""
filtered: Union[SimpleCookie, "BaseCookie[str]"] = (
SimpleCookie() if self._quote_cookie else BaseCookie()
)
if not self._cookies:
# Skip do_expiration() if there are no cookies.
return filtered
self._do_expiration()
if not self._cookies:
# Skip rest of function if no non-expired cookies.
return filtered
if type(request_url) is not URL:
warnings.warn(
"filter_cookies expects yarl.URL instances only,"
f"and will stop working in 4.x, got {type(request_url)}",
DeprecationWarning,
stacklevel=2,
)
request_url = URL(request_url)
hostname = request_url.raw_host or ""
is_not_secure = request_url.scheme not in ("https", "wss")
if is_not_secure and self._treat_as_secure_origin:
request_origin = URL()
with contextlib.suppress(ValueError):
request_origin = request_url.origin()
is_not_secure = request_origin not in self._treat_as_secure_origin
# Send shared cookie
for c in self._cookies[("", "")].values():
filtered[c.key] = c.value
if is_ip_address(hostname):
if not self._unsafe:
return filtered
domains: Iterable[str] = (hostname,)
else:
# Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
domains = itertools.accumulate(
reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
)
# Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
# Create every combination of (domain, path) pairs.
pairs = itertools.product(domains, paths)
path_len = len(request_url.path)
# Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
for p in pairs:
for name, cookie in self._cookies[p].items():
domain = cookie["domain"]
if (domain, name) in self._host_only_cookies and domain != hostname:
continue
# Skip edge case when the cookie has a trailing slash but request doesn't.
if len(cookie["path"]) > path_len:
continue
if is_not_secure and cookie["secure"]:
continue
# We already built the Morsel so reuse it here
if name in self._morsel_cache[p]:
filtered[name] = self._morsel_cache[p][name]
continue
# It's critical we use the Morsel so the coded_value
# (based on cookie version) is preserved
mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
self._morsel_cache[p][name] = mrsl_val
filtered[name] = mrsl_val
return filtered
@staticmethod
def _is_domain_match(domain: str, hostname: str) -> bool:
"""Implements domain matching adhering to RFC 6265."""
if hostname == domain:
return True
if not hostname.endswith(domain):
return False
non_matching = hostname[: -len(domain)]
if not non_matching.endswith("."):
return False
return not is_ip_address(hostname)
@classmethod
def _parse_date(cls, date_str: str) -> Optional[int]:
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return None
found_time = False
found_day = False
found_month = False
found_year = False
hour = minute = second = 0
day = 0
month = 0
year = 0
for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
token = token_match.group("token")
if not found_time:
time_match = cls.DATE_HMS_TIME_RE.match(token)
if time_match:
found_time = True
hour, minute, second = (int(s) for s in time_match.groups())
continue
if not found_day:
day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
if day_match:
found_day = True
day = int(day_match.group())
continue
if not found_month:
month_match = cls.DATE_MONTH_RE.match(token)
if month_match:
found_month = True
assert month_match.lastindex is not None
month = month_match.lastindex
continue
if not found_year:
year_match = cls.DATE_YEAR_RE.match(token)
if year_match:
found_year = True
year = int(year_match.group())
if 70 <= year <= 99:
year += 1900
elif 0 <= year <= 69:
year += 2000
if False in (found_day, found_month, found_year, found_time):
return None
if not 1 <= day <= 31:
return None
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return None
return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):
"""Implements a dummy cookie storage.
It can be used with the ClientSession when no cookie processing is needed.
"""
def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(loop=loop)
def __iter__(self) -> "Iterator[Morsel[str]]":
while False:
yield None
def __len__(self) -> int:
return 0
@property
def quote_cookie(self) -> bool:
return True
def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
pass
def clear_domain(self, domain: str) -> None:
pass
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
pass
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
return SimpleCookie()
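
# Usage sketch (illustrative only): storing and filtering cookies with the
# CookieJar above; the domain and cookie values are assumptions.
#
#     async def demo() -> None:
#         jar = CookieJar()
#         jar.update_cookies(
#             {"session": "abc123"}, response_url=URL("https://example.com/")
#         )
#         matched = jar.filter_cookies(URL("https://example.com/account"))
#         print(matched["session"].value)  # "abc123"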

View File

@@ -0,0 +1,182 @@
import io
import warnings
from typing import Any, Iterable, List, Optional
from urllib.parse import urlencode
from multidict import MultiDict, MultiDictProxy
from . import hdrs, multipart, payload
from .helpers import guess_filename
from .payload import Payload
__all__ = ("FormData",)
class FormData:
"""Helper class for form body generation.
Supports multipart/form-data and application/x-www-form-urlencoded.
"""
def __init__(
self,
fields: Iterable[Any] = (),
quote_fields: bool = True,
charset: Optional[str] = None,
*,
default_to_multipart: bool = False,
) -> None:
self._writer = multipart.MultipartWriter("form-data")
self._fields: List[Any] = []
self._is_multipart = default_to_multipart
self._is_processed = False
self._quote_fields = quote_fields
self._charset = charset
if isinstance(fields, dict):
fields = list(fields.items())
elif not isinstance(fields, (list, tuple)):
fields = (fields,)
self.add_fields(*fields)
@property
def is_multipart(self) -> bool:
return self._is_multipart
def add_field(
self,
name: str,
value: Any,
*,
content_type: Optional[str] = None,
filename: Optional[str] = None,
content_transfer_encoding: Optional[str] = None,
) -> None:
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
msg = (
"In v4, passing bytes will no longer create a file field. "
"Please explicitly use the filename parameter or pass a BytesIO object."
)
if filename is None and content_transfer_encoding is None:
warnings.warn(msg, DeprecationWarning)
filename = name
type_options: MultiDict[str] = MultiDict({"name": name})
if filename is not None and not isinstance(filename, str):
raise TypeError("filename must be an instance of str. Got: %s" % filename)
if filename is None and isinstance(value, io.IOBase):
filename = guess_filename(value, name)
if filename is not None:
type_options["filename"] = filename
self._is_multipart = True
headers = {}
if content_type is not None:
if not isinstance(content_type, str):
raise TypeError(
"content_type must be an instance of str. Got: %s" % content_type
)
headers[hdrs.CONTENT_TYPE] = content_type
self._is_multipart = True
if content_transfer_encoding is not None:
if not isinstance(content_transfer_encoding, str):
raise TypeError(
"content_transfer_encoding must be an instance"
" of str. Got: %s" % content_transfer_encoding
)
msg = (
"content_transfer_encoding is deprecated. "
"To maintain compatibility with v4 please pass a BytesPayload."
)
warnings.warn(msg, DeprecationWarning)
self._is_multipart = True
self._fields.append((type_options, headers, value))
def add_fields(self, *fields: Any) -> None:
to_add = list(fields)
while to_add:
rec = to_add.pop(0)
if isinstance(rec, io.IOBase):
k = guess_filename(rec, "unknown")
self.add_field(k, rec) # type: ignore[arg-type]
elif isinstance(rec, (MultiDictProxy, MultiDict)):
to_add.extend(rec.items())
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
k, fp = rec
self.add_field(k, fp) # type: ignore[arg-type]
else:
raise TypeError(
"Only io.IOBase, multidict and (name, file) "
"pairs allowed, use .add_field() for passing "
"more complex parameters, got {!r}".format(rec)
)
def _gen_form_urlencoded(self) -> payload.BytesPayload:
# form data (x-www-form-urlencoded)
data = []
for type_options, _, value in self._fields:
data.append((type_options["name"], value))
charset = self._charset if self._charset is not None else "utf-8"
if charset == "utf-8":
content_type = "application/x-www-form-urlencoded"
else:
content_type = "application/x-www-form-urlencoded; charset=%s" % charset
return payload.BytesPayload(
urlencode(data, doseq=True, encoding=charset).encode(),
content_type=content_type,
)
def _gen_form_data(self) -> multipart.MultipartWriter:
"""Encode a list of fields using the multipart/form-data MIME format"""
if self._is_processed:
raise RuntimeError("Form data has been processed already")
for dispparams, headers, value in self._fields:
try:
if hdrs.CONTENT_TYPE in headers:
part = payload.get_payload(
value,
content_type=headers[hdrs.CONTENT_TYPE],
headers=headers,
encoding=self._charset,
)
else:
part = payload.get_payload(
value, headers=headers, encoding=self._charset
)
except Exception as exc:
raise TypeError(
"Can not serialize value type: %r\n "
"headers: %r\n value: %r" % (type(value), headers, value)
) from exc
if dispparams:
part.set_content_disposition(
"form-data", quote_fields=self._quote_fields, **dispparams
)
            # FIXME cgi.FieldStorage doesn't like body parts with
            # Content-Length that were sent via chunked transfer encoding
assert part.headers is not None
part.headers.popall(hdrs.CONTENT_LENGTH, None)
self._writer.append_payload(part)
self._is_processed = True
return self._writer
def __call__(self) -> Payload:
if self._is_multipart:
return self._gen_form_data()
else:
return self._gen_form_urlencoded()
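
# Usage sketch (illustrative only): the field names and values below are
# assumptions.
#
#     form = FormData({"user": "alice"})
#     form.is_multipart              # False -> x-www-form-urlencoded
#     form.add_field(
#         "avatar",
#         io.BytesIO(b"\x89PNG..."),
#         filename="avatar.png",
#         content_type="image/png",
#     )
#     form.is_multipart              # True -> multipart/form-data
#     body = form()                  # MultipartWriter ready to send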

View File

@@ -0,0 +1,121 @@
"""HTTP Headers constants."""
# After changing the file content call ./tools/gen.py
# to regenerate the headers parser
import itertools
from typing import Final, Set
from multidict import istr
METH_ANY: Final[str] = "*"
METH_CONNECT: Final[str] = "CONNECT"
METH_HEAD: Final[str] = "HEAD"
METH_GET: Final[str] = "GET"
METH_DELETE: Final[str] = "DELETE"
METH_OPTIONS: Final[str] = "OPTIONS"
METH_PATCH: Final[str] = "PATCH"
METH_POST: Final[str] = "POST"
METH_PUT: Final[str] = "PUT"
METH_TRACE: Final[str] = "TRACE"
METH_ALL: Final[Set[str]] = {
METH_CONNECT,
METH_HEAD,
METH_GET,
METH_DELETE,
METH_OPTIONS,
METH_PATCH,
METH_POST,
METH_PUT,
METH_TRACE,
}
ACCEPT: Final[istr] = istr("Accept")
ACCEPT_CHARSET: Final[istr] = istr("Accept-Charset")
ACCEPT_ENCODING: Final[istr] = istr("Accept-Encoding")
ACCEPT_LANGUAGE: Final[istr] = istr("Accept-Language")
ACCEPT_RANGES: Final[istr] = istr("Accept-Ranges")
ACCESS_CONTROL_MAX_AGE: Final[istr] = istr("Access-Control-Max-Age")
ACCESS_CONTROL_ALLOW_CREDENTIALS: Final[istr] = istr("Access-Control-Allow-Credentials")
ACCESS_CONTROL_ALLOW_HEADERS: Final[istr] = istr("Access-Control-Allow-Headers")
ACCESS_CONTROL_ALLOW_METHODS: Final[istr] = istr("Access-Control-Allow-Methods")
ACCESS_CONTROL_ALLOW_ORIGIN: Final[istr] = istr("Access-Control-Allow-Origin")
ACCESS_CONTROL_EXPOSE_HEADERS: Final[istr] = istr("Access-Control-Expose-Headers")
ACCESS_CONTROL_REQUEST_HEADERS: Final[istr] = istr("Access-Control-Request-Headers")
ACCESS_CONTROL_REQUEST_METHOD: Final[istr] = istr("Access-Control-Request-Method")
AGE: Final[istr] = istr("Age")
ALLOW: Final[istr] = istr("Allow")
AUTHORIZATION: Final[istr] = istr("Authorization")
CACHE_CONTROL: Final[istr] = istr("Cache-Control")
CONNECTION: Final[istr] = istr("Connection")
CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition")
CONTENT_ENCODING: Final[istr] = istr("Content-Encoding")
CONTENT_LANGUAGE: Final[istr] = istr("Content-Language")
CONTENT_LENGTH: Final[istr] = istr("Content-Length")
CONTENT_LOCATION: Final[istr] = istr("Content-Location")
CONTENT_MD5: Final[istr] = istr("Content-MD5")
CONTENT_RANGE: Final[istr] = istr("Content-Range")
CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding")
CONTENT_TYPE: Final[istr] = istr("Content-Type")
COOKIE: Final[istr] = istr("Cookie")
DATE: Final[istr] = istr("Date")
DESTINATION: Final[istr] = istr("Destination")
DIGEST: Final[istr] = istr("Digest")
ETAG: Final[istr] = istr("Etag")
EXPECT: Final[istr] = istr("Expect")
EXPIRES: Final[istr] = istr("Expires")
FORWARDED: Final[istr] = istr("Forwarded")
FROM: Final[istr] = istr("From")
HOST: Final[istr] = istr("Host")
IF_MATCH: Final[istr] = istr("If-Match")
IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since")
IF_NONE_MATCH: Final[istr] = istr("If-None-Match")
IF_RANGE: Final[istr] = istr("If-Range")
IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since")
KEEP_ALIVE: Final[istr] = istr("Keep-Alive")
LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID")
LAST_MODIFIED: Final[istr] = istr("Last-Modified")
LINK: Final[istr] = istr("Link")
LOCATION: Final[istr] = istr("Location")
MAX_FORWARDS: Final[istr] = istr("Max-Forwards")
ORIGIN: Final[istr] = istr("Origin")
PRAGMA: Final[istr] = istr("Pragma")
PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate")
PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization")
RANGE: Final[istr] = istr("Range")
REFERER: Final[istr] = istr("Referer")
RETRY_AFTER: Final[istr] = istr("Retry-After")
SEC_WEBSOCKET_ACCEPT: Final[istr] = istr("Sec-WebSocket-Accept")
SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version")
SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol")
SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions")
SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key")
SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1")
SERVER: Final[istr] = istr("Server")
SET_COOKIE: Final[istr] = istr("Set-Cookie")
TE: Final[istr] = istr("TE")
TRAILER: Final[istr] = istr("Trailer")
TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding")
UPGRADE: Final[istr] = istr("Upgrade")
URI: Final[istr] = istr("URI")
USER_AGENT: Final[istr] = istr("User-Agent")
VARY: Final[istr] = istr("Vary")
VIA: Final[istr] = istr("Via")
WANT_DIGEST: Final[istr] = istr("Want-Digest")
WARNING: Final[istr] = istr("Warning")
WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate")
X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For")
X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host")
X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto")
# These are the upper/lower case variants of the headers/methods
# Example: {'hOst', 'host', 'HoST', 'HOSt', 'hOsT', 'HosT', 'hoSt', ...}
METH_HEAD_ALL: Final = frozenset(
map("".join, itertools.product(*zip(METH_HEAD.upper(), METH_HEAD.lower())))
)
METH_CONNECT_ALL: Final = frozenset(
map("".join, itertools.product(*zip(METH_CONNECT.upper(), METH_CONNECT.lower())))
)
HOST_ALL: Final = frozenset(
map("".join, itertools.product(*zip(HOST.upper(), HOST.lower())))
)
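
# Usage sketch (illustrative only): the constants above are case-insensitive
# istr values, and the *_ALL sets enumerate every case variant of a token.
#
#     >>> from multidict import CIMultiDict
#     >>> headers = CIMultiDict({CONTENT_TYPE: "application/json"})
#     >>> headers["content-type"]
#     'application/json'
#     >>> "HoSt" in HOST_ALL
#     True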

View File

@@ -0,0 +1,958 @@
"""Various helper functions"""
import asyncio
import base64
import binascii
import contextlib
import datetime
import enum
import functools
import inspect
import netrc
import os
import platform
import re
import sys
import time
import weakref
from collections import namedtuple
from contextlib import suppress
from email.parser import HeaderParser
from email.utils import parsedate
from math import ceil
from pathlib import Path
from types import MappingProxyType, TracebackType
from typing import (
Any,
Callable,
ContextManager,
Dict,
Generator,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Tuple,
Type,
TypeVar,
Union,
get_args,
overload,
)
from urllib.parse import quote
from urllib.request import getproxies, proxy_bypass
import attr
from multidict import MultiDict, MultiDictProxy, MultiMapping
from propcache.api import under_cached_property as reify
from yarl import URL
from . import hdrs
from .log import client_logger
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"
PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)
_T = TypeVar("_T")
_S = TypeVar("_S")
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset((204, 304, *range(100, 200)))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL
DEBUG = sys.flags.dev_mode or (
not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
chr(127),
}
SEPARATORS = {
"(",
")",
"<",
">",
"@",
",",
";",
":",
"\\",
'"',
"/",
"[",
"]",
"?",
"=",
"{",
"}",
" ",
chr(9),
}
TOKEN = CHAR ^ CTL ^ SEPARATORS
class noop:
def __await__(self) -> Generator[None, None, None]:
yield
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
"""Http basic authentication helper."""
def __new__(
cls, login: str, password: str = "", encoding: str = "latin1"
) -> "BasicAuth":
if login is None:
raise ValueError("None is not allowed as login value")
if password is None:
raise ValueError("None is not allowed as password value")
if ":" in login:
raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')
return super().__new__(cls, login, password, encoding)
@classmethod
def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
"""Create a BasicAuth object from an Authorization HTTP header."""
try:
auth_type, encoded_credentials = auth_header.split(" ", 1)
except ValueError:
raise ValueError("Could not parse authorization header.")
if auth_type.lower() != "basic":
raise ValueError("Unknown authorization method %s" % auth_type)
try:
decoded = base64.b64decode(
encoded_credentials.encode("ascii"), validate=True
).decode(encoding)
except binascii.Error:
raise ValueError("Invalid base64 encoding.")
try:
# RFC 2617 HTTP Authentication
# https://www.ietf.org/rfc/rfc2617.txt
# the colon must be present, but the username and password may be
# otherwise blank.
username, password = decoded.split(":", 1)
except ValueError:
raise ValueError("Invalid credentials.")
return cls(username, password, encoding=encoding)
@classmethod
def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
"""Create BasicAuth from url."""
if not isinstance(url, URL):
raise TypeError("url should be yarl.URL instance")
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return None
return cls(url.user or "", url.password or "", encoding=encoding)
def encode(self) -> str:
"""Encode credentials."""
creds = (f"{self.login}:{self.password}").encode(self.encoding)
return "Basic %s" % base64.b64encode(creds).decode(self.encoding)
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
"""Remove user and password from URL if present and return BasicAuth object."""
# Check raw_user and raw_password first as yarl is likely
# to already have these values parsed from the netloc in the cache.
if url.raw_user is None and url.raw_password is None:
return url, None
return url.with_user(None), BasicAuth(url.user or "", url.password or "")
def netrc_from_env() -> Optional[netrc.netrc]:
"""Load netrc from file.
Attempt to load it from the path specified by the env-var
NETRC or in the default location in the user's home directory.
Returns None if it couldn't be found or fails to parse.
"""
netrc_env = os.environ.get("NETRC")
if netrc_env is not None:
netrc_path = Path(netrc_env)
else:
try:
home_dir = Path.home()
except RuntimeError as e: # pragma: no cover
# if pathlib can't resolve home, it may raise a RuntimeError
client_logger.debug(
"Could not resolve home directory when "
"trying to look for .netrc file: %s",
e,
)
return None
netrc_path = home_dir / ("_netrc" if IS_WINDOWS else ".netrc")
try:
return netrc.netrc(str(netrc_path))
except netrc.NetrcParseError as e:
client_logger.warning("Could not parse .netrc file: %s", e)
except OSError as e:
netrc_exists = False
with contextlib.suppress(OSError):
netrc_exists = netrc_path.is_file()
# we couldn't read the file (doesn't exist, permissions, etc.)
if netrc_env or netrc_exists:
# only warn if the environment wanted us to load it,
# or it appears like the default file does actually exist
client_logger.warning("Could not read .netrc file: %s", e)
return None
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
proxy: URL
proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
"""
Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.
:raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
entry is found for the ``host``.
"""
if netrc_obj is None:
raise LookupError("No .netrc file found")
auth_from_netrc = netrc_obj.authenticators(host)
if auth_from_netrc is None:
raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
login, account, password = auth_from_netrc
# TODO(PY311): username = login or account
# Up to python 3.10, account could be None if not specified,
# and login will be empty string if not specified. From 3.11,
# login and account will be empty string if not specified.
username = login if (login or account is None) else account
# TODO(PY311): Remove this, as password will be empty string
# if not specified
if password is None:
password = ""
return BasicAuth(username, password)
def proxies_from_env() -> Dict[str, ProxyInfo]:
proxy_urls = {
k: URL(v)
for k, v in getproxies().items()
if k in ("http", "https", "ws", "wss")
}
netrc_obj = netrc_from_env()
stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()}
ret = {}
for proto, val in stripped.items():
proxy, auth = val
if proxy.scheme in ("https", "wss"):
client_logger.warning(
"%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
)
continue
if netrc_obj and auth is None:
if proxy.host is not None:
try:
auth = basicauth_from_netrc(netrc_obj, proxy.host)
except LookupError:
auth = None
ret[proto] = ProxyInfo(proxy, auth)
return ret
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
"""Get a permitted proxy for the given URL from the env."""
if url.host is not None and proxy_bypass(url.host):
raise LookupError(f"Proxying is disallowed for `{url.host!r}`")
proxies_in_env = proxies_from_env()
try:
proxy_info = proxies_in_env[url.scheme]
except KeyError:
raise LookupError(f"No proxies found for `{url!s}` in the env")
else:
return proxy_info.proxy, proxy_info.proxy_auth
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
type: str
subtype: str
suffix: str
parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
"""Parses a MIME type into its components.
mimetype is a MIME type string.
Returns a MimeType object.
Example:
>>> parse_mimetype('text/html; charset=utf-8')
MimeType(type='text', subtype='html', suffix='',
parameters={'charset': 'utf-8'})
"""
if not mimetype:
return MimeType(
type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
)
parts = mimetype.split(";")
params: MultiDict[str] = MultiDict()
for item in parts[1:]:
if not item:
continue
key, _, value = item.partition("=")
params.add(key.lower().strip(), value.strip(' "'))
fulltype = parts[0].strip().lower()
if fulltype == "*":
fulltype = "*/*"
mtype, _, stype = fulltype.partition("/")
stype, _, suffix = stype.partition("+")
return MimeType(
type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
)
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
"""Parse Content-Type header.
Returns a tuple of the parsed content type and a
MappingProxyType of parameters.
"""
msg = HeaderParser().parsestr(f"Content-Type: {raw}")
content_type = msg.get_content_type()
params = msg.get_params(())
content_dict = dict(params[1:]) # First element is content type again
return content_type, MappingProxyType(content_dict)
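
# Usage sketch (illustrative only) for parse_content_type above; the header
# value is an assumption.
#
#     >>> ctype, params = parse_content_type("text/html; charset=utf-8")
#     >>> ctype, params["charset"]
#     ('text/html', 'utf-8')
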
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
name = getattr(obj, "name", None)
if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">":
return Path(name).name
return default
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}
def quoted_string(content: str) -> str:
"""Return 7-bit content as quoted-string.
Format content into a quoted-string as defined in RFC5322 for
Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be US-ASCII or
a ValueError is raised.
"""
if not (QCONTENT > set(content)):
raise ValueError(f"bad content for quoted-string {content!r}")
return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
def content_disposition_header(
disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
"""Sets ``Content-Disposition`` header for MIME.
This is the MIME payload Content-Disposition header from RFC 2183
and RFC 7579 section 4.2, not the HTTP Content-Disposition from
RFC 6266.
disptype is a disposition type: inline, attachment, form-data.
Should be valid extension token (see RFC 2183)
quote_fields performs value quoting to 7-bit MIME headers
according to RFC 7578. Set to quote_fields to False if recipient
can take 8-bit file names and field values.
_charset specifies the charset to use when quote_fields is True.
params is a dict with disposition params.
"""
if not disptype or not (TOKEN > set(disptype)):
raise ValueError(f"bad content disposition type {disptype!r}")
value = disptype
if params:
lparams = []
for key, val in params.items():
if not key or not (TOKEN > set(key)):
raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
if quote_fields:
if key.lower() == "filename":
qval = quote(val, "", encoding=_charset)
lparams.append((key, '"%s"' % qval))
else:
try:
qval = quoted_string(val)
except ValueError:
qval = "".join(
(_charset, "''", quote(val, "", encoding=_charset))
)
lparams.append((key + "*", qval))
else:
lparams.append((key, '"%s"' % qval))
else:
qval = val.replace("\\", "\\\\").replace('"', '\\"')
lparams.append((key, '"%s"' % qval))
sparams = "; ".join("=".join(pair) for pair in lparams)
value = "; ".join((value, sparams))
return value
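
# Usage sketch (illustrative only) for content_disposition_header above; the
# parameter values are assumptions.
#
#     >>> content_disposition_header("attachment", filename="report.pdf")
#     'attachment; filename="report.pdf"'
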
def is_ip_address(host: Optional[str]) -> bool:
"""Check if host looks like an IP Address.
This check is only meant as a heuristic to ensure that
a host is not a domain name.
"""
if not host:
return False
# For a host to be an ipv4 address, it must be all numeric.
# The host must contain a colon to be an IPv6 address.
return ":" in host or host.replace(".", "").isdigit()
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""
def rfc822_formatted_time() -> str:
global _cached_current_datetime
global _cached_formatted_datetime
now = int(time.time())
if now != _cached_current_datetime:
# Weekday and month names for HTTP date/time formatting;
# always English!
# Tuples are constants stored in codeobject!
_weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
_monthname = (
"", # Dummy so we can use 1-based month numbers
"Jan",
"Feb",
"Mar",
"Apr",
"May",
"Jun",
"Jul",
"Aug",
"Sep",
"Oct",
"Nov",
"Dec",
)
year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now)
_cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
_weekdayname[wd],
day,
_monthname[month],
year,
hh,
mm,
ss,
)
_cached_current_datetime = now
return _cached_formatted_datetime
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
ref, name = info
ob = ref()
if ob is not None:
with suppress(Exception):
getattr(ob, name)()
def weakref_handle(
ob: object,
name: str,
timeout: float,
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
if timeout is not None and timeout > 0:
when = loop.time() + timeout
if timeout >= timeout_ceil_threshold:
when = ceil(when)
return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
return None
def call_later(
cb: Callable[[], Any],
timeout: float,
loop: asyncio.AbstractEventLoop,
timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
if timeout is None or timeout <= 0:
return None
now = loop.time()
when = calculate_timeout_when(now, timeout, timeout_ceil_threshold)
return loop.call_at(when, cb)
def calculate_timeout_when(
loop_time: float,
timeout: float,
timeout_ceiling_threshold: float,
) -> float:
"""Calculate when to execute a timeout."""
when = loop_time + timeout
if timeout > timeout_ceiling_threshold:
return ceil(when)
return when
class TimeoutHandle:
"""Timeout handle"""
__slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")
def __init__(
self,
loop: asyncio.AbstractEventLoop,
timeout: Optional[float],
ceil_threshold: float = 5,
) -> None:
self._timeout = timeout
self._loop = loop
self._ceil_threshold = ceil_threshold
self._callbacks: List[
Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
] = []
def register(
self, callback: Callable[..., None], *args: Any, **kwargs: Any
) -> None:
self._callbacks.append((callback, args, kwargs))
def close(self) -> None:
self._callbacks.clear()
def start(self) -> Optional[asyncio.TimerHandle]:
timeout = self._timeout
if timeout is not None and timeout > 0:
when = self._loop.time() + timeout
if timeout >= self._ceil_threshold:
when = ceil(when)
return self._loop.call_at(when, self.__call__)
else:
return None
def timer(self) -> "BaseTimerContext":
if self._timeout is not None and self._timeout > 0:
timer = TimerContext(self._loop)
self.register(timer.timeout)
return timer
else:
return TimerNoop()
def __call__(self) -> None:
for cb, args, kwargs in self._callbacks:
with suppress(Exception):
cb(*args, **kwargs)
self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
__slots__ = ()
def assert_timeout(self) -> None:
"""Raise TimeoutError if timeout has been exceeded."""
class TimerNoop(BaseTimerContext):
__slots__ = ()
def __enter__(self) -> BaseTimerContext:
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
return
class TimerContext(BaseTimerContext):
"""Low resolution timeout context manager"""
__slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._tasks: List[asyncio.Task[Any]] = []
self._cancelled = False
self._cancelling = 0
def assert_timeout(self) -> None:
"""Raise TimeoutError if timer has already been cancelled."""
if self._cancelled:
raise asyncio.TimeoutError from None
def __enter__(self) -> BaseTimerContext:
task = asyncio.current_task(loop=self._loop)
if task is None:
raise RuntimeError("Timeout context manager should be used inside a task")
if sys.version_info >= (3, 11):
# Remember if the task was already cancelling
# so when we __exit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = task.cancelling()
if self._cancelled:
raise asyncio.TimeoutError from None
self._tasks.append(task)
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
enter_task: Optional[asyncio.Task[Any]] = None
if self._tasks:
enter_task = self._tasks.pop()
if exc_type is asyncio.CancelledError and self._cancelled:
assert enter_task is not None
# The timeout was hit, and the task was cancelled
# so we need to uncancel the last task that entered the context manager
# since the cancellation should not leak out of the context manager
if sys.version_info >= (3, 11):
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
if enter_task.uncancel() > self._cancelling:
return None
raise asyncio.TimeoutError from exc_val
return None
def timeout(self) -> None:
if not self._cancelled:
for task in set(self._tasks):
task.cancel()
self._cancelled = True
def ceil_timeout(
delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
if delay is None or delay <= 0:
return async_timeout.timeout(None)
loop = asyncio.get_running_loop()
now = loop.time()
when = now + delay
if delay > ceil_threshold:
when = ceil(when)
return async_timeout.timeout_at(when)
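
# Usage sketch (illustrative only): ceil_timeout() returns an async-timeout
# context; delays above ceil_threshold are rounded up to whole loop-time
# seconds so many timers can share the same deadline. The delay below is an
# assumption.
#
#     async def fetch_with_deadline() -> None:
#         async with ceil_timeout(10):
#             await asyncio.sleep(0.1)
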
class HeadersMixin:
"""Mixin for handling headers."""
ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])
_headers: MultiMapping[str]
_content_type: Optional[str] = None
_content_dict: Optional[Dict[str, str]] = None
_stored_content_type: Union[str, None, _SENTINEL] = sentinel
def _parse_content_type(self, raw: Optional[str]) -> None:
self._stored_content_type = raw
if raw is None:
# default value according to RFC 2616
self._content_type = "application/octet-stream"
self._content_dict = {}
else:
content_type, content_mapping_proxy = parse_content_type(raw)
self._content_type = content_type
# _content_dict needs to be mutable so we can update it
self._content_dict = content_mapping_proxy.copy()
@property
def content_type(self) -> str:
"""The value of content part for Content-Type HTTP header."""
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
assert self._content_type is not None
return self._content_type
@property
def charset(self) -> Optional[str]:
"""The value of charset part for Content-Type HTTP header."""
raw = self._headers.get(hdrs.CONTENT_TYPE)
if self._stored_content_type != raw:
self._parse_content_type(raw)
assert self._content_dict is not None
return self._content_dict.get("charset")
@property
def content_length(self) -> Optional[int]:
"""The value of Content-Length HTTP header."""
content_length = self._headers.get(hdrs.CONTENT_LENGTH)
return None if content_length is None else int(content_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
if not fut.done():
fut.set_result(result)
_EXC_SENTINEL = BaseException()
class ErrorableProtocol(Protocol):
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = ...,
) -> None: ... # pragma: no cover
def set_exception(
fut: "asyncio.Future[_T] | ErrorableProtocol",
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
"""Set future exception.
If the future is marked as complete, this function is a no-op.
:param exc_cause: An exception that is a direct cause of ``exc``.
Only set if provided.
"""
if asyncio.isfuture(fut) and fut.done():
return
exc_is_sentinel = exc_cause is _EXC_SENTINEL
exc_causes_itself = exc is exc_cause
if not exc_is_sentinel and not exc_causes_itself:
exc.__cause__ = exc_cause
fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
"""Keys for static typing support in Application."""
__slots__ = ("_name", "_t", "__orig_class__")
# This may be set by Python when instantiating with a generic type. We need to
# support this, in order to support types that are not concrete classes,
# like Iterable, which can't be passed as the second parameter to __init__.
__orig_class__: Type[object]
def __init__(self, name: str, t: Optional[Type[_T]] = None):
# Prefix with module name to help deduplicate key names.
frame = inspect.currentframe()
while frame:
if frame.f_code.co_name == "<module>":
module: str = frame.f_globals["__name__"]
break
frame = frame.f_back
self._name = module + "." + name
self._t = t
def __lt__(self, other: object) -> bool:
if isinstance(other, AppKey):
return self._name < other._name
return True # Order AppKey above other types.
def __repr__(self) -> str:
t = self._t
if t is None:
with suppress(AttributeError):
# Set to type arg.
t = get_args(self.__orig_class__)[0]
if t is None:
t_repr = "<<Unknown>>"
elif isinstance(t, type):
if t.__module__ == "builtins":
t_repr = t.__qualname__
else:
t_repr = f"{t.__module__}.{t.__qualname__}"
else:
t_repr = repr(t)
return f"<AppKey({self._name}, type={t_repr})>"
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
__slots__ = ("_maps",)
def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
self._maps = tuple(maps)
def __init_subclass__(cls) -> None:
raise TypeError(
"Inheritance class {} from ChainMapProxy "
"is forbidden".format(cls.__name__)
)
@overload # type: ignore[override]
def __getitem__(self, key: AppKey[_T]) -> _T: ...
@overload
def __getitem__(self, key: str) -> Any: ...
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
for mapping in self._maps:
try:
return mapping[key]
except KeyError:
pass
raise KeyError(key)
@overload # type: ignore[override]
def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...
@overload
def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
@overload
def get(self, key: str, default: Any = ...) -> Any: ...
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
try:
return self[key]
except KeyError:
return default
def __len__(self) -> int:
# reuses stored hash values if possible
return len(set().union(*self._maps))
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
d: Dict[Union[str, AppKey[Any]], Any] = {}
for mapping in reversed(self._maps):
# reuses stored hash values if possible
d.update(mapping)
return iter(d)
def __contains__(self, key: object) -> bool:
return any(key in m for m in self._maps)
def __bool__(self) -> bool:
return any(self._maps)
def __repr__(self) -> str:
content = ", ".join(map(repr, self._maps))
return f"ChainMapProxy({content})"
# https://tools.ietf.org/html/rfc7232#section-2.3
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")
ETAG_ANY = "*"
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
value: str
is_weak: bool = False
def validate_etag_value(value: str) -> None:
if value != ETAG_ANY and not _ETAGC_RE.fullmatch(value):
raise ValueError(
f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
)
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
"""Process a date string, return a datetime object"""
if date_str is not None:
timetuple = parsedate(date_str)
if timetuple is not None:
with suppress(ValueError):
return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
"""Check if a request must return an empty body."""
return (
code in EMPTY_BODY_STATUS_CODES
or method in EMPTY_BODY_METHODS
or (200 <= code < 300 and method in hdrs.METH_CONNECT_ALL)
)
def should_remove_content_length(method: str, code: int) -> bool:
"""Check if a Content-Length header should be removed.
This should always be a subset of must_be_empty_body
"""
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
# https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
return code in EMPTY_BODY_STATUS_CODES or (
200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
)
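
# Usage sketch (illustrative only) for the checks above.
#
#     >>> must_be_empty_body("HEAD", 200), must_be_empty_body("GET", 304)
#     (True, True)
#     >>> must_be_empty_body("GET", 200), should_remove_content_length("GET", 204)
#     (False, True)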

View File

@@ -0,0 +1,72 @@
import sys
from http import HTTPStatus
from typing import Mapping, Tuple
from . import __version__
from .http_exceptions import HttpProcessingError as HttpProcessingError
from .http_parser import (
HeadersParser as HeadersParser,
HttpParser as HttpParser,
HttpRequestParser as HttpRequestParser,
HttpResponseParser as HttpResponseParser,
RawRequestMessage as RawRequestMessage,
RawResponseMessage as RawResponseMessage,
)
from .http_websocket import (
WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE,
WS_KEY as WS_KEY,
WebSocketError as WebSocketError,
WebSocketReader as WebSocketReader,
WebSocketWriter as WebSocketWriter,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMsgType as WSMsgType,
ws_ext_gen as ws_ext_gen,
ws_ext_parse as ws_ext_parse,
)
from .http_writer import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
StreamWriter as StreamWriter,
)
__all__ = (
"HttpProcessingError",
"RESPONSES",
"SERVER_SOFTWARE",
# .http_writer
"StreamWriter",
"HttpVersion",
"HttpVersion10",
"HttpVersion11",
# .http_parser
"HeadersParser",
"HttpParser",
"HttpRequestParser",
"HttpResponseParser",
"RawRequestMessage",
"RawResponseMessage",
# .http_websocket
"WS_CLOSED_MESSAGE",
"WS_CLOSING_MESSAGE",
"WS_KEY",
"WebSocketReader",
"WebSocketWriter",
"ws_ext_gen",
"ws_ext_parse",
"WSMessage",
"WebSocketError",
"WSMsgType",
"WSCloseCode",
)
SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format(
sys.version_info, __version__
)
RESPONSES: Mapping[int, Tuple[str, str]] = {
v: (v.phrase, v.description) for v in HTTPStatus.__members__.values()
}
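# Editor's sketch (illustrative): RESPONSES maps each status code to its
# (phrase, description) pair taken from http.HTTPStatus.
if __name__ == "__main__":  # pragma: no cover
    _phrase, _description = RESPONSES[404]
    assert _phrase == "Not Found"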


@@ -0,0 +1,112 @@
"""Low-level http related exceptions."""
from textwrap import indent
from typing import Optional, Union
from .typedefs import _CIMultiDict
__all__ = ("HttpProcessingError",)
class HttpProcessingError(Exception):
"""HTTP error.
Shortcut for raising HTTP errors with custom code, message and headers.
code: HTTP Error code.
message: (optional) Error message.
headers: (optional) Headers to be sent in response, a list of pairs
"""
code = 0
message = ""
headers = None
def __init__(
self,
*,
code: Optional[int] = None,
message: str = "",
headers: Optional[_CIMultiDict] = None,
) -> None:
if code is not None:
self.code = code
self.headers = headers
self.message = message
def __str__(self) -> str:
msg = indent(self.message, " ")
return f"{self.code}, message:\n{msg}"
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.code}, message={self.message!r}>"
class BadHttpMessage(HttpProcessingError):
code = 400
message = "Bad Request"
def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> None:
super().__init__(message=message, headers=headers)
self.args = (message,)
class HttpBadRequest(BadHttpMessage):
code = 400
message = "Bad Request"
class PayloadEncodingError(BadHttpMessage):
"""Base class for payload errors"""
class ContentEncodingError(PayloadEncodingError):
"""Content encoding error."""
class TransferEncodingError(PayloadEncodingError):
"""transfer encoding error."""
class ContentLengthError(PayloadEncodingError):
"""Not enough data for satisfy content length header."""
class LineTooLong(BadHttpMessage):
def __init__(
self, line: str, limit: str = "Unknown", actual_size: str = "Unknown"
) -> None:
super().__init__(
f"Got more than {limit} bytes ({actual_size}) when reading {line}."
)
self.args = (line, limit, actual_size)
class InvalidHeader(BadHttpMessage):
def __init__(self, hdr: Union[bytes, str]) -> None:
hdr_s = hdr.decode(errors="backslashreplace") if isinstance(hdr, bytes) else hdr
super().__init__(f"Invalid HTTP header: {hdr!r}")
self.hdr = hdr_s
self.args = (hdr,)
class BadStatusLine(BadHttpMessage):
def __init__(self, line: str = "", error: Optional[str] = None) -> None:
if not isinstance(line, str):
line = repr(line)
super().__init__(error or f"Bad status line {line!r}")
self.args = (line,)
self.line = line
class BadHttpMethod(BadStatusLine):
"""Invalid HTTP method in status line."""
def __init__(self, line: str = "", error: Optional[str] = None) -> None:
super().__init__(line, error or f"Bad HTTP method in status line {line!r}")
class InvalidURLError(BadHttpMessage):
pass

File diff suppressed because it is too large


@@ -0,0 +1,36 @@
"""WebSocket protocol versions 13 and 8."""
from ._websocket.helpers import WS_KEY, ws_ext_gen, ws_ext_parse
from ._websocket.models import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSCloseCode,
WSHandshakeError,
WSMessage,
WSMsgType,
)
from ._websocket.reader import WebSocketReader
from ._websocket.writer import WebSocketWriter
# Messages that the WebSocketResponse.receive needs to handle internally
_INTERNAL_RECEIVE_TYPES = frozenset(
(WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.PING, WSMsgType.PONG)
)
__all__ = (
"WS_CLOSED_MESSAGE",
"WS_CLOSING_MESSAGE",
"WS_KEY",
"WebSocketReader",
"WebSocketWriter",
"WSMessage",
"WebSocketError",
"WSMsgType",
"WSCloseCode",
"ws_ext_gen",
"ws_ext_parse",
"WSHandshakeError",
"WSMessage",
)


@@ -0,0 +1,249 @@
"""Http related parsers and protocol."""
import asyncio
import sys
import zlib
from typing import ( # noqa
Any,
Awaitable,
Callable,
Iterable,
List,
NamedTuple,
Optional,
Union,
)
from multidict import CIMultiDict
from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
MIN_PAYLOAD_FOR_WRITELINES = 2048
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
# writelines is not safe for use
# on Python 3.12+ until 3.12.9
# on Python 3.13+ until 3.13.2
# and on older versions it is not any faster than write
# CVE-2024-12254: https://github.com/python/cpython/pull/127656
class HttpVersion(NamedTuple):
major: int
minor: int
HttpVersion10 = HttpVersion(1, 0)
HttpVersion11 = HttpVersion(1, 1)
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]
class StreamWriter(AbstractStreamWriter):
length: Optional[int] = None
chunked: bool = False
_eof: bool = False
_compress: Optional[ZLibCompressor] = None
def __init__(
self,
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
on_chunk_sent: _T_OnChunkSent = None,
on_headers_sent: _T_OnHeadersSent = None,
) -> None:
self._protocol = protocol
self.loop = loop
self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
self._on_headers_sent: _T_OnHeadersSent = on_headers_sent
@property
def transport(self) -> Optional[asyncio.Transport]:
return self._protocol.transport
@property
def protocol(self) -> BaseProtocol:
return self._protocol
def enable_chunking(self) -> None:
self.chunked = True
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)
def _writelines(self, chunks: Iterable[bytes]) -> None:
size = 0
for chunk in chunks:
size += len(chunk)
self.buffer_size += size
self.output_size += size
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
transport.write(b"".join(chunks))
else:
transport.writelines(chunks)
async def write(
self,
chunk: Union[bytes, bytearray, memoryview],
*,
drain: bool = True,
LIMIT: int = 0x10000,
) -> None:
"""Writes chunk of data to a stream.
write_eof() indicates end of stream.
writer can't be used after write_eof() method being called.
write() return drain future.
"""
if self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if isinstance(chunk, memoryview):
if chunk.nbytes != len(chunk):
# just reshape it
chunk = chunk.cast("c")
if self._compress is not None:
chunk = await self._compress.compress(chunk)
if not chunk:
return
if self.length is not None:
chunk_len = len(chunk)
if self.length >= chunk_len:
self.length = self.length - chunk_len
else:
chunk = chunk[: self.length]
self.length = 0
if not chunk:
return
if chunk:
if self.chunked:
self._writelines(
(f"{len(chunk):x}\r\n".encode("ascii"), chunk, b"\r\n")
)
else:
self._write(chunk)
if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
await self.drain()
async def write_headers(
self, status_line: str, headers: "CIMultiDict[str]"
) -> None:
"""Write request/response status and headers."""
if self._on_headers_sent is not None:
await self._on_headers_sent(headers)
# status + headers
buf = _serialize_headers(status_line, headers)
self._write(buf)
def set_eof(self) -> None:
"""Indicate that the message is complete."""
self._eof = True
async def write_eof(self, chunk: bytes = b"") -> None:
if self._eof:
return
if chunk and self._on_chunk_sent is not None:
await self._on_chunk_sent(chunk)
if self._compress:
chunks: List[bytes] = []
chunks_len = 0
if chunk and (compressed_chunk := await self._compress.compress(chunk)):
chunks_len = len(compressed_chunk)
chunks.append(compressed_chunk)
flush_chunk = self._compress.flush()
chunks_len += len(flush_chunk)
chunks.append(flush_chunk)
assert chunks_len
if self.chunked:
chunk_len_pre = f"{chunks_len:x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, *chunks, b"\r\n0\r\n\r\n"))
elif len(chunks) > 1:
self._writelines(chunks)
else:
self._write(chunks[0])
elif self.chunked:
if chunk:
chunk_len_pre = f"{len(chunk):x}\r\n".encode("ascii")
self._writelines((chunk_len_pre, chunk, b"\r\n0\r\n\r\n"))
else:
self._write(b"0\r\n\r\n")
elif chunk:
self._write(chunk)
await self.drain()
self._eof = True
async def drain(self) -> None:
"""Flush the write buffer.
The intended use is to write
await w.write(data)
await w.drain()
"""
protocol = self._protocol
if protocol.transport is not None and protocol._paused:
await protocol._drain_helper()
def _safe_header(string: str) -> str:
if "\r" in string or "\n" in string:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
return string
def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
return line.encode("utf-8")
_serialize_headers = _py_serialize_headers
try:
import aiohttp._http_writer as _http_writer # type: ignore[import-not-found]
_c_serialize_headers = _http_writer._serialize_headers
if not NO_EXTENSIONS:
_serialize_headers = _c_serialize_headers
except ImportError:
pass
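# Editor's sketch (illustrative): the pure-Python serializer emits the CRLF
# wire format and refuses CR/LF inside names or values to block header
# injection; the C extension above replaces it when available.
if __name__ == "__main__":  # pragma: no cover
    _buf = _py_serialize_headers("HTTP/1.1 200 OK", CIMultiDict({"Server": "test"}))
    assert _buf == b"HTTP/1.1 200 OK\r\nServer: test\r\n\r\n"
    try:
        _py_serialize_headers("HTTP/1.1 200 OK", CIMultiDict({"X": "a\r\nEvil: 1"}))
    except ValueError:
        pass  # newline inside a header value is rejected
    else:
        raise AssertionError("header injection was not rejected")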


@@ -0,0 +1,8 @@
import logging
access_logger = logging.getLogger("aiohttp.access")
client_logger = logging.getLogger("aiohttp.client")
internal_logger = logging.getLogger("aiohttp.internal")
server_logger = logging.getLogger("aiohttp.server")
web_logger = logging.getLogger("aiohttp.web")
ws_logger = logging.getLogger("aiohttp.websocket")

File diff suppressed because it is too large


@@ -0,0 +1,519 @@
import asyncio
import enum
import io
import json
import mimetypes
import os
import sys
import warnings
from abc import ABC, abstractmethod
from itertools import chain
from typing import (
IO,
TYPE_CHECKING,
Any,
Dict,
Final,
Iterable,
Optional,
TextIO,
Tuple,
Type,
Union,
)
from multidict import CIMultiDict
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (
_SENTINEL,
content_disposition_header,
guess_filename,
parse_mimetype,
sentinel,
)
from .streams import StreamReader
from .typedefs import JSONEncoder, _CIMultiDict
__all__ = (
"PAYLOAD_REGISTRY",
"get_payload",
"payload_type",
"Payload",
"BytesPayload",
"StringPayload",
"IOBasePayload",
"BytesIOPayload",
"BufferedReaderPayload",
"TextIOPayload",
"StringIOPayload",
"JsonPayload",
"AsyncIterablePayload",
)
TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB
if TYPE_CHECKING:
from typing import List
class LookupError(Exception):
pass
class Order(str, enum.Enum):
normal = "normal"
try_first = "try_first"
try_last = "try_last"
def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
def register_payload(
factory: Type["Payload"], type: Any, *, order: Order = Order.normal
) -> None:
PAYLOAD_REGISTRY.register(factory, type, order=order)
class payload_type:
def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
self.type = type
self.order = order
def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
register_payload(factory, self.type, order=self.order)
return factory
PayloadType = Type["Payload"]
_PayloadRegistryItem = Tuple[PayloadType, Any]
class PayloadRegistry:
"""Payload registry.
note: we need zope.interface for more efficient adapter search
"""
__slots__ = ("_first", "_normal", "_last", "_normal_lookup")
def __init__(self) -> None:
self._first: List[_PayloadRegistryItem] = []
self._normal: List[_PayloadRegistryItem] = []
self._last: List[_PayloadRegistryItem] = []
self._normal_lookup: Dict[Any, PayloadType] = {}
def get(
self,
data: Any,
*args: Any,
_CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
**kwargs: Any,
) -> "Payload":
if self._first:
for factory, type_ in self._first:
if isinstance(data, type_):
return factory(data, *args, **kwargs)
# Try the fast lookup first
if lookup_factory := self._normal_lookup.get(type(data)):
return lookup_factory(data, *args, **kwargs)
# Bail early if its already a Payload
if isinstance(data, Payload):
return data
# Fallback to the slower linear search
for factory, type_ in _CHAIN(self._normal, self._last):
if isinstance(data, type_):
return factory(data, *args, **kwargs)
raise LookupError()
def register(
self, factory: PayloadType, type: Any, *, order: Order = Order.normal
) -> None:
if order is Order.try_first:
self._first.append((factory, type))
elif order is Order.normal:
self._normal.append((factory, type))
if isinstance(type, Iterable):
for t in type:
self._normal_lookup[t] = factory
else:
self._normal_lookup[type] = factory
elif order is Order.try_last:
self._last.append((factory, type))
else:
raise ValueError(f"Unsupported order {order!r}")
class Payload(ABC):
_default_content_type: str = "application/octet-stream"
_size: Optional[int] = None
def __init__(
self,
value: Any,
headers: Optional[
Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
] = None,
content_type: Union[str, None, _SENTINEL] = sentinel,
filename: Optional[str] = None,
encoding: Optional[str] = None,
**kwargs: Any,
) -> None:
self._encoding = encoding
self._filename = filename
self._headers: _CIMultiDict = CIMultiDict()
self._value = value
if content_type is not sentinel and content_type is not None:
self._headers[hdrs.CONTENT_TYPE] = content_type
elif self._filename is not None:
if sys.version_info >= (3, 13):
guesser = mimetypes.guess_file_type
else:
guesser = mimetypes.guess_type
content_type = guesser(self._filename)[0]
if content_type is None:
content_type = self._default_content_type
self._headers[hdrs.CONTENT_TYPE] = content_type
else:
self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
if headers:
self._headers.update(headers)
@property
def size(self) -> Optional[int]:
"""Size of the payload."""
return self._size
@property
def filename(self) -> Optional[str]:
"""Filename of the payload."""
return self._filename
@property
def headers(self) -> _CIMultiDict:
"""Custom item headers"""
return self._headers
@property
def _binary_headers(self) -> bytes:
return (
"".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode(
"utf-8"
)
+ b"\r\n"
)
@property
def encoding(self) -> Optional[str]:
"""Payload encoding"""
return self._encoding
@property
def content_type(self) -> str:
"""Content type"""
return self._headers[hdrs.CONTENT_TYPE]
def set_content_disposition(
self,
disptype: str,
quote_fields: bool = True,
_charset: str = "utf-8",
**params: Any,
) -> None:
"""Sets ``Content-Disposition`` header."""
self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
disptype, quote_fields=quote_fields, _charset=_charset, **params
)
@abstractmethod
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
"""Return string representation of the value.
This is named decode() to allow compatibility with bytes objects.
"""
@abstractmethod
async def write(self, writer: AbstractStreamWriter) -> None:
"""Write payload.
writer is an AbstractStreamWriter instance:
"""
class BytesPayload(Payload):
_value: bytes
def __init__(
self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
) -> None:
if "content_type" not in kwargs:
kwargs["content_type"] = "application/octet-stream"
super().__init__(value, *args, **kwargs)
if isinstance(value, memoryview):
self._size = value.nbytes
elif isinstance(value, (bytes, bytearray)):
self._size = len(value)
else:
raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
if self._size > TOO_LARGE_BYTES_BODY:
kwargs = {"source": self}
warnings.warn(
"Sending a large body directly with raw bytes might"
" lock the event loop. You should probably pass an "
"io.BytesIO object instead",
ResourceWarning,
**kwargs,
)
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.decode(encoding, errors)
async def write(self, writer: AbstractStreamWriter) -> None:
await writer.write(self._value)
class StringPayload(BytesPayload):
def __init__(
self,
value: str,
*args: Any,
encoding: Optional[str] = None,
content_type: Optional[str] = None,
**kwargs: Any,
) -> None:
if encoding is None:
if content_type is None:
real_encoding = "utf-8"
content_type = "text/plain; charset=utf-8"
else:
mimetype = parse_mimetype(content_type)
real_encoding = mimetype.parameters.get("charset", "utf-8")
else:
if content_type is None:
content_type = "text/plain; charset=%s" % encoding
real_encoding = encoding
super().__init__(
value.encode(real_encoding),
encoding=real_encoding,
content_type=content_type,
*args,
**kwargs,
)
class StringIOPayload(StringPayload):
def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
super().__init__(value.read(), *args, **kwargs)
class IOBasePayload(Payload):
_value: io.IOBase
def __init__(
self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
) -> None:
if "filename" not in kwargs:
kwargs["filename"] = guess_filename(value)
super().__init__(value, *args, **kwargs)
if self._filename is not None and disposition is not None:
if hdrs.CONTENT_DISPOSITION not in self.headers:
self.set_content_disposition(disposition, filename=self._filename)
async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
await writer.write(chunk)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return "".join(r.decode(encoding, errors) for r in self._value.readlines())
class TextIOPayload(IOBasePayload):
_value: io.TextIOBase
def __init__(
self,
value: TextIO,
*args: Any,
encoding: Optional[str] = None,
content_type: Optional[str] = None,
**kwargs: Any,
) -> None:
if encoding is None:
if content_type is None:
encoding = "utf-8"
content_type = "text/plain; charset=utf-8"
else:
mimetype = parse_mimetype(content_type)
encoding = mimetype.parameters.get("charset", "utf-8")
else:
if content_type is None:
content_type = "text/plain; charset=%s" % encoding
super().__init__(
value,
content_type=content_type,
encoding=encoding,
*args,
**kwargs,
)
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
except OSError:
return None
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read()
async def write(self, writer: AbstractStreamWriter) -> None:
loop = asyncio.get_event_loop()
try:
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
while chunk:
data = (
chunk.encode(encoding=self._encoding)
if self._encoding
else chunk.encode()
)
await writer.write(data)
chunk = await loop.run_in_executor(None, self._value.read, 2**16)
finally:
await loop.run_in_executor(None, self._value.close)
class BytesIOPayload(IOBasePayload):
_value: io.BytesIO
@property
def size(self) -> int:
position = self._value.tell()
end = self._value.seek(0, os.SEEK_END)
self._value.seek(position)
return end - position
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read().decode(encoding, errors)
class BufferedReaderPayload(IOBasePayload):
_value: io.BufferedIOBase
@property
def size(self) -> Optional[int]:
try:
return os.fstat(self._value.fileno()).st_size - self._value.tell()
except (OSError, AttributeError):
# data.fileno() is not supported, e.g.
# io.BufferedReader(io.BytesIO(b'data'))
# For some file-like objects (e.g. tarfile), the fileno() attribute may
# not exist at all, and will instead raise an AttributeError.
return None
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
return self._value.read().decode(encoding, errors)
class JsonPayload(BytesPayload):
def __init__(
self,
value: Any,
encoding: str = "utf-8",
content_type: str = "application/json",
dumps: JSONEncoder = json.dumps,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(
dumps(value).encode(encoding),
content_type=content_type,
encoding=encoding,
*args,
**kwargs,
)
if TYPE_CHECKING:
from typing import AsyncIterable, AsyncIterator
_AsyncIterator = AsyncIterator[bytes]
_AsyncIterable = AsyncIterable[bytes]
else:
from collections.abc import AsyncIterable, AsyncIterator
_AsyncIterator = AsyncIterator
_AsyncIterable = AsyncIterable
class AsyncIterablePayload(Payload):
_iter: Optional[_AsyncIterator] = None
_value: _AsyncIterable
def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
if not isinstance(value, AsyncIterable):
raise TypeError(
"value argument must support "
"collections.abc.AsyncIterable interface, "
"got {!r}".format(type(value))
)
if "content_type" not in kwargs:
kwargs["content_type"] = "application/octet-stream"
super().__init__(value, *args, **kwargs)
self._iter = value.__aiter__()
async def write(self, writer: AbstractStreamWriter) -> None:
if self._iter:
try:
# iter is not None check prevents rare cases
                # when the same iterable is used twice
while True:
chunk = await self._iter.__anext__()
await writer.write(chunk)
except StopAsyncIteration:
self._iter = None
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
raise TypeError("Unable to decode.")
class StreamReaderPayload(AsyncIterablePayload):
def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
super().__init__(value.iter_any(), *args, **kwargs)
PAYLOAD_REGISTRY = PayloadRegistry()
PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
PAYLOAD_REGISTRY.register(StringPayload, str)
PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
# try_last gives a chance for more specialized async iterables like
# aiohttp.multipart.BodyPartReaderPayload to override the default
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)
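# Editor's sketch (illustrative): the registry dispatches on the value's type,
# so bytes, text and binary file objects each get a matching Payload subclass.
if __name__ == "__main__":  # pragma: no cover
    assert isinstance(get_payload(b"raw"), BytesPayload)
    assert isinstance(get_payload("text"), StringPayload)
    assert isinstance(get_payload(io.BytesIO(b"raw")), BytesIOPayload)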


@@ -0,0 +1,78 @@
"""
Payload implementation for coroutines as data providers.
As a simple case, you can upload data from a file::
@aiohttp.streamer
async def file_sender(writer, file_name=None):
with open(file_name, 'rb') as f:
chunk = f.read(2**16)
while chunk:
await writer.write(chunk)
chunk = f.read(2**16)
Then you can use `file_sender` like this:
async with session.post('http://httpbin.org/post',
data=file_sender(file_name='huge_file')) as resp:
print(await resp.text())
.. note:: The coroutine must accept `writer` as its first argument
"""
import types
import warnings
from typing import Any, Awaitable, Callable, Dict, Tuple
from .abc import AbstractStreamWriter
from .payload import Payload, payload_type
__all__ = ("streamer",)
class _stream_wrapper:
def __init__(
self,
coro: Callable[..., Awaitable[None]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> None:
self.coro = types.coroutine(coro)
self.args = args
self.kwargs = kwargs
async def __call__(self, writer: AbstractStreamWriter) -> None:
await self.coro(writer, *self.args, **self.kwargs)
class streamer:
def __init__(self, coro: Callable[..., Awaitable[None]]) -> None:
warnings.warn(
"@streamer is deprecated, use async generators instead",
DeprecationWarning,
stacklevel=2,
)
self.coro = coro
def __call__(self, *args: Any, **kwargs: Any) -> _stream_wrapper:
return _stream_wrapper(self.coro, args, kwargs)
@payload_type(_stream_wrapper)
class StreamWrapperPayload(Payload):
async def write(self, writer: AbstractStreamWriter) -> None:
await self._value(writer)
def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
raise TypeError("Unable to decode.")
@payload_type(streamer)
class StreamPayload(StreamWrapperPayload):
def __init__(self, value: Any, *args: Any, **kwargs: Any) -> None:
super().__init__(value(), *args, **kwargs)
async def write(self, writer: AbstractStreamWriter) -> None:
await self._value(writer)
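# Editor's sketch (illustrative): calling a @streamer-decorated coroutine does
# not run it; it returns a _stream_wrapper that StreamWrapperPayload can write.
if __name__ == "__main__":  # pragma: no cover
    @streamer  # emits the DeprecationWarning above; async generators are preferred
    async def _hello(writer: AbstractStreamWriter) -> None:
        await writer.write(b"hello")

    assert isinstance(_hello(), _stream_wrapper)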


@@ -0,0 +1 @@
Marker


@@ -0,0 +1,436 @@
import asyncio
import contextlib
import inspect
import warnings
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterator,
Optional,
Protocol,
Type,
Union,
overload,
)
import pytest
from .test_utils import (
BaseTestServer,
RawTestServer,
TestClient,
TestServer,
loop_context,
setup_test_loop,
teardown_test_loop,
unused_port as _unused_port,
)
from .web import Application, BaseRequest, Request
from .web_protocol import _RequestHandler
try:
import uvloop
except ImportError: # pragma: no cover
uvloop = None # type: ignore[assignment]
class AiohttpClient(Protocol):
@overload
async def __call__(
self,
__param: Application,
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TestClient[Request, Application]: ...
@overload
async def __call__(
self,
__param: BaseTestServer,
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TestClient[BaseRequest, None]: ...
class AiohttpServer(Protocol):
def __call__(
self, app: Application, *, port: Optional[int] = None, **kwargs: Any
) -> Awaitable[TestServer]: ...
class AiohttpRawServer(Protocol):
def __call__(
self, handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
) -> Awaitable[RawTestServer]: ...
def pytest_addoption(parser): # type: ignore[no-untyped-def]
parser.addoption(
"--aiohttp-fast",
action="store_true",
default=False,
help="run tests faster by disabling extra checks",
)
parser.addoption(
"--aiohttp-loop",
action="store",
default="pyloop",
help="run tests with specific loop: pyloop, uvloop or all",
)
parser.addoption(
"--aiohttp-enable-loop-debug",
action="store_true",
default=False,
help="enable event loop debug mode",
)
def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def]
"""Set up pytest fixture.
Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.
"""
func = fixturedef.func
if inspect.isasyncgenfunction(func):
# async generator fixture
is_async_gen = True
elif inspect.iscoroutinefunction(func):
# regular async fixture
is_async_gen = False
else:
# not an async fixture, nothing to do
return
strip_request = False
if "request" not in fixturedef.argnames:
fixturedef.argnames += ("request",)
strip_request = True
def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
request = kwargs["request"]
if strip_request:
del kwargs["request"]
# if neither the fixture nor the test use the 'loop' fixture,
# 'getfixturevalue' will fail because the test is not parameterized
# (this can be removed someday if 'loop' is no longer parameterized)
if "loop" not in request.fixturenames:
raise Exception(
"Asynchronous fixtures must depend on the 'loop' fixture or "
"be used in tests depending from it."
)
_loop = request.getfixturevalue("loop")
if is_async_gen:
# for async generators, we need to advance the generator once,
# then advance it again in a finalizer
gen = func(*args, **kwargs)
def finalizer(): # type: ignore[no-untyped-def]
try:
return _loop.run_until_complete(gen.__anext__())
except StopAsyncIteration:
pass
request.addfinalizer(finalizer)
return _loop.run_until_complete(gen.__anext__())
else:
return _loop.run_until_complete(func(*args, **kwargs))
fixturedef.func = wrapper
@pytest.fixture
def fast(request): # type: ignore[no-untyped-def]
"""--fast config option"""
return request.config.getoption("--aiohttp-fast")
@pytest.fixture
def loop_debug(request): # type: ignore[no-untyped-def]
"""--enable-loop-debug config option"""
return request.config.getoption("--aiohttp-enable-loop-debug")
@contextlib.contextmanager
def _runtime_warning_context(): # type: ignore[no-untyped-def]
"""Context manager which checks for RuntimeWarnings.
This exists specifically to
avoid "coroutine 'X' was never awaited" warnings being missed.
If RuntimeWarnings occur in the context a RuntimeError is raised.
"""
with warnings.catch_warnings(record=True) as _warnings:
yield
rw = [
"{w.filename}:{w.lineno}:{w.message}".format(w=w)
for w in _warnings
if w.category == RuntimeWarning
]
if rw:
raise RuntimeError(
"{} Runtime Warning{},\n{}".format(
len(rw), "" if len(rw) == 1 else "s", "\n".join(rw)
)
)
@contextlib.contextmanager
def _passthrough_loop_context(loop, fast=False): # type: ignore[no-untyped-def]
"""Passthrough loop context.
    Sets up and tears down a loop unless one is passed in via the loop
    argument, in which case it is passed straight through.
"""
if loop:
# loop already exists, pass it straight through
yield loop
else:
# this shadows loop_context's standard behavior
loop = setup_test_loop()
yield loop
teardown_test_loop(loop, fast=fast)
def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def]
"""Fix pytest collecting for coroutines."""
if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj):
return list(collector._genfunctions(name, obj))
def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def]
"""Run coroutines in an event loop instead of a normal function call."""
fast = pyfuncitem.config.getoption("--aiohttp-fast")
if inspect.iscoroutinefunction(pyfuncitem.function):
existing_loop = pyfuncitem.funcargs.get(
"proactor_loop"
) or pyfuncitem.funcargs.get("loop", None)
with _runtime_warning_context():
with _passthrough_loop_context(existing_loop, fast=fast) as _loop:
testargs = {
arg: pyfuncitem.funcargs[arg]
for arg in pyfuncitem._fixtureinfo.argnames
}
_loop.run_until_complete(pyfuncitem.obj(**testargs))
return True
def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def]
if "loop_factory" not in metafunc.fixturenames:
return
loops = metafunc.config.option.aiohttp_loop
avail_factories: Dict[str, Type[asyncio.AbstractEventLoopPolicy]]
avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy}
if uvloop is not None: # pragma: no cover
avail_factories["uvloop"] = uvloop.EventLoopPolicy
if loops == "all":
loops = "pyloop,uvloop?"
factories = {} # type: ignore[var-annotated]
for name in loops.split(","):
required = not name.endswith("?")
name = name.strip(" ?")
if name not in avail_factories: # pragma: no cover
if required:
raise ValueError(
"Unknown loop '%s', available loops: %s"
% (name, list(factories.keys()))
)
else:
continue
factories[name] = avail_factories[name]
metafunc.parametrize(
"loop_factory", list(factories.values()), ids=list(factories.keys())
)
@pytest.fixture
def loop(loop_factory, fast, loop_debug): # type: ignore[no-untyped-def]
"""Return an instance of the event loop."""
policy = loop_factory()
asyncio.set_event_loop_policy(policy)
with loop_context(fast=fast) as _loop:
if loop_debug:
_loop.set_debug(True) # pragma: no cover
asyncio.set_event_loop(_loop)
yield _loop
@pytest.fixture
def proactor_loop(): # type: ignore[no-untyped-def]
policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore[attr-defined]
asyncio.set_event_loop_policy(policy)
with loop_context(policy.new_event_loop) as _loop:
asyncio.set_event_loop(_loop)
yield _loop
@pytest.fixture
def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]:
warnings.warn(
"Deprecated, use aiohttp_unused_port fixture instead",
DeprecationWarning,
stacklevel=2,
)
return aiohttp_unused_port
@pytest.fixture
def aiohttp_unused_port() -> Callable[[], int]:
"""Return a port that is unused on the current host."""
return _unused_port
@pytest.fixture
def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
"""Factory to create a TestServer instance, given an app.
aiohttp_server(app, **kwargs)
"""
servers = []
async def go(
app: Application, *, port: Optional[int] = None, **kwargs: Any
) -> TestServer:
server = TestServer(app, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
return server
yield go
async def finalize() -> None:
while servers:
await servers.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture
def test_server(aiohttp_server): # type: ignore[no-untyped-def] # pragma: no cover
warnings.warn(
"Deprecated, use aiohttp_server fixture instead",
DeprecationWarning,
stacklevel=2,
)
return aiohttp_server
@pytest.fixture
def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]:
"""Factory to create a RawTestServer instance, given a web handler.
aiohttp_raw_server(handler, **kwargs)
"""
servers = []
async def go(
handler: _RequestHandler, *, port: Optional[int] = None, **kwargs: Any
) -> RawTestServer:
server = RawTestServer(handler, port=port)
await server.start_server(loop=loop, **kwargs)
servers.append(server)
return server
yield go
async def finalize() -> None:
while servers:
await servers.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture
def raw_test_server( # type: ignore[no-untyped-def] # pragma: no cover
aiohttp_raw_server,
):
warnings.warn(
"Deprecated, use aiohttp_raw_server fixture instead",
DeprecationWarning,
stacklevel=2,
)
return aiohttp_raw_server
@pytest.fixture
def aiohttp_client(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpClient]:
"""Factory to create a TestClient instance.
aiohttp_client(app, **kwargs)
aiohttp_client(server, **kwargs)
aiohttp_client(raw_server, **kwargs)
"""
clients = []
@overload
async def go(
__param: Application,
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TestClient[Request, Application]: ...
@overload
async def go(
__param: BaseTestServer,
*,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TestClient[BaseRequest, None]: ...
async def go(
__param: Union[Application, BaseTestServer],
*args: Any,
server_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> TestClient[Any, Any]:
if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type]
__param, (Application, BaseTestServer)
):
__param = __param(loop, *args, **kwargs)
kwargs = {}
else:
assert not args, "args should be empty"
if isinstance(__param, Application):
server_kwargs = server_kwargs or {}
server = TestServer(__param, loop=loop, **server_kwargs)
client = TestClient(server, loop=loop, **kwargs)
elif isinstance(__param, BaseTestServer):
client = TestClient(__param, loop=loop, **kwargs)
else:
raise ValueError("Unknown argument type: %r" % type(__param))
await client.start_server()
clients.append(client)
return client
yield go
async def finalize() -> None:
while clients:
await clients.pop().close()
loop.run_until_complete(finalize())
@pytest.fixture
def test_client(aiohttp_client): # type: ignore[no-untyped-def] # pragma: no cover
warnings.warn(
"Deprecated, use aiohttp_client fixture instead",
DeprecationWarning,
stacklevel=2,
)
return aiohttp_client
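# Editor's sketch (hypothetical, illustrative only): the shape of a test that
# consumes the aiohttp_client fixture defined above; `web` is aiohttp.web.
async def _example_test_hello(aiohttp_client: AiohttpClient) -> None:
    from aiohttp import web

    async def handler(request: Request) -> "web.Response":
        return web.Response(text="ok")

    app = Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)
    resp = await client.get("/")
    assert resp.status == 200 and await resp.text() == "ok"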


@@ -0,0 +1,190 @@
import asyncio
import socket
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from .abc import AbstractResolver, ResolveResult
__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")
try:
import aiodns
aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")
except ImportError: # pragma: no cover
aiodns = None # type: ignore[assignment]
aiodns_default = False
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
_AI_ADDRCONFIG = socket.AI_ADDRCONFIG
if hasattr(socket, "AI_MASK"):
_AI_ADDRCONFIG &= socket.AI_MASK
class ThreadedResolver(AbstractResolver):
"""Threaded resolver.
Uses an Executor for synchronous getaddrinfo() calls.
concurrent.futures.ThreadPoolExecutor is used by default.
"""
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self._loop = loop or asyncio.get_running_loop()
async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> List[ResolveResult]:
infos = await self._loop.getaddrinfo(
host,
port,
type=socket.SOCK_STREAM,
family=family,
flags=_AI_ADDRCONFIG,
)
hosts: List[ResolveResult] = []
for family, _, proto, _, address in infos:
if family == socket.AF_INET6:
if len(address) < 3:
                    # IPv6 is not supported by this Python build,
                    # or IPv6 is not enabled on the host
continue
if address[3]:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but for performance we only do it when needed.
resolved_host, _port = await self._loop.getnameinfo(
address, _NAME_SOCKET_FLAGS
)
port = int(_port)
else:
resolved_host, port = address[:2]
else: # IPv4
assert family == socket.AF_INET
resolved_host, port = address # type: ignore[misc]
hosts.append(
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=proto,
flags=_NUMERIC_SOCKET_FLAGS,
)
)
return hosts
async def close(self) -> None:
pass
class AsyncResolver(AbstractResolver):
"""Use the `aiodns` package to make asynchronous DNS lookups"""
def __init__(
self,
loop: Optional[asyncio.AbstractEventLoop] = None,
*args: Any,
**kwargs: Any,
) -> None:
if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")
self._resolver = aiodns.DNSResolver(*args, **kwargs)
if not hasattr(self._resolver, "gethostbyname"):
# aiodns 1.1 is not available, fallback to DNSResolver.query
self.resolve = self._resolve_with_query # type: ignore
async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> List[ResolveResult]:
try:
resp = await self._resolver.getaddrinfo(
host,
port=port,
type=socket.SOCK_STREAM,
family=family,
flags=_AI_ADDRCONFIG,
)
except aiodns.error.DNSError as exc:
            msg = exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed"
raise OSError(None, msg) from exc
hosts: List[ResolveResult] = []
for node in resp.nodes:
address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
family = node.family
if family == socket.AF_INET6:
if len(address) > 3 and address[3]:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
                    # getnameinfo() unconditionally, but for performance we only do it when needed.
result = await self._resolver.getnameinfo(
(address[0].decode("ascii"), *address[1:]),
_NAME_SOCKET_FLAGS,
)
resolved_host = result.node
else:
resolved_host = address[0].decode("ascii")
port = address[1]
else: # IPv4
assert family == socket.AF_INET
resolved_host = address[0].decode("ascii")
port = address[1]
hosts.append(
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=0,
flags=_NUMERIC_SOCKET_FLAGS,
)
)
if not hosts:
raise OSError(None, "DNS lookup failed")
return hosts
async def _resolve_with_query(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
if family == socket.AF_INET6:
qtype = "AAAA"
else:
qtype = "A"
try:
resp = await self._resolver.query(host, qtype)
except aiodns.error.DNSError as exc:
            msg = exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed"
raise OSError(None, msg) from exc
hosts = []
for rr in resp:
hosts.append(
{
"hostname": host,
"host": rr.host,
"port": port,
"family": family,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
}
)
if not hosts:
raise OSError(None, "DNS lookup failed")
return hosts
async def close(self) -> None:
self._resolver.cancel()
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]
DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver
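# Editor's sketch (illustrative): a quick loopback resolution; each
# ResolveResult is a dict-like record that the TCP connector consumes.
if __name__ == "__main__":  # pragma: no cover
    async def _demo() -> None:
        resolver = ThreadedResolver()
        for entry in await resolver.resolve("localhost", 80):
            assert entry["port"] == 80 and entry["hostname"] == "localhost"

    asyncio.run(_demo())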


@@ -0,0 +1,727 @@
import asyncio
import collections
import warnings
from typing import (
Awaitable,
Callable,
Deque,
Final,
Generic,
List,
Optional,
Tuple,
TypeVar,
)
from .base_protocol import BaseProtocol
from .helpers import (
_EXC_SENTINEL,
BaseTimerContext,
TimerNoop,
set_exception,
set_result,
)
from .log import internal_logger
__all__ = (
"EMPTY_PAYLOAD",
"EofStream",
"StreamReader",
"DataQueue",
)
_T = TypeVar("_T")
class EofStream(Exception):
"""eof stream indication."""
class AsyncStreamIterator(Generic[_T]):
__slots__ = ("read_func",)
def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None:
self.read_func = read_func
def __aiter__(self) -> "AsyncStreamIterator[_T]":
return self
async def __anext__(self) -> _T:
try:
rv = await self.read_func()
except EofStream:
raise StopAsyncIteration
if rv == b"":
raise StopAsyncIteration
return rv
class ChunkTupleAsyncStreamIterator:
__slots__ = ("_stream",)
def __init__(self, stream: "StreamReader") -> None:
self._stream = stream
def __aiter__(self) -> "ChunkTupleAsyncStreamIterator":
return self
async def __anext__(self) -> Tuple[bytes, bool]:
rv = await self._stream.readchunk()
if rv == (b"", False):
raise StopAsyncIteration
return rv
class AsyncStreamReaderMixin:
__slots__ = ()
def __aiter__(self) -> AsyncStreamIterator[bytes]:
return AsyncStreamIterator(self.readline) # type: ignore[attr-defined]
def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
"""Returns an asynchronous iterator that yields chunks of size n."""
return AsyncStreamIterator(lambda: self.read(n)) # type: ignore[attr-defined]
def iter_any(self) -> AsyncStreamIterator[bytes]:
"""Yield all available data as soon as it is received."""
return AsyncStreamIterator(self.readany) # type: ignore[attr-defined]
def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
"""Yield chunks of data as they are received by the server.
The yielded objects are tuples
of (bytes, bool) as returned by the StreamReader.readchunk method.
"""
return ChunkTupleAsyncStreamIterator(self) # type: ignore[arg-type]
class StreamReader(AsyncStreamReaderMixin):
"""An enhancement of asyncio.StreamReader.
Supports asynchronous iteration by line, chunk or as available::
async for line in reader:
...
async for chunk in reader.iter_chunked(1024):
...
async for slice in reader.iter_any():
...
"""
__slots__ = (
"_protocol",
"_low_water",
"_high_water",
"_loop",
"_size",
"_cursor",
"_http_chunk_splits",
"_buffer",
"_buffer_offset",
"_eof",
"_waiter",
"_eof_waiter",
"_exception",
"_timer",
"_eof_callbacks",
"_eof_counter",
"total_bytes",
)
def __init__(
self,
protocol: BaseProtocol,
limit: int,
*,
timer: Optional[BaseTimerContext] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
self._protocol = protocol
self._low_water = limit
self._high_water = limit * 2
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._size = 0
self._cursor = 0
self._http_chunk_splits: Optional[List[int]] = None
self._buffer: Deque[bytes] = collections.deque()
self._buffer_offset = 0
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._eof_waiter: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._timer = TimerNoop() if timer is None else timer
self._eof_callbacks: List[Callable[[], None]] = []
self._eof_counter = 0
self.total_bytes = 0
def __repr__(self) -> str:
info = [self.__class__.__name__]
if self._size:
info.append("%d bytes" % self._size)
if self._eof:
info.append("eof")
if self._low_water != 2**16: # default limit
info.append("low=%d high=%d" % (self._low_water, self._high_water))
if self._waiter:
info.append("w=%r" % self._waiter)
if self._exception:
info.append("e=%r" % self._exception)
return "<%s>" % " ".join(info)
def get_read_buffer_limits(self) -> Tuple[int, int]:
return (self._low_water, self._high_water)
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._exception = exc
self._eof_callbacks.clear()
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
set_exception(waiter, exc, exc_cause)
def on_eof(self, callback: Callable[[], None]) -> None:
if self._eof:
try:
callback()
except Exception:
internal_logger.exception("Exception in eof callback")
else:
self._eof_callbacks.append(callback)
def feed_eof(self) -> None:
self._eof = True
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
waiter = self._eof_waiter
if waiter is not None:
self._eof_waiter = None
set_result(waiter, None)
if self._protocol._reading_paused:
self._protocol.resume_reading()
for cb in self._eof_callbacks:
try:
cb()
except Exception:
internal_logger.exception("Exception in eof callback")
self._eof_callbacks.clear()
def is_eof(self) -> bool:
"""Return True if 'feed_eof' was called."""
return self._eof
def at_eof(self) -> bool:
"""Return True if the buffer is empty and 'feed_eof' was called."""
return self._eof and not self._buffer
async def wait_eof(self) -> None:
if self._eof:
return
assert self._eof_waiter is None
self._eof_waiter = self._loop.create_future()
try:
await self._eof_waiter
finally:
self._eof_waiter = None
def unread_data(self, data: bytes) -> None:
"""rollback reading some data from stream, inserting it to buffer head."""
warnings.warn(
"unread_data() is deprecated "
"and will be removed in future releases (#3260)",
DeprecationWarning,
stacklevel=2,
)
if not data:
return
if self._buffer_offset:
self._buffer[0] = self._buffer[0][self._buffer_offset :]
self._buffer_offset = 0
self._size += len(data)
self._cursor -= len(data)
self._buffer.appendleft(data)
self._eof_counter = 0
# TODO: size is ignored, remove the param later
def feed_data(self, data: bytes, size: int = 0) -> None:
assert not self._eof, "feed_data after feed_eof"
if not data:
return
data_len = len(data)
self._size += data_len
self._buffer.append(data)
self.total_bytes += data_len
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
if self._size > self._high_water and not self._protocol._reading_paused:
self._protocol.pause_reading()
def begin_http_chunk_receiving(self) -> None:
if self._http_chunk_splits is None:
if self.total_bytes:
raise RuntimeError(
"Called begin_http_chunk_receiving when some data was already fed"
)
self._http_chunk_splits = []
def end_http_chunk_receiving(self) -> None:
if self._http_chunk_splits is None:
raise RuntimeError(
"Called end_chunk_receiving without calling "
"begin_chunk_receiving first"
)
# self._http_chunk_splits contains logical byte offsets from start of
# the body transfer. Each offset is the offset of the end of a chunk.
# "Logical" means bytes, accessible for a user.
# If no chunks containing logical data were received, current position
        # is definitely zero.
pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0
if self.total_bytes == pos:
# We should not add empty chunks here. So we check for that.
# Note, when chunked + gzip is used, we can receive a chunk
# of compressed data, but that data may not be enough for gzip FSM
# to yield any uncompressed data. That's why current position may
# not change after receiving a chunk.
return
self._http_chunk_splits.append(self.total_bytes)
# wake up readchunk when end of http chunk received
waiter = self._waiter
if waiter is not None:
self._waiter = None
set_result(waiter, None)
async def _wait(self, func_name: str) -> None:
if not self._protocol.connected:
raise RuntimeError("Connection closed.")
# StreamReader uses a future to link the protocol feed_data() method
# to a read coroutine. Running two read coroutines at the same time
        # would have unexpected behaviour. It would not be possible to know
# which coroutine would get the next data.
if self._waiter is not None:
raise RuntimeError(
"%s() called while another coroutine is "
"already waiting for incoming data" % func_name
)
waiter = self._waiter = self._loop.create_future()
try:
with self._timer:
await waiter
finally:
self._waiter = None
async def readline(self) -> bytes:
return await self.readuntil()
async def readuntil(self, separator: bytes = b"\n") -> bytes:
seplen = len(separator)
if seplen == 0:
raise ValueError("Separator should be at least one-byte string")
if self._exception is not None:
raise self._exception
chunk = b""
chunk_size = 0
not_enough = True
while not_enough:
while self._buffer and not_enough:
offset = self._buffer_offset
ichar = self._buffer[0].find(separator, offset) + 1
# Read from current offset to found separator or to the end.
data = self._read_nowait_chunk(
ichar - offset + seplen - 1 if ichar else -1
)
chunk += data
chunk_size += len(data)
if ichar:
not_enough = False
if chunk_size > self._high_water:
raise ValueError("Chunk too big")
if self._eof:
break
if not_enough:
await self._wait("readuntil")
return chunk
async def read(self, n: int = -1) -> bytes:
if self._exception is not None:
raise self._exception
        # migration problem; with DataQueue you have to catch the
        # EofStream exception, so the common pattern is to run payload.read()
        # inside an infinite loop -- which can become a real infinite loop with
        # StreamReader. Let's keep this check for one more major release.
if __debug__:
if self._eof and not self._buffer:
self._eof_counter = getattr(self, "_eof_counter", 0) + 1
if self._eof_counter > 5:
internal_logger.warning(
"Multiple access to StreamReader in eof state, "
"might be infinite loop.",
stack_info=True,
)
if not n:
return b""
if n < 0:
# This used to just loop creating a new waiter hoping to
# collect everything in self._buffer, but that would
# deadlock if the subprocess sends more than self.limit
# bytes. So just call self.readany() until EOF.
blocks = []
while True:
block = await self.readany()
if not block:
break
blocks.append(block)
return b"".join(blocks)
# TODO: should be `if` instead of `while`
# because waiter maybe triggered on chunk end,
# without feeding any data
while not self._buffer and not self._eof:
await self._wait("read")
return self._read_nowait(n)
async def readany(self) -> bytes:
if self._exception is not None:
raise self._exception
# TODO: should be `if` instead of `while`
# because waiter maybe triggered on chunk end,
# without feeding any data
while not self._buffer and not self._eof:
await self._wait("readany")
return self._read_nowait(-1)
async def readchunk(self) -> Tuple[bytes, bool]:
"""Returns a tuple of (data, end_of_http_chunk).
When chunked transfer
encoding is used, end_of_http_chunk is a boolean indicating if the end
        of the data corresponds to the end of an HTTP chunk, otherwise it is
always False.
"""
while True:
if self._exception is not None:
raise self._exception
while self._http_chunk_splits:
pos = self._http_chunk_splits.pop(0)
if pos == self._cursor:
return (b"", True)
if pos > self._cursor:
return (self._read_nowait(pos - self._cursor), True)
internal_logger.warning(
"Skipping HTTP chunk end due to data "
"consumption beyond chunk boundary"
)
if self._buffer:
return (self._read_nowait_chunk(-1), False)
# return (self._read_nowait(-1), False)
if self._eof:
# Special case for signifying EOF.
# (b'', True) is not a final return value actually.
return (b"", False)
await self._wait("readchunk")
async def readexactly(self, n: int) -> bytes:
if self._exception is not None:
raise self._exception
blocks: List[bytes] = []
while n > 0:
block = await self.read(n)
if not block:
partial = b"".join(blocks)
raise asyncio.IncompleteReadError(partial, len(partial) + n)
blocks.append(block)
n -= len(block)
return b"".join(blocks)
def read_nowait(self, n: int = -1) -> bytes:
# default was changed to be consistent with .read(-1)
#
        # I believe most users don't know about this method and
        # they are not affected.
if self._exception is not None:
raise self._exception
if self._waiter and not self._waiter.done():
raise RuntimeError(
"Called while some coroutine is waiting for incoming data."
)
return self._read_nowait(n)
def _read_nowait_chunk(self, n: int) -> bytes:
first_buffer = self._buffer[0]
offset = self._buffer_offset
if n != -1 and len(first_buffer) - offset > n:
data = first_buffer[offset : offset + n]
self._buffer_offset += n
elif offset:
self._buffer.popleft()
data = first_buffer[offset:]
self._buffer_offset = 0
else:
data = self._buffer.popleft()
data_len = len(data)
self._size -= data_len
self._cursor += data_len
chunk_splits = self._http_chunk_splits
# Prevent memory leak: drop useless chunk splits
while chunk_splits and chunk_splits[0] < self._cursor:
chunk_splits.pop(0)
if self._size < self._low_water and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
def _read_nowait(self, n: int) -> bytes:
"""Read not more than n bytes, or whole buffer if n == -1"""
self._timer.assert_timeout()
chunks = []
while self._buffer:
chunk = self._read_nowait_chunk(n)
chunks.append(chunk)
if n != -1:
n -= len(chunk)
if n == 0:
break
return b"".join(chunks) if chunks else b""
class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
__slots__ = ("_read_eof_chunk",)
def __init__(self) -> None:
self._read_eof_chunk = False
self.total_bytes = 0
def __repr__(self) -> str:
return "<%s>" % self.__class__.__name__
def exception(self) -> Optional[BaseException]:
return None
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
pass
def on_eof(self, callback: Callable[[], None]) -> None:
try:
callback()
except Exception:
internal_logger.exception("Exception in eof callback")
def feed_eof(self) -> None:
pass
def is_eof(self) -> bool:
return True
def at_eof(self) -> bool:
return True
async def wait_eof(self) -> None:
return
def feed_data(self, data: bytes, n: int = 0) -> None:
pass
async def readline(self) -> bytes:
return b""
async def read(self, n: int = -1) -> bytes:
return b""
# TODO add async def readuntil
async def readany(self) -> bytes:
return b""
async def readchunk(self) -> Tuple[bytes, bool]:
if not self._read_eof_chunk:
self._read_eof_chunk = True
return (b"", False)
return (b"", True)
async def readexactly(self, n: int) -> bytes:
raise asyncio.IncompleteReadError(b"", n)
def read_nowait(self, n: int = -1) -> bytes:
return b""
EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()
class DataQueue(Generic[_T]):
"""DataQueue is a general-purpose blocking queue with one reader."""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._eof = False
self._waiter: Optional[asyncio.Future[None]] = None
self._exception: Optional[BaseException] = None
self._buffer: Deque[Tuple[_T, int]] = collections.deque()
def __len__(self) -> int:
return len(self._buffer)
def is_eof(self) -> bool:
return self._eof
def at_eof(self) -> bool:
return self._eof and not self._buffer
def exception(self) -> Optional[BaseException]:
return self._exception
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
def feed_data(self, data: _T, size: int = 0) -> None:
self._buffer.append((data, size))
if (waiter := self._waiter) is not None:
self._waiter = None
set_result(waiter, None)
def feed_eof(self) -> None:
self._eof = True
if (waiter := self._waiter) is not None:
self._waiter = None
set_result(waiter, None)
async def read(self) -> _T:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
if self._buffer:
data, _ = self._buffer.popleft()
return data
if self._exception is not None:
raise self._exception
raise EofStream
def __aiter__(self) -> AsyncStreamIterator[_T]:
return AsyncStreamIterator(self.read)
class FlowControlDataQueue(DataQueue[_T]):
"""FlowControlDataQueue resumes and pauses an underlying stream.
It is a destination for parsed data.
This class is deprecated and will be removed in version 4.0.
"""
def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
super().__init__(loop=loop)
self._size = 0
self._protocol = protocol
self._limit = limit * 2
def feed_data(self, data: _T, size: int = 0) -> None:
super().feed_data(data, size)
self._size += size
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
async def read(self) -> _T:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
if self._buffer:
data, size = self._buffer.popleft()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
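# Usage sketch (illustrative only; DataQueue is internal machinery, so this
# demonstrates the feed/read contract above rather than a stable public API):
# one side feeds chunks and signals EOF, the other drains with ``async for``,
# which terminates cleanly once EofStream is raised internally.
if __name__ == "__main__":
    async def _demo() -> None:
        queue: DataQueue[bytes] = DataQueue(asyncio.get_running_loop())
        queue.feed_data(b"hello", 5)
        queue.feed_data(b"world", 5)
        queue.feed_eof()
        async for chunk in queue:  # AsyncStreamIterator maps EofStream to stop
            print(chunk)

    asyncio.run(_demo())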

View File

@@ -0,0 +1,37 @@
"""Helper methods to tune a TCP connection"""
import asyncio
import socket
from contextlib import suppress
from typing import Optional # noqa
__all__ = ("tcp_keepalive", "tcp_nodelay")
if hasattr(socket, "SO_KEEPALIVE"):
def tcp_keepalive(transport: asyncio.Transport) -> None:
sock = transport.get_extra_info("socket")
if sock is not None:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
else:
def tcp_keepalive(transport: asyncio.Transport) -> None: # pragma: no cover
pass
def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None:
sock = transport.get_extra_info("socket")
if sock is None:
return
if sock.family not in (socket.AF_INET, socket.AF_INET6):
return
value = bool(value)
    # The socket may already be closed; on Windows an OSError is raised.
with suppress(OSError):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
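# Usage sketch (illustrative; these helpers accept any asyncio.Transport
# backed by a TCP socket -- here a throwaway local server provides one):
if __name__ == "__main__":
    async def _demo() -> None:
        server = await asyncio.start_server(
            lambda reader, writer: writer.close(), "127.0.0.1", 0
        )
        host, port = server.sockets[0].getsockname()[:2]
        _, writer = await asyncio.open_connection(host, port)
        tcp_keepalive(writer.transport)      # no-op where SO_KEEPALIVE is absent
        tcp_nodelay(writer.transport, True)  # disable Nagle's algorithm
        sock = writer.transport.get_extra_info("socket")
        print("TCP_NODELAY =", sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY))
        writer.close()
        server.close()
        await server.wait_closed()

    asyncio.run(_demo())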

View File

@@ -0,0 +1,774 @@
"""Utilities shared by tests."""
import asyncio
import contextlib
import gc
import inspect
import ipaddress
import os
import socket
import sys
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterator,
List,
Optional,
Type,
TypeVar,
cast,
overload,
)
from unittest import IsolatedAsyncioTestCase, mock
from aiosignal import Signal
from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
import aiohttp
from aiohttp.client import (
_RequestContextManager,
_RequestOptions,
_WSRequestContextManager,
)
from . import ClientSession, hdrs
from .abc import AbstractCookieJar
from .client_reqrep import ClientResponse
from .client_ws import ClientWebSocketResponse
from .helpers import sentinel
from .http import HttpVersion, RawRequestMessage
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import StrOrURL
from .web import (
Application,
AppRunner,
BaseRequest,
BaseRunner,
Request,
Server,
ServerRunner,
SockSite,
UrlMappingMatchInfo,
)
from .web_protocol import _RequestHandler
if TYPE_CHECKING:
from ssl import SSLContext
else:
SSLContext = None
if sys.version_info >= (3, 11) and TYPE_CHECKING:
from typing import Unpack
if sys.version_info >= (3, 11):
from typing import Self
else:
Self = Any
_ApplicationNone = TypeVar("_ApplicationNone", Application, None)
_Request = TypeVar("_Request", bound=BaseRequest)
REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin"
def get_unused_port_socket(
host: str, family: socket.AddressFamily = socket.AF_INET
) -> socket.socket:
return get_port_socket(host, 0, family)
def get_port_socket(
host: str, port: int, family: socket.AddressFamily
) -> socket.socket:
s = socket.socket(family, socket.SOCK_STREAM)
if REUSE_ADDRESS:
# Windows has different semantics for SO_REUSEADDR,
# so don't set it. Ref:
# https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host, port))
return s
def unused_port() -> int:
"""Return a port that is unused on the current host."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return cast(int, s.getsockname()[1])
class BaseTestServer(ABC):
__test__ = False
def __init__(
self,
*,
scheme: str = "",
loop: Optional[asyncio.AbstractEventLoop] = None,
host: str = "127.0.0.1",
port: Optional[int] = None,
skip_url_asserts: bool = False,
socket_factory: Callable[
[str, int, socket.AddressFamily], socket.socket
] = get_port_socket,
**kwargs: Any,
) -> None:
self._loop = loop
self.runner: Optional[BaseRunner] = None
self._root: Optional[URL] = None
self.host = host
self.port = port
self._closed = False
self.scheme = scheme
self.skip_url_asserts = skip_url_asserts
self.socket_factory = socket_factory
async def start_server(
self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any
) -> None:
if self.runner:
return
self._loop = loop
self._ssl = kwargs.pop("ssl", None)
self.runner = await self._make_runner(handler_cancellation=True, **kwargs)
await self.runner.setup()
if not self.port:
self.port = 0
absolute_host = self.host
try:
version = ipaddress.ip_address(self.host).version
except ValueError:
version = 4
if version == 6:
absolute_host = f"[{self.host}]"
family = socket.AF_INET6 if version == 6 else socket.AF_INET
_sock = self.socket_factory(self.host, self.port, family)
self.host, self.port = _sock.getsockname()[:2]
site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl)
await site.start()
server = site._server
assert server is not None
sockets = server.sockets # type: ignore[attr-defined]
assert sockets is not None
self.port = sockets[0].getsockname()[1]
if not self.scheme:
self.scheme = "https" if self._ssl else "http"
self._root = URL(f"{self.scheme}://{absolute_host}:{self.port}")
@abstractmethod # pragma: no cover
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
pass
def make_url(self, path: StrOrURL) -> URL:
assert self._root is not None
url = URL(path)
if not self.skip_url_asserts:
assert not url.absolute
return self._root.join(url)
else:
return URL(str(self._root) + str(path))
@property
def started(self) -> bool:
return self.runner is not None
@property
def closed(self) -> bool:
return self._closed
@property
def handler(self) -> Server:
# for backward compatibility
# web.Server instance
runner = self.runner
assert runner is not None
assert runner.server is not None
return runner.server
async def close(self) -> None:
"""Close all fixtures created by the test client.
After that point, the TestClient is no longer usable.
This is an idempotent function: running close multiple times
will not have any additional effects.
close is also run when the object is garbage collected, and on
exit when used as a context manager.
"""
if self.started and not self.closed:
assert self.runner is not None
await self.runner.cleanup()
self._root = None
self.port = None
self._closed = True
def __enter__(self) -> None:
raise TypeError("Use async with instead")
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
# __exit__ should exist in pair with __enter__ but never executed
pass # pragma: no cover
async def __aenter__(self) -> "BaseTestServer":
await self.start_server(loop=self._loop)
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
await self.close()
class TestServer(BaseTestServer):
def __init__(
self,
app: Application,
*,
scheme: str = "",
host: str = "127.0.0.1",
port: Optional[int] = None,
**kwargs: Any,
):
self.app = app
super().__init__(scheme=scheme, host=host, port=port, **kwargs)
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
return AppRunner(self.app, **kwargs)
class RawTestServer(BaseTestServer):
def __init__(
self,
handler: _RequestHandler,
*,
scheme: str = "",
host: str = "127.0.0.1",
port: Optional[int] = None,
**kwargs: Any,
) -> None:
self._handler = handler
super().__init__(scheme=scheme, host=host, port=port, **kwargs)
async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner:
srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs)
return ServerRunner(srv, debug=debug, **kwargs)
class TestClient(Generic[_Request, _ApplicationNone]):
"""
A test client implementation.
To write functional tests for aiohttp based servers.
"""
__test__ = False
@overload
def __init__(
self: "TestClient[Request, Application]",
server: TestServer,
*,
cookie_jar: Optional[AbstractCookieJar] = None,
**kwargs: Any,
) -> None: ...
@overload
def __init__(
self: "TestClient[_Request, None]",
server: BaseTestServer,
*,
cookie_jar: Optional[AbstractCookieJar] = None,
**kwargs: Any,
) -> None: ...
def __init__(
self,
server: BaseTestServer,
*,
cookie_jar: Optional[AbstractCookieJar] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Any,
) -> None:
if not isinstance(server, BaseTestServer):
raise TypeError(
"server must be TestServer instance, found type: %r" % type(server)
)
self._server = server
self._loop = loop
if cookie_jar is None:
cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop)
self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs)
self._session._retry_connection = False
self._closed = False
self._responses: List[ClientResponse] = []
self._websockets: List[ClientWebSocketResponse] = []
async def start_server(self) -> None:
await self._server.start_server(loop=self._loop)
@property
def host(self) -> str:
return self._server.host
@property
def port(self) -> Optional[int]:
return self._server.port
@property
def server(self) -> BaseTestServer:
return self._server
@property
def app(self) -> _ApplicationNone:
return getattr(self._server, "app", None) # type: ignore[return-value]
@property
def session(self) -> ClientSession:
"""An internal aiohttp.ClientSession.
Unlike the methods on the TestClient, client session requests
do not automatically include the host in the url queried, and
will require an absolute path to the resource.
"""
return self._session
def make_url(self, path: StrOrURL) -> URL:
return self._server.make_url(path)
async def _request(
self, method: str, path: StrOrURL, **kwargs: Any
) -> ClientResponse:
resp = await self._session.request(method, self.make_url(path), **kwargs)
# save it to close later
self._responses.append(resp)
return resp
if sys.version_info >= (3, 11) and TYPE_CHECKING:
def request(
self, method: str, path: StrOrURL, **kwargs: Unpack[_RequestOptions]
) -> _RequestContextManager: ...
def get(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def options(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def head(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def post(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def put(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def patch(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
def delete(
self,
path: StrOrURL,
**kwargs: Unpack[_RequestOptions],
) -> _RequestContextManager: ...
else:
def request(
self, method: str, path: StrOrURL, **kwargs: Any
) -> _RequestContextManager:
"""Routes a request to tested http server.
The interface is identical to aiohttp.ClientSession.request,
except the loop kwarg is overridden by the instance used by the
test server.
"""
return _RequestContextManager(self._request(method, path, **kwargs))
def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP GET request."""
return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))
def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP POST request."""
return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))
def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP OPTIONS request."""
return _RequestContextManager(
self._request(hdrs.METH_OPTIONS, path, **kwargs)
)
def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP HEAD request."""
return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))
def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PUT request."""
return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))
def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(
self._request(hdrs.METH_PATCH, path, **kwargs)
)
def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(
self._request(hdrs.METH_DELETE, path, **kwargs)
)
def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
"""Initiate websocket connection.
        The API corresponds to aiohttp.ClientSession.ws_connect.
"""
return _WSRequestContextManager(self._ws_connect(path, **kwargs))
async def _ws_connect(
self, path: StrOrURL, **kwargs: Any
) -> ClientWebSocketResponse:
ws = await self._session.ws_connect(self.make_url(path), **kwargs)
self._websockets.append(ws)
return ws
async def close(self) -> None:
"""Close all fixtures created by the test client.
After that point, the TestClient is no longer usable.
This is an idempotent function: running close multiple times
will not have any additional effects.
close is also run on exit when used as a(n) (asynchronous)
context manager.
"""
if not self._closed:
for resp in self._responses:
resp.close()
for ws in self._websockets:
await ws.close()
await self._session.close()
await self._server.close()
self._closed = True
def __enter__(self) -> None:
raise TypeError("Use async with instead")
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
# __exit__ should exist in pair with __enter__ but never executed
pass # pragma: no cover
async def __aenter__(self) -> Self:
await self.start_server()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self.close()
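# Usage sketch (illustrative; the trivial app below is defined inline):
# TestServer binds an unused local port and TestClient resolves relative
# paths against it, so the test never hard-codes a host or port.
if __name__ == "__main__":
    from aiohttp import web

    async def _demo() -> None:
        async def hello(request: web.Request) -> web.Response:
            return web.Response(text="ok")

        app = web.Application()
        app.router.add_get("/", hello)
        async with TestClient(TestServer(app)) as client:
            resp = await client.get("/")
            print(resp.status, await resp.text())  # 200 ok

    asyncio.run(_demo())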
class AioHTTPTestCase(IsolatedAsyncioTestCase):
"""A base class to allow for unittest web applications using aiohttp.
Provides the following:
* self.client (aiohttp.test_utils.TestClient): an aiohttp test client.
* self.loop (asyncio.BaseEventLoop): the event loop in which the
application and server are running.
* self.app (aiohttp.web.Application): the application returned by
self.get_application()
    Note that the TestClient's methods are asynchronous: you have to
    execute functions on the test client using asynchronous methods.
"""
async def get_application(self) -> Application:
"""Get application.
This method should be overridden
to return the aiohttp.web.Application
object to test.
"""
return self.get_app()
def get_app(self) -> Application:
"""Obsolete method used to constructing web application.
Use .get_application() coroutine instead.
"""
raise RuntimeError("Did you forget to define get_application()?")
async def asyncSetUp(self) -> None:
self.loop = asyncio.get_running_loop()
return await self.setUpAsync()
async def setUpAsync(self) -> None:
self.app = await self.get_application()
self.server = await self.get_server(self.app)
self.client = await self.get_client(self.server)
await self.client.start_server()
async def asyncTearDown(self) -> None:
return await self.tearDownAsync()
async def tearDownAsync(self) -> None:
await self.client.close()
async def get_server(self, app: Application) -> TestServer:
"""Return a TestServer instance."""
return TestServer(app, loop=self.loop)
async def get_client(self, server: TestServer) -> TestClient[Request, Application]:
"""Return a TestClient instance."""
return TestClient(server, loop=self.loop)
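# Usage sketch (illustrative): a unittest-style test that overrides
# get_application() and exercises the app through self.client.
if __name__ == "__main__":
    import unittest

    from aiohttp import web

    class _HelloTestCase(AioHTTPTestCase):
        async def get_application(self) -> Application:
            async def hello(request: web.Request) -> web.Response:
                return web.Response(text="hello")

            app = web.Application()
            app.router.add_get("/", hello)
            return app

        async def test_hello(self) -> None:
            resp = await self.client.get("/")
            self.assertEqual(resp.status, 200)
            self.assertEqual(await resp.text(), "hello")

    unittest.main()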
def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any:
"""
    A decorator for use with asynchronous AioHTTPTestCase test methods.
    As of aiohttp 3.8+, it does nothing.
"""
warnings.warn(
"Decorator `@unittest_run_loop` is no longer needed in aiohttp 3.8+",
DeprecationWarning,
stacklevel=2,
)
return func
_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop]
@contextlib.contextmanager
def loop_context(
loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False
) -> Iterator[asyncio.AbstractEventLoop]:
"""A contextmanager that creates an event_loop, for test purposes.
Handles the creation and cleanup of a test loop.
"""
loop = setup_test_loop(loop_factory)
yield loop
teardown_test_loop(loop, fast=fast)
def setup_test_loop(
loop_factory: _LOOP_FACTORY = asyncio.new_event_loop,
) -> asyncio.AbstractEventLoop:
"""Create and return an asyncio.BaseEventLoop instance.
The caller should also call teardown_test_loop,
once they are done with the loop.
"""
loop = loop_factory()
asyncio.set_event_loop(loop)
return loop
def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None:
"""Teardown and cleanup an event_loop created by setup_test_loop."""
closed = loop.is_closed()
if not closed:
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
if not fast:
gc.collect()
asyncio.set_event_loop(None)
def _create_app_mock() -> mock.MagicMock:
def get_dict(app: Any, key: str) -> Any:
return app.__app_dict[key]
def set_dict(app: Any, key: str, value: Any) -> None:
app.__app_dict[key] = value
app = mock.MagicMock(spec=Application)
app.__app_dict = {}
app.__getitem__ = get_dict
app.__setitem__ = set_dict
app._debug = False
app.on_response_prepare = Signal(app)
app.on_response_prepare.freeze()
return app
def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock:
transport = mock.Mock()
def get_extra_info(key: str) -> Optional[SSLContext]:
if key == "sslcontext":
return sslcontext
else:
return None
transport.get_extra_info.side_effect = get_extra_info
return transport
def make_mocked_request(
method: str,
path: str,
headers: Any = None,
*,
match_info: Any = sentinel,
version: HttpVersion = HttpVersion(1, 1),
closing: bool = False,
app: Any = None,
writer: Any = sentinel,
protocol: Any = sentinel,
transport: Any = sentinel,
payload: StreamReader = EMPTY_PAYLOAD,
sslcontext: Optional[SSLContext] = None,
client_max_size: int = 1024**2,
loop: Any = ...,
) -> Request:
"""Creates mocked web.Request testing purposes.
Useful in unit tests, when spinning full web server is overkill or
specific conditions and errors are hard to trigger.
"""
task = mock.Mock()
if loop is ...:
# no loop passed, try to get the current one if
        # it is running, as we need a real loop to create
# executor jobs to be able to do testing
# with a real executor
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = mock.Mock()
loop.create_future.return_value = ()
if version < HttpVersion(1, 1):
closing = True
if headers:
headers = CIMultiDictProxy(CIMultiDict(headers))
raw_hdrs = tuple(
(k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items()
)
else:
headers = CIMultiDictProxy(CIMultiDict())
raw_hdrs = ()
chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower()
message = RawRequestMessage(
method,
path,
version,
headers,
raw_hdrs,
closing,
None,
False,
chunked,
URL(path),
)
if app is None:
app = _create_app_mock()
if transport is sentinel:
transport = _create_transport(sslcontext)
if protocol is sentinel:
protocol = mock.Mock()
protocol.transport = transport
type(protocol).peername = mock.PropertyMock(
return_value=transport.get_extra_info("peername")
)
type(protocol).ssl_context = mock.PropertyMock(return_value=sslcontext)
if writer is sentinel:
writer = mock.Mock()
writer.write_headers = make_mocked_coro(None)
writer.write = make_mocked_coro(None)
writer.write_eof = make_mocked_coro(None)
writer.drain = make_mocked_coro(None)
writer.transport = transport
protocol.transport = transport
protocol.writer = writer
req = Request(
message, payload, protocol, writer, task, loop, client_max_size=client_max_size
)
match_info = UrlMappingMatchInfo(
{} if match_info is sentinel else match_info, mock.Mock()
)
match_info.add_app(app)
req._match_info = match_info
return req
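# Usage sketch (illustrative): build a request without any server and feed it
# straight to a handler -- useful when only headers or match_info matter.
if __name__ == "__main__":
    from aiohttp import web

    async def _demo() -> None:
        async def handler(request: Request) -> web.Response:
            return web.Response(text=request.headers.get("X-Token", "missing"))

        req = make_mocked_request("GET", "/", headers={"X-Token": "abc"})
        resp = await handler(req)
        print(resp.text)  # -> abc

    asyncio.run(_demo())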
def make_mocked_coro(
return_value: Any = sentinel, raise_exception: Any = sentinel
) -> Any:
"""Creates a coroutine mock."""
async def mock_coro(*args: Any, **kwargs: Any) -> Any:
if raise_exception is not sentinel:
raise raise_exception
if not inspect.isawaitable(return_value):
return return_value
await return_value
return mock.Mock(wraps=mock_coro)
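# Usage sketch (illustrative): the returned mock records calls like a normal
# mock.Mock while still being awaitable, so it can stand in for a coroutine.
if __name__ == "__main__":
    async def _demo() -> None:
        fetch = make_mocked_coro(return_value=42)
        assert await fetch("http://example.com") == 42
        fetch.assert_called_once_with("http://example.com")

    asyncio.run(_demo())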

View File

@@ -0,0 +1,470 @@
from types import SimpleNamespace
from typing import TYPE_CHECKING, Awaitable, Mapping, Optional, Protocol, Type, TypeVar
import attr
from aiosignal import Signal
from multidict import CIMultiDict
from yarl import URL
from .client_reqrep import ClientResponse
if TYPE_CHECKING:
from .client import ClientSession
_ParamT_contra = TypeVar("_ParamT_contra", contravariant=True)
class _SignalCallback(Protocol[_ParamT_contra]):
def __call__(
self,
__client_session: ClientSession,
__trace_config_ctx: SimpleNamespace,
__params: _ParamT_contra,
) -> Awaitable[None]: ...
__all__ = (
"TraceConfig",
"TraceRequestStartParams",
"TraceRequestEndParams",
"TraceRequestExceptionParams",
"TraceConnectionQueuedStartParams",
"TraceConnectionQueuedEndParams",
"TraceConnectionCreateStartParams",
"TraceConnectionCreateEndParams",
"TraceConnectionReuseconnParams",
"TraceDnsResolveHostStartParams",
"TraceDnsResolveHostEndParams",
"TraceDnsCacheHitParams",
"TraceDnsCacheMissParams",
"TraceRequestRedirectParams",
"TraceRequestChunkSentParams",
"TraceResponseChunkReceivedParams",
"TraceRequestHeadersSentParams",
)
class TraceConfig:
"""First-class used to trace requests launched via ClientSession objects."""
def __init__(
self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace
) -> None:
self._on_request_start: Signal[_SignalCallback[TraceRequestStartParams]] = (
Signal(self)
)
self._on_request_chunk_sent: Signal[
_SignalCallback[TraceRequestChunkSentParams]
] = Signal(self)
self._on_response_chunk_received: Signal[
_SignalCallback[TraceResponseChunkReceivedParams]
] = Signal(self)
self._on_request_end: Signal[_SignalCallback[TraceRequestEndParams]] = Signal(
self
)
self._on_request_exception: Signal[
_SignalCallback[TraceRequestExceptionParams]
] = Signal(self)
self._on_request_redirect: Signal[
_SignalCallback[TraceRequestRedirectParams]
] = Signal(self)
self._on_connection_queued_start: Signal[
_SignalCallback[TraceConnectionQueuedStartParams]
] = Signal(self)
self._on_connection_queued_end: Signal[
_SignalCallback[TraceConnectionQueuedEndParams]
] = Signal(self)
self._on_connection_create_start: Signal[
_SignalCallback[TraceConnectionCreateStartParams]
] = Signal(self)
self._on_connection_create_end: Signal[
_SignalCallback[TraceConnectionCreateEndParams]
] = Signal(self)
self._on_connection_reuseconn: Signal[
_SignalCallback[TraceConnectionReuseconnParams]
] = Signal(self)
self._on_dns_resolvehost_start: Signal[
_SignalCallback[TraceDnsResolveHostStartParams]
] = Signal(self)
self._on_dns_resolvehost_end: Signal[
_SignalCallback[TraceDnsResolveHostEndParams]
] = Signal(self)
self._on_dns_cache_hit: Signal[_SignalCallback[TraceDnsCacheHitParams]] = (
Signal(self)
)
self._on_dns_cache_miss: Signal[_SignalCallback[TraceDnsCacheMissParams]] = (
Signal(self)
)
self._on_request_headers_sent: Signal[
_SignalCallback[TraceRequestHeadersSentParams]
] = Signal(self)
self._trace_config_ctx_factory = trace_config_ctx_factory
def trace_config_ctx(
self, trace_request_ctx: Optional[Mapping[str, str]] = None
) -> SimpleNamespace:
"""Return a new trace_config_ctx instance"""
return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx)
def freeze(self) -> None:
self._on_request_start.freeze()
self._on_request_chunk_sent.freeze()
self._on_response_chunk_received.freeze()
self._on_request_end.freeze()
self._on_request_exception.freeze()
self._on_request_redirect.freeze()
self._on_connection_queued_start.freeze()
self._on_connection_queued_end.freeze()
self._on_connection_create_start.freeze()
self._on_connection_create_end.freeze()
self._on_connection_reuseconn.freeze()
self._on_dns_resolvehost_start.freeze()
self._on_dns_resolvehost_end.freeze()
self._on_dns_cache_hit.freeze()
self._on_dns_cache_miss.freeze()
self._on_request_headers_sent.freeze()
@property
def on_request_start(self) -> "Signal[_SignalCallback[TraceRequestStartParams]]":
return self._on_request_start
@property
def on_request_chunk_sent(
self,
) -> "Signal[_SignalCallback[TraceRequestChunkSentParams]]":
return self._on_request_chunk_sent
@property
def on_response_chunk_received(
self,
) -> "Signal[_SignalCallback[TraceResponseChunkReceivedParams]]":
return self._on_response_chunk_received
@property
def on_request_end(self) -> "Signal[_SignalCallback[TraceRequestEndParams]]":
return self._on_request_end
@property
def on_request_exception(
self,
) -> "Signal[_SignalCallback[TraceRequestExceptionParams]]":
return self._on_request_exception
@property
def on_request_redirect(
self,
) -> "Signal[_SignalCallback[TraceRequestRedirectParams]]":
return self._on_request_redirect
@property
def on_connection_queued_start(
self,
) -> "Signal[_SignalCallback[TraceConnectionQueuedStartParams]]":
return self._on_connection_queued_start
@property
def on_connection_queued_end(
self,
) -> "Signal[_SignalCallback[TraceConnectionQueuedEndParams]]":
return self._on_connection_queued_end
@property
def on_connection_create_start(
self,
) -> "Signal[_SignalCallback[TraceConnectionCreateStartParams]]":
return self._on_connection_create_start
@property
def on_connection_create_end(
self,
) -> "Signal[_SignalCallback[TraceConnectionCreateEndParams]]":
return self._on_connection_create_end
@property
def on_connection_reuseconn(
self,
) -> "Signal[_SignalCallback[TraceConnectionReuseconnParams]]":
return self._on_connection_reuseconn
@property
def on_dns_resolvehost_start(
self,
) -> "Signal[_SignalCallback[TraceDnsResolveHostStartParams]]":
return self._on_dns_resolvehost_start
@property
def on_dns_resolvehost_end(
self,
) -> "Signal[_SignalCallback[TraceDnsResolveHostEndParams]]":
return self._on_dns_resolvehost_end
@property
def on_dns_cache_hit(self) -> "Signal[_SignalCallback[TraceDnsCacheHitParams]]":
return self._on_dns_cache_hit
@property
def on_dns_cache_miss(self) -> "Signal[_SignalCallback[TraceDnsCacheMissParams]]":
return self._on_dns_cache_miss
@property
def on_request_headers_sent(
self,
) -> "Signal[_SignalCallback[TraceRequestHeadersSentParams]]":
return self._on_request_headers_sent
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestStartParams:
"""Parameters sent by the `on_request_start` signal"""
method: str
url: URL
headers: "CIMultiDict[str]"
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestChunkSentParams:
"""Parameters sent by the `on_request_chunk_sent` signal"""
method: str
url: URL
chunk: bytes
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceResponseChunkReceivedParams:
"""Parameters sent by the `on_response_chunk_received` signal"""
method: str
url: URL
chunk: bytes
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestEndParams:
"""Parameters sent by the `on_request_end` signal"""
method: str
url: URL
headers: "CIMultiDict[str]"
response: ClientResponse
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestExceptionParams:
"""Parameters sent by the `on_request_exception` signal"""
method: str
url: URL
headers: "CIMultiDict[str]"
exception: BaseException
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestRedirectParams:
"""Parameters sent by the `on_request_redirect` signal"""
method: str
url: URL
headers: "CIMultiDict[str]"
response: ClientResponse
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceConnectionQueuedStartParams:
"""Parameters sent by the `on_connection_queued_start` signal"""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceConnectionQueuedEndParams:
"""Parameters sent by the `on_connection_queued_end` signal"""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceConnectionCreateStartParams:
"""Parameters sent by the `on_connection_create_start` signal"""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceConnectionCreateEndParams:
"""Parameters sent by the `on_connection_create_end` signal"""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceConnectionReuseconnParams:
"""Parameters sent by the `on_connection_reuseconn` signal"""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceDnsResolveHostStartParams:
"""Parameters sent by the `on_dns_resolvehost_start` signal"""
host: str
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceDnsResolveHostEndParams:
"""Parameters sent by the `on_dns_resolvehost_end` signal"""
host: str
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceDnsCacheHitParams:
"""Parameters sent by the `on_dns_cache_hit` signal"""
host: str
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceDnsCacheMissParams:
"""Parameters sent by the `on_dns_cache_miss` signal"""
host: str
@attr.s(auto_attribs=True, frozen=True, slots=True)
class TraceRequestHeadersSentParams:
"""Parameters sent by the `on_request_headers_sent` signal"""
method: str
url: URL
headers: "CIMultiDict[str]"
class Trace:
"""Internal dependency holder class.
    Used to keep together the main dependencies used
    at the moment of sending a signal.
"""
def __init__(
self,
session: "ClientSession",
trace_config: TraceConfig,
trace_config_ctx: SimpleNamespace,
) -> None:
self._trace_config = trace_config
self._trace_config_ctx = trace_config_ctx
self._session = session
async def send_request_start(
self, method: str, url: URL, headers: "CIMultiDict[str]"
) -> None:
return await self._trace_config.on_request_start.send(
self._session,
self._trace_config_ctx,
TraceRequestStartParams(method, url, headers),
)
async def send_request_chunk_sent(
self, method: str, url: URL, chunk: bytes
) -> None:
return await self._trace_config.on_request_chunk_sent.send(
self._session,
self._trace_config_ctx,
TraceRequestChunkSentParams(method, url, chunk),
)
async def send_response_chunk_received(
self, method: str, url: URL, chunk: bytes
) -> None:
return await self._trace_config.on_response_chunk_received.send(
self._session,
self._trace_config_ctx,
TraceResponseChunkReceivedParams(method, url, chunk),
)
async def send_request_end(
self,
method: str,
url: URL,
headers: "CIMultiDict[str]",
response: ClientResponse,
) -> None:
return await self._trace_config.on_request_end.send(
self._session,
self._trace_config_ctx,
TraceRequestEndParams(method, url, headers, response),
)
async def send_request_exception(
self,
method: str,
url: URL,
headers: "CIMultiDict[str]",
exception: BaseException,
) -> None:
return await self._trace_config.on_request_exception.send(
self._session,
self._trace_config_ctx,
TraceRequestExceptionParams(method, url, headers, exception),
)
async def send_request_redirect(
self,
method: str,
url: URL,
headers: "CIMultiDict[str]",
response: ClientResponse,
) -> None:
        return await self._trace_config.on_request_redirect.send(
self._session,
self._trace_config_ctx,
TraceRequestRedirectParams(method, url, headers, response),
)
async def send_connection_queued_start(self) -> None:
return await self._trace_config.on_connection_queued_start.send(
self._session, self._trace_config_ctx, TraceConnectionQueuedStartParams()
)
async def send_connection_queued_end(self) -> None:
return await self._trace_config.on_connection_queued_end.send(
self._session, self._trace_config_ctx, TraceConnectionQueuedEndParams()
)
async def send_connection_create_start(self) -> None:
return await self._trace_config.on_connection_create_start.send(
self._session, self._trace_config_ctx, TraceConnectionCreateStartParams()
)
async def send_connection_create_end(self) -> None:
return await self._trace_config.on_connection_create_end.send(
self._session, self._trace_config_ctx, TraceConnectionCreateEndParams()
)
async def send_connection_reuseconn(self) -> None:
return await self._trace_config.on_connection_reuseconn.send(
self._session, self._trace_config_ctx, TraceConnectionReuseconnParams()
)
async def send_dns_resolvehost_start(self, host: str) -> None:
return await self._trace_config.on_dns_resolvehost_start.send(
self._session, self._trace_config_ctx, TraceDnsResolveHostStartParams(host)
)
async def send_dns_resolvehost_end(self, host: str) -> None:
return await self._trace_config.on_dns_resolvehost_end.send(
self._session, self._trace_config_ctx, TraceDnsResolveHostEndParams(host)
)
async def send_dns_cache_hit(self, host: str) -> None:
return await self._trace_config.on_dns_cache_hit.send(
self._session, self._trace_config_ctx, TraceDnsCacheHitParams(host)
)
async def send_dns_cache_miss(self, host: str) -> None:
return await self._trace_config.on_dns_cache_miss.send(
self._session, self._trace_config_ctx, TraceDnsCacheMissParams(host)
)
async def send_request_headers(
self, method: str, url: URL, headers: "CIMultiDict[str]"
) -> None:
        return await self._trace_config.on_request_headers_sent.send(
self._session,
self._trace_config_ctx,
TraceRequestHeadersSentParams(method, url, headers),
)
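# Usage sketch (illustrative; performs a real HTTP request to example.com):
# time a request by pairing on_request_start and on_request_end callbacks
# through the shared trace_config_ctx namespace.
if __name__ == "__main__":
    import asyncio
    import time

    import aiohttp

    async def _on_start(session, ctx, params) -> None:
        ctx.start = time.monotonic()

    async def _on_end(session, ctx, params) -> None:
        print(f"{params.url} took {time.monotonic() - ctx.start:.3f}s")

    async def _demo() -> None:
        trace_config = TraceConfig()
        trace_config.on_request_start.append(_on_start)
        trace_config.on_request_end.append(_on_end)
        async with aiohttp.ClientSession(trace_configs=[trace_config]) as session:
            async with session.get("https://example.com") as resp:
                await resp.read()

    asyncio.run(_demo())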

View File

@@ -0,0 +1,69 @@
import json
import os
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable,
Mapping,
Protocol,
Tuple,
Union,
)
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr
from yarl import URL, Query as _Query
Query = _Query
DEFAULT_JSON_ENCODER = json.dumps
DEFAULT_JSON_DECODER = json.loads
if TYPE_CHECKING:
_CIMultiDict = CIMultiDict[str]
_CIMultiDictProxy = CIMultiDictProxy[str]
_MultiDict = MultiDict[str]
_MultiDictProxy = MultiDictProxy[str]
from http.cookies import BaseCookie, Morsel
from .web import Request, StreamResponse
else:
_CIMultiDict = CIMultiDict
_CIMultiDictProxy = CIMultiDictProxy
_MultiDict = MultiDict
_MultiDictProxy = MultiDictProxy
Byteish = Union[bytes, bytearray, memoryview]
JSONEncoder = Callable[[Any], str]
JSONDecoder = Callable[[str], Any]
LooseHeaders = Union[
Mapping[str, str],
Mapping[istr, str],
_CIMultiDict,
_CIMultiDictProxy,
Iterable[Tuple[Union[str, istr], str]],
]
RawHeaders = Tuple[Tuple[bytes, bytes], ...]
StrOrURL = Union[str, URL]
LooseCookiesMappings = Mapping[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]]
LooseCookiesIterables = Iterable[
Tuple[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]]
]
LooseCookies = Union[
LooseCookiesMappings,
LooseCookiesIterables,
"BaseCookie[str]",
]
Handler = Callable[["Request"], Awaitable["StreamResponse"]]
class Middleware(Protocol):
def __call__(
self, request: "Request", handler: Handler
) -> Awaitable["StreamResponse"]: ...
PathLike = Union[str, "os.PathLike[str]"]
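# Usage sketch (illustrative): a coroutine with this exact signature
# satisfies the Middleware protocol structurally (no subclassing needed);
# aiohttp additionally expects the @web.middleware decorator to mark it as
# new-style.
if __name__ == "__main__":
    from aiohttp import web

    @web.middleware
    async def add_server_header(
        request: "Request", handler: Handler
    ) -> "StreamResponse":
        response = await handler(request)
        response.headers["Server"] = "demo"
        return response

    _mw: Middleware = add_server_header  # passes the structural type check
    _app = web.Application(middlewares=[_mw])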

View File

@@ -0,0 +1,605 @@
import asyncio
import logging
import os
import socket
import sys
import warnings
from argparse import ArgumentParser
from collections.abc import Iterable
from contextlib import suppress
from importlib import import_module
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Iterable as TypingIterable,
List,
Optional,
Set,
Type,
Union,
cast,
)
from .abc import AbstractAccessLogger
from .helpers import AppKey as AppKey
from .log import access_logger
from .typedefs import PathLike
from .web_app import Application as Application, CleanupError as CleanupError
from .web_exceptions import (
HTTPAccepted as HTTPAccepted,
HTTPBadGateway as HTTPBadGateway,
HTTPBadRequest as HTTPBadRequest,
HTTPClientError as HTTPClientError,
HTTPConflict as HTTPConflict,
HTTPCreated as HTTPCreated,
HTTPError as HTTPError,
HTTPException as HTTPException,
HTTPExpectationFailed as HTTPExpectationFailed,
HTTPFailedDependency as HTTPFailedDependency,
HTTPForbidden as HTTPForbidden,
HTTPFound as HTTPFound,
HTTPGatewayTimeout as HTTPGatewayTimeout,
HTTPGone as HTTPGone,
HTTPInsufficientStorage as HTTPInsufficientStorage,
HTTPInternalServerError as HTTPInternalServerError,
HTTPLengthRequired as HTTPLengthRequired,
HTTPMethodNotAllowed as HTTPMethodNotAllowed,
HTTPMisdirectedRequest as HTTPMisdirectedRequest,
HTTPMove as HTTPMove,
HTTPMovedPermanently as HTTPMovedPermanently,
HTTPMultipleChoices as HTTPMultipleChoices,
HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired,
HTTPNoContent as HTTPNoContent,
HTTPNonAuthoritativeInformation as HTTPNonAuthoritativeInformation,
HTTPNotAcceptable as HTTPNotAcceptable,
HTTPNotExtended as HTTPNotExtended,
HTTPNotFound as HTTPNotFound,
HTTPNotImplemented as HTTPNotImplemented,
HTTPNotModified as HTTPNotModified,
HTTPOk as HTTPOk,
HTTPPartialContent as HTTPPartialContent,
HTTPPaymentRequired as HTTPPaymentRequired,
HTTPPermanentRedirect as HTTPPermanentRedirect,
HTTPPreconditionFailed as HTTPPreconditionFailed,
HTTPPreconditionRequired as HTTPPreconditionRequired,
HTTPProxyAuthenticationRequired as HTTPProxyAuthenticationRequired,
HTTPRedirection as HTTPRedirection,
HTTPRequestEntityTooLarge as HTTPRequestEntityTooLarge,
HTTPRequestHeaderFieldsTooLarge as HTTPRequestHeaderFieldsTooLarge,
HTTPRequestRangeNotSatisfiable as HTTPRequestRangeNotSatisfiable,
HTTPRequestTimeout as HTTPRequestTimeout,
HTTPRequestURITooLong as HTTPRequestURITooLong,
HTTPResetContent as HTTPResetContent,
HTTPSeeOther as HTTPSeeOther,
HTTPServerError as HTTPServerError,
HTTPServiceUnavailable as HTTPServiceUnavailable,
HTTPSuccessful as HTTPSuccessful,
HTTPTemporaryRedirect as HTTPTemporaryRedirect,
HTTPTooManyRequests as HTTPTooManyRequests,
HTTPUnauthorized as HTTPUnauthorized,
HTTPUnavailableForLegalReasons as HTTPUnavailableForLegalReasons,
HTTPUnprocessableEntity as HTTPUnprocessableEntity,
HTTPUnsupportedMediaType as HTTPUnsupportedMediaType,
HTTPUpgradeRequired as HTTPUpgradeRequired,
HTTPUseProxy as HTTPUseProxy,
HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates,
HTTPVersionNotSupported as HTTPVersionNotSupported,
NotAppKeyWarning as NotAppKeyWarning,
)
from .web_fileresponse import FileResponse as FileResponse
from .web_log import AccessLogger
from .web_middlewares import (
middleware as middleware,
normalize_path_middleware as normalize_path_middleware,
)
from .web_protocol import (
PayloadAccessError as PayloadAccessError,
RequestHandler as RequestHandler,
RequestPayloadError as RequestPayloadError,
)
from .web_request import (
BaseRequest as BaseRequest,
FileField as FileField,
Request as Request,
)
from .web_response import (
ContentCoding as ContentCoding,
Response as Response,
StreamResponse as StreamResponse,
json_response as json_response,
)
from .web_routedef import (
AbstractRouteDef as AbstractRouteDef,
RouteDef as RouteDef,
RouteTableDef as RouteTableDef,
StaticDef as StaticDef,
delete as delete,
get as get,
head as head,
options as options,
patch as patch,
post as post,
put as put,
route as route,
static as static,
view as view,
)
from .web_runner import (
AppRunner as AppRunner,
BaseRunner as BaseRunner,
BaseSite as BaseSite,
GracefulExit as GracefulExit,
NamedPipeSite as NamedPipeSite,
ServerRunner as ServerRunner,
SockSite as SockSite,
TCPSite as TCPSite,
UnixSite as UnixSite,
)
from .web_server import Server as Server
from .web_urldispatcher import (
AbstractResource as AbstractResource,
AbstractRoute as AbstractRoute,
DynamicResource as DynamicResource,
PlainResource as PlainResource,
PrefixedSubAppResource as PrefixedSubAppResource,
Resource as Resource,
ResourceRoute as ResourceRoute,
StaticResource as StaticResource,
UrlDispatcher as UrlDispatcher,
UrlMappingMatchInfo as UrlMappingMatchInfo,
View as View,
)
from .web_ws import (
WebSocketReady as WebSocketReady,
WebSocketResponse as WebSocketResponse,
WSMsgType as WSMsgType,
)
__all__ = (
# web_app
"AppKey",
"Application",
"CleanupError",
# web_exceptions
"NotAppKeyWarning",
"HTTPAccepted",
"HTTPBadGateway",
"HTTPBadRequest",
"HTTPClientError",
"HTTPConflict",
"HTTPCreated",
"HTTPError",
"HTTPException",
"HTTPExpectationFailed",
"HTTPFailedDependency",
"HTTPForbidden",
"HTTPFound",
"HTTPGatewayTimeout",
"HTTPGone",
"HTTPInsufficientStorage",
"HTTPInternalServerError",
"HTTPLengthRequired",
"HTTPMethodNotAllowed",
"HTTPMisdirectedRequest",
"HTTPMove",
"HTTPMovedPermanently",
"HTTPMultipleChoices",
"HTTPNetworkAuthenticationRequired",
"HTTPNoContent",
"HTTPNonAuthoritativeInformation",
"HTTPNotAcceptable",
"HTTPNotExtended",
"HTTPNotFound",
"HTTPNotImplemented",
"HTTPNotModified",
"HTTPOk",
"HTTPPartialContent",
"HTTPPaymentRequired",
"HTTPPermanentRedirect",
"HTTPPreconditionFailed",
"HTTPPreconditionRequired",
"HTTPProxyAuthenticationRequired",
"HTTPRedirection",
"HTTPRequestEntityTooLarge",
"HTTPRequestHeaderFieldsTooLarge",
"HTTPRequestRangeNotSatisfiable",
"HTTPRequestTimeout",
"HTTPRequestURITooLong",
"HTTPResetContent",
"HTTPSeeOther",
"HTTPServerError",
"HTTPServiceUnavailable",
"HTTPSuccessful",
"HTTPTemporaryRedirect",
"HTTPTooManyRequests",
"HTTPUnauthorized",
"HTTPUnavailableForLegalReasons",
"HTTPUnprocessableEntity",
"HTTPUnsupportedMediaType",
"HTTPUpgradeRequired",
"HTTPUseProxy",
"HTTPVariantAlsoNegotiates",
"HTTPVersionNotSupported",
# web_fileresponse
"FileResponse",
# web_middlewares
"middleware",
"normalize_path_middleware",
# web_protocol
"PayloadAccessError",
"RequestHandler",
"RequestPayloadError",
# web_request
"BaseRequest",
"FileField",
"Request",
# web_response
"ContentCoding",
"Response",
"StreamResponse",
"json_response",
# web_routedef
"AbstractRouteDef",
"RouteDef",
"RouteTableDef",
"StaticDef",
"delete",
"get",
"head",
"options",
"patch",
"post",
"put",
"route",
"static",
"view",
# web_runner
"AppRunner",
"BaseRunner",
"BaseSite",
"GracefulExit",
"ServerRunner",
"SockSite",
"TCPSite",
"UnixSite",
"NamedPipeSite",
# web_server
"Server",
# web_urldispatcher
"AbstractResource",
"AbstractRoute",
"DynamicResource",
"PlainResource",
"PrefixedSubAppResource",
"Resource",
"ResourceRoute",
"StaticResource",
"UrlDispatcher",
"UrlMappingMatchInfo",
"View",
# web_ws
"WebSocketReady",
"WebSocketResponse",
"WSMsgType",
# web
"run_app",
)
if TYPE_CHECKING:
from ssl import SSLContext
else:
try:
from ssl import SSLContext
except ImportError: # pragma: no cover
SSLContext = object # type: ignore[misc,assignment]
# Only display warning when using -Wdefault, -We, -X dev or similar.
warnings.filterwarnings("ignore", category=NotAppKeyWarning, append=True)
HostSequence = TypingIterable[str]
async def _run_app(
app: Union[Application, Awaitable[Application]],
*,
host: Optional[Union[str, HostSequence]] = None,
port: Optional[int] = None,
path: Union[PathLike, TypingIterable[PathLike], None] = None,
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None,
shutdown_timeout: float = 60.0,
keepalive_timeout: float = 75.0,
ssl_context: Optional[SSLContext] = None,
print: Optional[Callable[..., None]] = print,
backlog: int = 128,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
access_log_format: str = AccessLogger.LOG_FORMAT,
access_log: Optional[logging.Logger] = access_logger,
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
) -> None:
    # An internal function that actually does all the dirty work of running an application.
if asyncio.iscoroutine(app):
app = await app
app = cast(Application, app)
runner = AppRunner(
app,
handle_signals=handle_signals,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
keepalive_timeout=keepalive_timeout,
shutdown_timeout=shutdown_timeout,
handler_cancellation=handler_cancellation,
)
await runner.setup()
sites: List[BaseSite] = []
try:
if host is not None:
if isinstance(host, (str, bytes, bytearray, memoryview)):
sites.append(
TCPSite(
runner,
host,
port,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
)
else:
for h in host:
sites.append(
TCPSite(
runner,
h,
port,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
)
elif path is None and sock is None or port is not None:
sites.append(
TCPSite(
runner,
port=port,
ssl_context=ssl_context,
backlog=backlog,
reuse_address=reuse_address,
reuse_port=reuse_port,
)
)
if path is not None:
if isinstance(path, (str, os.PathLike)):
sites.append(
UnixSite(
runner,
path,
ssl_context=ssl_context,
backlog=backlog,
)
)
else:
for p in path:
sites.append(
UnixSite(
runner,
p,
ssl_context=ssl_context,
backlog=backlog,
)
)
if sock is not None:
if not isinstance(sock, Iterable):
sites.append(
SockSite(
runner,
sock,
ssl_context=ssl_context,
backlog=backlog,
)
)
else:
for s in sock:
sites.append(
SockSite(
runner,
s,
ssl_context=ssl_context,
backlog=backlog,
)
)
for site in sites:
await site.start()
if print: # pragma: no branch
names = sorted(str(s.name) for s in runner.sites)
print(
"======== Running on {} ========\n"
"(Press CTRL+C to quit)".format(", ".join(names))
)
        # sleep forever in 1-hour intervals
while True:
await asyncio.sleep(3600)
finally:
await runner.cleanup()
def _cancel_tasks(
to_cancel: Set["asyncio.Task[Any]"], loop: asyncio.AbstractEventLoop
) -> None:
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
def run_app(
app: Union[Application, Awaitable[Application]],
*,
host: Optional[Union[str, HostSequence]] = None,
port: Optional[int] = None,
path: Union[PathLike, TypingIterable[PathLike], None] = None,
sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None,
shutdown_timeout: float = 60.0,
keepalive_timeout: float = 75.0,
ssl_context: Optional[SSLContext] = None,
print: Optional[Callable[..., None]] = print,
backlog: int = 128,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
access_log_format: str = AccessLogger.LOG_FORMAT,
access_log: Optional[logging.Logger] = access_logger,
handle_signals: bool = True,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
"""Run an app locally"""
if loop is None:
loop = asyncio.new_event_loop()
# Configure if and only if in debugging mode and using the default logger
if loop.get_debug() and access_log and access_log.name == "aiohttp.access":
if access_log.level == logging.NOTSET:
access_log.setLevel(logging.DEBUG)
if not access_log.hasHandlers():
access_log.addHandler(logging.StreamHandler())
main_task = loop.create_task(
_run_app(
app,
host=host,
port=port,
path=path,
sock=sock,
shutdown_timeout=shutdown_timeout,
keepalive_timeout=keepalive_timeout,
ssl_context=ssl_context,
print=print,
backlog=backlog,
access_log_class=access_log_class,
access_log_format=access_log_format,
access_log=access_log,
handle_signals=handle_signals,
reuse_address=reuse_address,
reuse_port=reuse_port,
handler_cancellation=handler_cancellation,
)
)
try:
asyncio.set_event_loop(loop)
loop.run_until_complete(main_task)
except (GracefulExit, KeyboardInterrupt): # pragma: no cover
pass
finally:
try:
main_task.cancel()
with suppress(asyncio.CancelledError):
loop.run_until_complete(main_task)
finally:
_cancel_tasks(asyncio.all_tasks(loop), loop)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
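# Usage sketch (illustrative; the helper below exists only for demonstration
# and is not part of the module): the typical standalone entry point.
def _run_app_example() -> None:
    async def hello(request: Request) -> Response:
        return Response(text="Hello, world")

    app = Application()
    app.router.add_get("/", hello)
    run_app(app, host="127.0.0.1", port=8080)  # blocks until CTRL+C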
def main(argv: List[str]) -> None:
arg_parser = ArgumentParser(
description="aiohttp.web Application server", prog="aiohttp.web"
)
arg_parser.add_argument(
"entry_func",
help=(
"Callable returning the `aiohttp.web.Application` instance to "
"run. Should be specified in the 'module:function' syntax."
),
metavar="entry-func",
)
arg_parser.add_argument(
"-H",
"--hostname",
help="TCP/IP hostname to serve on (default: localhost)",
default=None,
)
arg_parser.add_argument(
"-P",
"--port",
help="TCP/IP port to serve on (default: %(default)r)",
type=int,
default=8080,
)
arg_parser.add_argument(
"-U",
"--path",
help="Unix file system path to serve on. Can be combined with hostname "
"to serve on both Unix and TCP.",
)
args, extra_argv = arg_parser.parse_known_args(argv)
# Import logic
mod_str, _, func_str = args.entry_func.partition(":")
if not func_str or not mod_str:
arg_parser.error("'entry-func' not in 'module:function' syntax")
if mod_str.startswith("."):
arg_parser.error("relative module names not supported")
try:
module = import_module(mod_str)
except ImportError as ex:
arg_parser.error(f"unable to import {mod_str}: {ex}")
try:
func = getattr(module, func_str)
except AttributeError:
arg_parser.error(f"module {mod_str!r} has no attribute {func_str!r}")
# Compatibility logic
if args.path is not None and not hasattr(socket, "AF_UNIX"):
arg_parser.error(
"file system paths not supported by your operating environment"
)
logging.basicConfig(level=logging.DEBUG)
if args.path and args.hostname is None:
host = port = None
else:
host = args.hostname or "localhost"
port = args.port
app = func(extra_argv)
run_app(app, host=host, port=port, path=args.path)
arg_parser.exit(message="Stopped\n")
if __name__ == "__main__": # pragma: no branch
main(sys.argv[1:]) # pragma: no cover
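# Usage sketch (illustrative; ``myapp`` is a hypothetical module name): the
# CLI above expects an entry function taking the leftover argv and returning
# an Application.  Given a file ``myapp.py`` containing
#
#     from aiohttp import web
#
#     async def hello(request):
#         return web.Response(text="hi")
#
#     def create_app(argv):
#         app = web.Application()
#         app.router.add_get("/", hello)
#         return app
#
# the server starts with:
#
#     python -m aiohttp.web -H 127.0.0.1 -P 8080 myapp:create_app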

View File

@@ -0,0 +1,620 @@
import asyncio
import logging
import warnings
from functools import lru_cache, partial, update_wrapper
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from aiosignal import Signal
from frozenlist import FrozenList
from . import hdrs
from .abc import (
AbstractAccessLogger,
AbstractMatchInfo,
AbstractRouter,
AbstractStreamWriter,
)
from .helpers import DEBUG, AppKey
from .http_parser import RawRequestMessage
from .log import web_logger
from .streams import StreamReader
from .typedefs import Handler, Middleware
from .web_exceptions import NotAppKeyWarning
from .web_log import AccessLogger
from .web_middlewares import _fix_request_current_app
from .web_protocol import RequestHandler
from .web_request import Request
from .web_response import StreamResponse
from .web_routedef import AbstractRouteDef
from .web_server import Server
from .web_urldispatcher import (
AbstractResource,
AbstractRoute,
Domain,
MaskDomain,
MatchedSubAppResource,
PrefixedSubAppResource,
SystemRoute,
UrlDispatcher,
)
__all__ = ("Application", "CleanupError")
if TYPE_CHECKING:
_AppSignal = Signal[Callable[["Application"], Awaitable[None]]]
_RespPrepareSignal = Signal[Callable[[Request, StreamResponse], Awaitable[None]]]
_Middlewares = FrozenList[Middleware]
_MiddlewaresHandlers = Optional[Sequence[Tuple[Middleware, bool]]]
_Subapps = List["Application"]
else:
# No type checker mode, skip types
_AppSignal = Signal
_RespPrepareSignal = Signal
_Middlewares = FrozenList
_MiddlewaresHandlers = Optional[Sequence]
_Subapps = List
_T = TypeVar("_T")
_U = TypeVar("_U")
_Resource = TypeVar("_Resource", bound=AbstractResource)
def _build_middlewares(
handler: Handler, apps: Tuple["Application", ...]
) -> Callable[[Request], Awaitable[StreamResponse]]:
"""Apply middlewares to handler."""
for app in apps[::-1]:
for m, _ in app._middlewares_handlers: # type: ignore[union-attr]
handler = update_wrapper(partial(m, handler=handler), handler) # type: ignore[misc]
return handler
_cached_build_middleware = lru_cache(maxsize=1024)(_build_middlewares)
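# Usage sketch (illustrative; the helper exists only for demonstration and
# references Application, which is defined below, so it must be called after
# the module is fully imported): nesting one application under a URL prefix.
def _subapp_example() -> "Application":
    from aiohttp import web

    async def ping(request: web.Request) -> web.Response:
        return web.Response(text="pong")

    admin = web.Application()
    admin.router.add_get("/ping", ping)
    root = web.Application()
    root.add_subapp("/admin", admin)  # now served at GET /admin/ping
    return root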
class Application(MutableMapping[Union[str, AppKey[Any]], Any]):
ATTRS = frozenset(
[
"logger",
"_debug",
"_router",
"_loop",
"_handler_args",
"_middlewares",
"_middlewares_handlers",
"_has_legacy_middlewares",
"_run_middlewares",
"_state",
"_frozen",
"_pre_frozen",
"_subapps",
"_on_response_prepare",
"_on_startup",
"_on_shutdown",
"_on_cleanup",
"_client_max_size",
"_cleanup_ctx",
]
)
def __init__(
self,
*,
logger: logging.Logger = web_logger,
router: Optional[UrlDispatcher] = None,
middlewares: Iterable[Middleware] = (),
handler_args: Optional[Mapping[str, Any]] = None,
client_max_size: int = 1024**2,
loop: Optional[asyncio.AbstractEventLoop] = None,
debug: Any = ..., # mypy doesn't support ellipsis
) -> None:
if router is None:
router = UrlDispatcher()
else:
warnings.warn(
"router argument is deprecated", DeprecationWarning, stacklevel=2
)
assert isinstance(router, AbstractRouter), router
if loop is not None:
warnings.warn(
"loop argument is deprecated", DeprecationWarning, stacklevel=2
)
if debug is not ...:
warnings.warn(
"debug argument is deprecated", DeprecationWarning, stacklevel=2
)
self._debug = debug
self._router: UrlDispatcher = router
self._loop = loop
self._handler_args = handler_args
self.logger = logger
self._middlewares: _Middlewares = FrozenList(middlewares)
# initialized on freezing
self._middlewares_handlers: _MiddlewaresHandlers = None
# initialized on freezing
self._run_middlewares: Optional[bool] = None
self._has_legacy_middlewares: bool = True
self._state: Dict[Union[AppKey[Any], str], object] = {}
self._frozen = False
self._pre_frozen = False
self._subapps: _Subapps = []
self._on_response_prepare: _RespPrepareSignal = Signal(self)
self._on_startup: _AppSignal = Signal(self)
self._on_shutdown: _AppSignal = Signal(self)
self._on_cleanup: _AppSignal = Signal(self)
self._cleanup_ctx = CleanupContext()
self._on_startup.append(self._cleanup_ctx._on_startup)
self._on_cleanup.append(self._cleanup_ctx._on_cleanup)
self._client_max_size = client_max_size
def __init_subclass__(cls: Type["Application"]) -> None:
warnings.warn(
"Inheritance class {} from web.Application "
"is discouraged".format(cls.__name__),
DeprecationWarning,
stacklevel=3,
)
if DEBUG: # pragma: no cover
def __setattr__(self, name: str, val: Any) -> None:
if name not in self.ATTRS:
warnings.warn(
"Setting custom web.Application.{} attribute "
"is discouraged".format(name),
DeprecationWarning,
stacklevel=2,
)
super().__setattr__(name, val)
# MutableMapping API
def __eq__(self, other: object) -> bool:
return self is other
@overload # type: ignore[override]
def __getitem__(self, key: AppKey[_T]) -> _T: ...
@overload
def __getitem__(self, key: str) -> Any: ...
def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
return self._state[key]
def _check_frozen(self) -> None:
if self._frozen:
warnings.warn(
"Changing state of started or joined application is deprecated",
DeprecationWarning,
stacklevel=3,
)
@overload # type: ignore[override]
def __setitem__(self, key: AppKey[_T], value: _T) -> None: ...
@overload
def __setitem__(self, key: str, value: Any) -> None: ...
def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None:
self._check_frozen()
if not isinstance(key, AppKey):
warnings.warn(
"It is recommended to use web.AppKey instances for keys.\n"
+ "https://docs.aiohttp.org/en/stable/web_advanced.html"
+ "#application-s-config",
category=NotAppKeyWarning,
stacklevel=2,
)
self._state[key] = value
def __delitem__(self, key: Union[str, AppKey[_T]]) -> None:
self._check_frozen()
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
return iter(self._state)
def __hash__(self) -> int:
return id(self)
@overload # type: ignore[override]
def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...
@overload
def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: ...
@overload
def get(self, key: str, default: Any = ...) -> Any: ...
def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
return self._state.get(key, default)
########
@property
def loop(self) -> asyncio.AbstractEventLoop:
        # Technically the loop can be None,
        # but we mask it with an explicit type cast
        # to provide a more convenient type annotation.
warnings.warn("loop property is deprecated", DeprecationWarning, stacklevel=2)
return cast(asyncio.AbstractEventLoop, self._loop)
def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None:
if loop is None:
loop = asyncio.get_event_loop()
if self._loop is not None and self._loop is not loop:
raise RuntimeError(
"web.Application instance initialized with different loop"
)
self._loop = loop
# set loop debug
if self._debug is ...:
self._debug = loop.get_debug()
# set loop to sub applications
for subapp in self._subapps:
subapp._set_loop(loop)
@property
def pre_frozen(self) -> bool:
return self._pre_frozen
def pre_freeze(self) -> None:
if self._pre_frozen:
return
self._pre_frozen = True
self._middlewares.freeze()
self._router.freeze()
self._on_response_prepare.freeze()
self._cleanup_ctx.freeze()
self._on_startup.freeze()
self._on_shutdown.freeze()
self._on_cleanup.freeze()
self._middlewares_handlers = tuple(self._prepare_middleware())
self._has_legacy_middlewares = any(
not new_style for _, new_style in self._middlewares_handlers
)
# If neither the current app nor any of its subapps has middlewares, skip
# the middleware machinery entirely; it otherwise hardcodes one middleware
# per app just to set up the current_app attribute. If no middlewares are
# configured, the handler receives the proper current_app without needing
# any of this code.
self._run_middlewares = True if self.middlewares else False
for subapp in self._subapps:
subapp.pre_freeze()
self._run_middlewares = self._run_middlewares or subapp._run_middlewares
@property
def frozen(self) -> bool:
return self._frozen
def freeze(self) -> None:
if self._frozen:
return
self.pre_freeze()
self._frozen = True
for subapp in self._subapps:
subapp.freeze()
@property
def debug(self) -> bool:
warnings.warn("debug property is deprecated", DeprecationWarning, stacklevel=2)
return self._debug # type: ignore[no-any-return]
def _reg_subapp_signals(self, subapp: "Application") -> None:
def reg_handler(signame: str) -> None:
subsig = getattr(subapp, signame)
async def handler(app: "Application") -> None:
await subsig.send(subapp)
appsig = getattr(self, signame)
appsig.append(handler)
reg_handler("on_startup")
reg_handler("on_shutdown")
reg_handler("on_cleanup")
def add_subapp(self, prefix: str, subapp: "Application") -> PrefixedSubAppResource:
if not isinstance(prefix, str):
raise TypeError("Prefix must be str")
prefix = prefix.rstrip("/")
if not prefix:
raise ValueError("Prefix cannot be empty")
factory = partial(PrefixedSubAppResource, prefix, subapp)
return self._add_subapp(factory, subapp)
def _add_subapp(
self, resource_factory: Callable[[], _Resource], subapp: "Application"
) -> _Resource:
if self.frozen:
raise RuntimeError("Cannot add sub application to frozen application")
if subapp.frozen:
raise RuntimeError("Cannot add frozen application")
resource = resource_factory()
self.router.register_resource(resource)
self._reg_subapp_signals(subapp)
self._subapps.append(subapp)
subapp.pre_freeze()
if self._loop is not None:
subapp._set_loop(self._loop)
return resource
def add_domain(self, domain: str, subapp: "Application") -> MatchedSubAppResource:
if not isinstance(domain, str):
raise TypeError("Domain must be str")
elif "*" in domain:
rule: Domain = MaskDomain(domain)
else:
rule = Domain(domain)
factory = partial(MatchedSubAppResource, rule, subapp)
return self._add_subapp(factory, subapp)
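# A short usage sketch for add_subapp() and add_domain(); the application
# names and hostnames are illustrative:
#
#     admin = web.Application()
#     app.add_subapp("/admin/", admin)           # routed by URL prefix
#     mobile = web.Application()
#     app.add_domain("m.example.com", mobile)    # routed by the Host header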
def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]:
return self.router.add_routes(routes)
@property
def on_response_prepare(self) -> _RespPrepareSignal:
return self._on_response_prepare
@property
def on_startup(self) -> _AppSignal:
return self._on_startup
@property
def on_shutdown(self) -> _AppSignal:
return self._on_shutdown
@property
def on_cleanup(self) -> _AppSignal:
return self._on_cleanup
@property
def cleanup_ctx(self) -> "CleanupContext":
return self._cleanup_ctx
@property
def router(self) -> UrlDispatcher:
return self._router
@property
def middlewares(self) -> _Middlewares:
return self._middlewares
def _make_handler(
self,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
**kwargs: Any,
) -> Server:
if not issubclass(access_log_class, AbstractAccessLogger):
raise TypeError(
"access_log_class must be subclass of "
"aiohttp.abc.AbstractAccessLogger, got {}".format(access_log_class)
)
self._set_loop(loop)
self.freeze()
kwargs["debug"] = self._debug
kwargs["access_log_class"] = access_log_class
if self._handler_args:
for k, v in self._handler_args.items():
kwargs[k] = v
return Server(
self._handle, # type: ignore[arg-type]
request_factory=self._make_request,
loop=self._loop,
**kwargs,
)
def make_handler(
self,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
**kwargs: Any,
) -> Server:
warnings.warn(
"Application.make_handler(...) is deprecated, use AppRunner API instead",
DeprecationWarning,
stacklevel=2,
)
return self._make_handler(
loop=loop, access_log_class=access_log_class, **kwargs
)
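# A minimal sketch of the AppRunner API that replaces make_handler(),
# per the deprecation warning above:
#
#     runner = web.AppRunner(app)
#     await runner.setup()
#     site = web.TCPSite(runner, "localhost", 8080)
#     await site.start()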
async def startup(self) -> None:
"""Causes on_startup signal
Should be called in the event loop along with the request handler.
"""
await self.on_startup.send(self)
async def shutdown(self) -> None:
"""Causes on_shutdown signal
Should be called before cleanup()
"""
await self.on_shutdown.send(self)
async def cleanup(self) -> None:
"""Causes on_cleanup signal
Should be called after shutdown()
"""
if self.on_cleanup.frozen:
await self.on_cleanup.send(self)
else:
# If an exception occurs in startup, ensure cleanup contexts are completed.
await self._cleanup_ctx._on_cleanup(self)
def _make_request(
self,
message: RawRequestMessage,
payload: StreamReader,
protocol: RequestHandler,
writer: AbstractStreamWriter,
task: "asyncio.Task[None]",
_cls: Type[Request] = Request,
) -> Request:
if TYPE_CHECKING:
assert self._loop is not None
return _cls(
message,
payload,
protocol,
writer,
task,
self._loop,
client_max_size=self._client_max_size,
)
def _prepare_middleware(self) -> Iterator[Tuple[Middleware, bool]]:
for m in reversed(self._middlewares):
if getattr(m, "__middleware_version__", None) == 1:
yield m, True
else:
warnings.warn(
f'old-style middleware "{m!r}" deprecated, see #2252',
DeprecationWarning,
stacklevel=2,
)
yield m, False
yield _fix_request_current_app(self), True
async def _handle(self, request: Request) -> StreamResponse:
loop = asyncio.get_event_loop()
debug = loop.get_debug()
match_info = await self._router.resolve(request)
if debug: # pragma: no cover
if not isinstance(match_info, AbstractMatchInfo):
raise TypeError(
"match_info should be AbstractMatchInfo "
"instance, not {!r}".format(match_info)
)
match_info.add_app(self)
match_info.freeze()
request._match_info = match_info
if request.headers.get(hdrs.EXPECT):
resp = await match_info.expect_handler(request)
await request.writer.drain()
if resp is not None:
return resp
handler = match_info.handler
if self._run_middlewares:
# If it's a SystemRoute, don't cache building the middlewares since
# they are constructed for every MatchInfoError as a new handler
# is made each time.
if not self._has_legacy_middlewares and not isinstance(
match_info.route, SystemRoute
):
handler = _cached_build_middleware(handler, match_info.apps)
else:
for app in match_info.apps[::-1]:
for m, new_style in app._middlewares_handlers: # type: ignore[union-attr]
if new_style:
handler = update_wrapper(
partial(m, handler=handler), handler # type: ignore[misc]
)
else:
handler = await m(app, handler) # type: ignore[arg-type,assignment]
return await handler(request)
def __call__(self) -> "Application":
"""gunicorn compatibility"""
return self
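# Because the instance is callable and returns itself, gunicorn can load the
# application object directly with aiohttp's worker (module name illustrative):
#
#     gunicorn mymodule:app --worker-class aiohttp.GunicornWebWorker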
def __repr__(self) -> str:
return f"<Application 0x{id(self):x}>"
def __bool__(self) -> bool:
return True
class CleanupError(RuntimeError):
@property
def exceptions(self) -> List[BaseException]:
return cast(List[BaseException], self.args[1])
if TYPE_CHECKING:
_CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]]
else:
_CleanupContextBase = FrozenList
class CleanupContext(_CleanupContextBase):
def __init__(self) -> None:
super().__init__()
self._exits: List[AsyncIterator[None]] = []
async def _on_startup(self, app: Application) -> None:
for cb in self:
it = cb(app).__aiter__()
await it.__anext__()
self._exits.append(it)
async def _on_cleanup(self, app: Application) -> None:
errors = []
for it in reversed(self._exits):
try:
await it.__anext__()
except StopAsyncIteration:
pass
except (Exception, asyncio.CancelledError) as exc:
errors.append(exc)
else:
errors.append(RuntimeError(f"{it!r} has more than one 'yield'"))
if errors:
if len(errors) == 1:
raise errors[0]
else:
raise CleanupError("Multiple errors on cleanup stage", errors)
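# A sketch of the cleanup-context contract enforced above: each callback must
# be an async generator with exactly one yield. The connect() helper is
# hypothetical:
#
#     async def db_ctx(app: Application) -> AsyncIterator[None]:
#         app["db"] = await connect()  # hypothetical connection factory
#         yield                        # the application runs here
#         await app["db"].close()
#
#     app.cleanup_ctx.append(db_ctx)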

View File

@@ -0,0 +1,452 @@
import warnings
from typing import Any, Dict, Iterable, List, Optional, Set # noqa
from yarl import URL
from .typedefs import LooseHeaders, StrOrURL
from .web_response import Response
__all__ = (
"HTTPException",
"HTTPError",
"HTTPRedirection",
"HTTPSuccessful",
"HTTPOk",
"HTTPCreated",
"HTTPAccepted",
"HTTPNonAuthoritativeInformation",
"HTTPNoContent",
"HTTPResetContent",
"HTTPPartialContent",
"HTTPMove",
"HTTPMultipleChoices",
"HTTPMovedPermanently",
"HTTPFound",
"HTTPSeeOther",
"HTTPNotModified",
"HTTPUseProxy",
"HTTPTemporaryRedirect",
"HTTPPermanentRedirect",
"HTTPClientError",
"HTTPBadRequest",
"HTTPUnauthorized",
"HTTPPaymentRequired",
"HTTPForbidden",
"HTTPNotFound",
"HTTPMethodNotAllowed",
"HTTPNotAcceptable",
"HTTPProxyAuthenticationRequired",
"HTTPRequestTimeout",
"HTTPConflict",
"HTTPGone",
"HTTPLengthRequired",
"HTTPPreconditionFailed",
"HTTPRequestEntityTooLarge",
"HTTPRequestURITooLong",
"HTTPUnsupportedMediaType",
"HTTPRequestRangeNotSatisfiable",
"HTTPExpectationFailed",
"HTTPMisdirectedRequest",
"HTTPUnprocessableEntity",
"HTTPFailedDependency",
"HTTPUpgradeRequired",
"HTTPPreconditionRequired",
"HTTPTooManyRequests",
"HTTPRequestHeaderFieldsTooLarge",
"HTTPUnavailableForLegalReasons",
"HTTPServerError",
"HTTPInternalServerError",
"HTTPNotImplemented",
"HTTPBadGateway",
"HTTPServiceUnavailable",
"HTTPGatewayTimeout",
"HTTPVersionNotSupported",
"HTTPVariantAlsoNegotiates",
"HTTPInsufficientStorage",
"HTTPNotExtended",
"HTTPNetworkAuthenticationRequired",
)
class NotAppKeyWarning(UserWarning):
"""Warning when not using AppKey in Application."""
############################################################
# HTTP Exceptions
############################################################
class HTTPException(Response, Exception):
# You should set in subclasses:
# status = 200
status_code = -1
empty_body = False
__http_exception__ = True
def __init__(
self,
*,
headers: Optional[LooseHeaders] = None,
reason: Optional[str] = None,
body: Any = None,
text: Optional[str] = None,
content_type: Optional[str] = None,
) -> None:
if body is not None:
warnings.warn(
"body argument is deprecated for http web exceptions",
DeprecationWarning,
)
Response.__init__(
self,
status=self.status_code,
headers=headers,
reason=reason,
body=body,
text=text,
content_type=content_type,
)
Exception.__init__(self, self.reason)
if self.body is None and not self.empty_body:
self.text = f"{self.status}: {self.reason}"
def __bool__(self) -> bool:
return True
class HTTPError(HTTPException):
"""Base class for exceptions with status codes in the 400s and 500s."""
class HTTPRedirection(HTTPException):
"""Base class for exceptions with status codes in the 300s."""
class HTTPSuccessful(HTTPException):
"""Base class for exceptions with status codes in the 200s."""
class HTTPOk(HTTPSuccessful):
status_code = 200
class HTTPCreated(HTTPSuccessful):
status_code = 201
class HTTPAccepted(HTTPSuccessful):
status_code = 202
class HTTPNonAuthoritativeInformation(HTTPSuccessful):
status_code = 203
class HTTPNoContent(HTTPSuccessful):
status_code = 204
empty_body = True
class HTTPResetContent(HTTPSuccessful):
status_code = 205
empty_body = True
class HTTPPartialContent(HTTPSuccessful):
status_code = 206
############################################################
# 3xx redirection
############################################################
class HTTPMove(HTTPRedirection):
def __init__(
self,
location: StrOrURL,
*,
headers: Optional[LooseHeaders] = None,
reason: Optional[str] = None,
body: Any = None,
text: Optional[str] = None,
content_type: Optional[str] = None,
) -> None:
if not location:
raise ValueError("HTTP redirects need a location to redirect to.")
super().__init__(
headers=headers,
reason=reason,
body=body,
text=text,
content_type=content_type,
)
self.headers["Location"] = str(URL(location))
self.location = location
class HTTPMultipleChoices(HTTPMove):
status_code = 300
class HTTPMovedPermanently(HTTPMove):
status_code = 301
class HTTPFound(HTTPMove):
status_code = 302
# This one is safe after a POST (the redirected location will be
# retrieved with GET):
class HTTPSeeOther(HTTPMove):
status_code = 303
class HTTPNotModified(HTTPRedirection):
# FIXME: this should include a date or etag header
status_code = 304
empty_body = True
class HTTPUseProxy(HTTPMove):
# Not a move, but looks a little like one
status_code = 305
class HTTPTemporaryRedirect(HTTPMove):
status_code = 307
class HTTPPermanentRedirect(HTTPMove):
status_code = 308
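# Redirect exceptions are raised from handlers and require a location,
# e.g. (handler name illustrative):
#
#     async def old_path(request):
#         raise HTTPMovedPermanently(location="/new-path")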
############################################################
# 4xx client error
############################################################
class HTTPClientError(HTTPError):
pass
class HTTPBadRequest(HTTPClientError):
status_code = 400
class HTTPUnauthorized(HTTPClientError):
status_code = 401
class HTTPPaymentRequired(HTTPClientError):
status_code = 402
class HTTPForbidden(HTTPClientError):
status_code = 403
class HTTPNotFound(HTTPClientError):
status_code = 404
class HTTPMethodNotAllowed(HTTPClientError):
status_code = 405
def __init__(
self,
method: str,
allowed_methods: Iterable[str],
*,
headers: Optional[LooseHeaders] = None,
reason: Optional[str] = None,
body: Any = None,
text: Optional[str] = None,
content_type: Optional[str] = None,
) -> None:
allow = ",".join(sorted(allowed_methods))
super().__init__(
headers=headers,
reason=reason,
body=body,
text=text,
content_type=content_type,
)
self.headers["Allow"] = allow
self.allowed_methods: Set[str] = set(allowed_methods)
self.method = method.upper()
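# HTTPMethodNotAllowed takes the offending method plus the allowed set,
# which it reflects in the Allow header, e.g.:
#
#     raise HTTPMethodNotAllowed(request.method, ["GET", "HEAD"])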
class HTTPNotAcceptable(HTTPClientError):
status_code = 406
class HTTPProxyAuthenticationRequired(HTTPClientError):
status_code = 407
class HTTPRequestTimeout(HTTPClientError):
status_code = 408
class HTTPConflict(HTTPClientError):
status_code = 409
class HTTPGone(HTTPClientError):
status_code = 410
class HTTPLengthRequired(HTTPClientError):
status_code = 411
class HTTPPreconditionFailed(HTTPClientError):
status_code = 412
class HTTPRequestEntityTooLarge(HTTPClientError):
status_code = 413
def __init__(self, max_size: float, actual_size: float, **kwargs: Any) -> None:
kwargs.setdefault(
"text",
"Maximum request body size {} exceeded, "
"actual body size {}".format(max_size, actual_size),
)
super().__init__(**kwargs)
class HTTPRequestURITooLong(HTTPClientError):
status_code = 414
class HTTPUnsupportedMediaType(HTTPClientError):
status_code = 415
class HTTPRequestRangeNotSatisfiable(HTTPClientError):
status_code = 416
class HTTPExpectationFailed(HTTPClientError):
status_code = 417
class HTTPMisdirectedRequest(HTTPClientError):
status_code = 421
class HTTPUnprocessableEntity(HTTPClientError):
status_code = 422
class HTTPFailedDependency(HTTPClientError):
status_code = 424
class HTTPUpgradeRequired(HTTPClientError):
status_code = 426
class HTTPPreconditionRequired(HTTPClientError):
status_code = 428
class HTTPTooManyRequests(HTTPClientError):
status_code = 429
class HTTPRequestHeaderFieldsTooLarge(HTTPClientError):
status_code = 431
class HTTPUnavailableForLegalReasons(HTTPClientError):
status_code = 451
def __init__(
self,
link: Optional[StrOrURL],
*,
headers: Optional[LooseHeaders] = None,
reason: Optional[str] = None,
body: Any = None,
text: Optional[str] = None,
content_type: Optional[str] = None,
) -> None:
super().__init__(
headers=headers,
reason=reason,
body=body,
text=text,
content_type=content_type,
)
self._link = None
if link:
self._link = URL(link)
self.headers["Link"] = f'<{str(self._link)}>; rel="blocked-by"'
@property
def link(self) -> Optional[URL]:
return self._link
############################################################
# 5xx Server Error
############################################################
# Response status codes beginning with the digit "5" indicate cases in
# which the server is aware that it has erred or is incapable of
# performing the request. Except when responding to a HEAD request, the
# server SHOULD include an entity containing an explanation of the error
# situation, and whether it is a temporary or permanent condition. User
# agents SHOULD display any included entity to the user. These response
# codes are applicable to any request method.
class HTTPServerError(HTTPError):
pass
class HTTPInternalServerError(HTTPServerError):
status_code = 500
class HTTPNotImplemented(HTTPServerError):
status_code = 501
class HTTPBadGateway(HTTPServerError):
status_code = 502
class HTTPServiceUnavailable(HTTPServerError):
status_code = 503
class HTTPGatewayTimeout(HTTPServerError):
status_code = 504
class HTTPVersionNotSupported(HTTPServerError):
status_code = 505
class HTTPVariantAlsoNegotiates(HTTPServerError):
status_code = 506
class HTTPInsufficientStorage(HTTPServerError):
status_code = 507
class HTTPNotExtended(HTTPServerError):
status_code = 510
class HTTPNetworkAuthenticationRequired(HTTPServerError):
status_code = 511

View File

@@ -0,0 +1,418 @@
import asyncio
import io
import os
import pathlib
import sys
from contextlib import suppress
from enum import Enum, auto
from mimetypes import MimeTypes
from stat import S_ISREG
from types import MappingProxyType
from typing import ( # noqa
IO,
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Final,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import ETAG_ANY, ETag, must_be_empty_body
from .typedefs import LooseHeaders, PathLike
from .web_exceptions import (
HTTPForbidden,
HTTPNotFound,
HTTPNotModified,
HTTPPartialContent,
HTTPPreconditionFailed,
HTTPRequestRangeNotSatisfiable,
)
from .web_response import StreamResponse
__all__ = ("FileResponse",)
if TYPE_CHECKING:
from .web_request import BaseRequest
_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE"))
CONTENT_TYPES: Final[MimeTypes] = MimeTypes()
# File extension to IANA encodings map that will be checked in the order defined.
ENCODING_EXTENSIONS = MappingProxyType(
{ext: CONTENT_TYPES.encodings_map[ext] for ext in (".br", ".gz")}
)
FALLBACK_CONTENT_TYPE = "application/octet-stream"
# Provide additional MIME type/extension pairs to be recognized.
# https://en.wikipedia.org/wiki/List_of_archive_formats#Compression_only
ADDITIONAL_CONTENT_TYPES = MappingProxyType(
{
"application/gzip": ".gz",
"application/x-brotli": ".br",
"application/x-bzip2": ".bz2",
"application/x-compress": ".Z",
"application/x-xz": ".xz",
}
)
class _FileResponseResult(Enum):
"""The result of the file response."""
SEND_FILE = auto() # i.e. a regular file to send
NOT_ACCEPTABLE = auto() # i.e. a socket or other non-regular file
PRE_CONDITION_FAILED = auto() # i.e. If-Match or If-None-Match failed
NOT_MODIFIED = auto() # 304 Not Modified
# Add custom pairs and clear the encodings map so guess_type ignores them.
CONTENT_TYPES.encodings_map.clear()
for content_type, extension in ADDITIONAL_CONTENT_TYPES.items():
CONTENT_TYPES.add_type(content_type, extension) # type: ignore[attr-defined]
_CLOSE_FUTURES: Set[asyncio.Future[None]] = set()
class FileResponse(StreamResponse):
"""A response object can be used to send files."""
def __init__(
self,
path: PathLike,
chunk_size: int = 256 * 1024,
status: int = 200,
reason: Optional[str] = None,
headers: Optional[LooseHeaders] = None,
) -> None:
super().__init__(status=status, reason=reason, headers=headers)
self._path = pathlib.Path(path)
self._chunk_size = chunk_size
def _seek_and_read(self, fobj: IO[Any], offset: int, chunk_size: int) -> bytes:
fobj.seek(offset)
return fobj.read(chunk_size) # type: ignore[no-any-return]
async def _sendfile_fallback(
self, writer: AbstractStreamWriter, fobj: IO[Any], offset: int, count: int
) -> AbstractStreamWriter:
# To keep memory usage low, fobj is transferred in chunks
# controlled by the constructor's chunk_size argument.
chunk_size = self._chunk_size
loop = asyncio.get_event_loop()
chunk = await loop.run_in_executor(
None, self._seek_and_read, fobj, offset, chunk_size
)
while chunk:
await writer.write(chunk)
count = count - chunk_size
if count <= 0:
break
chunk = await loop.run_in_executor(None, fobj.read, min(chunk_size, count))
await writer.drain()
return writer
async def _sendfile(
self, request: "BaseRequest", fobj: IO[Any], offset: int, count: int
) -> AbstractStreamWriter:
writer = await super().prepare(request)
assert writer is not None
if NOSENDFILE or self.compression:
return await self._sendfile_fallback(writer, fobj, offset, count)
loop = request._loop
transport = request.transport
assert transport is not None
try:
await loop.sendfile(transport, fobj, offset, count)
except NotImplementedError:
return await self._sendfile_fallback(writer, fobj, offset, count)
await super().write_eof()
return writer
@staticmethod
def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool:
if len(etags) == 1 and etags[0].value == ETAG_ANY:
return True
return any(
etag.value == etag_value for etag in etags if weak or not etag.is_weak
)
async def _not_modified(
self, request: "BaseRequest", etag_value: str, last_modified: float
) -> Optional[AbstractStreamWriter]:
self.set_status(HTTPNotModified.status_code)
self._length_check = False
self.etag = etag_value # type: ignore[assignment]
self.last_modified = last_modified # type: ignore[assignment]
# Delete any Content-Length headers provided by the user. An HTTP 304
# response should always have an empty body.
return await super().prepare(request)
async def _precondition_failed(
self, request: "BaseRequest"
) -> Optional[AbstractStreamWriter]:
self.set_status(HTTPPreconditionFailed.status_code)
self.content_length = 0
return await super().prepare(request)
def _make_response(
self, request: "BaseRequest", accept_encoding: str
) -> Tuple[
_FileResponseResult, Optional[io.BufferedReader], os.stat_result, Optional[str]
]:
"""Return the response result, io object, stat result, and encoding.
If an uncompressed file is returned, the encoding is set to
:py:data:`None`.
This method should be called from a thread executor
since it calls os.stat which may block.
"""
file_path, st, file_encoding = self._get_file_path_stat_encoding(
accept_encoding
)
if not file_path:
return _FileResponseResult.NOT_ACCEPTABLE, None, st, None
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.1-2
if (ifmatch := request.if_match) is not None and not self._etag_match(
etag_value, ifmatch, weak=False
):
return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding
if (
(unmodsince := request.if_unmodified_since) is not None
and ifmatch is None
and st.st_mtime > unmodsince.timestamp()
):
return _FileResponseResult.PRE_CONDITION_FAILED, None, st, file_encoding
# https://www.rfc-editor.org/rfc/rfc9110#section-13.1.2-2
if (ifnonematch := request.if_none_match) is not None and self._etag_match(
etag_value, ifnonematch, weak=True
):
return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding
if (
(modsince := request.if_modified_since) is not None
and ifnonematch is None
and st.st_mtime <= modsince.timestamp()
):
return _FileResponseResult.NOT_MODIFIED, None, st, file_encoding
fobj = file_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return _FileResponseResult.SEND_FILE, fobj, st, file_encoding
def _get_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[Optional[pathlib.Path], os.stat_result, Optional[str]]:
file_path = self._path
for file_extension, file_encoding in ENCODING_EXTENSIONS.items():
if file_encoding not in accept_encoding:
continue
compressed_path = file_path.with_suffix(file_path.suffix + file_extension)
with suppress(OSError):
# Do not follow symlinks and ignore any non-regular files.
st = compressed_path.lstat()
if S_ISREG(st.st_mode):
return compressed_path, st, file_encoding
# Fallback to the uncompressed file
st = file_path.stat()
return file_path if S_ISREG(st.st_mode) else None, st, None
async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
try:
response_result, fobj, st, file_encoding = await loop.run_in_executor(
None, self._make_response, request, accept_encoding
)
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
except OSError:
# Most likely to be FileNotFoundError or OSError for circular
# symlinks in python >= 3.13, so respond with 404.
self.set_status(HTTPNotFound.status_code)
return await super().prepare(request)
# Forbid special files like sockets, pipes, devices, etc.
if response_result is _FileResponseResult.NOT_ACCEPTABLE:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
if response_result is _FileResponseResult.PRE_CONDITION_FAILED:
return await self._precondition_failed(request)
if response_result is _FileResponseResult.NOT_MODIFIED:
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime
return await self._not_modified(request, etag_value, last_modified)
assert fobj is not None
try:
return await self._prepare_open_file(request, fobj, st, file_encoding)
finally:
# We do not await here because we do not want to wait
# for the executor to finish before returning the response
# so the connection can begin servicing another request
# as soon as possible.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)
async def _prepare_open_file(
self,
request: "BaseRequest",
fobj: io.BufferedReader,
st: os.stat_result,
file_encoding: Optional[str],
) -> Optional[AbstractStreamWriter]:
status = self._status
file_size: int = st.st_size
file_mtime: float = st.st_mtime
count: int = file_size
start: Optional[int] = None
if (ifrange := request.if_range) is None or file_mtime <= ifrange.timestamp():
# If-Range header check:
# condition = cached date >= last modification date
# return 206 if True else 200.
# if False:
# Range header would not be processed, return 200
# if True but Range header missing
# return 200
try:
rng = request.http_range
start = rng.start
end: Optional[int] = rng.stop
except ValueError:
# https://tools.ietf.org/html/rfc7233:
# A server generating a 416 (Range Not Satisfiable) response to
# a byte-range request SHOULD send a Content-Range header field
# with an unsatisfied-range value.
# The complete-length in a 416 response indicates the current
# length of the selected representation.
#
# Will do the same below. Many servers ignore this and do not
# send a Content-Range header with HTTP 416
self._headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}"
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
return await super().prepare(request)
# If a range request has been made, convert start, end slice
# notation into file pointer offset and count
if start is not None:
if start < 0 and end is None: # return tail of file
start += file_size
if start < 0:
# if Range:bytes=-1000 in request header but file size
# is only 200, there would be trouble without this
start = 0
count = file_size - start
else:
# rfc7233:If the last-byte-pos value is
# absent, or if the value is greater than or equal to
# the current length of the representation data,
# the byte range is interpreted as the remainder
# of the representation (i.e., the server replaces the
# value of last-byte-pos with a value that is one less than
# the current length of the selected representation).
count = (
min(end if end is not None else file_size, file_size) - start
)
if start >= file_size:
# HTTP 416 should be returned in this case.
#
# According to https://tools.ietf.org/html/rfc7233:
# If a valid byte-range-set includes at least one
# byte-range-spec with a first-byte-pos that is less than
# the current length of the representation, or at least one
# suffix-byte-range-spec with a non-zero suffix-length,
# then the byte-range-set is satisfiable. Otherwise, the
# byte-range-set is unsatisfiable.
self._headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}"
self.set_status(HTTPRequestRangeNotSatisfiable.status_code)
return await super().prepare(request)
status = HTTPPartialContent.status_code
# Even though you are sending the whole file, you should still
# return an HTTP 206 for a Range request.
self.set_status(status)
# If the Content-Type header is not already set, guess it based on the
# extension of the request path. The encoding returned by guess_type
# can be ignored since the map was cleared above.
if hdrs.CONTENT_TYPE not in self._headers:
if sys.version_info >= (3, 13):
guesser = CONTENT_TYPES.guess_file_type
else:
guesser = CONTENT_TYPES.guess_type
self.content_type = guesser(self._path)[0] or FALLBACK_CONTENT_TYPE
if file_encoding:
self._headers[hdrs.CONTENT_ENCODING] = file_encoding
self._headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING
# Disable compression if we are already sending
# a compressed file since we don't want to double
# compress.
self._compression = False
self.etag = f"{st.st_mtime_ns:x}-{st.st_size:x}" # type: ignore[assignment]
self.last_modified = file_mtime # type: ignore[assignment]
self.content_length = count
self._headers[hdrs.ACCEPT_RANGES] = "bytes"
if status == HTTPPartialContent.status_code:
real_start = start
assert real_start is not None
self._headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format(
real_start, real_start + count - 1, file_size
)
# If we are sending 0 bytes, calling sendfile() will throw a ValueError
if count == 0 or must_be_empty_body(request.method, status):
return await super().prepare(request)
# Be aware that start could be None or 0 here.
offset = start or 0
return await self._sendfile(request, fobj, offset, count)
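# A minimal sketch of serving a file with this class (path illustrative).
# If "report.pdf.br" exists next to "report.pdf" and the client sends
# Accept-Encoding: br, the precompressed variant is picked up automatically:
#
#     async def download(request):
#         return FileResponse("static/report.pdf", chunk_size=256 * 1024)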

View File

@@ -0,0 +1,216 @@
import datetime
import functools
import logging
import os
import re
import time as time_mod
from collections import namedtuple
from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa
from .abc import AbstractAccessLogger
from .web_request import BaseRequest
from .web_response import StreamResponse
KeyMethod = namedtuple("KeyMethod", "key method")
class AccessLogger(AbstractAccessLogger):
"""Helper object to log access.
Usage:
log = logging.getLogger("spam")
log_format = "%a %{User-Agent}i"
access_logger = AccessLogger(log, log_format)
access_logger.log(request, response, time)
Format:
%% The percent sign
%a Remote IP-address (IP-address of proxy if using reverse proxy)
%t Time when the request was started to process
%P The process ID of the child that serviced the request
%r First line of request
%s Response status code
%b Size of response in bytes, including HTTP headers
%T Time taken to serve the request, in seconds
%Tf Time taken to serve the request, in seconds with floating fraction
in .06f format
%D Time taken to serve the request, in microseconds
%{FOO}i request.headers['FOO']
%{FOO}o response.headers['FOO']
%{FOO}e os.environ['FOO']
"""
LOG_FORMAT_MAP = {
"a": "remote_address",
"t": "request_start_time",
"P": "process_id",
"r": "first_request_line",
"s": "response_status",
"b": "response_size",
"T": "request_time",
"Tf": "request_time_frac",
"D": "request_time_micro",
"i": "request_header",
"o": "response_header",
}
LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'
FORMAT_RE = re.compile(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)")
CLEANUP_RE = re.compile(r"(%[^s])")
_FORMAT_CACHE: Dict[str, Tuple[str, List[KeyMethod]]] = {}
def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None:
"""Initialise the logger.
logger is a logger object to be used for logging.
log_format is a string with an Apache-compatible log format description.
"""
super().__init__(logger, log_format=log_format)
_compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
if not _compiled_format:
_compiled_format = self.compile_format(log_format)
AccessLogger._FORMAT_CACHE[log_format] = _compiled_format
self._log_format, self._methods = _compiled_format
def compile_format(self, log_format: str) -> Tuple[str, List[KeyMethod]]:
"""Translate log_format into form usable by modulo formatting
All known atoms will be replaced with %s
Also methods for formatting of those atoms will be added to
_methods in appropriate order
For example we have log_format = "%a %t"
This format will be translated to "%s %s"
Also contents of _methods will be
[self._format_a, self._format_t]
These method will be called and results will be passed
to translated string format.
Each _format_* method receive 'args' which is list of arguments
given to self.log
Exceptions are _format_e, _format_i and _format_o methods which
also receive key name (by functools.partial)
"""
# list of (key, method) tuples, we don't use an OrderedDict as users
# can repeat the same key more than once
methods = list()
for atom in self.FORMAT_RE.findall(log_format):
if atom[1] == "":
format_key1 = self.LOG_FORMAT_MAP[atom[0]]
m = getattr(AccessLogger, "_format_%s" % atom[0])
key_method = KeyMethod(format_key1, m)
else:
format_key2 = (self.LOG_FORMAT_MAP[atom[2]], atom[1])
m = getattr(AccessLogger, "_format_%s" % atom[2])
key_method = KeyMethod(format_key2, functools.partial(m, atom[1]))
methods.append(key_method)
log_format = self.FORMAT_RE.sub(r"%s", log_format)
log_format = self.CLEANUP_RE.sub(r"%\1", log_format)
return log_format, methods
@staticmethod
def _format_i(
key: str, request: BaseRequest, response: StreamResponse, time: float
) -> str:
if request is None:
return "(no headers)"
# suboptimal, make istr(key) once
return request.headers.get(key, "-")
@staticmethod
def _format_o(
key: str, request: BaseRequest, response: StreamResponse, time: float
) -> str:
# suboptimal, make istr(key) once
return response.headers.get(key, "-")
@staticmethod
def _format_a(request: BaseRequest, response: StreamResponse, time: float) -> str:
if request is None:
return "-"
ip = request.remote
return ip if ip is not None else "-"
@staticmethod
def _format_t(request: BaseRequest, response: StreamResponse, time: float) -> str:
tz = datetime.timezone(datetime.timedelta(seconds=-time_mod.timezone))
now = datetime.datetime.now(tz)
start_time = now - datetime.timedelta(seconds=time)
return start_time.strftime("[%d/%b/%Y:%H:%M:%S %z]")
@staticmethod
def _format_P(request: BaseRequest, response: StreamResponse, time: float) -> str:
return "<%s>" % os.getpid()
@staticmethod
def _format_r(request: BaseRequest, response: StreamResponse, time: float) -> str:
if request is None:
return "-"
return "{} {} HTTP/{}.{}".format(
request.method,
request.path_qs,
request.version.major,
request.version.minor,
)
@staticmethod
def _format_s(request: BaseRequest, response: StreamResponse, time: float) -> int:
return response.status
@staticmethod
def _format_b(request: BaseRequest, response: StreamResponse, time: float) -> int:
return response.body_length
@staticmethod
def _format_T(request: BaseRequest, response: StreamResponse, time: float) -> str:
return str(round(time))
@staticmethod
def _format_Tf(request: BaseRequest, response: StreamResponse, time: float) -> str:
return "%06f" % time
@staticmethod
def _format_D(request: BaseRequest, response: StreamResponse, time: float) -> str:
return str(round(time * 1000000))
def _format_line(
self, request: BaseRequest, response: StreamResponse, time: float
) -> Iterable[Tuple[str, Callable[[BaseRequest, StreamResponse, float], str]]]:
return [(key, method(request, response, time)) for key, method in self._methods]
@property
def enabled(self) -> bool:
"""Check if logger is enabled."""
# Avoid formatting the log line if it will not be emitted.
return self.logger.isEnabledFor(logging.INFO)
def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
try:
fmt_info = self._format_line(request, response, time)
values = list()
extra = dict()
for key, value in fmt_info:
values.append(value)
if key.__class__ is str:
extra[key] = value
else:
k1, k2 = key # type: ignore[misc]
dct = extra.get(k1, {}) # type: ignore[var-annotated,has-type]
dct[k2] = value # type: ignore[index,has-type]
extra[k1] = dct # type: ignore[has-type,assignment]
self.logger.info(self._log_format % tuple(values), extra=extra)
except Exception:
self.logger.exception("Error in logging")
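# A sketch of supplying a custom format (format string illustrative);
# run_app() forwards access_log_format to this logger:
#
#     web.run_app(app, access_log_format='%a "%r" %s %b %Tf "%{User-Agent}i"')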

View File

@@ -0,0 +1,121 @@
import re
from typing import TYPE_CHECKING, Tuple, Type, TypeVar
from .typedefs import Handler, Middleware
from .web_exceptions import HTTPMove, HTTPPermanentRedirect
from .web_request import Request
from .web_response import StreamResponse
from .web_urldispatcher import SystemRoute
__all__ = (
"middleware",
"normalize_path_middleware",
)
if TYPE_CHECKING:
from .web_app import Application
_Func = TypeVar("_Func")
async def _check_request_resolves(request: Request, path: str) -> Tuple[bool, Request]:
alt_request = request.clone(rel_url=path)
match_info = await request.app.router.resolve(alt_request)
alt_request._match_info = match_info
if match_info.http_exception is None:
return True, alt_request
return False, request
def middleware(f: _Func) -> _Func:
f.__middleware_version__ = 1 # type: ignore[attr-defined]
return f
def normalize_path_middleware(
*,
append_slash: bool = True,
remove_slash: bool = False,
merge_slashes: bool = True,
redirect_class: Type[HTTPMove] = HTTPPermanentRedirect,
) -> Middleware:
"""Factory for producing a middleware that normalizes the path of a request.
Normalizing means:
- Add or remove a trailing slash to the path.
- Double slashes are replaced by one.
The middleware returns as soon as it finds a path that resolves
correctly. The order, if both merge and append/remove are enabled, is:
1) merge slashes
2) append/remove slash
3) both merge slashes and append/remove slash.
If the path resolves with at least one of those conditions, it will
redirect to the new path.
Only one of `append_slash` and `remove_slash` can be enabled. If both
are `True`, the factory will raise an assertion error.
If `append_slash` is `True`, the middleware will append a slash when
needed. If a resource is defined with a trailing slash and the request
comes without it, it will append it automatically.
If `remove_slash` is `True`, `append_slash` must be `False`. When enabled,
the middleware will remove trailing slashes and redirect if the resource
is defined without them.
If `merge_slashes` is `True`, multiple consecutive slashes in the
path are merged into one.
"""
correct_configuration = not (append_slash and remove_slash)
assert correct_configuration, "Cannot both remove and append slash"
@middleware
async def impl(request: Request, handler: Handler) -> StreamResponse:
if isinstance(request.match_info.route, SystemRoute):
paths_to_check = []
if "?" in request.raw_path:
path, query = request.raw_path.split("?", 1)
query = "?" + query
else:
query = ""
path = request.raw_path
if merge_slashes:
paths_to_check.append(re.sub("//+", "/", path))
if append_slash and not request.path.endswith("/"):
paths_to_check.append(path + "/")
if remove_slash and request.path.endswith("/"):
paths_to_check.append(path[:-1])
if merge_slashes and append_slash:
paths_to_check.append(re.sub("//+", "/", path + "/"))
if merge_slashes and remove_slash:
merged_slashes = re.sub("//+", "/", path)
paths_to_check.append(merged_slashes[:-1])
for path in paths_to_check:
path = re.sub("^//+", "/", path) # SECURITY: GHSA-v6wp-4m6f-gcjg
resolves, request = await _check_request_resolves(request, path)
if resolves:
raise redirect_class(request.raw_path + query)
return await handler(request)
return impl
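# Enabling the factory above is a one-liner (configuration illustrative):
#
#     app = web.Application(
#         middlewares=[normalize_path_middleware(append_slash=False,
#                                                remove_slash=True)]
#     )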
def _fix_request_current_app(app: "Application") -> Middleware:
@middleware
async def impl(request: Request, handler: Handler) -> StreamResponse:
match_info = request.match_info
prev = match_info.current_app
match_info.current_app = app
try:
return await handler(request)
finally:
match_info.current_app = prev
return impl
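# A minimal new-style middleware built with the decorator above
# (header value illustrative):
#
#     @middleware
#     async def server_header(request: Request, handler: Handler) -> StreamResponse:
#         response = await handler(request)
#         response.headers["Server"] = "my-app"
#         return response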

View File

@@ -0,0 +1,792 @@
import asyncio
import asyncio.streams
import sys
import traceback
import warnings
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Deque,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
import attr
import yarl
from propcache import under_cached_property
from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import ceil_timeout
from .http import (
HttpProcessingError,
HttpRequestParser,
HttpVersion10,
RawRequestMessage,
StreamWriter,
)
from .http_exceptions import BadHttpMethod
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException, HTTPInternalServerError
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse
__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")
if TYPE_CHECKING:
import ssl
from .web_server import Server
_RequestFactory = Callable[
[
RawRequestMessage,
StreamReader,
"RequestHandler",
AbstractStreamWriter,
"asyncio.Task[None]",
],
BaseRequest,
]
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]
ERROR = RawRequestMessage(
"UNKNOWN",
"/",
HttpVersion10,
{}, # type: ignore[arg-type]
{}, # type: ignore[arg-type]
True,
None,
False,
False,
yarl.URL("/"),
)
class RequestPayloadError(Exception):
"""Payload parsing error."""
class PayloadAccessError(Exception):
"""Payload was accessed after response was sent."""
_PAYLOAD_ACCESS_ERROR = PayloadAccessError()
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _ErrInfo:
status: int
exc: BaseException
message: str
_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]
class RequestHandler(BaseProtocol):
"""HTTP protocol implementation.
RequestHandler handles an incoming HTTP request. It reads the request line,
request headers and request payload, then calls the handle_request() method.
By default it always returns a 404 response.
RequestHandler also handles errors in the incoming request, such as a bad
status line, bad headers or an incomplete payload. If any error occurs,
the connection gets closed.
keepalive_timeout -- number of seconds before closing
keep-alive connection
tcp_keepalive -- TCP keep-alive is on, default is on
debug -- enable debug mode
logger -- custom logger object
access_log_class -- custom class for access_logger
access_log -- custom logging object
access_log_format -- access log format string
loop -- Optional event loop
max_line_size -- Optional maximum header line size
max_field_size -- Optional maximum header field size
max_headers -- Optional maximum header size
timeout_ceil_threshold -- Optional value to specify
threshold to ceil() timeout
values
"""
__slots__ = (
"_request_count",
"_keepalive",
"_manager",
"_request_handler",
"_request_factory",
"_tcp_keepalive",
"_next_keepalive_close_time",
"_keepalive_handle",
"_keepalive_timeout",
"_lingering_time",
"_messages",
"_message_tail",
"_handler_waiter",
"_waiter",
"_task_handler",
"_upgrade",
"_payload_parser",
"_request_parser",
"_reading_paused",
"logger",
"debug",
"access_log",
"access_logger",
"_close",
"_force_close",
"_current_request",
"_timeout_ceil_threshold",
"_request_in_progress",
"_logging_enabled",
"_cache",
)
def __init__(
self,
manager: "Server",
*,
loop: asyncio.AbstractEventLoop,
# Default should be high enough that it's likely longer than a reverse proxy's timeout.
keepalive_timeout: float = 3630,
tcp_keepalive: bool = True,
logger: Logger = server_logger,
access_log_class: Type[AbstractAccessLogger] = AccessLogger,
access_log: Logger = access_logger,
access_log_format: str = AccessLogger.LOG_FORMAT,
debug: bool = False,
max_line_size: int = 8190,
max_headers: int = 32768,
max_field_size: int = 8190,
lingering_time: float = 10.0,
read_bufsize: int = 2**16,
auto_decompress: bool = True,
timeout_ceil_threshold: float = 5,
):
super().__init__(loop)
# _request_count is the number of requests processed with the same connection.
self._request_count = 0
self._keepalive = False
self._current_request: Optional[BaseRequest] = None
self._manager: Optional[Server] = manager
self._request_handler: Optional[_RequestHandler] = manager.request_handler
self._request_factory: Optional[_RequestFactory] = manager.request_factory
self._tcp_keepalive = tcp_keepalive
# placeholder to be replaced on keepalive timeout setup
self._next_keepalive_close_time = 0.0
self._keepalive_handle: Optional[asyncio.Handle] = None
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)
self._messages: Deque[_MsgType] = deque()
self._message_tail = b""
self._waiter: Optional[asyncio.Future[None]] = None
self._handler_waiter: Optional[asyncio.Future[None]] = None
self._task_handler: Optional[asyncio.Task[None]] = None
self._upgrade = False
self._payload_parser: Any = None
self._request_parser: Optional[HttpRequestParser] = HttpRequestParser(
self,
loop,
read_bufsize,
max_line_size=max_line_size,
max_field_size=max_field_size,
max_headers=max_headers,
payload_exception=RequestPayloadError,
auto_decompress=auto_decompress,
)
self._timeout_ceil_threshold: float = 5
try:
self._timeout_ceil_threshold = float(timeout_ceil_threshold)
except (TypeError, ValueError):
pass
self.logger = logger
self.debug = debug
self.access_log = access_log
if access_log:
self.access_logger: Optional[AbstractAccessLogger] = access_log_class(
access_log, access_log_format
)
self._logging_enabled = self.access_logger.enabled
else:
self.access_logger = None
self._logging_enabled = False
self._close = False
self._force_close = False
self._request_in_progress = False
self._cache: dict[str, Any] = {}
def __repr__(self) -> str:
return "<{} {}>".format(
self.__class__.__name__,
"connected" if self.transport is not None else "disconnected",
)
@under_cached_property
def ssl_context(self) -> Optional["ssl.SSLContext"]:
"""Return SSLContext if available."""
return (
None
if self.transport is None
else self.transport.get_extra_info("sslcontext")
)
@under_cached_property
def peername(
self,
) -> Optional[Union[str, Tuple[str, int, int, int], Tuple[str, int]]]:
"""Return peername if available."""
return (
None
if self.transport is None
else self.transport.get_extra_info("peername")
)
@property
def keepalive_timeout(self) -> float:
return self._keepalive_timeout
async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
"""Do worker process exit preparations.
We need to clean up everything and stop accepting requests.
It is especially important for keep-alive connections.
"""
self._force_close = True
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()
# Wait for graceful handler completion
if self._request_in_progress:
# The future is only created when we are shutting
# down while the handler is still processing a request
# to avoid creating a future for every request.
self._handler_waiter = self._loop.create_future()
try:
async with ceil_timeout(timeout):
await self._handler_waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._handler_waiter = None
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# Then cancel handler and wait
try:
async with ceil_timeout(timeout):
if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())
if self._task_handler is not None and not self._task_handler.done():
await asyncio.shield(self._task_handler)
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# force-close non-idle handler
if self._task_handler is not None:
self._task_handler.cancel()
self.force_close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
super().connection_made(transport)
real_transport = cast(asyncio.Transport, transport)
if self._tcp_keepalive:
tcp_keepalive(real_transport)
assert self._manager is not None
self._manager.connection_made(self, real_transport)
loop = self._loop
if sys.version_info >= (3, 12):
task = asyncio.Task(self.start(), loop=loop, eager_start=True)
else:
task = loop.create_task(self.start())
self._task_handler = task
def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._manager is None:
return
self._manager.connection_lost(self, exc)
# Grab value before setting _manager to None.
handler_cancellation = self._manager.handler_cancellation
self.force_close()
super().connection_lost(exc)
self._manager = None
self._request_factory = None
self._request_handler = None
self._request_parser = None
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()
if self._current_request is not None:
if exc is None:
exc = ConnectionResetError("Connection lost")
self._current_request._cancel(exc)
if handler_cancellation and self._task_handler is not None:
self._task_handler.cancel()
self._task_handler = None
if self._payload_parser is not None:
self._payload_parser.feed_eof()
self._payload_parser = None
def set_parser(self, parser: Any) -> None:
# Actual type is WebReader
assert self._payload_parser is None
self._payload_parser = parser
if self._message_tail:
self._payload_parser.feed_data(self._message_tail)
self._message_tail = b""
def eof_received(self) -> None:
pass
def data_received(self, data: bytes) -> None:
if self._force_close or self._close:
return
# parse http messages
messages: Sequence[_MsgType]
if self._payload_parser is None and not self._upgrade:
assert self._request_parser is not None
try:
messages, upgraded, tail = self._request_parser.feed_data(data)
except HttpProcessingError as exc:
messages = [
(_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
]
upgraded = False
tail = b""
for msg, payload in messages or ():
self._request_count += 1
self._messages.append((msg, payload))
waiter = self._waiter
if messages and waiter is not None and not waiter.done():
# don't set result twice
waiter.set_result(None)
self._upgrade = upgraded
if upgraded and tail:
self._message_tail = tail
# no parser, just store
elif self._payload_parser is None and self._upgrade and data:
self._message_tail += data
# feed payload
elif data:
eof, tail = self._payload_parser.feed_data(data)
if eof:
self.close()
def keep_alive(self, val: bool) -> None:
"""Set keep-alive connection mode.
:param bool val: new state.
"""
self._keepalive = val
if self._keepalive_handle:
self._keepalive_handle.cancel()
self._keepalive_handle = None
def close(self) -> None:
"""Close connection.
Stop accepting new pipelined messages and close the
connection once handlers are done processing messages.
"""
self._close = True
if self._waiter:
self._waiter.cancel()
def force_close(self) -> None:
"""Forcefully close connection."""
self._force_close = True
if self._waiter:
self._waiter.cancel()
if self.transport is not None:
self.transport.close()
self.transport = None
def log_access(
self, request: BaseRequest, response: StreamResponse, time: Optional[float]
) -> None:
if self.access_logger is not None and self.access_logger.enabled:
if TYPE_CHECKING:
assert time is not None
self.access_logger.log(request, response, self._loop.time() - time)
def log_debug(self, *args: Any, **kw: Any) -> None:
if self.debug:
self.logger.debug(*args, **kw)
def log_exception(self, *args: Any, **kw: Any) -> None:
self.logger.exception(*args, **kw)
def _process_keepalive(self) -> None:
self._keepalive_handle = None
if self._force_close or not self._keepalive:
return
loop = self._loop
now = loop.time()
close_time = self._next_keepalive_close_time
if now < close_time:
# Keep alive close check fired too early, reschedule
self._keepalive_handle = loop.call_at(close_time, self._process_keepalive)
return
# handler in idle state
if self._waiter and not self._waiter.done():
self.force_close()
async def _handle_request(
self,
request: BaseRequest,
start_time: Optional[float],
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
self._request_in_progress = True
try:
try:
self._current_request = request
resp = await request_handler(request)
finally:
self._current_request = None
except HTTPException as exc:
resp = exc
resp, reset = await self.finish_response(request, resp, start_time)
except asyncio.CancelledError:
raise
except asyncio.TimeoutError as exc:
self.log_debug("Request handler timed out.", exc_info=exc)
resp = self.handle_error(request, 504)
resp, reset = await self.finish_response(request, resp, start_time)
except Exception as exc:
resp = self.handle_error(request, 500, exc)
resp, reset = await self.finish_response(request, resp, start_time)
else:
# Deprecation warning (See #2415)
if getattr(resp, "__http_exception__", False):
warnings.warn(
"returning HTTPException object is deprecated "
"(#2415) and will be removed, "
"please raise the exception instead",
DeprecationWarning,
)
resp, reset = await self.finish_response(request, resp, start_time)
finally:
self._request_in_progress = False
if self._handler_waiter is not None:
self._handler_waiter.set_result(None)
return resp, reset
async def start(self) -> None:
"""Process incoming request.
It reads the request line, request headers and request payload, then
calls the handle_request() method. Subclasses have to override
handle_request(). start() handles various exceptions in request
or response handling. The connection is always closed unless
keep_alive(True) is specified.
"""
loop = self._loop
manager = self._manager
assert manager is not None
keepalive_timeout = self._keepalive_timeout
resp = None
assert self._request_factory is not None
assert self._request_handler is not None
while not self._force_close:
if not self._messages:
try:
# wait for next request
self._waiter = loop.create_future()
await self._waiter
finally:
self._waiter = None
message, payload = self._messages.popleft()
# The time is only fetched if logging is enabled, as otherwise
# it's thrown away and never used.
start = loop.time() if self._logging_enabled else None
manager.requests_count += 1
writer = StreamWriter(self, loop)
if isinstance(message, _ErrInfo):
# make request_factory work
request_handler = self._make_error_handler(message)
message = ERROR
else:
request_handler = self._request_handler
# Important: don't hold a reference to the current task,
# as on traceback it will prevent the task from being
# collected and will cause a memory leak.
request = self._request_factory(
message,
payload,
self,
writer,
self._task_handler or asyncio.current_task(loop), # type: ignore[arg-type]
)
try:
# a new task is used to copy context vars (#3406)
coro = self._handle_request(request, start, request_handler)
if sys.version_info >= (3, 12):
task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
task = loop.create_task(coro)
try:
resp, reset = await task
except ConnectionError:
self.log_debug("Ignored premature client disconnection")
break
# Drop the processed task from asyncio.Task.all_tasks() early
del task
if reset:
self.log_debug("Ignored premature client disconnection 2")
break
# notify server about keep-alive
self._keepalive = bool(resp.keep_alive)
# check payload
if not payload.is_eof():
lingering_time = self._lingering_time
if not self._force_close and lingering_time:
self.log_debug(
"Start lingering close timer for %s sec.", lingering_time
)
now = loop.time()
end_t = now + lingering_time
try:
while not payload.is_eof() and now < end_t:
async with ceil_timeout(end_t - now):
# read and ignore
await payload.readany()
now = loop.time()
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task())
and t.cancelling()
):
raise
# if the payload is still incomplete
if not payload.is_eof() and not self._force_close:
self.log_debug("Uncompleted request.")
self.close()
payload.set_exception(_PAYLOAD_ACCESS_ERROR)
except asyncio.CancelledError:
self.log_debug("Ignored premature client disconnection")
self.force_close()
raise
except Exception as exc:
self.log_exception("Unhandled exception", exc_info=exc)
self.force_close()
except BaseException:
self.force_close()
raise
finally:
request._task = None # type: ignore[assignment] # Break reference cycle in case of exception
if self.transport is None and resp is not None:
self.log_debug("Ignored premature client disconnection.")
if self._keepalive and not self._close and not self._force_close:
# start keep-alive timer
close_time = loop.time() + keepalive_timeout
self._next_keepalive_close_time = close_time
if self._keepalive_handle is None:
self._keepalive_handle = loop.call_at(
close_time, self._process_keepalive
)
else:
break
# remove handler, close transport if no handlers left
if not self._force_close:
self._task_handler = None
if self.transport is not None:
self.transport.close()
async def finish_response(
self, request: BaseRequest, resp: StreamResponse, start_time: Optional[float]
) -> Tuple[StreamResponse, bool]:
"""Prepare the response and write_eof, then log access.
This has to
be called within the context of any exception so the access logger
can get exception information. Returns True if the client disconnects
prematurely.
"""
request._finish()
if self._request_parser is not None:
self._request_parser.set_upgraded(False)
self._upgrade = False
if self._message_tail:
self._request_parser.feed_data(self._message_tail)
self._message_tail = b""
try:
prepare_meth = resp.prepare
except AttributeError:
if resp is None:
self.log_exception("Missing return statement on request handler")
else:
self.log_exception(
"Web-handler should return a response instance, "
"got {!r}".format(resp)
)
exc = HTTPInternalServerError()
resp = Response(
status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
)
prepare_meth = resp.prepare
try:
await prepare_meth(request)
await resp.write_eof()
except ConnectionError:
self.log_access(request, resp, start_time)
return resp, True
self.log_access(request, resp, start_time)
return resp, False
def handle_error(
self,
request: BaseRequest,
status: int = 500,
exc: Optional[BaseException] = None,
message: Optional[str] = None,
) -> StreamResponse:
"""Handle errors.
Returns an HTTP response with the given status code, logs additional
information, and always closes the current connection.
"""
if self._request_count == 1 and isinstance(exc, BadHttpMethod):
# BadHttpMethod is common when a client sends non-HTTP
# or encrypted traffic to an HTTP port. This is expected
# to happen when connected to the public internet, so we log
# it at the debug level so as not to fill the logs with noise.
self.logger.debug(
"Error handling request from %s", request.remote, exc_info=exc
)
else:
self.log_exception(
"Error handling request from %s", request.remote, exc_info=exc
)
# some data already got sent, connection is broken
if request.writer.output_size > 0:
raise ConnectionError(
"Response is sent already, cannot send another response "
"with the error message"
)
ct = "text/plain"
if status == HTTPStatus.INTERNAL_SERVER_ERROR:
title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
tb = None
if self.debug:
with suppress(Exception):
tb = traceback.format_exc()
if "text/html" in request.headers.get("Accept", ""):
if tb:
tb = html_escape(tb)
msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
message = (
"<html><head>"
"<title>{title}</title>"
"</head><body>\n<h1>{title}</h1>"
"\n{msg}\n</body></html>\n"
).format(title=title, msg=msg)
ct = "text/html"
else:
if tb:
msg = tb
message = title + "\n\n" + msg
resp = Response(status=status, text=message, content_type=ct)
resp.force_close()
return resp
def _make_error_handler(
self, err_info: _ErrInfo
) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
async def handler(request: BaseRequest) -> StreamResponse:
return self.handle_error(
request, err_info.status, err_info.exc, err_info.message
)
return handler
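# --- Illustrative usage sketch (not part of this module) ---
# The loop above is driven by aiohttp's low-level web.Server, which
# builds requests via request_factory and awaits the user handler;
# keep-alive and error handling then follow the code in this file.
# Minimal wiring, assuming the public aiohttp.web names
# (web.Server / web.ServerRunner / web.TCPSite):
async def _lowlevel_server_example() -> None:
    from aiohttp import web  # local import: sketch only, avoids cycles

    async def handler(request: web.BaseRequest) -> web.StreamResponse:
        return web.Response(text="OK")

    runner = web.ServerRunner(web.Server(handler))
    await runner.setup()
    await web.TCPSite(runner, "localhost", 8080).start()
    try:
        await asyncio.sleep(3600)  # serve for a while
    finally:
        await runner.cleanup()     # close sites and idle connections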

View File

@@ -0,0 +1,914 @@
import asyncio
import datetime
import io
import re
import socket
import string
import tempfile
import types
import warnings
from http.cookies import SimpleCookie
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Final,
Iterator,
Mapping,
MutableMapping,
Optional,
Pattern,
Tuple,
Union,
cast,
)
from urllib.parse import parse_qsl
import attr
from multidict import (
CIMultiDict,
CIMultiDictProxy,
MultiDict,
MultiDictProxy,
MultiMapping,
)
from yarl import URL
from . import hdrs
from .abc import AbstractStreamWriter
from .helpers import (
_SENTINEL,
DEBUG,
ETAG_ANY,
LIST_QUOTED_ETAG_RE,
ChainMapProxy,
ETag,
HeadersMixin,
parse_http_date,
reify,
sentinel,
set_exception,
)
from .http_parser import RawRequestMessage
from .http_writer import HttpVersion
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
DEFAULT_JSON_DECODER,
JSONDecoder,
LooseHeaders,
RawHeaders,
StrOrURL,
)
from .web_exceptions import HTTPRequestEntityTooLarge
from .web_response import StreamResponse
__all__ = ("BaseRequest", "FileField", "Request")
if TYPE_CHECKING:
from .web_app import Application
from .web_protocol import RequestHandler
from .web_urldispatcher import UrlMappingMatchInfo
@attr.s(auto_attribs=True, frozen=True, slots=True)
class FileField:
name: str
filename: str
file: io.BufferedReader
content_type: str
headers: CIMultiDictProxy[str]
_TCHAR: Final[str] = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-"
# '-' at the end to prevent interpretation as range in a char class
_TOKEN: Final[str] = rf"[{_TCHAR}]+"
_QDTEXT: Final[str] = r"[{}]".format(
r"".join(chr(c) for c in (0x09, 0x20, 0x21) + tuple(range(0x23, 0x7F)))
)
# qdtext includes 0x5C to escape 0x5D ('\]')
# qdtext excludes obs-text (because obsoleted, and encoding not specified)
_QUOTED_PAIR: Final[str] = r"\\[\t !-~]"
_QUOTED_STRING: Final[str] = r'"(?:{quoted_pair}|{qdtext})*"'.format(
qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR
)
_FORWARDED_PAIR: Final[str] = (
r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format(
token=_TOKEN, quoted_string=_QUOTED_STRING
)
)
_QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])")
# same pattern as _QUOTED_PAIR but contains a capture group
_FORWARDED_PAIR_RE: Final[Pattern[str]] = re.compile(_FORWARDED_PAIR)
############################################################
# HTTP Request
############################################################
class BaseRequest(MutableMapping[str, Any], HeadersMixin):
POST_METHODS = {
hdrs.METH_PATCH,
hdrs.METH_POST,
hdrs.METH_PUT,
hdrs.METH_TRACE,
hdrs.METH_DELETE,
}
ATTRS = HeadersMixin.ATTRS | frozenset(
[
"_message",
"_protocol",
"_payload_writer",
"_payload",
"_headers",
"_method",
"_version",
"_rel_url",
"_post",
"_read_bytes",
"_state",
"_cache",
"_task",
"_client_max_size",
"_loop",
"_transport_sslcontext",
"_transport_peername",
]
)
_post: Optional[MultiDictProxy[Union[str, bytes, FileField]]] = None
_read_bytes: Optional[bytes] = None
def __init__(
self,
message: RawRequestMessage,
payload: StreamReader,
protocol: "RequestHandler",
payload_writer: AbstractStreamWriter,
task: "asyncio.Task[None]",
loop: asyncio.AbstractEventLoop,
*,
client_max_size: int = 1024**2,
state: Optional[Dict[str, Any]] = None,
scheme: Optional[str] = None,
host: Optional[str] = None,
remote: Optional[str] = None,
) -> None:
self._message = message
self._protocol = protocol
self._payload_writer = payload_writer
self._payload = payload
self._headers: CIMultiDictProxy[str] = message.headers
self._method = message.method
self._version = message.version
self._cache: Dict[str, Any] = {}
url = message.url
if url.absolute:
if scheme is not None:
url = url.with_scheme(scheme)
if host is not None:
url = url.with_host(host)
# an absolute URL is given, so override the auto-calculated
# url, host, and scheme; all other properties should be good
self._cache["url"] = url
self._cache["host"] = url.host
self._cache["scheme"] = url.scheme
self._rel_url = url.relative()
else:
self._rel_url = url
if scheme is not None:
self._cache["scheme"] = scheme
if host is not None:
self._cache["host"] = host
self._state = {} if state is None else state
self._task = task
self._client_max_size = client_max_size
self._loop = loop
self._transport_sslcontext = protocol.ssl_context
self._transport_peername = protocol.peername
if remote is not None:
self._cache["remote"] = remote
def clone(
self,
*,
method: Union[str, _SENTINEL] = sentinel,
rel_url: Union[StrOrURL, _SENTINEL] = sentinel,
headers: Union[LooseHeaders, _SENTINEL] = sentinel,
scheme: Union[str, _SENTINEL] = sentinel,
host: Union[str, _SENTINEL] = sentinel,
remote: Union[str, _SENTINEL] = sentinel,
client_max_size: Union[int, _SENTINEL] = sentinel,
) -> "BaseRequest":
"""Clone itself with replacement some attributes.
Creates and returns a new instance of Request object. If no parameters
are given, an exact copy is returned. If a parameter is not passed, it
will reuse the one from the current request object.
"""
if self._read_bytes:
raise RuntimeError("Cannot clone request after reading its content")
dct: Dict[str, Any] = {}
if method is not sentinel:
dct["method"] = method
if rel_url is not sentinel:
new_url: URL = URL(rel_url)
dct["url"] = new_url
dct["path"] = str(new_url)
if headers is not sentinel:
# a copy semantic
dct["headers"] = CIMultiDictProxy(CIMultiDict(headers))
dct["raw_headers"] = tuple(
(k.encode("utf-8"), v.encode("utf-8"))
for k, v in dct["headers"].items()
)
message = self._message._replace(**dct)
kwargs = {}
if scheme is not sentinel:
kwargs["scheme"] = scheme
if host is not sentinel:
kwargs["host"] = host
if remote is not sentinel:
kwargs["remote"] = remote
if client_max_size is sentinel:
client_max_size = self._client_max_size
return self.__class__(
message,
self._payload,
self._protocol,
self._payload_writer,
self._task,
self._loop,
client_max_size=client_max_size,
state=self._state.copy(),
**kwargs,
)
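# --- Illustrative sketch (not part of aiohttp) ---
# clone() must be called before the body has been read; any parameter
# left at the sentinel default is inherited from the original request.
def _clone_example(request: "BaseRequest") -> "BaseRequest":
    # Raises RuntimeError if the payload was already read.
    return request.clone(
        method="POST",
        headers={**request.headers, "X-Forwarded-Proto": "https"},
    )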
@property
def task(self) -> "asyncio.Task[None]":
return self._task
@property
def protocol(self) -> "RequestHandler":
return self._protocol
@property
def transport(self) -> Optional[asyncio.Transport]:
if self._protocol is None:
return None
return self._protocol.transport
@property
def writer(self) -> AbstractStreamWriter:
return self._payload_writer
@property
def client_max_size(self) -> int:
return self._client_max_size
@reify
def message(self) -> RawRequestMessage:
warnings.warn("Request.message is deprecated", DeprecationWarning, stacklevel=3)
return self._message
@reify
def rel_url(self) -> URL:
return self._rel_url
@reify
def loop(self) -> asyncio.AbstractEventLoop:
warnings.warn(
"request.loop property is deprecated", DeprecationWarning, stacklevel=2
)
return self._loop
# MutableMapping API
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
return iter(self._state)
########
@reify
def secure(self) -> bool:
"""A bool indicating if the request is handled with SSL."""
return self.scheme == "https"
@reify
def forwarded(self) -> Tuple[Mapping[str, str], ...]:
"""A tuple containing all parsed Forwarded header(s).
Makes an effort to parse Forwarded headers as specified by RFC 7239:
- It adds one (immutable) dictionary per Forwarded 'field-value', i.e.
per proxy. The element corresponds to the data in the Forwarded
field-value added by the first proxy encountered by the client. Each
subsequent item corresponds to those added by later proxies.
- It checks that every value has valid syntax in general as specified
in section 4: either a 'token' or a 'quoted-string'.
- It un-escapes found escape sequences.
- It does NOT validate 'by' and 'for' contents as specified in section
6.
- It does NOT validate 'host' contents (Host ABNF).
- It does NOT validate 'proto' contents for valid URI scheme names.
Returns a tuple containing one or more immutable dicts.
"""
elems = []
for field_value in self._message.headers.getall(hdrs.FORWARDED, ()):
length = len(field_value)
pos = 0
need_separator = False
elem: Dict[str, str] = {}
elems.append(types.MappingProxyType(elem))
while 0 <= pos < length:
match = _FORWARDED_PAIR_RE.match(field_value, pos)
if match is not None: # got a valid forwarded-pair
if need_separator:
# bad syntax here, skip to next comma
pos = field_value.find(",", pos)
else:
name, value, port = match.groups()
if value[0] == '"':
# quoted string: remove quotes and unescape
value = _QUOTED_PAIR_REPLACE_RE.sub(r"\1", value[1:-1])
if port:
value += port
elem[name.lower()] = value
pos += len(match.group(0))
need_separator = True
elif field_value[pos] == ",": # next forwarded-element
need_separator = False
elem = {}
elems.append(types.MappingProxyType(elem))
pos += 1
elif field_value[pos] == ";": # next forwarded-pair
need_separator = False
pos += 1
elif field_value[pos] in " \t":
# Allow whitespace even between forwarded-pairs, though
# RFC 7239 doesn't. This simplifies code and is in line
# with Postel's law.
pos += 1
else:
# bad syntax here, skip to next comma
pos = field_value.find(",", pos)
return tuple(elems)
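# --- Illustrative sketch (not part of aiohttp) ---
# For "Forwarded: for=192.0.2.60;proto=https, for=198.51.100.17" the
# property yields one read-only mapping per proxy hop, first hop first.
def _forwarded_example(request: "BaseRequest") -> Optional[str]:
    hops = request.forwarded
    return hops[0].get("for") if hops else None  # first-hop client address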
@reify
def scheme(self) -> str:
"""A string representing the scheme of the request.
The scheme is resolved in this order:
- overridden value by .clone(scheme=new_scheme) call.
- type of connection to peer: HTTPS if socket is SSL, HTTP otherwise.
'http' or 'https'.
"""
if self._transport_sslcontext:
return "https"
else:
return "http"
@reify
def method(self) -> str:
"""Read only property for getting HTTP method.
The value is upper-cased str like 'GET', 'POST', 'PUT' etc.
"""
return self._method
@reify
def version(self) -> HttpVersion:
"""Read only property for getting HTTP version of request.
Returns aiohttp.protocol.HttpVersion instance.
"""
return self._version
@reify
def host(self) -> str:
"""Hostname of the request.
Hostname is resolved in this order:
- overridden value by .clone(host=new_host) call.
- HOST HTTP header
- socket.getfqdn() value
For example, 'example.com' or 'localhost:8080'.
For historical reasons, the port number may be included.
"""
host = self._message.headers.get(hdrs.HOST)
if host is not None:
return host
return socket.getfqdn()
@reify
def remote(self) -> Optional[str]:
"""Remote IP of client initiated HTTP request.
The IP is resolved in this order:
- overridden value by .clone(remote=new_remote) call.
- peername of opened socket
"""
if self._transport_peername is None:
return None
if isinstance(self._transport_peername, (list, tuple)):
return str(self._transport_peername[0])
return str(self._transport_peername)
@reify
def url(self) -> URL:
"""The full URL of the request."""
# authority is used here because it may include the port number
# and we want yarl to parse it correctly
return URL.build(scheme=self.scheme, authority=self.host).join(self._rel_url)
@reify
def path(self) -> str:
"""The URL including *PATH INFO* without the host or scheme.
E.g., ``/app/blog``
"""
return self._rel_url.path
@reify
def path_qs(self) -> str:
"""The URL including PATH_INFO and the query string.
E.g, /app/blog?id=10
"""
return str(self._rel_url)
@reify
def raw_path(self) -> str:
"""The URL including raw *PATH INFO* without the host or scheme.
Warning: the path is unquoted and may contain invalid URL characters.
E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters``
"""
return self._message.path
@reify
def query(self) -> "MultiMapping[str]":
"""A multidict with all the variables in the query string."""
return self._rel_url.query
@reify
def query_string(self) -> str:
"""The query string in the URL.
E.g., id=10
"""
return self._rel_url.query_string
@reify
def headers(self) -> CIMultiDictProxy[str]:
"""A case-insensitive multidict proxy with all headers."""
return self._headers
@reify
def raw_headers(self) -> RawHeaders:
"""A sequence of pairs for all headers."""
return self._message.raw_headers
@reify
def if_modified_since(self) -> Optional[datetime.datetime]:
"""The value of If-Modified-Since HTTP header, or None.
This header is represented as a `datetime` object.
"""
return parse_http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE))
@reify
def if_unmodified_since(self) -> Optional[datetime.datetime]:
"""The value of If-Unmodified-Since HTTP header, or None.
This header is represented as a `datetime` object.
"""
return parse_http_date(self.headers.get(hdrs.IF_UNMODIFIED_SINCE))
@staticmethod
def _etag_values(etag_header: str) -> Iterator[ETag]:
"""Extract `ETag` objects from raw header."""
if etag_header == ETAG_ANY:
yield ETag(
is_weak=False,
value=ETAG_ANY,
)
else:
for match in LIST_QUOTED_ETAG_RE.finditer(etag_header):
is_weak, value, garbage = match.group(2, 3, 4)
# Any symbol captured by the 4th group means
# that the following sequence is invalid.
if garbage:
break
yield ETag(
is_weak=bool(is_weak),
value=value,
)
@classmethod
def _if_match_or_none_impl(
cls, header_value: Optional[str]
) -> Optional[Tuple[ETag, ...]]:
if not header_value:
return None
return tuple(cls._etag_values(header_value))
@reify
def if_match(self) -> Optional[Tuple[ETag, ...]]:
"""The value of If-Match HTTP header, or None.
This header is represented as a `tuple` of `ETag` objects.
"""
return self._if_match_or_none_impl(self.headers.get(hdrs.IF_MATCH))
@reify
def if_none_match(self) -> Optional[Tuple[ETag, ...]]:
"""The value of If-None-Match HTTP header, or None.
This header is represented as a `tuple` of `ETag` objects.
"""
return self._if_match_or_none_impl(self.headers.get(hdrs.IF_NONE_MATCH))
@reify
def if_range(self) -> Optional[datetime.datetime]:
"""The value of If-Range HTTP header, or None.
This header is represented as a `datetime` object.
"""
return parse_http_date(self.headers.get(hdrs.IF_RANGE))
@reify
def keep_alive(self) -> bool:
"""Is keepalive enabled by client?"""
return not self._message.should_close
@reify
def cookies(self) -> Mapping[str, str]:
"""Return request cookies.
A read-only dictionary-like object.
"""
raw = self.headers.get(hdrs.COOKIE, "")
parsed = SimpleCookie(raw)
return MappingProxyType({key: val.value for key, val in parsed.items()})
@reify
def http_range(self) -> slice:
"""The content of Range HTTP header.
Return a slice instance.
"""
rng = self._headers.get(hdrs.RANGE)
start, end = None, None
if rng is not None:
try:
pattern = r"^bytes=(\d*)-(\d*)$"
start, end = re.findall(pattern, rng)[0]
except IndexError: # pattern was not found in header
raise ValueError("range not in acceptable format")
end = int(end) if end else None
start = int(start) if start else None
if start is None and end is not None:
# end with no start is to return tail of content
start = -end
end = None
if start is not None and end is not None:
# end is inclusive in range header, exclusive for slice
end += 1
if start >= end:
raise ValueError("start cannot be after end")
if start is end is None: # No valid range supplied
raise ValueError("No start or end of range specified")
return slice(start, end, 1)
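# --- Illustrative sketch (not part of aiohttp) ---
# "Range: bytes=0-499" maps to slice(0, 500, 1) (the header end is
# inclusive, the slice end exclusive); "bytes=-500" maps to
# slice(-500, None, 1); a missing header maps to slice(None, None, 1).
def _http_range_example(request: "BaseRequest", body: bytes) -> bytes:
    try:
        return body[request.http_range]
    except ValueError:
        return body  # malformed Range header: serve the full body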
@reify
def content(self) -> StreamReader:
"""Return raw payload stream."""
return self._payload
@property
def has_body(self) -> bool:
"""Return True if request's HTTP BODY can be read, False otherwise."""
warnings.warn(
"Deprecated, use .can_read_body #2005", DeprecationWarning, stacklevel=2
)
return not self._payload.at_eof()
@property
def can_read_body(self) -> bool:
"""Return True if request's HTTP BODY can be read, False otherwise."""
return not self._payload.at_eof()
@reify
def body_exists(self) -> bool:
"""Return True if request has HTTP BODY, False otherwise."""
return type(self._payload) is not EmptyStreamReader
async def release(self) -> None:
"""Release request.
Eat unread part of HTTP BODY if present.
"""
while not self._payload.at_eof():
await self._payload.readany()
async def read(self) -> bytes:
"""Read request body if present.
Returns bytes object with full request content.
"""
if self._read_bytes is None:
body = bytearray()
while True:
chunk = await self._payload.readany()
body.extend(chunk)
if self._client_max_size:
body_size = len(body)
if body_size >= self._client_max_size:
raise HTTPRequestEntityTooLarge(
max_size=self._client_max_size, actual_size=body_size
)
if not chunk:
break
self._read_bytes = bytes(body)
return self._read_bytes
async def text(self) -> str:
"""Return BODY as text using encoding from .charset."""
bytes_body = await self.read()
encoding = self.charset or "utf-8"
return bytes_body.decode(encoding)
async def json(self, *, loads: JSONDecoder = DEFAULT_JSON_DECODER) -> Any:
"""Return BODY as JSON."""
body = await self.text()
return loads(body)
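# --- Illustrative sketch (not part of aiohttp) ---
# read() caches the raw bytes, so text()/json() may follow it; all
# three enforce client_max_size via HTTPRequestEntityTooLarge.
async def _body_example(request: "BaseRequest") -> Any:
    if request.content_type == "application/json":
        return await request.json()
    return await request.read()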
async def multipart(self) -> MultipartReader:
"""Return async iterator to process BODY as multipart."""
return MultipartReader(self._headers, self._payload)
async def post(self) -> "MultiDictProxy[Union[str, bytes, FileField]]":
"""Return POST parameters."""
if self._post is not None:
return self._post
if self._method not in self.POST_METHODS:
self._post = MultiDictProxy(MultiDict())
return self._post
content_type = self.content_type
if content_type not in (
"",
"application/x-www-form-urlencoded",
"multipart/form-data",
):
self._post = MultiDictProxy(MultiDict())
return self._post
out: MultiDict[Union[str, bytes, FileField]] = MultiDict()
if content_type == "multipart/form-data":
multipart = await self.multipart()
max_size = self._client_max_size
field = await multipart.next()
while field is not None:
size = 0
field_ct = field.headers.get(hdrs.CONTENT_TYPE)
if isinstance(field, BodyPartReader):
assert field.name is not None
# Note that according to RFC 7578, the Content-Type header
# is optional, even for files, so we can't assume it's
# present.
# https://tools.ietf.org/html/rfc7578#section-4.4
if field.filename:
# store file in temp file
tmp = await self._loop.run_in_executor(
None, tempfile.TemporaryFile
)
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
await self._loop.run_in_executor(None, tmp.write, chunk)
size += len(chunk)
if 0 < max_size < size:
await self._loop.run_in_executor(None, tmp.close)
raise HTTPRequestEntityTooLarge(
max_size=max_size, actual_size=size
)
chunk = await field.read_chunk(size=2**16)
await self._loop.run_in_executor(None, tmp.seek, 0)
if field_ct is None:
field_ct = "application/octet-stream"
ff = FileField(
field.name,
field.filename,
cast(io.BufferedReader, tmp),
field_ct,
field.headers,
)
out.add(field.name, ff)
else:
# deal with ordinary data
value = await field.read(decode=True)
if field_ct is None or field_ct.startswith("text/"):
charset = field.get_charset(default="utf-8")
out.add(field.name, value.decode(charset))
else:
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size, actual_size=size
)
else:
raise ValueError(
"To decode nested multipart you need to use custom reader",
)
field = await multipart.next()
else:
data = await self.read()
if data:
charset = self.charset or "utf-8"
out.extend(
parse_qsl(
data.rstrip().decode(charset),
keep_blank_values=True,
encoding=charset,
)
)
self._post = MultiDictProxy(out)
return self._post
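# --- Illustrative sketch (not part of aiohttp; "upload" is a
# hypothetical form-field name) ---
# post() yields str values for ordinary form fields and FileField
# objects for uploads, whose payload is spooled to a temporary file.
async def _post_example(request: "BaseRequest") -> int:
    data = await request.post()
    upload = data.get("upload")
    if isinstance(upload, FileField):
        size = len(upload.file.read())
        upload.file.close()  # release the temp-file descriptor
        return size
    return 0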
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""Extra info from protocol transport"""
protocol = self._protocol
if protocol is None:
return default
transport = protocol.transport
if transport is None:
return default
return transport.get_extra_info(name, default)
def __repr__(self) -> str:
ascii_encodable_path = self.path.encode("ascii", "backslashreplace").decode(
"ascii"
)
return "<{} {} {} >".format(
self.__class__.__name__, self._method, ascii_encodable_path
)
def __eq__(self, other: object) -> bool:
return id(self) == id(other)
def __bool__(self) -> bool:
return True
async def _prepare_hook(self, response: StreamResponse) -> None:
return
def _cancel(self, exc: BaseException) -> None:
set_exception(self._payload, exc)
def _finish(self) -> None:
if self._post is None or self.content_type != "multipart/form-data":
return
# NOTE: Release file descriptors for the
# NOTE: `tempfile.TemporaryFile`-created `_io.BufferedRandom`
# NOTE: instances of files sent within multipart request body
# NOTE: via HTTP POST request.
for file_name, file_field_object in self._post.items():
if isinstance(file_field_object, FileField):
file_field_object.file.close()
class Request(BaseRequest):
ATTRS = BaseRequest.ATTRS | frozenset(["_match_info"])
_match_info: Optional["UrlMappingMatchInfo"] = None
if DEBUG:
def __setattr__(self, name: str, val: Any) -> None:
if name not in self.ATTRS:
warnings.warn(
"Setting custom {}.{} attribute "
"is discouraged".format(self.__class__.__name__, name),
DeprecationWarning,
stacklevel=2,
)
super().__setattr__(name, val)
def clone(
self,
*,
method: Union[str, _SENTINEL] = sentinel,
rel_url: Union[StrOrURL, _SENTINEL] = sentinel,
headers: Union[LooseHeaders, _SENTINEL] = sentinel,
scheme: Union[str, _SENTINEL] = sentinel,
host: Union[str, _SENTINEL] = sentinel,
remote: Union[str, _SENTINEL] = sentinel,
client_max_size: Union[int, _SENTINEL] = sentinel,
) -> "Request":
ret = super().clone(
method=method,
rel_url=rel_url,
headers=headers,
scheme=scheme,
host=host,
remote=remote,
client_max_size=client_max_size,
)
new_ret = cast(Request, ret)
new_ret._match_info = self._match_info
return new_ret
@reify
def match_info(self) -> "UrlMappingMatchInfo":
"""Result of route resolving."""
match_info = self._match_info
assert match_info is not None
return match_info
@property
def app(self) -> "Application":
"""Application instance."""
match_info = self._match_info
assert match_info is not None
return match_info.current_app
@property
def config_dict(self) -> ChainMapProxy:
match_info = self._match_info
assert match_info is not None
lst = match_info.apps
app = self.app
idx = lst.index(app)
sublist = list(reversed(lst[: idx + 1]))
return ChainMapProxy(sublist)
async def _prepare_hook(self, response: StreamResponse) -> None:
match_info = self._match_info
if match_info is None:
return
for app in match_info._apps:
if on_response_prepare := app.on_response_prepare:
await on_response_prepare.send(self, response)

View File

@@ -0,0 +1,838 @@
import asyncio
import collections.abc
import datetime
import enum
import json
import math
import time
import warnings
import zlib
from concurrent.futures import Executor
from http import HTTPStatus
from http.cookies import SimpleCookie
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
MutableMapping,
Optional,
Union,
cast,
)
from multidict import CIMultiDict, istr
from . import hdrs, payload
from .abc import AbstractStreamWriter
from .compression_utils import ZLibCompressor
from .helpers import (
ETAG_ANY,
QUOTED_ETAG_RE,
ETag,
HeadersMixin,
must_be_empty_body,
parse_http_date,
rfc822_formatted_time,
sentinel,
should_remove_content_length,
validate_etag_value,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11
from .payload import Payload
from .typedefs import JSONEncoder, LooseHeaders
REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus}
LARGE_BODY_SIZE = 1024**2
__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response")
if TYPE_CHECKING:
from .web_request import BaseRequest
BaseClass = MutableMapping[str, Any]
else:
BaseClass = collections.abc.MutableMapping
# TODO(py311): Convert to StrEnum for wider use
class ContentCoding(enum.Enum):
# The content codings that we have support for.
#
# Additional registered codings are listed at:
# https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
deflate = "deflate"
gzip = "gzip"
identity = "identity"
CONTENT_CODINGS = {coding.value: coding for coding in ContentCoding}
############################################################
# HTTP Response classes
############################################################
class StreamResponse(BaseClass, HeadersMixin):
_body: Union[None, bytes, bytearray, Payload]
_length_check = True
_body = None
_keep_alive: Optional[bool] = None
_chunked: bool = False
_compression: bool = False
_compression_strategy: int = zlib.Z_DEFAULT_STRATEGY
_compression_force: Optional[ContentCoding] = None
_req: Optional["BaseRequest"] = None
_payload_writer: Optional[AbstractStreamWriter] = None
_eof_sent: bool = False
_must_be_empty_body: Optional[bool] = None
_body_length = 0
_cookies: Optional[SimpleCookie] = None
def __init__(
self,
*,
status: int = 200,
reason: Optional[str] = None,
headers: Optional[LooseHeaders] = None,
_real_headers: Optional[CIMultiDict[str]] = None,
) -> None:
"""Initialize a new stream response object.
_real_headers is an internal parameter used to pass a pre-populated
headers object. It is used by the `Response` class to avoid copying
the headers when creating a new response object. It is not intended
to be used by external code.
"""
self._state: Dict[str, Any] = {}
if _real_headers is not None:
self._headers = _real_headers
elif headers is not None:
self._headers: CIMultiDict[str] = CIMultiDict(headers)
else:
self._headers = CIMultiDict()
self._set_status(status, reason)
@property
def prepared(self) -> bool:
return self._eof_sent or self._payload_writer is not None
@property
def task(self) -> "Optional[asyncio.Task[None]]":
if self._req:
return self._req.task
else:
return None
@property
def status(self) -> int:
return self._status
@property
def chunked(self) -> bool:
return self._chunked
@property
def compression(self) -> bool:
return self._compression
@property
def reason(self) -> str:
return self._reason
def set_status(
self,
status: int,
reason: Optional[str] = None,
) -> None:
assert (
not self.prepared
), "Cannot change the response status code after the headers have been sent"
self._set_status(status, reason)
def _set_status(self, status: int, reason: Optional[str]) -> None:
self._status = int(status)
if reason is None:
reason = REASON_PHRASES.get(self._status, "")
elif "\n" in reason:
raise ValueError("Reason cannot contain \\n")
self._reason = reason
@property
def keep_alive(self) -> Optional[bool]:
return self._keep_alive
def force_close(self) -> None:
self._keep_alive = False
@property
def body_length(self) -> int:
return self._body_length
@property
def output_length(self) -> int:
warnings.warn("output_length is deprecated", DeprecationWarning)
assert self._payload_writer
return self._payload_writer.buffer_size
def enable_chunked_encoding(self, chunk_size: Optional[int] = None) -> None:
"""Enables automatic chunked transfer encoding."""
if hdrs.CONTENT_LENGTH in self._headers:
raise RuntimeError(
"You can't enable chunked encoding when a content length is set"
)
if chunk_size is not None:
warnings.warn("Chunk size is deprecated #1615", DeprecationWarning)
self._chunked = True
def enable_compression(
self,
force: Optional[Union[bool, ContentCoding]] = None,
strategy: int = zlib.Z_DEFAULT_STRATEGY,
) -> None:
"""Enables response compression encoding."""
# Backwards compatibility for when force was a bool <0.17.
if isinstance(force, bool):
force = ContentCoding.deflate if force else ContentCoding.identity
warnings.warn(
"Using boolean for force is deprecated #3318", DeprecationWarning
)
elif force is not None:
assert isinstance(
force, ContentCoding
), "force should one of None, bool or ContentEncoding"
self._compression = True
self._compression_force = force
self._compression_strategy = strategy
@property
def headers(self) -> "CIMultiDict[str]":
return self._headers
@property
def cookies(self) -> SimpleCookie:
if self._cookies is None:
self._cookies = SimpleCookie()
return self._cookies
def set_cookie(
self,
name: str,
value: str,
*,
expires: Optional[str] = None,
domain: Optional[str] = None,
max_age: Optional[Union[int, str]] = None,
path: str = "/",
secure: Optional[bool] = None,
httponly: Optional[bool] = None,
version: Optional[str] = None,
samesite: Optional[str] = None,
) -> None:
"""Set or update response cookie.
Sets a new cookie or updates an existing one with a new value.
Only those parameters which are not None are updated.
"""
if self._cookies is None:
self._cookies = SimpleCookie()
self._cookies[name] = value
c = self._cookies[name]
if expires is not None:
c["expires"] = expires
elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
del c["expires"]
if domain is not None:
c["domain"] = domain
if max_age is not None:
c["max-age"] = str(max_age)
elif "max-age" in c:
del c["max-age"]
c["path"] = path
if secure is not None:
c["secure"] = secure
if httponly is not None:
c["httponly"] = httponly
if version is not None:
c["version"] = version
if samesite is not None:
c["samesite"] = samesite
def del_cookie(
self,
name: str,
*,
domain: Optional[str] = None,
path: str = "/",
secure: Optional[bool] = None,
httponly: Optional[bool] = None,
samesite: Optional[str] = None,
) -> None:
"""Delete cookie.
Creates a new, empty, already-expired cookie.
"""
# TODO: do we need domain/path here?
if self._cookies is not None:
self._cookies.pop(name, None)
self.set_cookie(
name,
"",
max_age=0,
expires="Thu, 01 Jan 1970 00:00:00 GMT",
domain=domain,
path=path,
secure=secure,
httponly=httponly,
samesite=samesite,
)
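# --- Illustrative sketch (not part of aiohttp) ---
# set_cookie() only touches parameters that are not None; del_cookie()
# replaces the cookie with an empty, already-expired one.
def _cookie_example(resp: "StreamResponse") -> None:
    resp.set_cookie(
        "session", "abc123", max_age=3600, httponly=True, samesite="Lax"
    )
    resp.del_cookie("legacy_session")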
@property
def content_length(self) -> Optional[int]:
# Just a placeholder for adding setter
return super().content_length
@content_length.setter
def content_length(self, value: Optional[int]) -> None:
if value is not None:
value = int(value)
if self._chunked:
raise RuntimeError(
"You can't set content length when chunked encoding is enable"
)
self._headers[hdrs.CONTENT_LENGTH] = str(value)
else:
self._headers.pop(hdrs.CONTENT_LENGTH, None)
@property
def content_type(self) -> str:
# Just a placeholder for adding setter
return super().content_type
@content_type.setter
def content_type(self, value: str) -> None:
self.content_type # read header values if needed
self._content_type = str(value)
self._generate_content_type_header()
@property
def charset(self) -> Optional[str]:
# Just a placeholder for adding setter
return super().charset
@charset.setter
def charset(self, value: Optional[str]) -> None:
ctype = self.content_type # read header values if needed
if ctype == "application/octet-stream":
raise RuntimeError(
"Setting charset for application/octet-stream "
"doesn't make sense, setup content_type first"
)
assert self._content_dict is not None
if value is None:
self._content_dict.pop("charset", None)
else:
self._content_dict["charset"] = str(value).lower()
self._generate_content_type_header()
@property
def last_modified(self) -> Optional[datetime.datetime]:
"""The value of Last-Modified HTTP header, or None.
This header is represented as a `datetime` object.
"""
return parse_http_date(self._headers.get(hdrs.LAST_MODIFIED))
@last_modified.setter
def last_modified(
self, value: Optional[Union[int, float, datetime.datetime, str]]
) -> None:
if value is None:
self._headers.pop(hdrs.LAST_MODIFIED, None)
elif isinstance(value, (int, float)):
self._headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value))
)
elif isinstance(value, datetime.datetime):
self._headers[hdrs.LAST_MODIFIED] = time.strftime(
"%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple()
)
elif isinstance(value, str):
self._headers[hdrs.LAST_MODIFIED] = value
@property
def etag(self) -> Optional[ETag]:
quoted_value = self._headers.get(hdrs.ETAG)
if not quoted_value:
return None
elif quoted_value == ETAG_ANY:
return ETag(value=ETAG_ANY)
match = QUOTED_ETAG_RE.fullmatch(quoted_value)
if not match:
return None
is_weak, value = match.group(1, 2)
return ETag(
is_weak=bool(is_weak),
value=value,
)
@etag.setter
def etag(self, value: Optional[Union[ETag, str]]) -> None:
if value is None:
self._headers.pop(hdrs.ETAG, None)
elif (isinstance(value, str) and value == ETAG_ANY) or (
isinstance(value, ETag) and value.value == ETAG_ANY
):
self._headers[hdrs.ETAG] = ETAG_ANY
elif isinstance(value, str):
validate_etag_value(value)
self._headers[hdrs.ETAG] = f'"{value}"'
elif isinstance(value, ETag) and isinstance(value.value, str):
validate_etag_value(value.value)
hdr_value = f'W/"{value.value}"' if value.is_weak else f'"{value.value}"'
self._headers[hdrs.ETAG] = hdr_value
else:
raise ValueError(
f"Unsupported etag type: {type(value)}. "
f"etag must be str, ETag or None"
)
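# --- Illustrative sketch (not part of aiohttp) ---
# The setter quotes bare strings, renders weak tags as W/"...", and
# validates characters via validate_etag_value().
def _etag_example(resp: "StreamResponse") -> None:
    resp.etag = "v1"                            # -> ETag: "v1"
    resp.etag = ETag(is_weak=True, value="v1")  # -> ETag: W/"v1"
    resp.etag = None                            # removes the header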
def _generate_content_type_header(
self, CONTENT_TYPE: istr = hdrs.CONTENT_TYPE
) -> None:
assert self._content_dict is not None
assert self._content_type is not None
params = "; ".join(f"{k}={v}" for k, v in self._content_dict.items())
if params:
ctype = self._content_type + "; " + params
else:
ctype = self._content_type
self._headers[CONTENT_TYPE] = ctype
async def _do_start_compression(self, coding: ContentCoding) -> None:
if coding is ContentCoding.identity:
return
assert self._payload_writer is not None
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._payload_writer.enable_compression(
coding.value, self._compression_strategy
)
# Compressed payload may have different content length,
# remove the header
self._headers.popall(hdrs.CONTENT_LENGTH, None)
async def _start_compression(self, request: "BaseRequest") -> None:
if self._compression_force:
await self._do_start_compression(self._compression_force)
return
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
for value, coding in CONTENT_CODINGS.items():
if value in accept_encoding:
await self._do_start_compression(coding)
return
async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
if self._eof_sent:
return None
if self._payload_writer is not None:
return self._payload_writer
self._must_be_empty_body = must_be_empty_body(request.method, self.status)
return await self._start(request)
async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
self._req = request
writer = self._payload_writer = request._payload_writer
await self._prepare_headers()
await request._prepare_hook(self)
await self._write_headers()
return writer
async def _prepare_headers(self) -> None:
request = self._req
assert request is not None
writer = self._payload_writer
assert writer is not None
keep_alive = self._keep_alive
if keep_alive is None:
keep_alive = request.keep_alive
self._keep_alive = keep_alive
version = request.version
headers = self._headers
if self._cookies:
for cookie in self._cookies.values():
value = cookie.output(header="")[1:]
headers.add(hdrs.SET_COOKIE, value)
if self._compression:
await self._start_compression(request)
if self._chunked:
if version != HttpVersion11:
raise RuntimeError(
"Using chunked encoding is forbidden "
"for HTTP/{0.major}.{0.minor}".format(request.version)
)
if not self._must_be_empty_body:
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = "chunked"
elif self._length_check: # Disabled for WebSockets
writer.length = self.content_length
if writer.length is None:
if version >= HttpVersion11:
if not self._must_be_empty_body:
writer.enable_chunking()
headers[hdrs.TRANSFER_ENCODING] = "chunked"
elif not self._must_be_empty_body:
keep_alive = False
# HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2
# HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4
if self._must_be_empty_body:
if hdrs.CONTENT_LENGTH in headers and should_remove_content_length(
request.method, self.status
):
del headers[hdrs.CONTENT_LENGTH]
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-10
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-13
if hdrs.TRANSFER_ENCODING in headers:
del headers[hdrs.TRANSFER_ENCODING]
elif (writer.length if self._length_check else self.content_length) != 0:
# https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream")
headers.setdefault(hdrs.DATE, rfc822_formatted_time())
headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE)
# connection header
if hdrs.CONNECTION not in headers:
if keep_alive:
if version == HttpVersion10:
headers[hdrs.CONNECTION] = "keep-alive"
elif version == HttpVersion11:
headers[hdrs.CONNECTION] = "close"
async def _write_headers(self) -> None:
request = self._req
assert request is not None
writer = self._payload_writer
assert writer is not None
# status line
version = request.version
status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}"
await writer.write_headers(status_line, self._headers)
async def write(self, data: Union[bytes, bytearray, memoryview]) -> None:
assert isinstance(
data, (bytes, bytearray, memoryview)
), "data argument must be byte-ish (%r)" % type(data)
if self._eof_sent:
raise RuntimeError("Cannot call write() after write_eof()")
if self._payload_writer is None:
raise RuntimeError("Cannot call write() before prepare()")
await self._payload_writer.write(data)
async def drain(self) -> None:
assert not self._eof_sent, "EOF has already been sent"
assert self._payload_writer is not None, "Response has not been started"
warnings.warn(
"drain method is deprecated, use await resp.write()",
DeprecationWarning,
stacklevel=2,
)
await self._payload_writer.drain()
async def write_eof(self, data: bytes = b"") -> None:
assert isinstance(
data, (bytes, bytearray, memoryview)
), "data argument must be byte-ish (%r)" % type(data)
if self._eof_sent:
return
assert self._payload_writer is not None, "Response has not been started"
await self._payload_writer.write_eof(data)
self._eof_sent = True
self._req = None
self._body_length = self._payload_writer.output_size
self._payload_writer = None
def __repr__(self) -> str:
if self._eof_sent:
info = "eof"
elif self.prepared:
assert self._req is not None
info = f"{self._req.method} {self._req.path} "
else:
info = "not prepared"
return f"<{self.__class__.__name__} {self.reason} {info}>"
def __getitem__(self, key: str) -> Any:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._state[key] = value
def __delitem__(self, key: str) -> None:
del self._state[key]
def __len__(self) -> int:
return len(self._state)
def __iter__(self) -> Iterator[str]:
return iter(self._state)
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, other: object) -> bool:
return self is other
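# --- Illustrative usage sketch (not part of this module) ---
# Typical streaming flow: prepare() sends the headers, write() streams
# chunks, write_eof() finalizes the body length and drops the writer.
async def _stream_example(request: "BaseRequest") -> StreamResponse:
    resp = StreamResponse(status=200)
    resp.enable_chunked_encoding()  # HTTP/1.1 only; no Content-Length
    await resp.prepare(request)
    for chunk in (b"hello ", b"world"):
        await resp.write(chunk)
    await resp.write_eof()
    return resp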
class Response(StreamResponse):
_compressed_body: Optional[bytes] = None
def __init__(
self,
*,
body: Any = None,
status: int = 200,
reason: Optional[str] = None,
text: Optional[str] = None,
headers: Optional[LooseHeaders] = None,
content_type: Optional[str] = None,
charset: Optional[str] = None,
zlib_executor_size: Optional[int] = None,
zlib_executor: Optional[Executor] = None,
) -> None:
if body is not None and text is not None:
raise ValueError("body and text are not allowed together")
if headers is None:
real_headers: CIMultiDict[str] = CIMultiDict()
else:
real_headers = CIMultiDict(headers)
if content_type is not None and "charset" in content_type:
raise ValueError("charset must not be in content_type argument")
if text is not None:
if hdrs.CONTENT_TYPE in real_headers:
if content_type or charset:
raise ValueError(
"passing both Content-Type header and "
"content_type or charset params "
"is forbidden"
)
else:
# fast path for filling headers
if not isinstance(text, str):
raise TypeError("text argument must be str (%r)" % type(text))
if content_type is None:
content_type = "text/plain"
if charset is None:
charset = "utf-8"
real_headers[hdrs.CONTENT_TYPE] = content_type + "; charset=" + charset
body = text.encode(charset)
text = None
elif hdrs.CONTENT_TYPE in real_headers:
if content_type is not None or charset is not None:
raise ValueError(
"passing both Content-Type header and "
"content_type or charset params "
"is forbidden"
)
elif content_type is not None:
if charset is not None:
content_type += "; charset=" + charset
real_headers[hdrs.CONTENT_TYPE] = content_type
super().__init__(status=status, reason=reason, _real_headers=real_headers)
if text is not None:
self.text = text
else:
self.body = body
self._zlib_executor_size = zlib_executor_size
self._zlib_executor = zlib_executor
@property
def body(self) -> Optional[Union[bytes, Payload]]:
return self._body
@body.setter
def body(self, body: Any) -> None:
if body is None:
self._body = None
elif isinstance(body, (bytes, bytearray)):
self._body = body
else:
try:
self._body = body = payload.PAYLOAD_REGISTRY.get(body)
except payload.LookupError:
raise ValueError("Unsupported body type %r" % type(body))
headers = self._headers
# set content-type
if hdrs.CONTENT_TYPE not in headers:
headers[hdrs.CONTENT_TYPE] = body.content_type
# copy payload headers
if body.headers:
for key, value in body.headers.items():
if key not in headers:
headers[key] = value
self._compressed_body = None
@property
def text(self) -> Optional[str]:
if self._body is None:
return None
return self._body.decode(self.charset or "utf-8")
@text.setter
def text(self, text: str) -> None:
assert text is None or isinstance(
text, str
), "text argument must be str (%r)" % type(text)
if self.content_type == "application/octet-stream":
self.content_type = "text/plain"
if self.charset is None:
self.charset = "utf-8"
self._body = text.encode(self.charset)
self._compressed_body = None
@property
def content_length(self) -> Optional[int]:
if self._chunked:
return None
if hdrs.CONTENT_LENGTH in self._headers:
return int(self._headers[hdrs.CONTENT_LENGTH])
if self._compressed_body is not None:
# Return length of the compressed body
return len(self._compressed_body)
elif isinstance(self._body, Payload):
# A payload without content length, or a compressed payload
return None
elif self._body is not None:
return len(self._body)
else:
return 0
@content_length.setter
def content_length(self, value: Optional[int]) -> None:
raise RuntimeError("Content length is set automatically")
async def write_eof(self, data: bytes = b"") -> None:
if self._eof_sent:
return
if self._compressed_body is None:
body: Optional[Union[bytes, Payload]] = self._body
else:
body = self._compressed_body
assert not data, f"data arg is not supported, got {data!r}"
assert self._req is not None
assert self._payload_writer is not None
if body is None or self._must_be_empty_body:
await super().write_eof()
elif isinstance(self._body, Payload):
await self._body.write(self._payload_writer)
await super().write_eof()
else:
await super().write_eof(cast(bytes, body))
async def _start(self, request: "BaseRequest") -> AbstractStreamWriter:
if hdrs.CONTENT_LENGTH in self._headers:
if should_remove_content_length(request.method, self.status):
del self._headers[hdrs.CONTENT_LENGTH]
elif not self._chunked:
if isinstance(self._body, Payload):
if self._body.size is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(self._body.size)
else:
body_len = len(self._body) if self._body else "0"
# https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-7
if body_len != "0" or (
self.status != 304 and request.method not in hdrs.METH_HEAD_ALL
):
self._headers[hdrs.CONTENT_LENGTH] = str(body_len)
return await super()._start(request)
async def _do_start_compression(self, coding: ContentCoding) -> None:
if self._chunked or isinstance(self._body, Payload):
return await super()._do_start_compression(coding)
if coding is ContentCoding.identity:
return
# Instead of using _payload_writer.enable_compression,
# compress the whole body
compressor = ZLibCompressor(
encoding=coding.value,
max_sync_chunk_size=self._zlib_executor_size,
executor=self._zlib_executor,
)
assert self._body is not None
if self._zlib_executor_size is None and len(self._body) > LARGE_BODY_SIZE:
warnings.warn(
"Synchronous compression of large response bodies "
f"({len(self._body)} bytes) might block the async event loop. "
"Consider providing a custom value to zlib_executor_size/"
"zlib_executor response properties or disabling compression on it."
)
self._compressed_body = (
await compressor.compress(self._body) + compressor.flush()
)
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._compressed_body))
def json_response(
data: Any = sentinel,
*,
text: Optional[str] = None,
body: Optional[bytes] = None,
status: int = 200,
reason: Optional[str] = None,
headers: Optional[LooseHeaders] = None,
content_type: str = "application/json",
dumps: JSONEncoder = json.dumps,
) -> Response:
if data is not sentinel:
if text or body:
raise ValueError("only one of data, text, or body should be specified")
else:
text = dumps(data)
return Response(
text=text,
body=body,
status=status,
reason=reason,
headers=headers,
content_type=content_type,
)
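# --- Illustrative usage sketch (not part of this module) ---
# json_response() serializes `data` with `dumps` and sets the
# application/json content type; `data` is mutually exclusive
# with `text` and `body`.
def _json_response_example() -> Response:
    return json_response({"ok": True}, status=201)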

View File

@@ -0,0 +1,214 @@
import abc
import os # noqa
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
overload,
)
import attr
from . import hdrs
from .abc import AbstractView
from .typedefs import Handler, PathLike
if TYPE_CHECKING:
from .web_request import Request
from .web_response import StreamResponse
from .web_urldispatcher import AbstractRoute, UrlDispatcher
else:
Request = StreamResponse = UrlDispatcher = AbstractRoute = None
__all__ = (
"AbstractRouteDef",
"RouteDef",
"StaticDef",
"RouteTableDef",
"head",
"options",
"get",
"post",
"patch",
"put",
"delete",
"route",
"view",
"static",
)
class AbstractRouteDef(abc.ABC):
@abc.abstractmethod
def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
pass # pragma: no cover
_HandlerType = Union[Type[AbstractView], Handler]
@attr.s(auto_attribs=True, frozen=True, repr=False, slots=True)
class RouteDef(AbstractRouteDef):
method: str
path: str
handler: _HandlerType
kwargs: Dict[str, Any]
def __repr__(self) -> str:
info = []
for name, value in sorted(self.kwargs.items()):
info.append(f", {name}={value!r}")
return "<RouteDef {method} {path} -> {handler.__name__!r}{info}>".format(
method=self.method, path=self.path, handler=self.handler, info="".join(info)
)
def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
if self.method in hdrs.METH_ALL:
reg = getattr(router, "add_" + self.method.lower())
return [reg(self.path, self.handler, **self.kwargs)]
else:
return [
router.add_route(self.method, self.path, self.handler, **self.kwargs)
]
@attr.s(auto_attribs=True, frozen=True, repr=False, slots=True)
class StaticDef(AbstractRouteDef):
prefix: str
path: PathLike
kwargs: Dict[str, Any]
def __repr__(self) -> str:
info = []
for name, value in sorted(self.kwargs.items()):
info.append(f", {name}={value!r}")
return "<StaticDef {prefix} -> {path}{info}>".format(
prefix=self.prefix, path=self.path, info="".join(info)
)
def register(self, router: UrlDispatcher) -> List[AbstractRoute]:
resource = router.add_static(self.prefix, self.path, **self.kwargs)
routes = resource.get_info().get("routes", {})
return list(routes.values())
def route(method: str, path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return RouteDef(method, path, handler, kwargs)
def head(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_HEAD, path, handler, **kwargs)
def options(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_OPTIONS, path, handler, **kwargs)
def get(
path: str,
handler: _HandlerType,
*,
name: Optional[str] = None,
allow_head: bool = True,
**kwargs: Any,
) -> RouteDef:
return route(
hdrs.METH_GET, path, handler, name=name, allow_head=allow_head, **kwargs
)
def post(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_POST, path, handler, **kwargs)
def put(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_PUT, path, handler, **kwargs)
def patch(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_PATCH, path, handler, **kwargs)
def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef:
return route(hdrs.METH_DELETE, path, handler, **kwargs)
def view(path: str, handler: Type[AbstractView], **kwargs: Any) -> RouteDef:
return route(hdrs.METH_ANY, path, handler, **kwargs)
def static(prefix: str, path: PathLike, **kwargs: Any) -> StaticDef:
return StaticDef(prefix, path, kwargs)
_Deco = Callable[[_HandlerType], _HandlerType]
class RouteTableDef(Sequence[AbstractRouteDef]):
"""Route definition table"""
def __init__(self) -> None:
self._items: List[AbstractRouteDef] = []
def __repr__(self) -> str:
return f"<RouteTableDef count={len(self._items)}>"
@overload
def __getitem__(self, index: int) -> AbstractRouteDef: ...
@overload
def __getitem__(self, index: slice) -> List[AbstractRouteDef]: ...
def __getitem__(self, index): # type: ignore[no-untyped-def]
return self._items[index]
def __iter__(self) -> Iterator[AbstractRouteDef]:
return iter(self._items)
def __len__(self) -> int:
return len(self._items)
def __contains__(self, item: object) -> bool:
return item in self._items
def route(self, method: str, path: str, **kwargs: Any) -> _Deco:
def inner(handler: _HandlerType) -> _HandlerType:
self._items.append(RouteDef(method, path, handler, kwargs))
return handler
return inner
def head(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_HEAD, path, **kwargs)
def get(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_GET, path, **kwargs)
def post(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_POST, path, **kwargs)
def put(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_PUT, path, **kwargs)
def patch(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_PATCH, path, **kwargs)
def delete(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_DELETE, path, **kwargs)
def options(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_OPTIONS, path, **kwargs)
def view(self, path: str, **kwargs: Any) -> _Deco:
return self.route(hdrs.METH_ANY, path, **kwargs)
def static(self, prefix: str, path: PathLike, **kwargs: Any) -> None:
self._items.append(StaticDef(prefix, path, kwargs))
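# --- Illustrative usage sketch (not part of this module) ---
# RouteTableDef collects RouteDef/StaticDef items via decorators; the
# table is later registered on an application, e.g. with
# web.Application().add_routes(routes) in aiohttp.
routes = RouteTableDef()

@routes.get("/ping")
async def _ping(request: Request) -> StreamResponse:
    ...  # return web.Response(text="pong") in a real application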

View File

@@ -0,0 +1,399 @@
import asyncio
import signal
import socket
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Set
from yarl import URL
from .typedefs import PathLike
from .web_app import Application
from .web_server import Server
if TYPE_CHECKING:
from ssl import SSLContext
else:
try:
from ssl import SSLContext
except ImportError: # pragma: no cover
SSLContext = object # type: ignore[misc,assignment]
__all__ = (
"BaseSite",
"TCPSite",
"UnixSite",
"NamedPipeSite",
"SockSite",
"BaseRunner",
"AppRunner",
"ServerRunner",
"GracefulExit",
)
class GracefulExit(SystemExit):
code = 1
def _raise_graceful_exit() -> None:
raise GracefulExit()
class BaseSite(ABC):
__slots__ = ("_runner", "_ssl_context", "_backlog", "_server")
def __init__(
self,
runner: "BaseRunner",
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
) -> None:
if runner.server is None:
raise RuntimeError("Call runner.setup() before making a site")
if shutdown_timeout != 60.0:
msg = "shutdown_timeout should be set on BaseRunner"
warnings.warn(msg, DeprecationWarning, stacklevel=2)
runner._shutdown_timeout = shutdown_timeout
self._runner = runner
self._ssl_context = ssl_context
self._backlog = backlog
self._server: Optional[asyncio.AbstractServer] = None
@property
@abstractmethod
def name(self) -> str:
pass # pragma: no cover
@abstractmethod
async def start(self) -> None:
self._runner._reg_site(self)
async def stop(self) -> None:
self._runner._check_site(self)
if self._server is not None: # Maybe not started yet
self._server.close()
self._runner._unreg_site(self)
class TCPSite(BaseSite):
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")
def __init__(
self,
runner: "BaseRunner",
host: Optional[str] = None,
port: Optional[int] = None,
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
reuse_address: Optional[bool] = None,
reuse_port: Optional[bool] = None,
) -> None:
super().__init__(
runner,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
self._host = host
if port is None:
port = 8443 if self._ssl_context else 8080
self._port = port
self._reuse_address = reuse_address
self._reuse_port = reuse_port
@property
def name(self) -> str:
scheme = "https" if self._ssl_context else "http"
host = "0.0.0.0" if not self._host else self._host
return str(URL.build(scheme=scheme, host=host, port=self._port))
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_server(
server,
self._host,
self._port,
ssl=self._ssl_context,
backlog=self._backlog,
reuse_address=self._reuse_address,
reuse_port=self._reuse_port,
)
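# --- Illustrative usage sketch (not part of this module) ---
# Typical lifecycle (AppRunner is exported in __all__ above and defined
# later in this module): setup() builds the Server, a site binds it,
# and cleanup() stops the sites and shuts the server down gracefully.
async def _tcpsite_example(app: Application) -> None:
    runner = AppRunner(app)
    await runner.setup()
    site = TCPSite(runner, "0.0.0.0", 8080)
    await site.start()
    try:
        await asyncio.sleep(3600)  # serve for a while
    finally:
        await runner.cleanup()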
class UnixSite(BaseSite):
__slots__ = ("_path",)
def __init__(
self,
runner: "BaseRunner",
path: PathLike,
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
) -> None:
super().__init__(
runner,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
self._path = path
@property
def name(self) -> str:
scheme = "https" if self._ssl_context else "http"
return f"{scheme}://unix:{self._path}:"
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_unix_server(
server,
self._path,
ssl=self._ssl_context,
backlog=self._backlog,
)
class NamedPipeSite(BaseSite):
__slots__ = ("_path",)
def __init__(
self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0
) -> None:
loop = asyncio.get_event_loop()
if not isinstance(
loop, asyncio.ProactorEventLoop # type: ignore[attr-defined]
):
raise RuntimeError(
"Named Pipes only available in proactor loop under windows"
)
super().__init__(runner, shutdown_timeout=shutdown_timeout)
self._path = path
@property
def name(self) -> str:
return self._path
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
_server = await loop.start_serving_pipe( # type: ignore[attr-defined]
server, self._path
)
self._server = _server[0]
class SockSite(BaseSite):
__slots__ = ("_sock", "_name")
def __init__(
self,
runner: "BaseRunner",
sock: socket.socket,
*,
shutdown_timeout: float = 60.0,
ssl_context: Optional[SSLContext] = None,
backlog: int = 128,
) -> None:
super().__init__(
runner,
shutdown_timeout=shutdown_timeout,
ssl_context=ssl_context,
backlog=backlog,
)
self._sock = sock
scheme = "https" if self._ssl_context else "http"
if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
name = f"{scheme}://unix:{sock.getsockname()}:"
else:
host, port = sock.getsockname()[:2]
name = str(URL.build(scheme=scheme, host=host, port=port))
self._name = name
@property
def name(self) -> str:
return self._name
async def start(self) -> None:
await super().start()
loop = asyncio.get_event_loop()
server = self._runner.server
assert server is not None
self._server = await loop.create_server(
server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
)
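# A minimal usage sketch (illustrative): SockSite serves on a socket that
# the caller has already created and bound, which suits inherited file
# descriptors and socket activation.
#
#     sock = socket.socket()
#     sock.bind(("127.0.0.1", 0))    # kernel-assigned port
#     sock.listen(128)
#     site = SockSite(runner, sock)
#     await site.start()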
class BaseRunner(ABC):
__slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")
def __init__(
self,
*,
handle_signals: bool = False,
shutdown_timeout: float = 60.0,
**kwargs: Any,
) -> None:
self._handle_signals = handle_signals
self._kwargs = kwargs
self._server: Optional[Server] = None
self._sites: List[BaseSite] = []
self._shutdown_timeout = shutdown_timeout
@property
def server(self) -> Optional[Server]:
return self._server
@property
def addresses(self) -> List[Any]:
ret: List[Any] = []
for site in self._sites:
server = site._server
if server is not None:
sockets = server.sockets # type: ignore[attr-defined]
if sockets is not None:
for sock in sockets:
ret.append(sock.getsockname())
return ret
@property
def sites(self) -> Set[BaseSite]:
return set(self._sites)
async def setup(self) -> None:
loop = asyncio.get_event_loop()
if self._handle_signals:
try:
loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
except NotImplementedError: # pragma: no cover
# add_signal_handler is not implemented on Windows
pass
self._server = await self._make_server()
@abstractmethod
async def shutdown(self) -> None:
"""Call any shutdown hooks to help server close gracefully."""
async def cleanup(self) -> None:
# The loop over sites is intentional: an exception from gather()
# would leave self._sites in an unpredictable state.
# The loop guarantees that a site is either deleted on success or
# still present on failure.
for site in list(self._sites):
await site.stop()
if self._server: # If setup succeeded
# Yield to the event loop so that requests which arrived before the sites
# were stopped have all started being handled before we proceed to close
# idle connections.
await asyncio.sleep(0)
self._server.pre_shutdown()
await self.shutdown()
await self._server.shutdown(self._shutdown_timeout)
await self._cleanup_server()
self._server = None
if self._handle_signals:
loop = asyncio.get_running_loop()
try:
loop.remove_signal_handler(signal.SIGINT)
loop.remove_signal_handler(signal.SIGTERM)
except NotImplementedError: # pragma: no cover
# remove_signal_handler is not implemented on Windows
pass
@abstractmethod
async def _make_server(self) -> Server:
pass # pragma: no cover
@abstractmethod
async def _cleanup_server(self) -> None:
pass # pragma: no cover
def _reg_site(self, site: BaseSite) -> None:
if site in self._sites:
raise RuntimeError(f"Site {site} is already registered in runner {self}")
self._sites.append(site)
def _check_site(self, site: BaseSite) -> None:
if site not in self._sites:
raise RuntimeError(f"Site {site} is not registered in runner {self}")
def _unreg_site(self, site: BaseSite) -> None:
if site not in self._sites:
raise RuntimeError(f"Site {site} is not registered in runner {self}")
self._sites.remove(site)
class ServerRunner(BaseRunner):
"""Low-level web server runner"""
__slots__ = ("_web_server",)
def __init__(
self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any
) -> None:
super().__init__(handle_signals=handle_signals, **kwargs)
self._web_server = web_server
async def shutdown(self) -> None:
pass
async def _make_server(self) -> Server:
return self._web_server
async def _cleanup_server(self) -> None:
pass
class AppRunner(BaseRunner):
"""Web Application runner"""
__slots__ = ("_app",)
def __init__(
self, app: Application, *, handle_signals: bool = False, **kwargs: Any
) -> None:
super().__init__(handle_signals=handle_signals, **kwargs)
if not isinstance(app, Application):
raise TypeError(
"The first argument should be web.Application "
"instance, got {!r}".format(app)
)
self._app = app
@property
def app(self) -> Application:
return self._app
async def shutdown(self) -> None:
await self._app.shutdown()
async def _make_server(self) -> Server:
loop = asyncio.get_event_loop()
self._app._set_loop(loop)
self._app.on_startup.freeze()
await self._app.startup()
self._app.freeze()
return self._app._make_handler(loop=loop, **self._kwargs)
async def _cleanup_server(self) -> None:
await self._app.cleanup()
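# Taken together, the canonical lifecycle is setup() -> start sites ->
# cleanup(). A hedged end-to-end sketch (illustrative, not part of this
# module):
#
#     async def serve(app: Application) -> None:
#         runner = AppRunner(app)
#         await runner.setup()
#         await TCPSite(runner, port=8080).start()
#         try:
#             await asyncio.Event().wait()   # park until cancelled
#         finally:
#             await runner.cleanup()         # stops sites, shuts server down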

View File

@@ -0,0 +1,84 @@
"""Low level HTTP server."""
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa
from .abc import AbstractStreamWriter
from .http_parser import RawRequestMessage
from .streams import StreamReader
from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
from .web_request import BaseRequest
__all__ = ("Server",)
class Server:
def __init__(
self,
handler: _RequestHandler,
*,
request_factory: Optional[_RequestFactory] = None,
handler_cancellation: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,
**kwargs: Any,
) -> None:
self._loop = loop or asyncio.get_running_loop()
self._connections: Dict[RequestHandler, asyncio.Transport] = {}
self._kwargs = kwargs
# requests_count is the total number of requests processed by the server
# over its lifetime.
self.requests_count = 0
self.request_handler = handler
self.request_factory = request_factory or self._make_request
self.handler_cancellation = handler_cancellation
@property
def connections(self) -> List[RequestHandler]:
return list(self._connections.keys())
def connection_made(
self, handler: RequestHandler, transport: asyncio.Transport
) -> None:
self._connections[handler] = transport
def connection_lost(
self, handler: RequestHandler, exc: Optional[BaseException] = None
) -> None:
if handler in self._connections:
if handler._task_handler:
handler._task_handler.add_done_callback(
lambda f: self._connections.pop(handler, None)
)
else:
del self._connections[handler]
def _make_request(
self,
message: RawRequestMessage,
payload: StreamReader,
protocol: RequestHandler,
writer: AbstractStreamWriter,
task: "asyncio.Task[None]",
) -> BaseRequest:
return BaseRequest(message, payload, protocol, writer, task, self._loop)
def pre_shutdown(self) -> None:
for conn in self._connections:
conn.close()
async def shutdown(self, timeout: Optional[float] = None) -> None:
coros = (conn.shutdown(timeout) for conn in self._connections)
await asyncio.gather(*coros)
self._connections.clear()
def __call__(self) -> RequestHandler:
try:
return RequestHandler(self, loop=self._loop, **self._kwargs)
except TypeError:
# Failsafe creation: remove all custom handler_args
kwargs = {
k: v
for k, v in self._kwargs.items()
if k in ["debug", "access_log_class"]
}
return RequestHandler(self, loop=self._loop, **kwargs)
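# A hedged usage sketch (illustrative): Server is the low-level entry
# point; paired with ServerRunner it skips the Application layer entirely.
# The handler below is an assumption, not part of this module.
#
#     from aiohttp import web
#
#     async def handler(request: BaseRequest) -> web.Response:
#         return web.Response(text="OK")
#
#     runner = web.ServerRunner(web.Server(handler))
#     await runner.setup()
#     await web.TCPSite(runner, port=8080).start()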

File diff suppressed because it is too large

View File

@@ -0,0 +1,627 @@
import asyncio
import base64
import binascii
import hashlib
import json
import sys
from typing import Any, Final, Iterable, Optional, Tuple, Union, cast
import attr
from multidict import CIMultiDict
from . import hdrs
from ._websocket.reader import WebSocketDataQueue
from ._websocket.writer import DEFAULT_LIMIT
from .abc import AbstractStreamWriter
from .client_exceptions import WSMessageTypeError
from .helpers import calculate_timeout_when, set_exception, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WS_KEY,
WebSocketError,
WebSocketReader,
WebSocketWriter,
WSCloseCode,
WSMessage,
WSMsgType as WSMsgType,
ws_ext_gen,
ws_ext_parse,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES
from .log import ws_logger
from .streams import EofStream
from .typedefs import JSONDecoder, JSONEncoder
from .web_exceptions import HTTPBadRequest, HTTPException
from .web_request import BaseRequest
from .web_response import StreamResponse
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
__all__ = (
"WebSocketResponse",
"WebSocketReady",
"WSMsgType",
)
THRESHOLD_CONNLOST_ACCESS: Final[int] = 5
@attr.s(auto_attribs=True, frozen=True, slots=True)
class WebSocketReady:
ok: bool
protocol: Optional[str]
def __bool__(self) -> bool:
return self.ok
class WebSocketResponse(StreamResponse):
_length_check: bool = False
_ws_protocol: Optional[str] = None
_writer: Optional[WebSocketWriter] = None
_reader: Optional[WebSocketDataQueue] = None
_closed: bool = False
_closing: bool = False
_conn_lost: int = 0
_close_code: Optional[int] = None
_loop: Optional[asyncio.AbstractEventLoop] = None
_waiting: bool = False
_close_wait: Optional[asyncio.Future[None]] = None
_exception: Optional[BaseException] = None
_heartbeat_when: float = 0.0
_heartbeat_cb: Optional[asyncio.TimerHandle] = None
_pong_response_cb: Optional[asyncio.TimerHandle] = None
_ping_task: Optional[asyncio.Task[None]] = None
def __init__(
self,
*,
timeout: float = 10.0,
receive_timeout: Optional[float] = None,
autoclose: bool = True,
autoping: bool = True,
heartbeat: Optional[float] = None,
protocols: Iterable[str] = (),
compress: bool = True,
max_msg_size: int = 4 * 1024 * 1024,
writer_limit: int = DEFAULT_LIMIT,
) -> None:
super().__init__(status=101)
self._protocols = protocols
self._timeout = timeout
self._receive_timeout = receive_timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._compress: Union[bool, int] = compress
self._max_msg_size = max_msg_size
self._writer_limit = writer_limit
def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
def _reset_heartbeat(self) -> None:
if self._heartbeat is None:
return
self._cancel_pong_response_cb()
req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
)
loop = self._loop
assert loop is not None
now = loop.time()
when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
self._heartbeat_when = when
if self._heartbeat_cb is None:
# We do not cancel the previous heartbeat_cb here because
# it generates a significant amount of TimerHandle churn
# which causes asyncio to rebuild the heap frequently.
# Instead _send_heartbeat() will reschedule the next
# heartbeat if it fires too early.
self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
def _send_heartbeat(self) -> None:
self._heartbeat_cb = None
loop = self._loop
assert loop is not None and self._writer is not None
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return
req = self._req
timeout_ceil_threshold = (
req._protocol._timeout_ceil_threshold if req is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
coro = self._writer.send_frame(b"", WSMsgType.PING)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12: try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
ping_task = loop.create_task(coro)
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None
def _pong_not_received(self) -> None:
if self._req is not None and self._req.transport is not None:
self._handle_ping_pong_exception(
asyncio.TimeoutError(
f"No PONG received after {self._pong_heartbeat} seconds"
)
)
def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
self._exception = exc
if self._waiting and not self._closing and self._reader is not None:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()
async def prepare(self, request: BaseRequest) -> AbstractStreamWriter:
# run the pre-check first so that do_handshake() exceptions do not hide it
if self._payload_writer is not None:
return self._payload_writer
protocol, writer = self._pre_start(request)
payload_writer = await super().prepare(request)
assert payload_writer is not None
self._post_start(request, protocol, writer)
await payload_writer.drain()
return payload_writer
def _handshake(
self, request: BaseRequest
) -> Tuple["CIMultiDict[str]", Optional[str], int, bool]:
headers = request.headers
if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip():
raise HTTPBadRequest(
text=(
"No WebSocket UPGRADE hdr: {}\n Can "
'"Upgrade" only to "WebSocket".'
).format(headers.get(hdrs.UPGRADE))
)
if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower():
raise HTTPBadRequest(
text="No CONNECTION upgrade hdr: {}".format(
headers.get(hdrs.CONNECTION)
)
)
# find common sub-protocol between client and server
protocol: Optional[str] = None
if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
req_protocols = [
str(proto.strip())
for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",")
]
for proto in req_protocols:
if proto in self._protocols:
protocol = proto
break
else:
# No overlap found: Return no protocol as per spec
ws_logger.warning(
"%s: Client protocols %r dont overlap server-known ones %r",
request.remote,
req_protocols,
self._protocols,
)
# check supported version
version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "")
if version not in ("13", "8", "7"):
raise HTTPBadRequest(text=f"Unsupported version: {version}")
# check client handshake for validity
key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
try:
if not key or len(base64.b64decode(key)) != 16:
raise HTTPBadRequest(text=f"Handshake error: {key!r}")
except binascii.Error:
raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None
accept_val = base64.b64encode(
hashlib.sha1(key.encode() + WS_KEY).digest()
).decode()
response_headers = CIMultiDict(
{
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: accept_val,
}
)
notakeover = False
compress = 0
if self._compress:
extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
# On the server side this always returns without raising an exception;
# if anything went wrong, the compress extension is simply dropped.
compress, notakeover = ws_ext_parse(extensions, isserver=True)
if compress:
enabledext = ws_ext_gen(
compress=compress, isserver=True, server_notakeover=notakeover
)
response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext
if protocol:
response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
return (
response_headers,
protocol,
compress,
notakeover,
)
def _pre_start(self, request: BaseRequest) -> Tuple[Optional[str], WebSocketWriter]:
self._loop = request._loop
headers, protocol, compress, notakeover = self._handshake(request)
self.set_status(101)
self.headers.update(headers)
self.force_close()
self._compress = compress
transport = request._protocol.transport
assert transport is not None
writer = WebSocketWriter(
request._protocol,
transport,
compress=compress,
notakeover=notakeover,
limit=self._writer_limit,
)
return protocol, writer
def _post_start(
self, request: BaseRequest, protocol: Optional[str], writer: WebSocketWriter
) -> None:
self._ws_protocol = protocol
self._writer = writer
self._reset_heartbeat()
loop = self._loop
assert loop is not None
self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
request.protocol.set_parser(
WebSocketReader(
self._reader, self._max_msg_size, compress=bool(self._compress)
)
)
# disable HTTP keepalive for WebSocket
request.protocol.keep_alive(False)
def can_prepare(self, request: BaseRequest) -> WebSocketReady:
if self._writer is not None:
raise RuntimeError("Already started")
try:
_, protocol, _, _ = self._handshake(request)
except HTTPException:
return WebSocketReady(False, None)
else:
return WebSocketReady(True, protocol)
@property
def closed(self) -> bool:
return self._closed
@property
def close_code(self) -> Optional[int]:
return self._close_code
@property
def ws_protocol(self) -> Optional[str]:
return self._ws_protocol
@property
def compress(self) -> Union[int, bool]:
return self._compress
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""Get optional transport information.
If no value associated with ``name`` is found, ``default`` is returned.
"""
writer = self._writer
if writer is None:
return default
transport = writer.transport
if transport is None:
return default
return transport.get_extra_info(name, default)
def exception(self) -> Optional[BaseException]:
return self._exception
async def ping(self, message: bytes = b"") -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
await self._writer.send_frame(message, WSMsgType.PING)
async def pong(self, message: bytes = b"") -> None:
# unsolicited pong
if self._writer is None:
raise RuntimeError("Call .prepare() first")
await self._writer.send_frame(message, WSMsgType.PONG)
async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
) -> None:
"""Send a frame over the websocket."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")
await self._writer.send_frame(message, opcode, compress)
async def send_str(self, data: str, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)
async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)
async def send_json(
self,
data: Any,
compress: Optional[int] = None,
*,
dumps: JSONEncoder = json.dumps,
) -> None:
await self.send_str(dumps(data), compress=compress)
async def write_eof(self) -> None: # type: ignore[override]
if self._eof_sent:
return
if self._payload_writer is None:
raise RuntimeError("Response has not been started")
await self.close()
self._eof_sent = True
async def close(
self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True
) -> bool:
"""Close websocket connection."""
if self._writer is None:
raise RuntimeError("Call .prepare() first")
if self._closed:
return False
self._set_closed()
try:
await self._writer.close(code, message)
writer = self._payload_writer
assert writer is not None
if drain:
await writer.drain()
except (asyncio.CancelledError, asyncio.TimeoutError):
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True
reader = self._reader
assert reader is not None
# we need to break the `receive()` cycle before we can call
# `reader.read()`, as `close()` may be called from a different task
if self._waiting:
assert self._loop is not None
assert self._close_wait is None
self._close_wait = self._loop.create_future()
reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._close_wait
if self._closing:
self._close_transport()
return True
try:
async with async_timeout.timeout(self._timeout):
while True:
msg = await reader.read()
if msg.type is WSMsgType.CLOSE:
self._set_code_close_transport(msg.data)
return True
except asyncio.CancelledError:
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
raise
except Exception as exc:
self._exception = exc
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE)
return True
def _set_closing(self, code: WSCloseCode) -> None:
"""Set the close code and mark the connection as closing."""
self._closing = True
self._close_code = code
self._cancel_heartbeat()
def _set_code_close_transport(self, code: WSCloseCode) -> None:
"""Set the close code and close the transport."""
self._close_code = code
self._close_transport()
def _close_transport(self) -> None:
"""Close the transport."""
if self._req is not None and self._req.transport is not None:
self._req.transport.close()
async def receive(self, timeout: Optional[float] = None) -> WSMessage:
if self._reader is None:
raise RuntimeError("Call .prepare() first")
receive_timeout = timeout or self._receive_timeout
while True:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
self._conn_lost += 1
if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
raise RuntimeError("WebSocket connection is closed.")
return WS_CLOSED_MESSAGE
elif self._closing:
return WS_CLOSING_MESSAGE
try:
self._waiting = True
try:
if receive_timeout:
# Entering the context manager and creating
# the Timeout() object can take almost 50% of the
# run time in this loop, so we avoid it if
# there is no read timeout.
async with async_timeout.timeout(receive_timeout):
msg = await self._reader.read()
else:
msg = await self._reader.read()
self._reset_heartbeat()
finally:
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except asyncio.TimeoutError:
raise
except EofStream:
self._close_code = WSCloseCode.OK
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._set_closing(WSCloseCode.ABNORMAL_CLOSURE)
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type not in _INTERNAL_RECEIVE_TYPES:
# If it's not a close/closing/ping/pong message
# we can return it immediately
return msg
if msg.type is WSMsgType.CLOSE:
self._set_closing(msg.data)
# Could be closed while awaiting reader.
if not self._closed and self._autoclose:
# The client is likely going to close the
# connection out from under us, so we do not
# want to drain any pending writes, as that would
# likely result in writing to a broken pipe.
await self.close(drain=False)
elif msg.type is WSMsgType.CLOSING:
self._set_closing(WSCloseCode.OK)
elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
async def receive_str(self, *, timeout: Optional[float] = None) -> str:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return cast(str, msg.data)
async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
)
return cast(bytes, msg.data)
async def receive_json(
self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None
) -> Any:
data = await self.receive_str(timeout=timeout)
return loads(data)
async def write(self, data: bytes) -> None:
raise RuntimeError("Cannot call .write() for websocket")
def __aiter__(self) -> "WebSocketResponse":
return self
async def __anext__(self) -> WSMessage:
msg = await self.receive()
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
return msg
def _cancel(self, exc: BaseException) -> None:
# web_protocol calls this from connection_lost
# or when the server is shutting down.
self._closing = True
self._cancel_heartbeat()
if self._reader is not None:
set_exception(self._reader, exc)
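# A hedged usage sketch (illustrative): the typical server-side handler
# prepares the response, then iterates messages until the peer closes.
#
#     async def ws_handler(request: BaseRequest) -> WebSocketResponse:
#         ws = WebSocketResponse(heartbeat=30.0)
#         await ws.prepare(request)
#         async for msg in ws:               # stops on CLOSE/CLOSING/CLOSED
#             if msg.type is WSMsgType.TEXT:
#                 await ws.send_str(msg.data)    # echo back
#         return ws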

View File

@@ -0,0 +1,255 @@
"""Async gunicorn worker for aiohttp.web"""
import asyncio
import inspect
import os
import re
import signal
import sys
from types import FrameType
from typing import TYPE_CHECKING, Any, Optional
from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
from gunicorn.workers import base
from aiohttp import web
from .helpers import set_result
from .web_app import Application
from .web_log import AccessLogger
if TYPE_CHECKING:
import ssl
SSLContext = ssl.SSLContext
else:
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError: # pragma: no cover
ssl = None # type: ignore[assignment]
SSLContext = object # type: ignore[misc,assignment]
__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker")
class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported]
DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default
def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover
super().__init__(*args, **kw)
self._task: Optional[asyncio.Task[None]] = None
self.exit_code = 0
self._notify_waiter: Optional[asyncio.Future[bool]] = None
def init_process(self) -> None:
# create a new event loop after the fork
asyncio.get_event_loop().close()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super().init_process()
def run(self) -> None:
self._task = self.loop.create_task(self._run())
try: # ignore all finalization problems
self.loop.run_until_complete(self._task)
except Exception:
self.log.exception("Exception in gunicorn worker")
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
sys.exit(self.exit_code)
async def _run(self) -> None:
runner = None
if isinstance(self.wsgi, Application):
app = self.wsgi
elif inspect.iscoroutinefunction(self.wsgi) or (
sys.version_info < (3, 14) and asyncio.iscoroutinefunction(self.wsgi)
):
wsgi = await self.wsgi()
if isinstance(wsgi, web.AppRunner):
runner = wsgi
app = runner.app
else:
app = wsgi
else:
raise RuntimeError(
"wsgi app should be either Application or "
"async function returning Application, got {}".format(self.wsgi)
)
if runner is None:
access_log = self.log.access_log if self.cfg.accesslog else None
runner = web.AppRunner(
app,
logger=self.log,
keepalive_timeout=self.cfg.keepalive,
access_log=access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format
),
shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
)
await runner.setup()
ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None
assert runner is not None
server = runner.server
assert server is not None
for sock in self.sockets:
site = web.SockSite(
runner,
sock,
ssl_context=ctx,
)
await site.start()
# If our parent changed then we shut down.
pid = os.getpid()
try:
while self.alive: # type: ignore[has-type]
self.notify()
cnt = server.requests_count
if self.max_requests and cnt > self.max_requests:
self.alive = False
self.log.info("Max requests, shutting down: %s", self)
elif pid == os.getpid() and self.ppid != os.getppid():
self.alive = False
self.log.info("Parent changed, shutting down: %s", self)
else:
await self._wait_next_notify()
except BaseException:
pass
await runner.cleanup()
def _wait_next_notify(self) -> "asyncio.Future[bool]":
self._notify_waiter_done()
loop = self.loop
assert loop is not None
self._notify_waiter = waiter = loop.create_future()
self.loop.call_later(1.0, self._notify_waiter_done, waiter)
return waiter
def _notify_waiter_done(
self, waiter: Optional["asyncio.Future[bool]"] = None
) -> None:
if waiter is None:
waiter = self._notify_waiter
if waiter is not None:
set_result(waiter, True)
if waiter is self._notify_waiter:
self._notify_waiter = None
def init_signals(self) -> None:
# Set up signals through the event loop API.
self.loop.add_signal_handler(
signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None
)
self.loop.add_signal_handler(
signal.SIGTERM, self.handle_exit, signal.SIGTERM, None
)
self.loop.add_signal_handler(
signal.SIGINT, self.handle_quit, signal.SIGINT, None
)
self.loop.add_signal_handler(
signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None
)
self.loop.add_signal_handler(
signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None
)
self.loop.add_signal_handler(
signal.SIGABRT, self.handle_abort, signal.SIGABRT, None
)
# Don't let SIGTERM and SIGUSR1 disturb active requests
# by interrupting system calls
signal.siginterrupt(signal.SIGTERM, False)
signal.siginterrupt(signal.SIGUSR1, False)
# Reset signals so Gunicorn doesn't swallow subprocess return codes
# See: https://github.com/aio-libs/aiohttp/issues/6130
def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None:
self.alive = False
# worker_int callback
self.cfg.worker_int(self)
# wakeup closing process
self._notify_waiter_done()
def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None:
self.alive = False
self.exit_code = 1
self.cfg.worker_abort(self)
sys.exit(1)
@staticmethod
def _create_ssl_context(cfg: Any) -> "SSLContext":
"""Creates SSLContext instance for usage in asyncio.create_server.
See ssl.SSLSocket.__init__ for more details.
"""
if ssl is None: # pragma: no cover
raise RuntimeError("SSL is not supported.")
ctx = ssl.SSLContext(cfg.ssl_version)
ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
ctx.verify_mode = cfg.cert_reqs
if cfg.ca_certs:
ctx.load_verify_locations(cfg.ca_certs)
if cfg.ciphers:
ctx.set_ciphers(cfg.ciphers)
return ctx
def _get_valid_log_format(self, source_format: str) -> str:
if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
return self.DEFAULT_AIOHTTP_LOG_FORMAT
elif re.search(r"%\([^\)]+\)", source_format):
raise ValueError(
"Gunicorn's style options in form of `%(name)s` are not "
"supported for the log formatting. Please use aiohttp's "
"format specification to configure access log formatting: "
"http://docs.aiohttp.org/en/stable/logging.html"
"#format-specification"
)
else:
return source_format
class GunicornUVLoopWebWorker(GunicornWebWorker):
def init_process(self) -> None:
import uvloop
# Close any existing event loop before setting a
# new policy.
asyncio.get_event_loop().close()
# Setup uvloop policy, so that every
# asyncio.get_event_loop() will create an instance
# of uvloop event loop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
super().init_process()
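# A hedged usage sketch (illustrative; the module path `myapp:app` is an
# assumption): the worker class is selected on the gunicorn command line.
#
#     gunicorn myapp:app --bind 127.0.0.1:8080 \
#         --worker-class aiohttp.GunicornWebWorker
#
# `myapp:app` may name either a web.Application instance or an async
# factory returning one (see _run() above).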