##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################

import os
import os.path
import socket
import time

from waitress import trigger
from waitress.adjustments import Adjustments
from waitress.channel import HTTPChannel
from waitress.task import ThreadedTaskDispatcher
from waitress.utilities import cleanup_unix_socket

from waitress.compat import (
    IPPROTO_IPV6,
    IPV6_V6ONLY,
)
from . import wasyncore
from .proxy_headers import proxy_headers_middleware


def create_server(
    application,
    map=None,
    _start=True,  # test shim
    _sock=None,  # test shim
    _dispatcher=None,  # test shim
    **kw  # adjustments
):
    """
    if __name__ == '__main__':
        server = create_server(app)
        server.run()
    """
    if application is None:
        raise ValueError(
            'The "app" passed to ``create_server`` was ``None``.  You forgot '
            "to return a WSGI app within your application."
        )
    adj = Adjustments(**kw)

    if map is None:  # pragma: nocover
        map = {}

    dispatcher = _dispatcher
    if dispatcher is None:
        dispatcher = ThreadedTaskDispatcher()
        dispatcher.set_thread_count(adj.threads)

    if adj.unix_socket and hasattr(socket, "AF_UNIX"):
        sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None)
        return UnixWSGIServer(
            application,
            map,
            _start,
            _sock,
            dispatcher=dispatcher,
            adj=adj,
            sockinfo=sockinfo,
        )

    effective_listen = []
    last_serv = None
    if not adj.sockets:
        for sockinfo in adj.listen:
            # When TcpWSGIServer is called, it registers itself in the map. This
            # side-effect is all we need it for, so we don't store a reference to
            # or return it to the user.
            last_serv = TcpWSGIServer(
                application,
                map,
                _start,
                _sock,
                dispatcher=dispatcher,
                adj=adj,
                sockinfo=sockinfo,
            )
            effective_listen.append(
                (last_serv.effective_host, last_serv.effective_port)
            )

    for sock in adj.sockets:
        sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname())
        if sock.family == socket.AF_INET or sock.family == socket.AF_INET6:
            last_serv = TcpWSGIServer(
                application,
                map,
                _start,
                sock,
                dispatcher=dispatcher,
                adj=adj,
                bind_socket=False,
                sockinfo=sockinfo,
            )
            effective_listen.append(
                (last_serv.effective_host, last_serv.effective_port)
            )
        elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            last_serv = UnixWSGIServer(
                application,
                map,
                _start,
                sock,
                dispatcher=dispatcher,
                adj=adj,
                bind_socket=False,
                sockinfo=sockinfo,
            )
            effective_listen.append(
                (last_serv.effective_host, last_serv.effective_port)
            )

    # We are running a single server, so we can just return the last server,
    # saves us from having to create one more object
    if len(effective_listen) == 1:
        # In this case we have no need to use a MultiSocketServer
        return last_serv

    # Return a class that has a utility function to print out the sockets it's
    # listening on, and has a .run() function. All of the TcpWSGIServers
    # registered themselves in the map above.
    return MultiSocketServer(map, adj, effective_listen, dispatcher)


# This class is only ever used if we have multiple listen sockets. It allows
# the serve() API to call .run() which starts the wasyncore loop, and catches
# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down.
class MultiSocketServer(object):
    asyncore = wasyncore  # test shim

    def __init__(
        self, map=None, adj=None, effective_listen=None, dispatcher=None,
    ):
        self.adj = adj
        self.map = map
        self.effective_listen = effective_listen
        self.task_dispatcher = dispatcher

    def print_listen(self, format_str):  # pragma: nocover
        for l in self.effective_listen:
            l = list(l)

            if ":" in l[0]:
                l[0] = "[{}]".format(l[0])

            print(format_str.format(*l))

    def run(self):
        try:
            self.asyncore.loop(
                timeout=self.adj.asyncore_loop_timeout,
                map=self.map,
                use_poll=self.adj.asyncore_use_poll,
            )
        except (SystemExit, KeyboardInterrupt):
            self.close()

    def close(self):
        self.task_dispatcher.shutdown()
        wasyncore.close_all(self.map)


class BaseWSGIServer(wasyncore.dispatcher, object):

    channel_class = HTTPChannel
    next_channel_cleanup = 0
    socketmod = socket  # test shim
    asyncore = wasyncore  # test shim

    def __init__(
        self,
        application,
        map=None,
        _start=True,  # test shim
        _sock=None,  # test shim
        dispatcher=None,  # dispatcher
        adj=None,  # adjustments
        sockinfo=None,  # opaque object
        bind_socket=True,
        **kw
    ):
        if adj is None:
            adj = Adjustments(**kw)

        if adj.trusted_proxy or adj.clear_untrusted_proxy_headers:
            # wrap the application to deal with proxy headers
            # we wrap it here because webtest subclasses the TcpWSGIServer
            # directly and thus doesn't run any code that's in create_server
            application = proxy_headers_middleware(
                application,
                trusted_proxy=adj.trusted_proxy,
                trusted_proxy_count=adj.trusted_proxy_count,
                trusted_proxy_headers=adj.trusted_proxy_headers,
                clear_untrusted=adj.clear_untrusted_proxy_headers,
                log_untrusted=adj.log_untrusted_proxy_headers,
                logger=self.logger,
            )

        if map is None:
            # use a nonglobal socket map by default to hopefully prevent
            # conflicts with apps and libs that use the wasyncore global socket
            # map ala https://github.com/Pylons/waitress/issues/63
            map = {}
        if sockinfo is None:
            sockinfo = adj.listen[0]

        self.sockinfo = sockinfo
        self.family = sockinfo[0]
        self.socktype = sockinfo[1]
        self.application = application
        self.adj = adj
        self.trigger = trigger.trigger(map)
        if dispatcher is None:
            dispatcher = ThreadedTaskDispatcher()
            dispatcher.set_thread_count(self.adj.threads)

        self.task_dispatcher = dispatcher
        self.asyncore.dispatcher.__init__(self, _sock, map=map)
        if _sock is None:
            self.create_socket(self.family, self.socktype)
            if self.family == socket.AF_INET6:  # pragma: nocover
                self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)

        self.set_reuse_addr()

        if bind_socket:
            self.bind_server_socket()

        self.effective_host, self.effective_port = self.getsockname()
        self.server_name = self.get_server_name(self.effective_host)
        self.active_channels = {}
        if _start:
            self.accept_connections()

    def bind_server_socket(self):
        raise NotImplementedError  # pragma: no cover

    def get_server_name(self, ip):
        """Given an IP or hostname, try to determine the server name."""

        if not ip:
            raise ValueError("Requires an IP to get the server name")

        server_name = str(ip)

        # If we are bound to all IP's, just return the current hostname, only
        # fall-back to "localhost" if we fail to get the hostname
        if server_name == "0.0.0.0" or server_name == "::":
            try:
                return str(self.socketmod.gethostname())
            except (socket.error, UnicodeDecodeError):  # pragma: no cover
                # We also deal with UnicodeDecodeError in case of Windows with
                # non-ascii hostname
                return "localhost"

        # Now let's try and convert the IP address to a proper hostname
        try:
            server_name = self.socketmod.gethostbyaddr(server_name)[0]
        except (socket.error, UnicodeDecodeError):  # pragma: no cover
            # We also deal with UnicodeDecodeError in case of Windows with
            # non-ascii hostname
            pass

        # If it contains an IPv6 literal, make sure to surround it with
        # brackets
        if ":" in server_name and "[" not in server_name:
            server_name = "[{}]".format(server_name)

        return server_name

    def getsockname(self):
        raise NotImplementedError  # pragma: no cover

    def accept_connections(self):
        self.accepting = True
        self.socket.listen(self.adj.backlog)  # Get around asyncore NT limit

    def add_task(self, task):
        self.task_dispatcher.add_task(task)

    def readable(self):
        now = time.time()
        if now >= self.next_channel_cleanup:
            self.next_channel_cleanup = now + self.adj.cleanup_interval
            self.maintenance(now)
        return self.accepting and len(self._map) < self.adj.connection_limit

    def writable(self):
        return False

    def handle_read(self):
        pass

    def handle_connect(self):
        pass

    def handle_accept(self):
        try:
            v = self.accept()
            if v is None:
                return
            conn, addr = v
        except socket.error:
            # Linux: On rare occasions we get a bogus socket back from
            # accept.  socketmodule.c:makesockaddr complains that the
            # address family is unknown.  We don't want the whole server
            # to shut down because of this.
            if self.adj.log_socket_errors:
                self.logger.warning("server accept() threw an exception", exc_info=True)
            return
        self.set_socket_options(conn)
        addr = self.fix_addr(addr)
        self.channel_class(self, conn, addr, self.adj, map=self._map)

    def run(self):
        try:
            self.asyncore.loop(
                timeout=self.adj.asyncore_loop_timeout,
                map=self._map,
                use_poll=self.adj.asyncore_use_poll,
            )
        except (SystemExit, KeyboardInterrupt):
            self.task_dispatcher.shutdown()

    def pull_trigger(self):
        self.trigger.pull_trigger()

    def set_socket_options(self, conn):
        pass

    def fix_addr(self, addr):
        return addr

    def maintenance(self, now):
        """
        Closes channels that have not had any activity in a while.

        The timeout is configured through adj.channel_timeout (seconds).
        """
        cutoff = now - self.adj.channel_timeout
        for channel in self.active_channels.values():
            if (not channel.requests) and channel.last_activity < cutoff:
                channel.will_close = True

    def print_listen(self, format_str):  # pragma: nocover
        print(format_str.format(self.effective_host, self.effective_port))

    def close(self):
        self.trigger.close()
        return wasyncore.dispatcher.close(self)


class TcpWSGIServer(BaseWSGIServer):
    def bind_server_socket(self):
        (_, _, _, sockaddr) = self.sockinfo
        self.bind(sockaddr)

    def getsockname(self):
        try:
            return self.socketmod.getnameinfo(
                self.socket.getsockname(), self.socketmod.NI_NUMERICSERV
            )
        except:  # pragma: no cover
            # This only happens on Linux because a DNS issue is considered a
            # temporary failure that will raise (even when NI_NAMEREQD is not
            # set). Instead we try again, but this time we just ask for the
            # numerichost and the numericserv (port) and return those. It is
            # better than nothing.
            return self.socketmod.getnameinfo(
                self.socket.getsockname(),
                self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV,
            )

    def set_socket_options(self, conn):
        for (level, optname, value) in self.adj.socket_options:
            conn.setsockopt(level, optname, value)


if hasattr(socket, "AF_UNIX"):

    class UnixWSGIServer(BaseWSGIServer):
        def __init__(
            self,
            application,
            map=None,
            _start=True,  # test shim
            _sock=None,  # test shim
            dispatcher=None,  # dispatcher
            adj=None,  # adjustments
            sockinfo=None,  # opaque object
            **kw
        ):
            if sockinfo is None:
                sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None)

            super(UnixWSGIServer, self).__init__(
                application,
                map=map,
                _start=_start,
                _sock=_sock,
                dispatcher=dispatcher,
                adj=adj,
                sockinfo=sockinfo,
                **kw
            )

        def bind_server_socket(self):
            cleanup_unix_socket(self.adj.unix_socket)
            self.bind(self.adj.unix_socket)
            if os.path.exists(self.adj.unix_socket):
                os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms)

        def getsockname(self):
            return ("unix", self.socket.getsockname())

        def fix_addr(self, addr):
            return ("localhost", None)

        def get_server_name(self, ip):
            return "localhost"


# Compatibility alias.
WSGIServer = TcpWSGIServer