You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
bazarr/libs/sqlalchemy/dialects/mysql/aiomysql.py

333 lines
9.7 KiB

# dialects/mysql/aiomysql.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""
.. dialect:: mysql+aiomysql
    :name: aiomysql
    :dbapi: aiomysql
    :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
    :url: https://github.com/aio-libs/aiomysql

The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.

Using a special asyncio mediation layer, the aiomysql dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.

This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")

""" # noqa
from collections import deque

from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
class AsyncAdapt_aiomysql_cursor:
    """Blocking-style cursor facade over an aiomysql cursor.

    Every coroutine produced by the wrapped cursor is passed to the
    ``await_`` callable taken from the owning adapted connection, and the
    value that callable returns is used as the coroutine's result.

    For this (non server-side) cursor, ``execute()`` immediately fetches
    all rows into an in-memory buffer; ``fetchone()``, ``fetchmany()``,
    ``fetchall()`` and iteration then consume that buffer from the front.
    The buffer is a ``collections.deque`` so that removing the first row
    is O(1) instead of the O(n) cost of ``list.pop(0)``.
    """

    # TODO: base on connectors/asyncio.py
    # see #10415
    server_side = False

    __slots__ = (
        "_adapt_connection",
        "_connection",
        "await_",
        "_cursor",
        "_rows",
    )

    def __init__(self, adapt_connection):
        """Open an aiomysql cursor on the connection wrapped by
        ``adapt_connection``.

        :param adapt_connection: the adapted connection; its
         ``_connection``, ``await_``, ``dbapi`` and ``_execute_mutex``
         attributes are used by this cursor.
        """
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self.await_ = adapt_connection.await_

        cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)

        # see https://github.com/aio-libs/aiomysql/issues/543
        self._cursor = self.await_(cursor.__aenter__())
        self._rows = deque()

    @property
    def description(self):
        """``description`` of the wrapped cursor."""
        return self._cursor.description

    @property
    def rowcount(self):
        """``rowcount`` of the wrapped cursor."""
        return self._cursor.rowcount

    @property
    def arraysize(self):
        """``arraysize`` of the wrapped cursor; the default size used by
        :meth:`fetchmany`."""
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value):
        self._cursor.arraysize = value

    @property
    def lastrowid(self):
        """``lastrowid`` of the wrapped cursor."""
        return self._cursor.lastrowid

    def close(self):
        """Discard any buffered rows; the wrapped cursor is not closed."""
        # note we aren't actually closing the cursor here,
        # we are just letting GC do it.  to allow this to be async
        # we would need the Result to change how it does "Safe close cursor".
        # MySQL "cursors" don't actually have state to be "closed" besides
        # exhausting rows, which we already have done for sync cursor.
        # another option would be to emulate aiosqlite dialect and assign
        # cursor only if we are doing server side cursor operation.
        self._rows.clear()

    def execute(self, operation, parameters=None):
        """Run ``operation`` and return what the wrapped cursor's
        ``execute()`` returned."""
        return self.await_(self._execute_async(operation, parameters))

    def executemany(self, operation, seq_of_parameters):
        """Run ``operation`` once per parameter set and return what the
        wrapped cursor's ``executemany()`` returned."""
        return self.await_(
            self._executemany_async(operation, seq_of_parameters)
        )

    async def _execute_async(self, operation, parameters):
        # the connection-wide mutex is held for the statement and the
        # row pre-fetch that follows it.
        async with self._adapt_connection._execute_mutex:
            result = await self._cursor.execute(operation, parameters)

            if not self.server_side:
                # aiomysql has a "fake" async result, so we have to pull it
                # out of that here since our default result is not async.
                # we could just as easily grab "_rows" here and be done
                # with it but this is safer.
                self._rows = deque(await self._cursor.fetchall())
            return result

    async def _executemany_async(self, operation, seq_of_parameters):
        async with self._adapt_connection._execute_mutex:
            return await self._cursor.executemany(operation, seq_of_parameters)

    def setinputsizes(self, *inputsizes):
        """Accepted and ignored."""
        pass

    def __iter__(self):
        """Yield buffered rows, removing each one from the buffer."""
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        """Return and remove the next buffered row, or ``None`` if the
        buffer is empty."""
        if self._rows:
            return self._rows.popleft()
        return None

    def fetchmany(self, size=None):
        """Return and remove up to ``size`` buffered rows as a list.

        :param size: maximum number of rows; defaults to ``arraysize``.
        """
        if size is None:
            size = self.arraysize

        rows = self._rows
        return [rows.popleft() for _ in range(min(size, len(rows)))]

    def fetchall(self):
        """Return all remaining buffered rows as a list and empty the
        buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
    """Server-side variant of the aiomysql cursor adapter.

    Because ``server_side`` is true, rows are not pre-fetched on
    ``execute()``; each ``fetch*`` call is forwarded to the wrapped
    ``SSCursor`` through ``await_``.
    """

    # TODO: base on connectors/asyncio.py
    # see #10415
    __slots__ = ()
    server_side = True

    def __init__(self, adapt_connection):
        """Open an aiomysql ``SSCursor`` on the connection wrapped by
        ``adapt_connection``.

        No row buffer is created here; the ``_rows`` slot stays unset
        for this class.
        """
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self.await_ = adapt_connection.await_

        cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)

        self._cursor = self.await_(cursor.__aenter__())

    def close(self):
        """Close the wrapped cursor; later calls are no-ops."""
        if self._cursor is not None:
            self.await_(self._cursor.close())
            self._cursor = None

    def __iter__(self):
        """Yield rows one at a time via :meth:`fetchone` until it returns
        ``None``.

        The inherited ``__iter__`` reads the ``_rows`` buffer, which this
        class never assigns, so it would raise ``AttributeError``.
        """
        while True:
            row = self.fetchone()
            if row is None:
                break
            yield row

    def fetchone(self):
        """Return the result of the wrapped cursor's ``fetchone()``."""
        return self.await_(self._cursor.fetchone())

    def fetchmany(self, size=None):
        """Return the result of the wrapped cursor's
        ``fetchmany(size=size)``."""
        return self.await_(self._cursor.fetchmany(size=size))

    def fetchall(self):
        """Return the result of the wrapped cursor's ``fetchall()``."""
        return self.await_(self._cursor.fetchall())
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
    """Blocking-style facade over an aiomysql connection.

    Coroutines returned by the wrapped connection are handed to
    ``await_`` (``await_only`` for this class).  Each instance owns an
    ``asyncio.Lock`` in ``_execute_mutex``, which the cursor adapters
    acquire around statement execution.
    """

    # TODO: base on connectors/asyncio.py
    # see #10415
    __slots__ = ("dbapi", "_execute_mutex")

    await_ = staticmethod(await_only)

    def __init__(self, dbapi, connection):
        """Store the DBAPI facade and the raw connection, and create the
        per-connection execution lock."""
        self._connection = connection
        self.dbapi = dbapi
        self._execute_mutex = asyncio.Lock()

    def cursor(self, server_side=False):
        """Return a new cursor adapter bound to this connection.

        :param server_side: when true, a server-side cursor adapter is
         returned instead of the buffered one.
        """
        if not server_side:
            return AsyncAdapt_aiomysql_cursor(self)
        return AsyncAdapt_aiomysql_ss_cursor(self)

    def ping(self, reconnect):
        """Run the wrapped connection's ``ping(reconnect)`` and return
        its result."""
        ping_coro = self._connection.ping(reconnect)
        return self.await_(ping_coro)

    def character_set_name(self):
        """Return the wrapped connection's character set name (called
        directly, without ``await_``)."""
        return self._connection.character_set_name()

    def autocommit(self, value):
        """Run the wrapped connection's ``autocommit(value)``."""
        autocommit_coro = self._connection.autocommit(value)
        self.await_(autocommit_coro)

    def commit(self):
        """Run the wrapped connection's ``commit()``."""
        commit_coro = self._connection.commit()
        self.await_(commit_coro)

    def rollback(self):
        """Run the wrapped connection's ``rollback()``."""
        rollback_coro = self._connection.rollback()
        self.await_(rollback_coro)

    def close(self) -> None:
        """Run the wrapped connection's ``ensure_closed()``."""
        closing_coro = self._connection.ensure_closed()
        self.await_(closing_coro)

    def terminate(self):
        """Call the wrapped connection's ``close()`` directly."""
        # it's not awaitable.
        self._connection.close()
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
    """Connection adapter identical to its parent except that ``await_``
    is ``await_fallback`` rather than ``await_only``."""

    # TODO: base on connectors/asyncio.py
    # see #10415
    await_ = staticmethod(await_fallback)

    __slots__ = ()
class AsyncAdapt_aiomysql_dbapi:
    """DBAPI-module-like facade assembled from the ``aiomysql`` and
    ``pymysql`` modules.

    Exception classes are copied from ``aiomysql``; type constants and
    ``Binary`` are copied from ``pymysql``.  ``Cursor`` and ``SSCursor``
    are subclasses of the aiomysql cursor classes whose
    ``_show_warnings()`` does nothing.
    """

    def __init__(self, aiomysql, pymysql):
        """
        :param aiomysql: the imported ``aiomysql`` module.
        :param pymysql: the imported ``pymysql`` module.
        """
        self.aiomysql = aiomysql
        self.pymysql = pymysql
        self.paramstyle = "format"
        self._init_dbapi_attributes()
        self.Cursor, self.SSCursor = self._init_cursors_subclasses()

    def _init_dbapi_attributes(self):
        """Copy the exception classes and type constants onto this
        object under the same names."""
        # each exception name is listed once; the original tuple
        # repeated "InterfaceError".
        for name in (
            "Warning",
            "Error",
            "InterfaceError",
            "DataError",
            "DatabaseError",
            "OperationalError",
            "IntegrityError",
            "ProgrammingError",
            "InternalError",
            "NotSupportedError",
        ):
            setattr(self, name, getattr(self.aiomysql, name))

        for name in (
            "NUMBER",
            "STRING",
            "DATETIME",
            "BINARY",
            "TIMESTAMP",
            "Binary",
        ):
            setattr(self, name, getattr(self.pymysql, name))

    def connect(self, *arg, **kw):
        """Create a raw connection and return it wrapped in a connection
        adapter.

        Two keyword arguments are consumed here and not forwarded:

        :param async_fallback: when true according to ``util.asbool()``,
         the creator's result is passed to ``await_fallback`` and wrapped
         in the fallback adapter; otherwise ``await_only`` and the
         regular adapter are used.
        :param async_creator_fn: callable invoked with the remaining
         arguments in place of ``aiomysql.connect``.
        """
        async_fallback = kw.pop("async_fallback", False)
        creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)

        if util.asbool(async_fallback):
            return AsyncAdaptFallback_aiomysql_connection(
                self,
                await_fallback(creator_fn(*arg, **kw)),
            )
        else:
            return AsyncAdapt_aiomysql_connection(
                self,
                await_only(creator_fn(*arg, **kw)),
            )

    def _init_cursors_subclasses(self):
        """Return ``(Cursor, SSCursor)`` subclasses with warnings
        display disabled."""

        # suppress unconditional warning emitted by aiomysql
        class Cursor(self.aiomysql.Cursor):
            async def _show_warnings(self, conn):
                pass

        class SSCursor(self.aiomysql.SSCursor):
            async def _show_warnings(self, conn):
                pass

        return Cursor, SSCursor
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
    """MySQL dialect for the aiomysql driver, built on the pymysql
    dialect and using the adapter classes in this module."""

    driver = "aiomysql"
    supports_statement_cache = True

    supports_server_side_cursors = True
    _sscursor = AsyncAdapt_aiomysql_ss_cursor

    is_async = True
    has_terminate = True

    @classmethod
    def import_dbapi(cls):
        """Import ``aiomysql`` and ``pymysql`` and return the DBAPI
        facade built from them."""
        aiomysql_module = __import__("aiomysql")
        pymysql_module = __import__("pymysql")
        return AsyncAdapt_aiomysql_dbapi(aiomysql_module, pymysql_module)

    @classmethod
    def get_pool_class(cls, url):
        """Return the fallback pool class when the URL query sets
        ``async_fallback`` to a true value, else the regular async
        adapted queue pool."""
        use_fallback = util.asbool(url.query.get("async_fallback", False))
        return (
            pool.FallbackAsyncAdaptedQueuePool
            if use_fallback
            else pool.AsyncAdaptedQueuePool
        )

    def do_terminate(self, dbapi_connection) -> None:
        """Call ``terminate()`` on the given adapted connection."""
        dbapi_connection.terminate()

    def create_connect_args(self, url):
        """Delegate to the parent implementation, asking it to translate
        ``username`` to ``user`` and ``database`` to ``db``."""
        translations = {"username": "user", "database": "db"}
        return super().create_connect_args(url, _translate_args=translations)

    def is_disconnect(self, e, connection, cursor):
        """Return true when the parent dialect reports a disconnect, or
        when the lower-cased error text contains ``"not connected"``."""
        if super().is_disconnect(e, connection, cursor):
            return True

        message = str(e).lower()
        return "not connected" in message

    def _found_rows_client_flag(self):
        """Return pymysql's ``CLIENT.FOUND_ROWS`` constant."""
        from pymysql.constants import CLIENT

        return CLIENT.FOUND_ROWS

    def get_driver_connection(self, connection):
        """Return the raw connection held in the adapter's
        ``_connection`` attribute."""
        return connection._connection
# Module-level alias for the dialect class defined above.
# NOTE(review): presumably this is the name SQLAlchemy's dialect loader
# looks up when resolving "mysql+aiomysql" -- confirm.
dialect = MySQLDialect_aiomysql