Source code for pydal.adapters.postgres

import re
from .._compat import PY2, with_metaclass, iterkeys, to_unicode
from .._globals import IDENTITY, THREAD_LOCAL
from ..drivers import psycopg2_adapt
from .base import SQLAdapter
from . import AdapterMeta, adapters, with_connection, with_connection_or_raise


class PostgreMeta(AdapterMeta):
    """Metaclass that swaps a generic Postgres adapter for a driver-specific one.

    Instantiating ``Postgre`` or ``PostgreNew`` directly does not build that
    class; instead a psycopg2 or pg8000 flavoured subclass is selected and
    instantiated. Every other class using this metaclass is instantiated
    unchanged.
    """

    def __call__(cls, *args, **kwargs):
        """Instantiate the adapter class that matches the requested driver.

        Reads ``kwargs['uri']`` (the part before ``://`` is split on ``:`` and
        its second item, if present, is the requested driver) and
        ``kwargs['db']._drivers_available`` (only drivers listed there are
        considered available).

        Selection order:
          1. the driver named in the URI scheme, when it is available;
          2. the first entry of ``cls.drivers`` that is available;
          3. ``cls.drivers[0]`` when none is available.

        ``'psycopg2'`` maps to the psycopg2 adapter; anything else maps to
        the pg8000 adapter.
        """
        # The concrete classes are defined further down in this module, so
        # they are looked up here, at call time.
        if cls is Postgre:
            psyco_cls, pg8000_cls = PostgrePsyco, PostgrePG8000
        elif cls is PostgreNew:
            # Without this branch a bare 'postgres2://' URI would instantiate
            # PostgreNew itself, which has no driver-specific _config_json().
            psyco_cls, pg8000_cls = PostgrePsycoNew, PostgrePG8000New
        else:
            return AdapterMeta.__call__(cls, *args, **kwargs)

        installed = set(iterkeys(kwargs['db']._drivers_available))
        available_drivers = [
            name for name in cls.drivers if name in installed]

        uri_items = kwargs['uri'].split('://', 1)[0].split(':')
        uri_driver = uri_items[1] if len(uri_items) > 1 else None

        if uri_driver and uri_driver in available_drivers:
            driver = uri_driver
        elif available_drivers:
            driver = available_drivers[0]
        else:
            driver = cls.drivers[0]

        target = psyco_cls if driver == 'psycopg2' else pg8000_cls
        return AdapterMeta.__call__(target, *args, **kwargs)
@adapters.register_for('postgres')
class Postgre(with_metaclass(PostgreMeta, SQLAdapter)):
    """Generic PostgreSQL adapter, registered for the ``postgres`` scheme.

    The metaclass replaces direct instantiation of this class with a
    driver-specific subclass; the methods here are shared by those subclasses.
    """

    dbengine = 'postgres'
    drivers = ('psycopg2', 'pg8000')
    support_distributed_transaction = True

    # user[:password]@host[:port]/db[?sslmode=...]; host may be a bracketed
    # literal such as [::1]. Raw strings keep the regex escapes (\: \[ \?)
    # from being read as (invalid) Python string escapes.
    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>\[[^/]+\]|'
        r'[^\:@]+)(\:(?P<port>[0-9]+))?/(?P<db>[^\?]+)'
        r'(\?sslmode=(?P<sslmode>.+))?$')

    def __init__(self, db, uri, pool_size=0, folder=None, db_codec='UTF-8',
                 credential_decoder=IDENTITY, driver_args=None,
                 adapter_args=None, do_connect=True, srid=4326,
                 after_connection=None):
        """Store ``srid`` and forward everything else to the parent adapter.

        ``driver_args`` and ``adapter_args`` default to a fresh empty dict
        per instance. ``_initialize_`` writes connection credentials into
        ``self.driver_args``, so a shared default dict would be a hazard if
        the parent stores the passed object as-is (NOTE(review): parent
        behaviour not visible here).
        """
        self.srid = srid
        super(Postgre, self).__init__(
            db, uri, pool_size, folder, db_codec, credential_decoder,
            {} if driver_args is None else driver_args,
            {} if adapter_args is None else adapter_args,
            do_connect, after_connection)

    def _initialize_(self, do_connect):
        """Parse ``self.uri`` and fill ``self.driver_args`` for the driver.

        Raises:
            SyntaxError: when the URI does not match ``REGEX_URI`` or the
                user, host or database name is missing.
        """
        super(Postgre, self)._initialize_(do_connect)
        ruri = self.uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = self.credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        sslmode = m.group('sslmode')
        self.driver_args['database'] = db
        self.driver_args['user'] = user
        self.driver_args['host'] = host
        self.driver_args['port'] = int(port)
        self.driver_args['password'] = password
        if sslmode:
            self.driver_args['sslmode'] = sslmode
        # Record the name and version of the driver module in use, if any.
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None

    def _get_json_dialect(self):
        """Return the dialect class used when JSON support is enabled."""
        from ..dialects.postgre import PostgreDialectJSON
        return PostgreDialectJSON

    def _get_json_parser(self):
        """Return the parser class used when JSON support is enabled."""
        from ..parsers.postgre import PostgreAutoJSONParser
        return PostgreAutoJSONParser

    @property
    def _last_insert(self):
        """Marker set by ``_insert``; stored on ``THREAD_LOCAL``."""
        return THREAD_LOCAL._pydal_last_insert_

    @_last_insert.setter
    def _last_insert(self, value):
        THREAD_LOCAL._pydal_last_insert_ = value

    def connector(self):
        """Open a connection by calling the driver with ``driver_args``."""
        return self.driver.connect(**self.driver_args)

    def after_connection(self):
        """Set session encoding and string handling, then configure JSON."""
        self.execute("SET CLIENT_ENCODING TO 'UTF8'")
        self.execute("SET standard_conforming_strings=on;")
        self._config_json()

    def lastrowid(self, table=None):
        """Return the id produced by the most recent insert as an int.

        When ``_insert`` recorded an id field, the value is read from the
        row already pending on the cursor (presumably the statement returned
        it -- see the ``retval`` argument in ``_insert``). Otherwise
        ``select lastval()`` is executed and its result is read.
        """
        if self._last_insert:
            return int(self.cursor.fetchone()[0])
        self.execute("select lastval()")
        return int(self.cursor.fetchone()[0])

    def _insert(self, table, fields):
        """Build the INSERT statement for ``table`` via the dialect.

        ``fields`` is an iterable of ``(field, value)`` pairs. When the table
        has an ``_id`` attribute, its name is passed to the dialect as the
        column to return and ``_last_insert`` is set so ``lastrowid`` knows
        to read it from the cursor. With no fields, the dialect's
        ``insert_empty`` is used instead.
        """
        self._last_insert = None
        if fields:
            retval = None
            if hasattr(table, '_id'):
                self._last_insert = (table._id, 1)
                retval = table._id.name
            return self.dialect.insert(
                table.sqlsafe,
                ','.join(el[0].sqlsafe_name for el in fields),
                ','.join(self.expand(v, f.type) for f, v in fields),
                retval)
        return self.dialect.insert_empty(table.sqlsafe)

    @with_connection
    def prepare(self, key):
        """Execute ``PREPARE TRANSACTION`` with ``key`` as the identifier."""
        self.execute("PREPARE TRANSACTION '%s';" % key)

    @with_connection
    def commit_prepared(self, key):
        """Execute ``COMMIT PREPARED`` for the transaction named ``key``."""
        self.execute("COMMIT PREPARED '%s';" % key)

    @with_connection
    def rollback_prepared(self, key):
        """Execute ``ROLLBACK PREPARED`` for the transaction named ``key``."""
        self.execute("ROLLBACK PREPARED '%s';" % key)
@adapters.register_for('postgres:psycopg2')
class PostgrePsyco(Postgre):
    """PostgreSQL adapter bound to the psycopg2 driver."""

    drivers = ('psycopg2',)

    @staticmethod
    def _version_tuple(version):
        """Turn the first dotted number in ``version`` into an int tuple.

        The result is zero-padded to at least three items so that, e.g.,
        ``'2.5'`` compares equal to ``(2, 5, 0)``. Text around the number
        (such as build flags after the version) is ignored. Returns
        ``(0, 0, 0)`` when no number is found.
        """
        found = re.search(r'\d+(?:\.\d+)*', str(version))
        parts = tuple(int(p) for p in found.group(0).split('.')) \
            if found else ()
        return parts + (0,) * (3 - len(parts))

    def _config_json(self):
        """Enable the JSON dialect/parser when driver and server allow it.

        The JSON dialect needs driver >= 2.0.12 and a connection
        ``server_version`` >= 90200; the JSON parser needs driver >= 2.5.0.
        """
        # Compare versions numerically: as plain strings '2.10.0' would sort
        # before '2.5.0'.
        driver_version = self._version_tuple(self.driver.__version__)
        use_json = driver_version >= (2, 0, 12) and \
            self.connection.server_version >= 90200
        if use_json:
            self.dialect = self._get_json_dialect()(self)
        if driver_version >= (2, 5, 0):
            self.parser = self._get_json_parser()(self)

    def adapt(self, obj):
        """Return ``obj`` quoted by psycopg2's adapter.

        On Python 3 a ``bytes`` result is decoded as UTF-8; otherwise the
        quoted value is returned as produced by the driver.
        """
        adapted = psycopg2_adapt(obj)
        # deal with new relic Connection Wrapper (newrelic>=2.10.0.8)
        cxn = getattr(self.connection, '__wrapped__', self.connection)
        adapted.prepare(cxn)
        rv = adapted.getquoted()
        if not PY2:
            if isinstance(rv, bytes):
                return rv.decode('utf-8')
        return rv
@adapters.register_for('postgres:pg8000')
class PostgrePG8000(Postgre):
    """PostgreSQL adapter bound to the pg8000 driver."""

    drivers = ('pg8000',)

    @staticmethod
    def _version_tuple(version):
        """Turn the first dotted number in ``version`` into an int tuple.

        ``version`` is converted with ``str()`` first, so both plain strings
        and version objects with a sensible string form are accepted. The
        result is zero-padded to at least three items; ``(0, 0, 0)`` is
        returned when no number is found.
        """
        found = re.search(r'\d+(?:\.\d+)*', str(version))
        parts = tuple(int(p) for p in found.group(0).split('.')) \
            if found else ()
        return parts + (0,) * (3 - len(parts))

    def _config_json(self):
        """Enable the JSON dialect/parser when server and driver allow it.

        The JSON dialect needs a server version >= 9.2.0; the JSON parser
        needs driver >= 1.10.2.
        """
        # Compare numerically: as plain strings '10.5' sorts before '9.2.0'
        # and '1.9.0' sorts after '1.10.2'.
        # NOTE(review): the exact type of connection._server_version is not
        # visible here; str() is used so both str and version objects work.
        server_version = self._version_tuple(self.connection._server_version)
        if server_version >= (9, 2, 0):
            self.dialect = self._get_json_dialect()(self)
        if self._version_tuple(self.driver.__version__) >= (1, 10, 2):
            self.parser = self._get_json_parser()(self)

    def adapt(self, obj):
        """Return ``obj`` as a single-quoted literal.

        ``%`` is doubled to ``%%`` and ``'`` is doubled to ``''``.
        """
        return "'%s'" % obj.replace("%", "%%").replace("'", "''")

    @with_connection_or_raise
    def execute(self, *args, **kwargs):
        """Run a statement; on Python 2 the first argument is made unicode."""
        if PY2:
            args = list(args)
            args[0] = to_unicode(args[0])
        return super(PostgrePG8000, self).execute(*args, **kwargs)
@adapters.register_for('postgres2')
class PostgreNew(Postgre):
    """Postgres adapter variant registered for the ``postgres2`` scheme.

    Differs from ``Postgre`` only in which dialect and parser classes are
    handed out for JSON support.
    """

    def _get_json_dialect(self):
        """Return the ``PostgreDialectArraysJSON`` class."""
        # Imported lazily, at call time, as in the parent class.
        from ..dialects.postgre import PostgreDialectArraysJSON as dialect_cls
        return dialect_cls

    def _get_json_parser(self):
        """Return the ``PostgreNewAutoJSONParser`` class."""
        from ..parsers.postgre import PostgreNewAutoJSONParser as parser_cls
        return parser_cls
@adapters.register_for('postgres2:psycopg2')
class PostgrePsycoNew(PostgrePsyco, PostgreNew):
    """psycopg2 adapter for the ``postgres2`` scheme.

    ``PostgreNew`` is mixed in so that its ``_get_json_dialect`` and
    ``_get_json_parser`` overrides take effect; inheriting from
    ``PostgrePsyco`` alone would make this class behave exactly like the
    plain ``postgres:psycopg2`` adapter. ``PostgrePsyco`` stays first in the
    bases so its driver-specific methods win.
    """
    pass
@adapters.register_for('postgres2:pg8000')
class PostgrePG8000New(PostgrePG8000, PostgreNew):
    """pg8000 adapter for the ``postgres2`` scheme.

    ``PostgreNew`` is mixed in so that its ``_get_json_dialect`` and
    ``_get_json_parser`` overrides take effect; inheriting from
    ``PostgrePG8000`` alone would make this class behave exactly like the
    plain ``postgres:pg8000`` adapter. ``PostgrePG8000`` stays first in the
    bases so its driver-specific methods win.
    """
    pass
@adapters.register_for('jdbc:postgres')
class JDBCPostgre(Postgre):
    """PostgreSQL adapter bound to the zxJDBC driver."""

    drivers = ('zxJDBC',)

    # user[:password]@host[:port]/db -- no sslmode suffix in this variant.
    # Raw strings keep the regex escapes (\: \[) from being read as
    # (invalid) Python string escapes.
    REGEX_URI = re.compile(
        r'^(?P<user>[^:@]+)(\:(?P<password>[^@]*))?@(?P<host>\[[^/]+\]|'
        r'[^\:/]+)(\:(?P<port>[0-9]+))?/(?P<db>.+)$')

    @staticmethod
    def _version_tuple(version):
        """Turn the first dotted number in ``version`` into an int tuple.

        The result is zero-padded to at least three items; ``(0, 0, 0)`` is
        returned when no number is found.
        """
        found = re.search(r'\d+(?:\.\d+)*', str(version))
        parts = tuple(int(p) for p in found.group(0).split('.')) \
            if found else ()
        return parts + (0,) * (3 - len(parts))

    def _initialize_(self, do_connect):
        """Parse ``self.uri`` and build ``self.dsn`` for zxJDBC.

        ``self.dsn`` is a ``(jdbc_url, user, password)`` tuple.

        Raises:
            SyntaxError: when the URI does not match ``REGEX_URI`` or the
                user, host or database name is missing.
        """
        # super(Postgre, ...) starts the lookup *after* Postgre, so
        # Postgre._initialize_ (which fills driver_args) is skipped here.
        super(Postgre, self)._initialize_(do_connect)
        ruri = self.uri.split('://', 1)[1]
        m = self.REGEX_URI.match(ruri)
        if not m:
            raise SyntaxError("Invalid URI string in DAL")
        user = self.credential_decoder(m.group('user'))
        if not user:
            raise SyntaxError('User required')
        password = self.credential_decoder(m.group('password'))
        if not password:
            password = ''
        host = m.group('host')
        if not host:
            raise SyntaxError('Host name required')
        db = m.group('db')
        if not db:
            raise SyntaxError('Database name required')
        port = m.group('port') or '5432'
        self.dsn = (
            'jdbc:postgresql://%s:%s/%s' % (host, port, db), user, password)
        # Record the name and version of the driver module in use, if any.
        if self.driver:
            self.__version__ = "%s %s" % (self.driver.__name__,
                                          self.driver.__version__)
        else:
            self.__version__ = None
        THREAD_LOCAL._pydal_last_insert_ = None

    def connector(self):
        """Open a connection from ``self.dsn`` plus any ``driver_args``."""
        return self.driver.connect(*self.dsn, **self.driver_args)

    def after_connection(self):
        """Set the client encoding, open a transaction and configure JSON."""
        self.connection.set_client_encoding('UTF8')
        self.execute('BEGIN;')
        self.execute("SET CLIENT_ENCODING TO 'UNICODE';")
        self._config_json()

    def _config_json(self):
        """Enable the JSON dialect when the server version is >= 9.2.0."""
        # Compare numerically: as plain strings '10.3' sorts before '9.2.0',
        # which would switch JSON support off on PostgreSQL 10 and later.
        use_json = \
            self._version_tuple(self.connection.dbversion) >= (9, 2, 0)
        if use_json:
            self.dialect = self._get_json_dialect()(self)