import re
from .._compat import PY2, with_metaclass, iterkeys, to_unicode
from .._globals import IDENTITY, THREAD_LOCAL
from ..drivers import psycopg2_adapt
from .base import SQLAdapter
from . import AdapterMeta, adapters, with_connection, with_connection_or_raise
class PostgreMeta(AdapterMeta):
    """Metaclass that resolves a driver-agnostic Postgre adapter to a
    driver-specific subclass at instantiation time.

    Instantiating ``Postgre``, ``PostgreNew`` or ``PostgreBoolean`` yields an
    instance of the matching psycopg2 or pg8000 subclass. The driver is
    chosen as follows:

    * the one named in the URI scheme (e.g. ``postgres2:pg8000://...``)
      when it is among the available drivers;
    * otherwise the first of ``cls.drivers`` that is available;
    * otherwise ``cls.drivers[0]``.

    Any other class is instantiated unchanged.
    """

    def __call__(cls, *args, **kwargs):
        # Looked up at call time: these classes are defined later in the
        # module than this metaclass.
        driver_classes = {
            Postgre: {'psycopg2': PostgrePsyco,
                      'pg8000': PostgrePG8000},
            PostgreNew: {'psycopg2': PostgrePsycoNew,
                         'pg8000': PostgrePG8000New},
            PostgreBoolean: {'psycopg2': PostgrePsycoBoolean,
                             'pg8000': PostgrePG8000Boolean},
        }
        if cls not in driver_classes:
            return AdapterMeta.__call__(cls, *args, **kwargs)
        # Adapter __init__ takes (db, uri, ...); accept them by keyword
        # (the usual case) or positionally.
        db = kwargs['db'] if 'db' in kwargs else args[0]
        uri = kwargs['uri'] if 'uri' in kwargs else args[1]
        available_drivers = [
            driver for driver in cls.drivers
            if driver in iterkeys(db._drivers_available)]
        # Scheme part of the URI, e.g. 'postgres:psycopg2' -> ['postgres',
        # 'psycopg2']; the second item, if any, is an explicit driver.
        uri_items = uri.split('://', 1)[0].split(':')
        uri_driver = uri_items[1] if len(uri_items) > 1 else None
        if uri_driver and uri_driver in available_drivers:
            driver = uri_driver
        elif available_drivers:
            driver = available_drivers[0]
        else:
            driver = cls.drivers[0]
        return AdapterMeta.__call__(
            driver_classes[cls][driver], *args, **kwargs)
@adapters.register_for('postgres')
class Postgre(with_metaclass(PostgreMeta, SQLAdapter)):
    """Adapter for PostgreSQL databases (``postgres://`` URIs).

    ``PostgreMeta`` turns ``Postgre(...)`` into an instance of the psycopg2-
    or pg8000-specific subclass; those subclasses provide ``_config_json``
    and ``adapt``.
    """

    dbengine = 'postgres'
    drivers = ('psycopg2', 'pg8000')
    support_distributed_transaction = True
    # user[:password]@host[:port]/db[?sslmode=...]; the host may be a
    # bracketed IPv6 literal. Raw strings keep the backslashes literal
    # without "invalid escape sequence" warnings (same pattern as before).
    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>\[[^/]+\]|'
        r'[^\:@]+)(\:(?P<port>[0-9]+))?/(?P<db>[^\?]+)'
        r'(\?sslmode=(?P<sslmode>.+))?$')

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args=None,
                 adapter_args=None, do_connect=True, srid=4326,
                 after_connection=None):
        """Store ``srid`` and delegate to ``SQLAdapter.__init__``.

        ``driver_args``/``adapter_args`` default to fresh dicts.
        ``_initialize_`` writes connection settings into
        ``self.driver_args``. If the base class keeps a reference to the
        dict it is given, a shared ``{}`` default would leak those settings
        (e.g. ``sslmode``) from one adapter instance to the next.
        """
        self.srid = srid
        if driver_args is None:
            driver_args = {}
        if adapter_args is None:
            adapter_args = {}
        super(Postgre, self).__init__(
            db, uri, pool_size, folder, db_codec, credential_decoder,
            driver_args, adapter_args, do_connect, after_connection)

    def _initialize_(self, do_connect):
        """Parse ``self.uri`` and store the connection settings in
        ``self.driver_args`` (database, user, host, port, password and,
        when present, sslmode).

        Raises SyntaxError if the URI does not match ``REGEX_URI`` or lacks
        a user, host or database name. The port defaults to 5432.
        """
        super(Postgre, self)._initialize_(do_connect)
        ruri = self.uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = self.credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        sslmode = m.group('sslmode')
        self.driver_args['database'] = db
        self.driver_args['user'] = user
        self.driver_args['host'] = host
        self.driver_args['port'] = int(port)
        self.driver_args['password'] = password
        if sslmode:
            self.driver_args['sslmode'] = sslmode
        # Record "<driver module name> <driver version>", or None when no
        # driver module was loaded.
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None

    def _get_json_dialect(self):
        """Return the dialect class to install when JSON is supported."""
        from ..dialects.postgre import PostgreDialectJSON
        return PostgreDialectJSON

    def _get_json_parser(self):
        """Return the parser class to install when JSON is supported."""
        from ..parsers.postgre import PostgreAutoJSONParser
        return PostgreAutoJSONParser

    @property
    def _last_insert(self):
        """Pending insert info set by ``_insert``: ``(id_field, 1)``, or None.

        Kept on ``THREAD_LOCAL`` (presumably a ``threading.local``) rather
        than on the adapter instance.
        """
        return THREAD_LOCAL._pydal_last_insert_

    @_last_insert.setter
    def _last_insert(self, value):
        THREAD_LOCAL._pydal_last_insert_ = value

    def connector(self):
        """Open a new DB-API connection using ``self.driver_args``."""
        return self.driver.connect(**self.driver_args)

    def after_connection(self):
        """Configure a new session: UTF-8 client encoding, standard
        conforming strings, then driver-specific JSON support via
        ``_config_json`` (defined by the driver subclasses)."""
        self.execute("SET CLIENT_ENCODING TO 'UTF8'")
        self.execute("SET standard_conforming_strings=on;")
        self._config_json()

    def lastrowid(self, table=None):
        """Return the id of the last inserted row as an int.

        When ``_insert`` asked the dialect to return the id column, the value
        is read from the cursor's pending row; otherwise ``lastval()`` is
        queried.
        """
        if self._last_insert:
            return int(self.cursor.fetchone()[0])
        self.execute("select lastval()")
        return int(self.cursor.fetchone()[0])

    def _insert(self, table, fields):
        """Build the INSERT statement for ``fields`` (``(field, value)`` pairs).

        If the table has an ``_id`` field, its name is passed to the dialect
        as the column to return, and ``_last_insert`` is set so that
        ``lastrowid`` reads it from the cursor. With no fields, an
        empty-insert statement is produced.
        """
        self._last_insert = None
        if fields:
            retval = None
            if hasattr(table, '_id'):
                self._last_insert = (table._id, 1)
                retval = table._id.name
            return self.dialect.insert(
                table.sqlsafe,
                ','.join(el[0].sqlsafe_name for el in fields),
                ','.join(self.expand(v, f.type) for f, v in fields),
                retval)
        return self.dialect.insert_empty(table.sqlsafe)

    # Two-phase commit hooks. NOTE(review): ``key`` is interpolated verbatim
    # inside single quotes, so it must not contain a quote character.
    @with_connection
    def prepare(self, key):
        """Run ``PREPARE TRANSACTION`` with identifier ``key``."""
        self.execute("PREPARE TRANSACTION '%s';" % key)

    @with_connection
    def commit_prepared(self, key):
        """Run ``COMMIT PREPARED`` for transaction ``key``."""
        self.execute("COMMIT PREPARED '%s';" % key)

    @with_connection
    def rollback_prepared(self, key):
        """Run ``ROLLBACK PREPARED`` for transaction ``key``."""
        self.execute("ROLLBACK PREPARED '%s';" % key)
@adapters.register_for('postgres:psycopg2')
class PostgrePsyco(Postgre):
    """Postgre adapter bound to the psycopg2 driver."""

    drivers = ('psycopg2',)

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted-numeric part of ``version`` as a tuple
        of ints, padded with zeros to at least three components.

        For example ``'2.8.6 (dt dec pq3 ext lo64)'`` -> ``(2, 8, 6)``.
        Returns ``(0, 0, 0)`` when ``version`` does not start with a number.
        Tuples compare numerically, whereas strings compare
        lexicographically (``'2.0.9' >= '2.0.12'`` is True).
        """
        if isinstance(version, bytes):
            version = version.decode('ascii', 'replace')
        match = re.match(r'\s*(\d+(?:\.\d+)*)', str(version))
        parts = [int(p) for p in match.group(1).split('.')] if match else []
        parts.extend([0] * (3 - len(parts)))
        return tuple(parts)

    def _config_json(self):
        """Install the JSON dialect, and the JSON parser, when supported.

        The dialect requires psycopg2 >= 2.0.12 and a server >= 9.2. The
        parser additionally requires psycopg2 >= 2.5.0.
        """
        driver_version = self._version_tuple(self.driver.__version__)
        # psycopg2 reports server_version as an int, e.g. 90200 for 9.2.0.
        use_json = driver_version >= (2, 0, 12) and \
            self.connection.server_version >= 90200
        if use_json:
            self.dialect = self._get_json_dialect()(self)
            if driver_version >= (2, 5, 0):
                self.parser = self._get_json_parser()(self)

    def adapt(self, obj):
        """Quote ``obj`` as an SQL literal using psycopg2's adapter.

        On Python 3 a ``bytes`` result is decoded as UTF-8 to ``str``.
        """
        adapted = psycopg2_adapt(obj)
        # deal with new relic Connection Wrapper (newrelic>=2.10.0.8)
        cxn = getattr(self.connection, '__wrapped__', self.connection)
        adapted.prepare(cxn)
        rv = adapted.getquoted()
        if not PY2:
            if isinstance(rv, bytes):
                return rv.decode('utf-8')
        return rv
@adapters.register_for('postgres:pg8000')
class PostgrePG8000(Postgre):
    """Postgre adapter bound to the pg8000 driver."""

    drivers = ('pg8000',)

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted-numeric part of ``version`` as a tuple
        of ints, padded with zeros to at least three components.

        Works for strings, bytes, and objects whose ``str()`` is a version
        (e.g. ``'10.5'`` -> ``(10, 5, 0)``). Returns ``(0, 0, 0)`` when no
        leading number is found. Tuples compare numerically, unlike version
        strings (``'1.9.0' >= '1.10.2'`` and ``'10.5' < '9.2.0'`` as
        strings).
        """
        if isinstance(version, bytes):
            version = version.decode('ascii', 'replace')
        match = re.match(r'\s*(\d+(?:\.\d+)*)', str(version))
        parts = [int(p) for p in match.group(1).split('.')] if match else []
        parts.extend([0] * (3 - len(parts)))
        return tuple(parts)

    def _config_json(self):
        """Install the JSON dialect (server >= 9.2) and, when pg8000 is
        >= 1.10.2, the JSON parser."""
        server_version = self._version_tuple(self.connection._server_version)
        if server_version >= (9, 2, 0):
            self.dialect = self._get_json_dialect()(self)
            if self._version_tuple(self.driver.__version__) >= (1, 10, 2):
                self.parser = self._get_json_parser()(self)

    def adapt(self, obj):
        """Quote the string ``obj`` as an SQL string literal.

        Single quotes are doubled. ``%`` is doubled too, presumably because
        pg8000 treats ``%`` as a parameter marker (TODO confirm).
        """
        return "'%s'" % obj.replace("%", "%%").replace("'", "''")

    @with_connection_or_raise
    def execute(self, *args, **kwargs):
        """Execute a statement. On Python 2 the SQL (first positional
        argument) is converted to unicode before delegating to the base
        implementation."""
        if PY2:
            args = list(args)
            args[0] = to_unicode(args[0])
        return super(PostgrePG8000, self).execute(*args, **kwargs)
@adapters.register_for('postgres2')
class PostgreNew(Postgre):
    """Postgre variant registered for ``postgres2://`` URIs.

    Identical to ``Postgre`` except for the JSON dialect and parser classes
    it supplies: the array-aware ones from the postgre dialect/parser
    modules.
    """

    def _get_json_dialect(self):
        from ..dialects.postgre import PostgreDialectArraysJSON as dialect
        return dialect

    def _get_json_parser(self):
        from ..parsers.postgre import PostgreNewAutoJSONParser as parser
        return parser
@adapters.register_for('postgres2:psycopg2')
class PostgrePsycoNew(PostgrePsyco, PostgreNew):
    """psycopg2 adapter for ``postgres2`` URIs.

    PostgrePsyco comes first in the MRO, so its driver-specific members
    (``drivers``, ``_config_json``, ``adapt``) take precedence.
    PostgreNew supplies the array-aware ``_get_json_dialect`` and
    ``_get_json_parser``. Inheriting from PostgrePsyco alone silently used
    the plain Postgre JSON dialect/parser instead.
    """
    pass
@adapters.register_for('postgres2:pg8000')
class PostgrePG8000New(PostgrePG8000, PostgreNew):
    """pg8000 adapter for ``postgres2`` URIs.

    PostgrePG8000 comes first in the MRO, so its driver-specific members
    (``drivers``, ``_config_json``, ``adapt``, ``execute``) take
    precedence. PostgreNew supplies the array-aware JSON dialect/parser
    that PostgrePG8000 alone would not.
    """
    pass
@adapters.register_for('postgres3')
class PostgreBoolean(PostgreNew):
    """Postgre variant registered for ``postgres3://`` URIs.

    Same as ``PostgreNew`` apart from the JSON dialect and parser classes
    it hands out: the Boolean variants.
    """

    def _get_json_parser(self):
        from ..parsers.postgre import PostgreBooleanAutoJSONParser as parser
        return parser

    def _get_json_dialect(self):
        from ..dialects.postgre import PostgreDialectBooleanJSON as dialect
        return dialect
@adapters.register_for('postgres3:psycopg2')
class PostgrePsycoBoolean(PostgrePsycoNew, PostgreBoolean):
    """psycopg2 adapter for ``postgres3`` URIs.

    The psycopg2-specific members come from PostgrePsycoNew. PostgreBoolean
    supplies the Boolean JSON dialect/parser, which were otherwise
    unreachable from this class.
    """
    pass
@adapters.register_for('postgres3:pg8000')
class PostgrePG8000Boolean(PostgrePG8000New, PostgreBoolean):
    """pg8000 adapter for ``postgres3`` URIs.

    The pg8000-specific members come from PostgrePG8000New. PostgreBoolean
    supplies the Boolean JSON dialect/parser, which were otherwise
    unreachable from this class.
    """
    pass
@adapters.register_for('jdbc:postgres')
class JDBCPostgre(Postgre):
    """PostgreSQL adapter for ``jdbc:postgres://`` URIs using the zxJDBC
    driver."""

    drivers = ('zxJDBC',)
    # user[:password]@host[:port]/db -- no sslmode suffix; everything after
    # the first '/' is the database name. Raw strings: same pattern as
    # before, without "invalid escape sequence" warnings.
    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>\[[^/]+\]|'
        r'[^\:/]+)(\:(?P<port>[0-9]+))?/(?P<db>.+)$')

    @staticmethod
    def _version_tuple(version):
        """Return the leading dotted-numeric part of ``version`` as a tuple
        of ints, padded with zeros to at least three components.

        Returns ``(0, 0, 0)`` when no leading number is found. Tuples compare
        numerically, whereas ``'10.4' >= '9.2.0'`` is False as strings.
        """
        if isinstance(version, bytes):
            version = version.decode('ascii', 'replace')
        match = re.match(r'\s*(\d+(?:\.\d+)*)', str(version))
        parts = [int(p) for p in match.group(1).split('.')] if match else []
        parts.extend([0] * (3 - len(parts)))
        return tuple(parts)

    def _initialize_(self, do_connect):
        """Parse ``self.uri`` into ``self.dsn``, a
        ``(jdbc_url, user, password)`` tuple.

        Raises SyntaxError if the URI does not match ``REGEX_URI`` or lacks
        a user, host or database name. The port defaults to 5432.
        """
        # super(Postgre, ...) bypasses Postgre._initialize_, which would
        # parse the URI with Postgre's regex and fill driver_args instead.
        super(Postgre, self)._initialize_(do_connect)
        ruri = self.uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = self.credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        self.dsn = (
            'jdbc:postgresql://%s:%s/%s' % (host, port, db), user, password)
        # Record "<driver module name> <driver version>", or None when no
        # driver module was loaded.
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None

    def connector(self):
        """Open a connection with ``self.dsn`` as positional arguments and
        ``self.driver_args`` as keyword arguments."""
        return self.driver.connect(*self.dsn, **self.driver_args)

    def after_connection(self):
        """Configure a new session: set the client encoding, open a
        transaction, then enable JSON support if the server allows it."""
        self.connection.set_client_encoding('UTF8')
        self.execute('BEGIN;')
        self.execute("SET CLIENT_ENCODING TO 'UNICODE';")
        self._config_json()

    def _config_json(self):
        """Install the JSON dialect when the server is >= 9.2."""
        if self._version_tuple(self.connection.dbversion) >= (9, 2, 0):
            self.dialect = self._get_json_dialect()(self)