##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################

from collections import deque
import socket
import sys
import threading
import time

from .buffers import ReadOnlyFileBasedBuffer
from .utilities import build_http_date, logger, queue_logger

rename_headers = {  # or keep them without the HTTP_ prefix added
    "CONTENT_LENGTH": "CONTENT_LENGTH",
    "CONTENT_TYPE": "CONTENT_TYPE",
}

hop_by_hop = frozenset(
    (
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
    )
)


class ThreadedTaskDispatcher:
    """A Task Dispatcher that creates a thread for each task."""

    stop_count = 0  # Number of threads that will stop soon.
    active_count = 0  # Number of currently active threads
    logger = logger
    queue_logger = queue_logger

    def __init__(self):
        self.threads = set()
        self.queue = deque()
        self.lock = threading.Lock()
        self.queue_cv = threading.Condition(self.lock)
        self.thread_exit_cv = threading.Condition(self.lock)

    def start_new_thread(self, target, thread_no):
        t = threading.Thread(
            target=target, name="waitress-{}".format(thread_no), args=(thread_no,)
        )
        t.daemon = True
        t.start()

    def handler_thread(self, thread_no):
        while True:
            with self.lock:
                while not self.queue and self.stop_count == 0:
                    # Mark ourselves as idle before waiting to be
                    # woken up, then we will once again be active
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                if self.stop_count > 0:
                    self.active_count -= 1
                    self.stop_count -= 1
                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    break

                task = self.queue.popleft()
            try:
                task.service()
            except BaseException:
                self.logger.exception("Exception when servicing %r", task)

    def set_thread_count(self, count):
        with self.lock:
            threads = self.threads
            thread_no = 0
            running = len(threads) - self.stop_count
            while running < count:
                # Start threads.
                while thread_no in threads:
                    thread_no = thread_no + 1
                threads.add(thread_no)
                running += 1
                self.start_new_thread(self.handler_thread, thread_no)
                self.active_count += 1
                thread_no = thread_no + 1
            if running > count:
                # Stop threads.
                self.stop_count += running - count
                self.queue_cv.notify_all()

    def add_task(self, task):
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
            queue_size = len(self.queue)
            idle_threads = len(self.threads) - self.stop_count - self.active_count
            if queue_size > idle_threads:
                self.queue_logger.warning(
                    "Task queue depth is %d", queue_size - idle_threads
                )

    def shutdown(self, cancel_pending=True, timeout=5):
        self.set_thread_count(0)
        # Ensure the threads shut down.
        threads = self.threads
        expiration = time.time() + timeout
        with self.lock:
            while threads:
                if time.time() >= expiration:
                    self.logger.warning("%d thread(s) still running", len(threads))
                    break
                self.thread_exit_cv.wait(0.1)
            if cancel_pending:
                # Cancel remaining tasks.
                queue = self.queue
                if len(queue) > 0:
                    self.logger.warning("Canceling %d pending task(s)", len(queue))
                while queue:
                    task = queue.popleft()
                    task.cancel()
                self.queue_cv.notify_all()
                return True
        return False


class Task:
    close_on_finish = False
    status = "200 OK"
    wrote_header = False
    start_time = 0
    content_length = None
    content_bytes_written = 0
    logged_write_excess = False
    logged_write_no_body = False
    complete = False
    chunked_response = False
    logger = logger

    def __init__(self, channel, request):
        self.channel = channel
        self.request = request
        self.response_headers = []
        version = request.version
        if version not in ("1.0", "1.1"):
            # fall back to a version we support.
            version = "1.0"
        self.version = version

    def service(self):
        try:
            self.start()
            self.execute()
            self.finish()
        except OSError:
            self.close_on_finish = True
            if self.channel.adj.log_socket_errors:
                raise

    @property
    def has_body(self):
        return not (
            self.status.startswith("1")
            or self.status.startswith("204")
            or self.status.startswith("304")
        )

    def build_response_header(self):
        version = self.version
        # Figure out whether the connection should be closed.
        connection = self.request.headers.get("CONNECTION", "").lower()
        response_headers = []
        content_length_header = None
        date_header = None
        server_header = None
        connection_close_header = None

        for (headername, headerval) in self.response_headers:
            headername = "-".join([x.capitalize() for x in headername.split("-")])

            if headername == "Content-Length":
                if self.has_body:
                    content_length_header = headerval
                else:
                    continue  # pragma: no cover

            if headername == "Date":
                date_header = headerval

            if headername == "Server":
                server_header = headerval

            if headername == "Connection":
                connection_close_header = headerval.lower()
            # replace with properly capitalized version
            response_headers.append((headername, headerval))

        if (
            content_length_header is None
            and self.content_length is not None
            and self.has_body
        ):
            content_length_header = str(self.content_length)
            response_headers.append(("Content-Length", content_length_header))

        def close_on_finish():
            if connection_close_header is None:
                response_headers.append(("Connection", "close"))
            self.close_on_finish = True

        if version == "1.0":
            if connection == "keep-alive":
                if not content_length_header:
                    close_on_finish()
                else:
                    response_headers.append(("Connection", "Keep-Alive"))
            else:
                close_on_finish()
        elif version == "1.1":
            if connection == "close":
                close_on_finish()

            if not content_length_header:
                # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
                # for any response with a status code of 1xx, 204 or 304.
                if self.has_body:
                    response_headers.append(("Transfer-Encoding", "chunked"))
                    self.chunked_response = True

                if not self.close_on_finish:
                    close_on_finish()

            # under HTTP 1.1 keep-alive is default, no need to set the header
        else:
            raise AssertionError("neither HTTP/1.0 or HTTP/1.1")

        # Set the Server and Date field, if not yet specified. This is needed
        # if the server is used as a proxy.
        ident = self.channel.server.adj.ident

        if not server_header:
            if ident:
                response_headers.append(("Server", ident))
        else:
            response_headers.append(("Via", ident or "waitress"))

        if not date_header:
            response_headers.append(("Date", build_http_date(self.start_time)))

        self.response_headers = response_headers

        first_line = "HTTP/%s %s" % (self.version, self.status)
        # NB: sorting headers needs to preserve same-named-header order
        # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
        # rely on stable sort to keep relative position of same-named headers
        next_lines = [
            "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0])
        ]
        lines = [first_line] + next_lines
        res = "%s\r\n\r\n" % "\r\n".join(lines)

        return res.encode("latin-1")

    def remove_content_length_header(self):
        response_headers = []

        for header_name, header_value in self.response_headers:
            if header_name.lower() == "content-length":
                continue  # pragma: nocover
            response_headers.append((header_name, header_value))

        self.response_headers = response_headers

    def start(self):
        self.start_time = time.time()

    def finish(self):
        if not self.wrote_header:
            self.write(b"")
        if self.chunked_response:
            # not self.write, it will chunk it!
            self.channel.write_soon(b"0\r\n\r\n")

    def write(self, data):
        if not self.complete:
            raise RuntimeError("start_response was not called before body written")
        channel = self.channel
        if not self.wrote_header:
            rh = self.build_response_header()
            channel.write_soon(rh)
            self.wrote_header = True

        if data and self.has_body:
            towrite = data
            cl = self.content_length
            if self.chunked_response:
                # use chunked encoding response
                towrite = hex(len(data))[2:].upper().encode("latin-1") + b"\r\n"
                towrite += data + b"\r\n"
            elif cl is not None:
                towrite = data[: cl - self.content_bytes_written]
                self.content_bytes_written += len(towrite)
                if towrite != data and not self.logged_write_excess:
                    self.logger.warning(
                        "application-written content exceeded the number of "
                        "bytes specified by Content-Length header (%s)" % cl
                    )
                    self.logged_write_excess = True
            if towrite:
                channel.write_soon(towrite)
        elif data:
            # Cheat, and tell the application we have written all of the bytes,
            # even though the response shouldn't have a body and we are
            # ignoring it entirely.
            self.content_bytes_written += len(data)

            if not self.logged_write_no_body:
                self.logger.warning(
                    "application-written content was ignored due to HTTP "
                    "response that may not contain a message-body: (%s)" % self.status
                )
                self.logged_write_no_body = True


class ErrorTask(Task):
    """An error task produces an error response"""

    complete = True

    def execute(self):
        e = self.request.error
        status, headers, body = e.to_response()
        self.status = status
        self.response_headers.extend(headers)
        # We need to explicitly tell the remote client we are closing the
        # connection, because self.close_on_finish is set, and we are going to
        # slam the door in the client's face.
        self.response_headers.append(("Connection", "close"))
        self.close_on_finish = True
        self.content_length = len(body)
        self.write(body.encode("latin-1"))


class WSGITask(Task):
    """A WSGI task produces a response from a WSGI application."""

    environ = None

    def execute(self):
        environ = self.get_environment()

        def start_response(status, headers, exc_info=None):
            if self.complete and not exc_info:
                raise AssertionError(
                    "start_response called a second time without providing exc_info."
                )
            if exc_info:
                try:
                    if self.wrote_header:
                        # higher levels will catch and handle raised exception:
                        # 1. "service" method in task.py
                        # 2. "service" method in channel.py
"handler_thread" method in task.py raise exc_info[1] else: # As per WSGI spec existing headers must be cleared self.response_headers = [] finally: exc_info = None self.complete = True if not status.__class__ is str: raise AssertionError("status %s is not a string" % status) if "\n" in status or "\r" in status: raise ValueError( "carriage return/line feed character present in status" ) self.status = status # Prepare the headers for output for k, v in headers: if not k.__class__ is str: raise AssertionError( "Header name %r is not a string in %r" % (k, (k, v)) ) if not v.__class__ is str: raise AssertionError( "Header value %r is not a string in %r" % (v, (k, v)) ) if "\n" in v or "\r" in v: raise ValueError( "carriage return/line feed character present in header value" ) if "\n" in k or "\r" in k: raise ValueError( "carriage return/line feed character present in header name" ) kl = k.lower() if kl == "content-length": self.content_length = int(v) elif kl in hop_by_hop: raise AssertionError( '%s is a "hop-by-hop" header; it cannot be used by ' "a WSGI application (see PEP 3333)" % k ) self.response_headers.extend(headers) # Return a method used to write the response data. return self.write # Call the application to handle the request and write a response app_iter = self.channel.server.application(environ, start_response) can_close_app_iter = True try: if app_iter.__class__ is ReadOnlyFileBasedBuffer: cl = self.content_length size = app_iter.prepare(cl) if size: if cl != size: if cl is not None: self.remove_content_length_header() self.content_length = size self.write(b"") # generate headers # if the write_soon below succeeds then the channel will # take over closing the underlying file via the channel's # _flush_some or handle_close so we intentionally avoid # calling close in the finally block self.channel.write_soon(app_iter) can_close_app_iter = False return first_chunk_len = None for chunk in app_iter: if first_chunk_len is None: first_chunk_len = len(chunk) # Set a Content-Length header if one is not supplied. # start_response may not have been called until first # iteration as per PEP, so we must reinterrogate # self.content_length here if self.content_length is None: app_iter_len = None if hasattr(app_iter, "__len__"): app_iter_len = len(app_iter) if app_iter_len == 1: self.content_length = first_chunk_len # transmit headers only after first iteration of the iterable # that returns a non-empty bytestring (PEP 3333) if chunk: self.write(chunk) cl = self.content_length if cl is not None: if self.content_bytes_written != cl: # close the connection so the client isn't sitting around # waiting for more data when there are too few bytes # to service content-length self.close_on_finish = True if self.request.command != "HEAD": self.logger.warning( "application returned too few bytes (%s) " "for specified Content-Length (%s) via app_iter" % (self.content_bytes_written, cl), ) finally: if can_close_app_iter and hasattr(app_iter, "close"): app_iter.close() def get_environment(self): """Returns a WSGI environment.""" environ = self.environ if environ is not None: # Return the cached copy. 
            return environ

        request = self.request
        path = request.path
        channel = self.channel
        server = channel.server
        url_prefix = server.adj.url_prefix

        if path.startswith("/"):
            # strip extra slashes at the beginning of a path that starts
            # with any number of slashes
            path = "/" + path.lstrip("/")

        if url_prefix:
            # NB: url_prefix is guaranteed by the configuration machinery to
            # be either the empty string or a string that starts with a single
            # slash and ends without any slashes
            if path == url_prefix:
                # if the path is the same as the url prefix, the SCRIPT_NAME
                # should be the url_prefix and PATH_INFO should be empty
                path = ""
            else:
                # if the path starts with the url prefix plus a slash,
                # the SCRIPT_NAME should be the url_prefix and PATH_INFO should
                # be the value of path from the slash until its end
                url_prefix_with_trailing_slash = url_prefix + "/"
                if path.startswith(url_prefix_with_trailing_slash):
                    path = path[len(url_prefix) :]

        environ = {
            "REMOTE_ADDR": channel.addr[0],
            # Nah, we aren't actually going to look up the reverse DNS for
            # REMOTE_ADDR, but we will happily set this environment variable
            # for the WSGI application. Spec says we can just set this to
            # REMOTE_ADDR, so we do.
            "REMOTE_HOST": channel.addr[0],
            # try and set the REMOTE_PORT to something useful, but maybe None
            "REMOTE_PORT": str(channel.addr[1]),
            "REQUEST_METHOD": request.command.upper(),
            "SERVER_PORT": str(server.effective_port),
            "SERVER_NAME": server.server_name,
            "SERVER_SOFTWARE": server.adj.ident,
            "SERVER_PROTOCOL": "HTTP/%s" % self.version,
            "SCRIPT_NAME": url_prefix,
            "PATH_INFO": path,
            "QUERY_STRING": request.query,
            "wsgi.url_scheme": request.url_scheme,
            # the following environment variables are required by the WSGI spec
            "wsgi.version": (1, 0),
            # apps should use the logging module
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "wsgi.input": request.get_body_stream(),
            "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
            "wsgi.input_terminated": True,  # wsgi.input is EOF terminated
        }

        for key, value in dict(request.headers).items():
            value = value.strip()
            mykey = rename_headers.get(key, None)
            if mykey is None:
                mykey = "HTTP_" + key
            if mykey not in environ:
                environ[mykey] = value

        # Insert a callable into the environment that allows the application to
        # check if the client disconnected. Only works with
        # channel_request_lookahead larger than 0.
        environ["waitress.client_disconnected"] = self.channel.check_client_disconnected

        # cache the environ for this request
        self.environ = environ
        return environ