You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
bazarr/libs/websocket/_handshake.py

203 lines
6.4 KiB

6 years ago
"""
_handshake.py
6 years ago
websocket - WebSocket client library for Python
Copyright 2023 engn33r
6 years ago
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6 years ago
http://www.apache.org/licenses/LICENSE-2.0
6 years ago
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
6 years ago
"""
import hashlib
import hmac
import os
from base64 import encodebytes as base64encode
from http import HTTPStatus
6 years ago
from ._cookiejar import SimpleCookieJar
from ._exceptions import *
from ._http import *
from ._logging import *
from ._socket import *
# Names exported by "from ._handshake import *".
__all__ = [
    "handshake_response",
    "handshake",
    "SUPPORTED_REDIRECT_STATUSES",
]
# Value sent in the Sec-WebSocket-Version request header.
VERSION = 13

# Redirect statuses that handshake() returns to its caller without
# validating the WebSocket response headers.
SUPPORTED_REDIRECT_STATUSES = (
    HTTPStatus.MOVED_PERMANENTLY,  # 301
    HTTPStatus.FOUND,  # 302
    HTTPStatus.SEE_OTHER,  # 303
    HTTPStatus.TEMPORARY_REDIRECT,  # 307
    HTTPStatus.PERMANENT_REDIRECT,  # 308
)

# Statuses accepted by default when reading the handshake response: the
# redirects above followed by 101 Switching Protocols.
SUCCESS_STATUSES = (*SUPPORTED_REDIRECT_STATUSES, HTTPStatus.SWITCHING_PROTOCOLS)
# Module-level jar shared by every handshake in this process:
# handshake_response adds each response's Set-Cookie value to it, and the
# request builder sends whatever it returns for the target host.
CookieJar: SimpleCookieJar = SimpleCookieJar()
class handshake_response:
    """Result of the opening handshake.

    Holds the response status, the response headers and the negotiated
    subprotocol (or None). Creating an instance also records the response's
    Set-Cookie header, if any, in the module-level CookieJar.
    """

    def __init__(self, status: int, headers: dict, subprotocol):
        self.status, self.headers, self.subprotocol = status, headers, subprotocol
        set_cookie = headers.get("set-cookie")
        CookieJar.add(set_cookie)
def handshake(
    sock, url: str, hostname: str, port: int, resource: str, **options
) -> handshake_response:
    """Run the client side of the WebSocket opening handshake on ``sock``.

    Sends the request built from the arguments and ``options``, then reads
    the response. A status in SUPPORTED_REDIRECT_STATUSES is returned as-is,
    without header validation (presumably so the caller can follow it).
    Otherwise the response is checked against the generated key and the
    requested subprotocols.

    Raises:
        WebSocketBadStatusException: (from _get_resp_headers) when the status
            is not in SUCCESS_STATUSES.
        WebSocketException: when the response headers fail validation.
    """
    request_lines, key = _get_handshake_headers(resource, url, hostname, port, options)
    request = "\r\n".join(request_lines)
    send(sock, request)
    dump("request header", request)

    status, response_headers = _get_resp_headers(sock)
    if status in SUPPORTED_REDIRECT_STATUSES:
        # Nothing to validate on a redirect; hand it back to the caller.
        return handshake_response(status, response_headers, None)

    valid, subprotocol = _validate(response_headers, key, options.get("subprotocols"))
    if valid:
        return handshake_response(status, response_headers, subprotocol)
    raise WebSocketException("Invalid WebSocket Header")
def _pack_hostname(hostname: str) -> str:
# IPv6 address
if ":" in hostname:
return f"[{hostname}]"
return hostname
6 years ago
def _user_header_value(header, name: str) -> tuple:
    """Look up ``name`` case-insensitively in the user's ``header`` option.

    ``header`` may be falsy, a dict of name -> value, or an iterable of
    "Name: value" strings. Returns ``(found, value)``. For a dict entry the
    value is returned exactly as given (possibly None). For a string line it
    is the text after the first ":" with surrounding whitespace removed.
    """
    if not header:
        return False, None
    wanted = name.lower()
    if isinstance(header, dict):
        for header_name, value in header.items():
            if str(header_name).strip().lower() == wanted:
                return True, value
        return False, None
    for line in header:
        header_name, sep, value = str(line).partition(":")
        if sep and header_name.strip().lower() == wanted:
            return True, value.strip()
    return False, None


def _get_handshake_headers(
    resource: str, url: str, host: str, port: int, options: dict
) -> tuple:
    """Build the request lines of the WebSocket opening handshake.

    Args:
        resource: request target placed on the GET line.
        url: the connection URL. Only its scheme (the text before the first
            ":") is used, to choose the default port and the Origin scheme.
        host: host name used for the Host/Origin headers and cookie lookup.
        port: connection port. It is left out of Host/Origin only when it is
            the default for the scheme (443 for "wss", 80 otherwise).
        options: connection options. Keys read here: "host",
            "suppress_origin", "origin", "header" (dict, or list of
            "Name: value" lines), "connection" (a header value, or a complete
            "Connection: ..." line), "subprotocols", "cookie".

    Returns:
        ``(headers, key)``. ``headers`` is a list of request lines ending with
        two empty strings, so joining them with CRLF terminates the header
        block. ``key`` is the Sec-WebSocket-Key value that the response's
        Sec-WebSocket-Accept must be checked against.
    """
    user_header = options.get("header")
    if user_header and not isinstance(user_header, dict):
        # Materialise once: the lines are searched below and then sent.
        user_header = list(user_header)

    # Only the scheme matters here; partition() also copes with a URL that
    # has no ":" instead of raising on tuple unpacking.
    scheme = url.partition(":")[0]
    secure = scheme == "wss"
    # The port is omitted from Host only when it is the default for *this*
    # scheme (RFC 6455 section 4.1). Checking 80/443 for either scheme would
    # drop a non-default port such as wss://host:80.
    default_port = 443 if secure else 80
    packed_host = _pack_hostname(host)
    hostport = packed_host if port == default_port else f"{packed_host}:{port}"

    headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]
    if options.get("host"):
        headers.append(f'Host: {options["host"]}')
    else:
        headers.append(f"Host: {hostport}")

    if not options.get("suppress_origin"):
        if options.get("origin") is not None:
            headers.append(f'Origin: {options["origin"]}')
        else:
            origin_scheme = "https" if secure else "http"
            headers.append(f"Origin: {origin_scheme}://{hostport}")

    # Don't add a second Sec-WebSocket-Key / -Version when the caller already
    # supplied one, in either the dict or the list form and in any letter
    # case. A caller-supplied key is the one the response is validated
    # against.
    key_given, given_key = _user_header_value(user_header, "Sec-WebSocket-Key")
    if key_given:
        key = given_key
    else:
        key = _create_sec_websocket_key()
        headers.append(f"Sec-WebSocket-Key: {key}")
    version_given, _ = _user_header_value(user_header, "Sec-WebSocket-Version")
    if not version_given:
        headers.append(f"Sec-WebSocket-Version: {VERSION}")

    connection = options.get("connection")
    if not connection:
        headers.append("Connection: Upgrade")
    else:
        conn_name, conn_sep, _ = str(connection).partition(":")
        if conn_sep and conn_name.strip().lower() == "connection":
            # A complete "Connection: ..." line is still sent verbatim.
            headers.append(connection)
        else:
            # A bare value such as "keep-alive, Upgrade" becomes the header
            # value instead of a malformed request line.
            headers.append(f"Connection: {connection}")

    if subprotocols := options.get("subprotocols"):
        headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')

    if user_header:
        if isinstance(user_header, dict):
            # Entries whose value is None are skipped rather than sent.
            # f-string formatting also accepts non-str values (e.g. ints).
            user_header = [
                f"{name}: {value}"
                for name, value in user_header.items()
                if value is not None
            ]
        headers.extend(user_header)

    # Cookies the shared jar holds for this host come first, then the
    # caller's own "cookie" option; empty parts are dropped.
    server_cookie = CookieJar.get(host)
    client_cookie = options.get("cookie", None)
    if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
        headers.append(f"Cookie: {cookie}")

    # Two empty entries: CRLF-joining yields the blank line ending the headers.
    headers.extend(("", ""))
    return headers, key
def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
    """Read the handshake response status line and headers from ``sock``.

    Returns:
        ``(status, headers)`` when ``status`` is in ``success_statuses``.

    Raises:
        WebSocketBadStatusException: for any other status. It carries the
            status, reason phrase and headers. It also carries the response
            body when the response has a parseable, non-negative
            Content-Length: the bytes from one ``sock.recv()`` of that size.
            Otherwise the body is None.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status in success_statuses:
        return status, resp_headers

    response_body = None
    content_len = resp_headers.get("content-length")
    if content_len:
        try:
            body_len = int(content_len)
        except (TypeError, ValueError):
            # A malformed Content-Length must not replace the bad-status
            # error below with an unrelated ValueError.
            body_len = -1
        if body_len >= 0:
            # Read the error body so it can be included in the exception.
            # A single recv() may return fewer bytes than declared.
            response_body = sock.recv(body_len)
    raise WebSocketBadStatusException(
        f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
        status,
        status_message,
        resp_headers,
        response_body,
    )
_HEADERS_TO_CHECK = {
"upgrade": "websocket",
"connection": "upgrade",
}
def _validate(headers, key: str, subprotocols) -> tuple:
6 years ago
subproto = None
for k, v in _HEADERS_TO_CHECK.items():
r = headers.get(k, None)
if not r:
return False, None
r = [x.strip().lower() for x in r.split(",")]
if v not in r:
6 years ago
return False, None
if subprotocols:
subproto = headers.get("sec-websocket-protocol", None)
if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
error(f"Invalid subprotocol: {subprotocols}")
6 years ago
return False, None
subproto = subproto.lower()
6 years ago
result = headers.get("sec-websocket-accept", None)
if not result:
return False, None
result = result.lower()
if isinstance(result, str):
result = result.encode("utf-8")
6 years ago
value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
6 years ago
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
if hmac.compare_digest(hashed, result):
6 years ago
return True, subproto
else:
return False, None
def _create_sec_websocket_key() -> str:
6 years ago
randomness = os.urandom(16)
return base64encode(randomness).decode("utf-8").strip()