You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
276 lines
8.8 KiB
276 lines
8.8 KiB
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
|
|
|
"""asyncio library query support"""
|
|
|
|
import asyncio
|
|
import socket
|
|
import sys
|
|
|
|
import dns._asyncbackend
|
|
import dns._features
|
|
import dns.exception
|
|
import dns.inet
|
|
|
|
# True when running on Windows.  Win32 needs UDP sockets to be explicitly
# bound before recvfrom(), so socket creation consults this flag.
_is_win32 = "win32" == sys.platform
|
|
|
|
|
|
def _get_running_loop():
    """Return the event loop for the current context.

    Uses ``asyncio.get_running_loop()`` when the ``asyncio`` module provides
    it, and falls back to ``asyncio.get_event_loop()`` on interpreters whose
    ``asyncio`` lacks that function.
    """
    running_loop_getter = getattr(asyncio, "get_running_loop", None)
    if running_loop_getter is None:  # pragma: no cover
        return asyncio.get_event_loop()
    return running_loop_getter()
|
|
|
|
|
|
class _DatagramProtocol(asyncio.DatagramProtocol):
    """asyncio datagram protocol that hands one result to a waiting future.

    A receiver stores a future in ``recvfrom`` before waiting.  The first
    datagram, error, or connection loss that arrives while that future is
    still pending completes it.  Datagrams that arrive while no future is
    pending are dropped.

    Subclassing :class:`asyncio.DatagramProtocol` provides no-op
    ``pause_writing``/``resume_writing`` callbacks.  Without them, a transport
    applying write flow control would fail with ``AttributeError``, and the
    event loop's exception handler would report that failure.
    """

    def __init__(self):
        super().__init__()
        self.transport = None  # set by connection_made()
        self.recvfrom = None  # future of the receiver currently waiting, if any

    def _pending(self):
        """Return True if a receiver is waiting on a not-yet-completed future."""
        return self.recvfrom is not None and not self.recvfrom.done()

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        if self._pending():
            self.recvfrom.set_result((data, addr))

    def error_received(self, exc):  # pragma: no cover
        if self._pending():
            self.recvfrom.set_exception(exc)

    def connection_lost(self, exc):
        if self._pending():
            # exc is None for an EOF we triggered ourselves (closing the
            # transport); report that to the waiting receiver as EOFError.
            self.recvfrom.set_exception(EOFError() if exc is None else exc)

    def close(self):
        # If connection_made() has not run there is no transport to close.
        if self.transport is not None:
            self.transport.close()
|
|
|
|
|
|
async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, bounded by *timeout* seconds when one is given.

    With ``timeout=None`` the awaitable is awaited directly, with no limit.
    Otherwise ``asyncio.wait_for`` enforces the limit, and expiry is
    reported as :class:`dns.exception.Timeout`.
    """
    if timeout is None:
        return await awaitable
    try:
        return await asyncio.wait_for(awaitable, timeout)
    except asyncio.TimeoutError:
        raise dns.exception.Timeout(timeout=timeout)
|
|
|
|
|
|
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket for the asyncio backend.

    Wraps the transport/protocol pair returned by
    ``loop.create_datagram_endpoint``.
    """

    def __init__(self, family, transport, protocol):
        super().__init__(family)
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        # asyncio datagram transports' sendto() does not block, so there is
        # nothing to bound and *timeout* is ignored.
        self.transport.sendto(what, destination)
        return len(what)

    async def recvfrom(self, size, timeout):
        """Wait for one datagram and return ``(data, address)``.

        *size* is ignored, because there is no way to pass a size limit to
        the protocol.  Only one ``recvfrom`` may be outstanding at a time.
        Timeout expiry is reported as :class:`dns.exception.Timeout`.
        """
        # Check for a concurrent receive *before* entering the try block.
        # Otherwise the finally clause would reset protocol.recvfrom and
        # strand the other caller's pending future until its timeout.
        assert self.protocol.recvfrom is None
        done = _get_running_loop().create_future()
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            self.protocol.recvfrom = None

    async def close(self):
        self.protocol.close()

    async def getpeername(self):
        return self.transport.get_extra_info("peername")

    async def getsockname(self):
        return self.transport.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        raise NotImplementedError
|
|
|
|
|
|
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket for the asyncio backend.

    Wraps the reader/writer pair produced by ``asyncio.open_connection``.
    """

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader = reader
        self.writer = writer

    def _writer_info(self, name):
        """Look up transport information *name* through the writer."""
        return self.writer.get_extra_info(name)

    async def sendall(self, what, timeout):
        # write() itself does not wait; drain() is the awaitable step, so
        # *timeout* bounds the drain.
        self.writer.write(what)
        flushing = self.writer.drain()
        return await _maybe_wait_for(flushing, timeout)

    async def recv(self, size, timeout):
        reading = self.reader.read(size)
        return await _maybe_wait_for(reading, timeout)

    async def close(self):
        self.writer.close()

    async def getpeername(self):
        return self._writer_info("peername")

    async def getsockname(self):
        return self._writer_info("sockname")

    async def getpeercert(self, timeout):
        return self._writer_info("peercert")
|
|
|
|
|
|
if dns._features.have("doh"):
    import anyio
    import httpcore
    import httpcore._backends.anyio
    import httpx

    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        """httpcore network backend that finds server addresses with dnspython.

        Addresses come from a literal host, a fixed bootstrap address, or a
        lookup through the given dnspython resolver, optionally limited to
        *family*.
        """

        def __init__(self, resolver, local_port, bootstrap_address, family):
            # Reject the unsupported configuration before doing any setup.
            if local_port != 0:
                raise NotImplementedError(
                    "the asyncio transport for HTTPX cannot set the local port"
                )
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            """Connect to *host*:*port*, trying each candidate address in turn.

            The candidates are determined as follows:

            - *host* itself, if it is an IP address;
            - otherwise the bootstrap address, if one was configured;
            - otherwise the addresses from a resolver lookup.

            Each attempt has its own deadline, derived from 2.0 seconds and the
            overall *timeout* via ``_expiration_for_this_attempt``.
            *socket_options* is accepted but not used.

            Raises httpcore.ConnectError if no candidate connects.  The last
            per-address failure is chained as the cause.
            """
            addresses = []
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                addresses.append(host)
            elif self._bootstrap_address is not None:
                addresses.append(self._bootstrap_address)
            else:
                timeout = _remaining(expiration)
                family = self._family
                if local_address:
                    # Resolve in the family of the requested local address.
                    family = dns.inet.af_for_address(local_address)
                answers = await self._resolver.resolve_name(
                    host, family=family, lifetime=timeout
                )
                addresses = answers.addresses()
            last_error = None
            for address in addresses:
                try:
                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                    timeout = _remaining(attempt_expiration)
                    with anyio.fail_after(timeout):
                        stream = await anyio.connect_tcp(
                            remote_host=address,
                            remote_port=port,
                            local_host=local_address,
                        )
                    return _CoreAnyIOStream(stream)
                except Exception as e:  # pylint: disable=broad-except
                    # Best effort: move on to the next candidate address, but
                    # remember why this one failed for the final error.
                    last_error = e
            raise httpcore.ConnectError from last_error

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            raise NotImplementedError

        async def sleep(self, seconds):  # pylint: disable=signature-differs
            await anyio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        """httpx async transport whose connection pool uses _NetworkBackend.

        Keyword arguments specific to this class:

        - *local_port*: must be 0, otherwise ``NotImplementedError``.
        - *bootstrap_address*: a fixed server address to use.
        - *resolver*: a ``dns.asyncresolver.Resolver`` is created if None.
        - *family*: address family restriction.

        All other arguments are passed to ``httpx.AsyncHTTPTransport``.
        """

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            # NOTE: swaps the network backend of httpx's private connection
            # pool; this relies on httpx/httpcore internals.
            self._pool._network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )

else:
    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
|
|
|
|
|
|
class Backend(dns._asyncbackend.Backend):
    """dnspython async backend implemented on top of :mod:`asyncio`."""

    def name(self):
        return "asyncio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        """Create a socket of type *socktype* in address family *af*.

        ``SOCK_DGRAM`` returns a :class:`DatagramSocket` bound to *source*
        and aimed at *destination*.  On Windows, a missing *source* becomes
        ``(dns.inet.any_for_af(af), 0)``.

        ``SOCK_STREAM`` returns a :class:`StreamSocket` connected to
        *destination* within *timeout*.  It uses *ssl_context* and
        *server_hostname* when they are given.  It raises ``ValueError`` if
        *destination* is None.

        Any other socket type raises ``NotImplementedError``.
        """
        loop = _get_running_loop()

        if socktype == socket.SOCK_DGRAM:
            if source is None and _is_win32:
                # Win32 wants explicit binding before recvfrom().  This is
                # the proper fix for [#637].
                source = (dns.inet.any_for_af(af), 0)
            transport, protocol = await loop.create_datagram_endpoint(
                _DatagramProtocol,
                local_addr=source,
                family=af,
                proto=proto,
                remote_addr=destination,
            )
            return DatagramSocket(af, transport, protocol)

        if socktype == socket.SOCK_STREAM:
            if destination is None:
                # This shouldn't happen; the check keeps code analysis tools
                # happy.
                raise ValueError("destination required for stream sockets")
            remote_host = destination[0]
            remote_port = destination[1]
            connecting = asyncio.open_connection(
                remote_host,
                remote_port,
                ssl=ssl_context,
                family=af,
                proto=proto,
                local_addr=source,
                server_hostname=server_hostname,
            )
            reader, writer = await _maybe_wait_for(connecting, timeout)
            return StreamSocket(af, reader, writer)

        raise NotImplementedError(
            f"unsupported socket type {socktype}"
        )  # pragma: no cover

    async def sleep(self, interval):
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        # This backend does not require datagram sockets to be connected.
        return False

    def get_transport_class(self):
        # The DoH-capable transport when the "doh" feature is available,
        # otherwise the null transport placeholder.
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        return await _maybe_wait_for(awaitable, timeout)
|