You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
bazarr/libs/werkzeug/middleware/http_proxy.py

236 lines
7.6 KiB

5 years ago
"""
Basic HTTP Proxy
================
.. autoclass:: ProxyMiddleware
:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
from __future__ import annotations
import typing as t
from http import client
from urllib.parse import quote
from urllib.parse import urlsplit
5 years ago
from ..datastructures import EnvironHeaders
from ..http import is_hop_by_hop_header
from ..wsgi import get_input_stream
if t.TYPE_CHECKING:
from _typeshed.wsgi import StartResponse
from _typeshed.wsgi import WSGIApplication
from _typeshed.wsgi import WSGIEnvironment
5 years ago
class ProxyMiddleware:
    """Proxy requests under a path to an external server, routing other
    requests to the app.

    This middleware can only proxy HTTP requests, as HTTP is the only
    protocol handled by the WSGI server. Other protocols, such as
    WebSocket requests, cannot be proxied at this layer. This should
    only be used for development, in production a real proxy server
    should be used.

    The middleware takes a dict mapping a path prefix to a dict
    describing the host to be proxied to::

        app = ProxyMiddleware(app, {
            "/static/": {
                "target": "http://127.0.0.1:5001/",
            }
        })

    Each host has the following options:

    ``target``:
        The target URL to dispatch to. This is required.
    ``remove_prefix``:
        Whether to remove the prefix from the URL before dispatching it
        to the target. The default is ``False``.
    ``host``:
        ``"<auto>"`` (default):
            The host header is automatically rewritten to the URL of the
            target.
        ``None``:
            The host header is unmodified from the client request.
        Any other value:
            The host header is overwritten with the value.
    ``headers``:
        A dictionary of headers to be sent with the request to the
        target. The default is ``{}``.
    ``ssl_context``:
        A :class:`ssl.SSLContext` defining how to verify requests if the
        target is HTTPS. The default is ``None``.

    In the example above, everything under ``"/static/"`` is proxied to
    the server on port 5001 and the host header is rewritten to the
    target. The ``"/static/"`` prefix is kept in the forwarded URL
    because ``remove_prefix`` defaults to ``False``; set it to ``True``
    to strip the prefix.

    Each options dict is shallow-copied before defaults are filled in,
    so the dicts passed in ``targets`` are not modified.

    :param app: The WSGI application to wrap.
    :param targets: Proxy target configurations. See description above.
    :param chunk_size: Size of chunks to read from input stream and
        write to target.
    :param timeout: Seconds before an operation to a target fails.

    .. versionadded:: 0.14
    """

    def __init__(
        self,
        app: WSGIApplication,
        targets: t.Mapping[str, dict[str, t.Any]],
        chunk_size: int = 2 << 13,
        timeout: int = 10,
    ) -> None:
        def _set_defaults(opts: dict[str, t.Any]) -> dict[str, t.Any]:
            # Fill defaults on a shallow copy so the caller's dict is
            # left exactly as it was passed in.
            opts = dict(opts)
            opts.setdefault("remove_prefix", False)
            opts.setdefault("host", "<auto>")
            opts.setdefault("headers", {})
            opts.setdefault("ssl_context", None)
            return opts

        self.app = app
        # Normalize every prefix to the form "/prefix/" so matching in
        # ``__call__`` is a plain ``startswith`` check.
        self.targets = {
            f"/{k.strip('/')}/": _set_defaults(v) for k, v in targets.items()
        }
        self.chunk_size = chunk_size
        self.timeout = timeout

    def proxy_to(
        self, opts: dict[str, t.Any], path: str, prefix: str
    ) -> WSGIApplication:
        """Build a WSGI application that forwards one request to a target.

        :param opts: Target options with defaults already applied
            (``target``, ``remove_prefix``, ``host``, ``headers``,
            ``ssl_context``).
        :param path: The request path (``PATH_INFO``) being proxied.
        :param prefix: The normalized prefix that matched ``path``.
        :return: A WSGI callable that performs the upstream request and
            streams the upstream response back.
        """
        target = urlsplit(opts["target"])
        # socket can handle unicode host, but header must be ascii
        host = target.hostname.encode("idna").decode("ascii")

        def application(
            environ: WSGIEnvironment, start_response: StartResponse
        ) -> t.Iterable[bytes]:
            # Forward end-to-end headers only. Content-Length and Host
            # are dropped here and re-added below according to the
            # request body and the ``host`` option.
            headers = [
                (k, v)
                for k, v in EnvironHeaders(environ).items()
                if not is_hop_by_hop_header(k)
                and k.lower() not in ("content-length", "host")
            ]
            headers.append(("Connection", "close"))

            if opts["host"] == "<auto>":
                headers.append(("Host", host))
            elif opts["host"] is None:
                headers.append(("Host", environ["HTTP_HOST"]))
            else:
                headers.append(("Host", opts["host"]))

            headers.extend(opts["headers"].items())
            remote_path = path

            if opts["remove_prefix"]:
                remote_path = remote_path[len(prefix) :].lstrip("/")
                remote_path = f"{target.path.rstrip('/')}/{remote_path}"

            content_length = environ.get("CONTENT_LENGTH")
            chunked = False

            if content_length not in ("", None):
                headers.append(("Content-Length", content_length))  # type: ignore
            elif content_length is not None:
                # CONTENT_LENGTH is present but empty: the body length is
                # unknown, so stream it with chunked transfer encoding.
                headers.append(("Transfer-Encoding", "chunked"))
                chunked = True

            # Bound before the ``try`` so the error handler can close it
            # even if constructing the connection itself fails.
            con = None

            try:
                if target.scheme == "http":
                    con = client.HTTPConnection(
                        host, target.port or 80, timeout=self.timeout
                    )
                elif target.scheme == "https":
                    con = client.HTTPSConnection(
                        host,
                        target.port or 443,
                        timeout=self.timeout,
                        context=opts["ssl_context"],
                    )
                else:
                    raise RuntimeError(
                        "Target scheme must be 'http' or 'https', got"
                        f" {target.scheme!r}."
                    )

                con.connect()
                # safe = https://url.spec.whatwg.org/#url-path-segment-string
                # as well as percent for things that are already quoted
                remote_url = quote(remote_path, safe="!$&'()*+,/:;=@%")
                # QUERY_STRING may be absent from the environ (PEP 3333).
                querystring = environ.get("QUERY_STRING")

                if querystring:
                    remote_url = f"{remote_url}?{querystring}"

                con.putrequest(environ["REQUEST_METHOD"], remote_url, skip_host=True)

                for k, v in headers:
                    # Extra headers from ``opts`` must not re-enable
                    # keep-alive; the connection is used for one request.
                    if k.lower() == "connection":
                        v = "close"

                    con.putheader(k, v)

                con.endheaders()
                stream = get_input_stream(environ)

                while True:
                    data = stream.read(self.chunk_size)

                    if not data:
                        break

                    if chunked:
                        con.send(b"%x\r\n%s\r\n" % (len(data), data))
                    else:
                        con.send(data)

                if chunked:
                    # A chunked body must end with a zero-length chunk,
                    # otherwise the target keeps waiting for more data.
                    con.send(b"0\r\n\r\n")

                resp = con.getresponse()
            except OSError:
                from ..exceptions import BadGateway

                # Do not leak the socket when the upstream request fails.
                if con is not None:
                    con.close()

                return BadGateway()(environ, start_response)

            start_response(
                f"{resp.status} {resp.reason}",
                [
                    (k.title(), v)
                    for k, v in resp.getheaders()
                    if not is_hop_by_hop_header(k)
                ],
            )

            def read() -> t.Iterator[bytes]:
                # Stream the upstream body in chunks. A read error ends
                # the stream early instead of raising mid-response.
                try:
                    while True:
                        try:
                            data = resp.read(self.chunk_size)
                        except OSError:
                            break

                        if not data:
                            break

                        yield data
                finally:
                    # Runs on exhaustion and when the server closes the
                    # iterator early, releasing the upstream socket.
                    resp.close()
                    con.close()

            return read()

        return application

    def __call__(
        self, environ: WSGIEnvironment, start_response: StartResponse
    ) -> t.Iterable[bytes]:
        """Dispatch to the first target whose prefix matches ``PATH_INFO``,
        or to the wrapped application when no prefix matches.
        """
        path = environ["PATH_INFO"]
        app = self.app

        for prefix, opts in self.targets.items():
            if path.startswith(prefix):
                app = self.proxy_to(opts, path, prefix)
                break

        return app(environ, start_response)