# -*- coding: utf-8 -*-
# postgresql.py

#
#

# Copyright 2014-2020 Álvaro Justen <https://github.com/turicas/rows/>

#

# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published by the
# Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.

#

# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
# for more details.

#

# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import unicode_literals

import json
import string

import six
from psycopg2 import connect as pgconnect

import rows.fields as fields
from rows.plugins.utils import (
    create_table,
    ipartition,
    make_unique_name,
    prepare_to_export,
)
from rows.utils import Source

# Lists the user tables visible to the connection, skipping PostgreSQL's own
# catalog schemas; the export uses it to avoid reusing an existing name.
SQL_TABLE_NAMES = """
    SELECT
        tablename
    FROM pg_tables
    WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
"""
# Statement templates completed with str.format(). The table name is always
# wrapped in double quotes (a quoted PostgreSQL identifier); row values are
# never formatted in -- the INSERT's {placeholders} slot receives driver
# placeholders instead.
SQL_CREATE_TABLE = 'CREATE TABLE IF NOT EXISTS "{table_name}" ({field_types})'
SQL_SELECT_ALL = 'SELECT * FROM "{table_name}"'
SQL_INSERT = 'INSERT INTO "{table_name}" ({field_names}) VALUES ({placeholders})'
# PostgreSQL column type used in CREATE TABLE for each rows field class.
# NOTE(review): REAL is single precision and INTEGER is 32-bit, so large or
# high-precision Python floats/ints may not round-trip exactly; kept as-is so
# existing tables and callers are unaffected.
SQL_TYPES = {
    fields.BinaryField: "BYTEA",
    fields.BoolField: "BOOLEAN",
    fields.DateField: "DATE",
    fields.DatetimeField: "TIMESTAMP(0) WITHOUT TIME ZONE",
    fields.DecimalField: "NUMERIC",
    fields.FloatField: "REAL",
    fields.IntegerField: "INTEGER",
    fields.JSONField: "JSONB",
    fields.PercentField: "REAL",
    fields.TextField: "TEXT",
    fields.UUIDField: "UUID",
}
# Column type for any field class not listed in SQL_TYPES.
DEFAULT_TYPE = "BYTEA"
#

# TODO: unify this and rows.utils.POSTGRESQL_TYPES

#
def _python_to_postgresql(field_types):
    """Return a ``convert_row(row)`` callable adapting row values for psycopg2.

    ``field_types`` lists one rows field class per column, in the same order
    as the values of each row later passed to ``convert_row``; pairing stops
    at the shorter of the two (``zip`` semantics). ``convert_row`` returns a
    new list with one converted value per column:

    - ``None`` is kept as ``None`` so psycopg2 sends SQL NULL, whatever the
      field type (running it through a field's ``serialize`` could produce a
      non-NULL value, e.g. an empty string, depending on that method);
    - ``JSONField`` values are dumped to JSON text: psycopg2 does not adapt
      ``dict`` out of the box and adapts ``list`` to an SQL ARRAY, neither of
      which can be stored in a JSONB column (the type used for JSONField in
      ``SQL_TYPES``);
    - values of the other known field types are passed through unchanged,
      relying on psycopg2's built-in adaptation;
    - anything else goes through the field class's ``serialize`` method.
    """
    # Field types whose Python values psycopg2 adapts natively.
    native_types = (
        fields.BinaryField,
        fields.BoolField,
        fields.DateField,
        fields.DatetimeField,
        fields.DecimalField,
        fields.FloatField,
        fields.IntegerField,
        fields.PercentField,
        fields.TextField,
    )

    def convert_value(field_type, value):
        if value is None:
            return None
        elif field_type is fields.JSONField:
            return json.dumps(value)
        elif field_type in native_types:
            return value
        else:  # don't know this field
            return field_type.serialize(value)

    def convert_row(row):
        return [
            convert_value(field_type, value)
            for field_type, value in zip(field_types, row)
        ]

    return convert_row
#
def get_source(connection_or_uri):
    """Wrap a PostgreSQL connection, or a URI to open one, in a rows Source.

    A ``str``/``bytes`` argument is treated as a connection URI: it is opened
    here with psycopg2's ``connect``, the returned source is marked as owning
    the connection (``should_close=True``), and ``source.uri`` keeps the URI.
    Anything else is assumed to be an already-open connection that the caller
    owns (``should_close=False``, ``source.uri`` is None).

    The connection object is available as ``source.fobj``.
    """
    owns_connection = isinstance(
        connection_or_uri, (six.binary_type, six.text_type)
    )
    if owns_connection:
        connection = pgconnect(connection_or_uri)
    else:  # the caller handed us a live connection
        connection = connection_or_uri

    # TODO: may improve Source for non-fobj cases (when open() is not needed)
    source = Source.from_file(
        connection,
        plugin_name="postgresql",
        mode=None,
        is_file=False,
        local=False,
        should_close=owns_connection,
    )
    source.uri = connection_or_uri if owns_connection else None
    return source
#

# Check whether a given table name is valid for rows.

def _valid_table_name(name):
#

Rules: - Should start with a letter or ‘_’ - Letters can be capitalized or not - Accepts letters, numbers and _

    if name[0] not in "_" + string.ascii_letters or not set(name).issubset(
        "_" + string.ascii_letters + string.digits
    ):
        return False

    else:
        return True
#
def import_from_postgresql(
    connection_or_uri,
    table_name="table1",
    query=None,
    query_args=None,
    close_connection=None,
    *args,
    **kwargs
):
    """Import the result of a PostgreSQL query as a rows table.

    :param connection_or_uri: a connection URI (``str``/``bytes``, opened here
        with psycopg2) or an already-open connection object.
    :param table_name: table read entirely (``SELECT *``) when ``query`` is
        None; it must pass ``_valid_table_name`` because it is formatted into
        the SQL text.
    :param query: custom SQL to run instead of selecting from ``table_name``.
    :param query_args: parameters passed to ``cursor.execute`` alongside
        ``query`` (defaults to an empty tuple).
    :param close_connection: True closes the connection afterwards, False
        keeps it open, None closes it only when it was opened here from a URI.
    :raises ValueError: if ``query`` is None and ``table_name`` is invalid.

    Remaining ``args``/``kwargs`` are forwarded to ``create_table``, whose
    result is returned; its ``meta`` records ``imported_from`` and the
    ``source``. The cursor is always closed, and the connection is closed
    according to ``close_connection`` even when the query fails.
    """
    if query is None:
        if not _valid_table_name(table_name):
            raise ValueError("Invalid table name: {}".format(table_name))

        query = SQL_SELECT_ALL.format(table_name=table_name)

    if query_args is None:
        query_args = tuple()

    source = get_source(connection_or_uri)
    connection = source.fobj
    should_close = close_connection or (
        close_connection is None and source.should_close
    )

    try:
        cursor = connection.cursor()
        try:
            cursor.execute(query, query_args)
            table_rows = list(cursor.fetchall())  # TODO: make it lazy
            header = [six.text_type(info[0]) for info in cursor.description]
        finally:
            cursor.close()
        # psycopg2 implicitly opens a transaction on the first execute();
        # committing presumably ends it so a caller-owned connection is not
        # left idle in a transaction. (Original note: "WHY?")
        connection.commit()
    finally:
        if should_close:
            connection.close()

    meta = {"imported_from": "postgresql", "source": source}
    return create_table([header] + table_rows, meta=meta, *args, **kwargs)
#
def _quote_identifier(name):
    """Return ``name`` as a double-quoted PostgreSQL identifier.

    Quoting lets field names that are SQL reserved words (``order``, ``user``,
    ``group``...) be used as column names; unquoted, they make the generated
    CREATE TABLE / INSERT statements a syntax error. Embedded double quotes
    are escaped by doubling them. Quoted identifiers are case-sensitive, so a
    field name keeps its exact case (rows presumably slugs field names to
    lowercase already -- confirm -- in which case column names are unchanged).
    """
    return '"{}"'.format(name.replace('"', '""'))


def export_to_postgresql(
    table,
    connection_or_uri,
    table_name=None,
    table_name_format="table{index}",
    batch_size=100,
    close_connection=None,
    *args,
    **kwargs
):
    """Export a rows ``table`` into a PostgreSQL table.

    The target table is created with ``CREATE TABLE IF NOT EXISTS`` (column
    types from ``SQL_TYPES``, falling back to ``DEFAULT_TYPE``), then rows are
    inserted with ``executemany`` in batches of ``batch_size`` and committed
    once at the end.

    :param table: the rows table to export.
    :param connection_or_uri: a connection URI (opened here) or an already
        open connection object.
    :param table_name: target table name; must pass ``_valid_table_name``.
        When None, a name is built from ``table.name`` with
        ``make_unique_name``, avoiding the names of existing non-system tables.
    :param table_name_format: format passed to ``make_unique_name``.
    :param batch_size: rows sent per ``executemany`` call; must be >= 1.
    :param close_connection: True closes the connection afterwards, False
        keeps it open, None closes it only when it was opened here from a URI.
    :returns: ``(connection, table_name)``; the connection may already be
        closed, depending on ``close_connection``.
    :raises ValueError: for an invalid ``table_name`` or ``batch_size``.

    Remaining ``args``/``kwargs`` are forwarded to ``prepare_to_export``. The
    cursor is always closed, and the connection is closed according to
    ``close_connection`` even when the export fails.
    """
    # TODO: should add transaction support?
    if table_name is not None and not _valid_table_name(table_name):
        raise ValueError("Invalid table name: {}".format(table_name))
    if batch_size < 1:
        # Batches of zero rows can never consume the input.
        raise ValueError("Invalid batch size: {}".format(batch_size))

    source = get_source(connection_or_uri)
    connection = source.fobj
    should_close = close_connection or (
        close_connection is None and source.should_close
    )

    try:
        cursor = connection.cursor()
        try:
            if table_name is None:
                cursor.execute(SQL_TABLE_NAMES)
                table_names = [item[0] for item in cursor.fetchall()]
                table_name = make_unique_name(
                    table.name,
                    existing_names=table_names,
                    name_format=table_name_format,
                    start=1,
                )

            prepared_table = prepare_to_export(table, *args, **kwargs)
            # TODO: use same code/logic of CREATE TABLE as
            # rows.utils.pg_create_table_sql
            field_names = next(prepared_table)  # first item is the header
            field_types = list(map(table.fields.get, field_names))
            quoted_names = [_quote_identifier(name) for name in field_names]
            columns = [
                "{} {}".format(quoted, SQL_TYPES.get(field_type, DEFAULT_TYPE))
                for quoted, field_type in zip(quoted_names, field_types)
            ]
            cursor.execute(
                SQL_CREATE_TABLE.format(
                    table_name=table_name, field_types=", ".join(columns)
                )
            )

            insert_sql = SQL_INSERT.format(
                table_name=table_name,
                field_names=", ".join(quoted_names),
                placeholders=", ".join("%s" for _ in field_names),
            )
            _convert_row = _python_to_postgresql(field_types)
            for batch in ipartition(prepared_table, batch_size):
                cursor.executemany(insert_sql, map(_convert_row, batch))

            connection.commit()
        finally:
            cursor.close()
    finally:
        if should_close:
            connection.close()
    return connection, table_name