This repository has been archived on 2023-07-16. You can view files and clone it, but cannot push or open issues or pull requests.
hackaday.io-spambot-hunter/hadsh/db/db.py

153 lines
4.1 KiB
Python

import threading
from concurrent.futures import Future

try:
    from urllib.parse import unquote, urlparse
except ImportError:  # Python 2
    from urllib import unquote
    from urlparse import urlparse

from psycopg2 import connect
from tornado.gen import Return, coroutine
from tornado.ioloop import IOLoop
class Database(object):
    """
    Tornado-coroutine wrapper around a single psycopg2 connection.

    The blocking psycopg2 calls are run as callbacks on a dedicated
    ``DatabaseThread`` that runs its own IOLoop.  Results are handed back
    to the calling coroutine through ``concurrent.futures.Future`` objects.
    """

    def __init__(self, db_uri, **kwargs):
        """
        Parse the database URI and keyword arguments.

        :param db_uri: URI of the form
            ``scheme://[user[:password]@][host][:port]/dbname``.  ``host``
            may be a bracketed IPv6 literal such as ``[::1]``.  The user,
            password, host and database name are percent-decoded.
        :param kwargs: ``log`` (optional logger used to report failed
            queries) is removed; everything else is passed verbatim to
            psycopg2's ``connect``.
        :raises ValueError: if the port is not an integer or an IPv6
            literal lacks its closing ``]``.
        """
        parsed_uri = urlparse(db_uri)
        (user, password, host, port) = \
            self._parse_netloc(parsed_uri.netloc)
        self._log = kwargs.pop('log', None)
        self._db_args = dict(
            dbname=unquote(parsed_uri.path[1:]),
            user=user, password=password,
            host=host, port=port, **kwargs)
        self._conn_ioloop = None
        self._conn_thread = None
        self._conn = None
        # Future of a connect() that is still in progress, shared by
        # concurrent callers so only one connection/thread is created.
        self._connecting = None

    @staticmethod
    def _parse_netloc(netloc):
        """
        Split a URI network location into ``(user, password, host, port)``.

        Missing parts are ``None``; ``port`` is an ``int``.
        """
        (user_password, _, host_port) = netloc.rpartition('@')
        if ':' in user_password:
            (user, password) = user_password.split(':', 1)
            user = unquote(user)
            password = unquote(password)
        else:
            user = unquote(user_password) or None
            password = None

        if host_port.startswith('['):
            # IPv6 literal: "[addr]" optionally followed by ":port".
            end_literal = host_port.find(']')
            if end_literal < 0:
                raise ValueError(
                    'Unterminated IPv6 literal in database URI')
            # libpq's "host" keyword expects the bare address; the
            # brackets only exist to delimit it inside a URI.
            host = host_port[1:end_literal]
            rest = host_port[end_literal + 1:]
            port = rest[1:] if rest.startswith(':') else ''
        else:
            # Hostname, IPv4 address or percent-encoded socket directory.
            (host, _, port) = host_port.partition(':')

        host = unquote(host) or None
        port = int(port) if port else None
        return (user, password, host, port)

    @coroutine
    def connect(self):
        """
        Connect to the server.

        Starts a ``DatabaseThread`` running a private IOLoop and opens the
        psycopg2 connection on it.  If another coroutine is already
        connecting, waits for that attempt instead of opening a second
        connection.

        :raises Exception: whatever psycopg2's ``connect`` raised.  In that
            case the worker thread is joined and its IOLoop closed.
        """
        assert self._conn is None
        if self._connecting is not None:
            # Share the in-flight attempt; starting another one would leave
            # an orphaned thread and connection behind.
            yield self._connecting
            return
        assert self._conn_ioloop is None

        future = Future()
        io_loop = IOLoop()
        thread = threading.Thread(
            target=io_loop.start,
            name='DatabaseThread')

        def _connect():
            # Runs on the database thread.
            try:
                conn = connect(**self._db_args)
            except Exception as ex:
                io_loop.stop()
                future.set_exception(ex)
                return
            self._conn = conn
            self._conn_ioloop = io_loop
            self._conn_thread = thread
            future.set_result(None)

        self._connecting = future
        thread.start()
        io_loop.add_callback(_connect)
        try:
            yield future
        except Exception:
            # _connect() already stopped the loop; reap the thread and free
            # the loop before propagating the error.
            thread.join()
            io_loop.close()
            raise
        finally:
            self._connecting = None

    def close(self):
        """
        Close the connection and stop the database thread.

        Blocks the calling thread until the database thread has exited.
        Does nothing when there is no open connection.
        """
        if self._conn is None:
            return
        conn = self._conn
        io_loop = self._conn_ioloop

        def _close():
            # Runs on the database thread.  Stop the loop even if closing
            # the connection fails, otherwise join() below never returns.
            try:
                conn.close()
            finally:
                io_loop.stop()

        io_loop.add_callback(_close)
        self._conn_thread.join()
        io_loop.close()
        self._conn = None
        self._conn_ioloop = None
        self._conn_thread = None

    @coroutine
    def query(self, sql, *args, commit=False):
        """
        Execute ``sql`` with positional parameters on the database thread.

        Connects first if there is no open connection.

        :param sql: SQL text using psycopg2 ``%s`` placeholders.
        :param args: values bound to the placeholders (passed as a tuple).
        :param commit: if true, call ``commit()`` explicitly after executing.
        :returns: the rows from ``fetchall()`` when the statement produces
            a result set, otherwise ``None``.
        :raises Exception: whatever psycopg2 raised.  It is also logged,
            together with the SQL and arguments, when a logger was given.
        """
        if self._conn is None:
            yield self.connect()
        assert self._conn is not None
        assert self._conn_ioloop is not None
        future = Future()

        def _query():
            # Runs on the database thread.
            try:
                # NOTE(review): psycopg2's connection context manager
                # commits on successful exit (rolls back on error), so the
                # statement ends up committed even when commit=False.
                with self._conn:
                    with self._conn.cursor() as cur:
                        cur.execute(sql, args)
                        res = cur.fetchall() if cur.description else None
                    if commit:
                        self._conn.commit()
                # Publish only after the transaction block has completed.
                future.set_result(res)
            except Exception as ex:
                future.set_exception(ex)

        self._conn_ioloop.add_callback(_query)
        try:
            result = yield future
        except Exception:
            if self._log:
                self._log.exception('Failed SQL query:\n%s\nARGS: %s',
                                    sql, args)
            raise
        raise Return(result)