# utils.py

# Copyright 2014-2020 Álvaro Justen (https://github.com/turicas/rows/)

#    This program is free software: you can redistribute it and/or modify it
#    under the terms of the GNU Lesser General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or (at
#    your option) any later version.

#    This program is distributed in the hope that it will be useful, but
#    WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
#    General Public License for more details.

#    You should have received a copy of the GNU Lesser General Public License
#    along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import unicode_literals

import cgi
import csv
import gzip
import io
import itertools
import mimetypes
import os
import re
import shlex
import sqlite3
import subprocess
import tempfile
from collections import OrderedDict
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from textwrap import dedent

import six

try:
    import requests
except ImportError:
    requests = None
try:
    from tqdm import tqdm
except ImportError:
    tqdm = None

import rows
from rows.plugins.utils import make_header, slug

try:
    import lzma
except ImportError:
    lzma = None
try:
    import bz2
except ImportError:
    bz2 = None

try:
    from urlparse import urlparse  # Python 2
except ImportError:
    from urllib.parse import urlparse  # Python 3

try:
    import magic
except (ImportError, TypeError):
    magic = None
else:
    if not hasattr(magic, "detect_from_content"):
        # This is not the file-magic library
        magic = None

if requests:
    chardet = requests.compat.chardet
else:
    chardet = None
try:
    import urllib3
except ImportError:
    from requests.packages import urllib3
else:
    try:
        urllib3.disable_warnings()
    except AttributeError:
        # Old versions of urllib3 or requests
        pass
# TODO: should get this information from the plugins
COMPRESSED_EXTENSIONS = ("gz", "xz", "bz2")
TEXT_PLAIN = {
    "txt": "text/txt",
    "text": "text/txt",
    "csv": "text/csv",
    "json": "application/json",
}
OCTET_STREAM = {
    "microsoft ooxml": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    "par archive data": "application/parquet",
}
FILE_EXTENSIONS = {
    "csv": "text/csv",
    "db": "application/x-sqlite3",
    "htm": "text/html",
    "html": "text/html",
    "json": "application/json",
    "ods": "application/vnd.oasis.opendocument.spreadsheet",
    "parquet": "application/parquet",
    "sqlite": "application/x-sqlite3",
    "text": "text/txt",
    "tsv": "text/csv",
    "txt": "text/txt",
    "xls": "application/vnd.ms-excel",
    "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    "pdf": "application/pdf",
}
MIME_TYPE_TO_PLUGIN_NAME = {
    "application/json": "json",
    "application/parquet": "parquet",
    "application/vnd.ms-excel": "xls",
    "application/vnd.oasis.opendocument.spreadsheet": "ods",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
    "application/x-sqlite3": "sqlite",
    "text/csv": "csv",
    "text/html": "html",
    "text/txt": "txt",
    "application/pdf": "pdf",
}
regexp_sizes = re.compile("([0-9,.]+ [a-zA-Z]+B)")
MULTIPLIERS = {"B": 1, "KiB": 1024, "MiB": 1024 ** 2, "GiB": 1024 ** 3}
POSTGRESQL_TYPES = {
    rows.fields.BinaryField: "BYTEA",
    rows.fields.BoolField: "BOOLEAN",
    rows.fields.DateField: "DATE",
    rows.fields.DatetimeField: "TIMESTAMP(0) WITHOUT TIME ZONE",
    rows.fields.DecimalField: "NUMERIC",
    rows.fields.FloatField: "REAL",
    rows.fields.IntegerField: "BIGINT",  # TODO: detect when it's really needed
    rows.fields.JSONField: "JSONB",
    rows.fields.PercentField: "REAL",
    rows.fields.TextField: "TEXT",
    rows.fields.UUIDField: "UUID",
}
DEFAULT_POSTGRESQL_TYPE = "BYTEA"
SQL_CREATE_TABLE = "CREATE {pre_table}TABLE{post_table} " '"{table_name}" ({field_types})'


class ProgressBar:
    def __init__(self, prefix, pre_prefix="", total=None, unit=" rows"):
        self.prefix = prefix
        self.progress = tqdm(
            desc=pre_prefix, total=total, unit=unit, unit_scale=True, dynamic_ncols=True
        )
        self.started = False
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @property
    def description(self):
        return self.progress.desc

    @description.setter
    def description(self, value):
        self.progress.desc = value
        self.progress.refresh()

    @property
    def total(self):
        return self.progress.total

    @total.setter
    def total(self, value):
        self.progress.total = value
        self.progress.refresh()

    def update(self, last_done=1, total_done=None):
        if not last_done and not total_done:
            raise ValueError("Either last_done or total_done must be specified")

        if not self.started:
            self.started = True
            self.progress.desc = self.prefix
            self.progress.unpause()

        if last_done:
            self.progress.n += last_done
        else:
            self.progress.n = total_done
        self.progress.refresh()

    def close(self):
        self.progress.close()
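# Illustrative sketch (not part of the library API): ProgressBar above wraps
# tqdm, so it requires tqdm to be installed. Typical use as a context manager:
#
#     with ProgressBar(prefix="Importing", total=1000, unit=" rows") as bar:
#         for _ in range(1000):
#             bar.update(1)  # advance by one row
#     # bar.close() is called automatically by __exit__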


@dataclass
class Source(object):
    "Define a source to import a `rows.Table`"

    uri: (str, Path)
    plugin_name: str
    encoding: str
    fobj: object = None
    compressed: bool = None
    should_delete: bool = False
    should_close: bool = False
    is_file: bool = None
    local: bool = None

    @classmethod
    def from_file(
        cls,
        filename_or_fobj,
        plugin_name=None,
        encoding=None,
        mode="rb",
        compressed=None,
        should_delete=False,
        should_close=None,
        is_file=True,
        local=True,
    ):
        "Create a Source from a filename or fobj"
        if isinstance(filename_or_fobj, Source):
            return filename_or_fobj

        elif isinstance(filename_or_fobj, (six.binary_type, six.text_type, Path)):
            fobj = open_compressed(filename_or_fobj, mode=mode)
            filename = filename_or_fobj
            should_close = True if should_close is None else should_close

        else:  # Don't know exactly what it is; assume a file-like object
            fobj = filename_or_fobj
            filename = getattr(fobj, "name", None)
            if not isinstance(
                filename, (six.binary_type, six.text_type)
            ):  # BytesIO object
                filename = None
            should_close = False if should_close is None else should_close

        if is_file and local and filename and not isinstance(filename, Path):
            filename = Path(filename)

        return Source(
            compressed=compressed,
            encoding=encoding,
            fobj=fobj,
            is_file=is_file,
            local=local,
            plugin_name=plugin_name,
            should_close=should_close,
            should_delete=should_delete,
            uri=filename,
        )
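# Illustrative sketch for Source.from_file above (filename and plugin name are
# hypothetical): the file is opened for you, so remember to close it.
#
#     source = Source.from_file("data.csv.gz", plugin_name="csv", encoding="utf-8")
#     source.uri, source.plugin_name  # (Path("data.csv.gz"), "csv")
#     source.fobj.close()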
def plugin_name_by_uri(uri):
    "Return the plugin name based on the URI"
    # TODO: also parse URIs like `sqlite://`
    # TODO: integrate this function with detect_source
    parsed = urlparse(uri)
    if parsed.scheme:
        if parsed.scheme == "sqlite":
            return "sqlite"
        elif parsed.scheme == "postgres":
            return "postgresql"

    basename = os.path.basename(parsed.path)
    if not basename.strip():
        raise RuntimeError("Could not identify file format.")

    extension = basename.split(".")[-1].lower()
    if extension in COMPRESSED_EXTENSIONS:
        extension = basename.split(".")[-2].lower()

    plugin_name = extension
    if extension in FILE_EXTENSIONS:
        plugin_name = MIME_TYPE_TO_PLUGIN_NAME[FILE_EXTENSIONS[plugin_name]]

    return plugin_name
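# Illustrative examples for plugin_name_by_uri above (the URIs are made up;
# results follow FILE_EXTENSIONS and MIME_TYPE_TO_PLUGIN_NAME):
#
#     plugin_name_by_uri("https://example.com/report.xlsx")  # -> "xlsx"
#     plugin_name_by_uri("/tmp/data.csv.gz")                 # -> "csv"
#     plugin_name_by_uri("postgres://user:pw@host/db")       # -> "postgresql"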
def extension_by_source(source, mime_type):
    "Return the file extension used by this plugin"
    # TODO: should get this information from the plugin
    extension = source.plugin_name
    if extension:
        return extension

    if mime_type:
        return mime_type.split("/")[-1]
def normalize_mime_type(mime_type, mime_name, file_extension):

    file_extension = file_extension.lower() if file_extension else ""
    mime_name = mime_name.lower() if mime_name else ""
    mime_type = mime_type.lower() if mime_type else ""

    if mime_type == "text/plain" and file_extension in TEXT_PLAIN:
        return TEXT_PLAIN[file_extension]

    elif mime_type == "application/octet-stream" and mime_name in OCTET_STREAM:
        return OCTET_STREAM[mime_name]

    elif file_extension in FILE_EXTENSIONS:
        return FILE_EXTENSIONS[file_extension]

    else:
        return mime_type
def plugin_name_by_mime_type(mime_type, mime_name, file_extension):
    "Return the plugin name based on the MIME type"

    return MIME_TYPE_TO_PLUGIN_NAME.get(
        normalize_mime_type(mime_type, mime_name, file_extension), None
    )
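# Illustrative example for plugin_name_by_mime_type above (values are made up):
# an XLSX served as a generic "application/octet-stream" is still recognized
# through its libmagic name or file extension.
#
#     plugin_name_by_mime_type("application/octet-stream", "Microsoft OOXML", "xlsx")
#     # -> "xlsx"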
def detect_local_source(path, content, mime_type=None, encoding=None):
    # TODO: may add sample_size
    filename = os.path.basename(path)
    parts = filename.split(".")
    extension = parts[-1].lower() if len(parts) > 1 else None
    if extension in COMPRESSED_EXTENSIONS:
        extension = parts[-2].lower() if len(parts) > 2 else None

    if magic is not None:
        detected = magic.detect_from_content(content)
        encoding = detected.encoding or encoding
        mime_name = detected.name
        mime_type = detected.mime_type or mime_type

    else:
        if chardet and not encoding:
            encoding = chardet.detect(content)["encoding"] or encoding
        mime_name = None
        mime_type = mime_type or mimetypes.guess_type(filename)[0]

    plugin_name = plugin_name_by_mime_type(mime_type, mime_name, extension)
    if encoding == "binary":
        encoding = None

    return Source(uri=path, plugin_name=plugin_name, encoding=encoding)
def local_file(path, sample_size=1048576):
    # TODO: may change sample_size
    if path.split(".")[-1].lower() in COMPRESSED_EXTENSIONS:
        compressed = True
        fobj = open_compressed(path, mode="rb")
        content = fobj.read(sample_size)
        fobj.close()
    else:
        compressed = False
        with open(path, "rb") as fobj:
            content = fobj.read(sample_size)

    source = detect_local_source(path, content, mime_type=None, encoding=None)

    return Source(
        uri=path,
        plugin_name=source.plugin_name,
        encoding=source.encoding,
        compressed=compressed,
        should_delete=False,
        is_file=True,
        local=True,
    )
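# Illustrative sketch for local_file above (hypothetical path): only a sample
# of the file is read to detect the plugin and encoding.
#
#     source = local_file("/tmp/dataset.csv.xz")
#     source.plugin_name, source.compressed  # ("csv", True)
#     # source.encoding depends on the detection backend (file-magic or chardet)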
def download_file(
    uri,
    filename=None,
    verify_ssl=True,
    timeout=5,
    progress=False,
    detect=False,
    chunk_size=8192,
    sample_size=1048576,
):
    # TODO: add ability to continue download
    response = requests.get(
        uri,
        verify=verify_ssl,
        timeout=timeout,
        stream=True,
        headers={"user-agent": "rows-{}".format(rows.__version__)},
    )
    if not response.ok:
        raise RuntimeError("HTTP response: {}".format(response.status_code))
    # Get data from headers (if available) to help plugin + encoding detection
    real_filename, encoding, mime_type = uri, None, None
    headers = response.headers
    if "content-type" in headers:
        mime_type, options = cgi.parse_header(headers["content-type"])
        encoding = options.get("charset", encoding)
    if "content-disposition" in headers:
        _, options = cgi.parse_header(headers["content-disposition"])
        real_filename = options.get("filename", real_filename)

    if progress:
        total = response.headers.get("content-length", None)
        total = int(total) if total else None
        progress_bar = ProgressBar(prefix="Downloading file", total=total, unit="bytes")
    if filename is None:
        tmp = tempfile.NamedTemporaryFile(delete=False)
        fobj = open_compressed(tmp.name, mode="wb")
    else:
        fobj = open_compressed(filename, mode="wb")
    sample_data = b""
    for data in response.iter_content(chunk_size=chunk_size):
        fobj.write(data)
        if detect and len(sample_data) <= sample_size:
            sample_data += data
        if progress:
            progress_bar.update(len(data))
    fobj.close()
    if progress:
        progress_bar.close()
    # Detect file type and rename temporary file to have the correct extension
    if detect:
        # TODO: check if this will work for compressed files
        source = detect_local_source(real_filename, sample_data, mime_type, encoding)
        extension = extension_by_source(source, mime_type)
        plugin_name = source.plugin_name
        encoding = source.encoding
    else:
        extension, plugin_name, encoding = None, None, None
        if mime_type:
            extension = mime_type.split("/")[-1]

    if filename is None:
        filename = tmp.name
        if extension:
            filename += "." + extension
        os.rename(tmp.name, filename)

    return Source(
        uri=filename,
        plugin_name=plugin_name,
        encoding=encoding,
        should_delete=True,
        is_file=True,
        local=False,
    )
def detect_source(uri, verify_ssl, progress, timeout=5):
    """Return a `rows.Source` with information for a given URI

    If the URI starts with "http" or "https", the file will be downloaded.

    This function should only be used if the URI already exists because it's
    going to download/open the file to detect its encoding and MIME type.
    """
    # TODO: should also support other schemes, like file://, sqlite:// etc.
    if uri.lower().startswith("http://") or uri.lower().startswith("https://"):
        return download_file(
            uri, verify_ssl=verify_ssl, timeout=timeout, progress=progress, detect=True
        )

    elif uri.startswith("postgres://"):
        return Source(
            should_delete=False,
            encoding=None,
            plugin_name="postgresql",
            uri=uri,
            is_file=False,
            local=None,
        )
    else:
        return local_file(uri)
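# Illustrative sketch for detect_source above (URL is hypothetical; downloading
# requires the requests library): HTTP(S) URIs are downloaded to a temporary
# file, which is then inspected.
#
#     source = detect_source("https://example.com/data.csv", verify_ssl=True, progress=False)
#     table = import_from_source(source, default_encoding="utf-8")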
def import_from_source(source, default_encoding, *args, **kwargs):
    "Import data described in a `rows.Source` into a `rows.Table`"
    # TODO: test open_compressed
    plugin_name = source.plugin_name
    kwargs["encoding"] = (
        kwargs.get("encoding", None) or source.encoding or default_encoding
    )

    try:
        import_function = getattr(rows, "import_from_{}".format(plugin_name))
    except AttributeError:
        raise ValueError('Plugin (import) "{}" not found'.format(plugin_name))

    table = import_function(source.uri, *args, **kwargs)

    return table
def import_from_uri(
    uri, default_encoding="utf-8", verify_ssl=True, progress=False, *args, **kwargs
):
    "Given a URI, detects the plugin and encoding and imports into a `rows.Table`"
    # TODO: support '-' also
    # TODO: (optimization) if kwargs.get("encoding", None) is not None,
    # we can skip encoding detection
    source = detect_source(uri, verify_ssl=verify_ssl, progress=progress)
    return import_from_source(source, default_encoding, *args, **kwargs)
def export_to_uri(table, uri, *args, **kwargs):
    "Given a `rows.Table` and a URI, detects the plugin (from the URI) and exports"
    # TODO: support '-' also
    plugin_name = plugin_name_by_uri(uri)

    try:
        export_function = getattr(rows, "export_to_{}".format(plugin_name))
    except AttributeError:
        raise ValueError('Plugin (export) "{}" not found'.format(plugin_name))

    return export_function(table, uri, *args, **kwargs)
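# Illustrative round trip using the helpers above (URL and filename are
# hypothetical; the export plugin is chosen from the ".xlsx" extension):
#
#     table = import_from_uri("https://example.com/data.csv")
#     export_to_uri(table, "data.xlsx")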
def open_compressed(
    filename,
    mode="r",
    buffering=-1,
    encoding=None,
    errors=None,
    newline=None,
    closefd=True,
    opener=None,
):
    """Return a text-based file object from a filename, even if compressed

    NOTE: if the file is compressed, options like `buffering` apply to the
    compressed file object (not to the uncompressed file object returned).
    """
    binary_mode = "b" in mode
    if not binary_mode and "t" not in mode:
        # For some reason, passing only mode="r" to bzip2 is equivalent to
        # "rb", not "rt", so we force it here.
        mode += "t"
    if binary_mode and encoding:
        raise ValueError("encoding should not be specified in binary mode")

    extension = str(filename).split(".")[-1].lower()
    mode_binary = mode.replace("t", "b")
    get_fobj_binary = lambda: open(
        filename,
        mode=mode_binary,
        buffering=buffering,
        errors=errors,
        newline=newline,
        closefd=closefd,
        opener=opener,
    )
    get_fobj_text = lambda: open(
        filename,
        mode=mode,
        buffering=buffering,
        encoding=encoding,
        errors=errors,
        newline=newline,
        closefd=closefd,
        opener=opener,
    )
    known_extensions = ("xz", "gz", "bz2")

    if extension not in known_extensions:  # No compression
        if binary_mode:
            return get_fobj_binary()
        else:
            return get_fobj_text()

    elif extension == "xz":
        if lzma is None:
            raise RuntimeError("lzma support is not installed")
        fobj_binary = lzma.LZMAFile(get_fobj_binary(), mode=mode_binary)

    elif extension == "gz":
        fobj_binary = gzip.GzipFile(fileobj=get_fobj_binary())

    elif extension == "bz2":
        if bz2 is None:
            raise RuntimeError("bzip2 support is not installed")
        fobj_binary = bz2.BZ2File(get_fobj_binary(), mode=mode_binary)

    if binary_mode:
        return fobj_binary
    else:
        return io.TextIOWrapper(fobj_binary, encoding=encoding)
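# Illustrative usage of open_compressed above (filenames are hypothetical): the
# decompressor is chosen from the extension; text mode wraps the binary stream
# in an io.TextIOWrapper.
#
#     with open_compressed("data.csv.gz", mode="r", encoding="utf-8") as fobj:
#         first_line = fobj.readline()
#     with open_compressed("data.csv", mode="rb") as fobj:  # plain file, binary mode
#         chunk = fobj.read(1024)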
def csv_to_sqlite(
    input_filename,
    output_filename,
    samples=None,
    dialect=None,
    batch_size=10000,
    encoding="utf-8",
    callback=None,
    force_types=None,
    chunk_size=8388608,
    table_name="table1",
    schema=None,
):
    "Export a CSV file to SQLite, based on field type detection from samples"
    # TODO: automatically detect encoding if encoding is None
    # TODO: should be able to specify fields
    # TODO: if schema is provided and the names are in uppercase, this function will fail
    if dialect is None:  # Get a sample to detect dialect
        fobj = open_compressed(input_filename, mode="rb")
        sample = fobj.read(chunk_size)
        fobj.close()
        dialect = rows.plugins.csv.discover_dialect(sample, encoding=encoding)
    elif isinstance(dialect, six.text_type):
        dialect = csv.get_dialect(dialect)

    if schema is None:  # Identify data types
        fobj = open_compressed(input_filename, encoding=encoding)
        data = list(islice(csv.DictReader(fobj, dialect=dialect), samples))
        fobj.close()
        schema = rows.import_from_dicts(data).fields
        if force_types is not None:
            schema.update(force_types)
    # Create a lazy table object to be converted
    # TODO: this laziness feature will be incorporated into the library soon,
    # so we can call rows.import_from_csv here instead of csv.reader
    fobj = open_compressed(input_filename, encoding=encoding)
    csv_reader = csv.reader(fobj, dialect=dialect)
    header = make_header(next(csv_reader))  # Read (and consume) the header row
    table = rows.Table(fields=OrderedDict([(field, schema[field]) for field in header]))
    table._rows = csv_reader
    # Export to SQLite
    result = rows.export_to_sqlite(
        table,
        output_filename,
        table_name=table_name,
        batch_size=batch_size,
        callback=callback,
    )
    fobj.close()
    return result
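# Illustrative sketch for csv_to_sqlite above (filenames are hypothetical):
# convert a (possibly compressed) CSV into a SQLite table, detecting the field
# types from the first 5000 rows.
#
#     csv_to_sqlite("data.csv.gz", "data.sqlite", samples=5000, table_name="data")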
def sqlite_to_csv(
    input_filename,
    table_name,
    output_filename,
    dialect=csv.excel,
    batch_size=10000,
    encoding="utf-8",
    callback=None,
    query=None,
):
    "Export a table inside a SQLite database to CSV"
    # TODO: should be able to specify fields
    # TODO: should be able to specify a custom query
    if isinstance(dialect, six.text_type):
        dialect = csv.get_dialect(dialect)

    if query is None:
        query = "SELECT * FROM {}".format(table_name)
    connection = sqlite3.Connection(input_filename)
    cursor = connection.cursor()
    result = cursor.execute(query)
    header = [item[0] for item in cursor.description]
    fobj = open_compressed(output_filename, mode="w", encoding=encoding)
    writer = csv.writer(fobj, dialect=dialect)
    writer.writerow(header)
    total_written = 0
    for batch in rows.plugins.utils.ipartition(result, batch_size):
        writer.writerows(batch)
        written = len(batch)
        total_written += written
        if callback:
            callback(written, total_written)
    fobj.close()
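# Illustrative sketch for sqlite_to_csv above (filenames are hypothetical): the
# ".gz" extension makes open_compressed gzip the output transparently.
#
#     sqlite_to_csv("data.sqlite", "data", "export.csv.gz")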
class CsvLazyDictWriter:
    """Lazy CSV dict writer, with compressed output option

    This class is almost the same as `csv.DictWriter`, with the following
    differences:

    - You don't need to pass `fieldnames` (it's extracted on the first
      `.writerow` call);
    - You can pass either a filename or a fobj (like `sys.stdout`);
    - If passing a filename, it can end with `.gz`, `.xz` or `.bz2` and the
      output file will be automatically compressed.
    """
    def __init__(self, filename_or_fobj, encoding="utf-8", *args, **kwargs):
        self.writer = None
        self.filename_or_fobj = filename_or_fobj
        self.encoding = encoding
        self._fobj = None
        self.writer_args = args
        self.writer_kwargs = kwargs
        self.writer_kwargs["lineterminator"] = kwargs.get("lineterminator", "\n")
        # TODO: check if the line terminator should be the same on other OSes
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
    @property
    def fobj(self):
        if self._fobj is None:
            if getattr(self.filename_or_fobj, "read", None) is not None:
                self._fobj = self.filename_or_fobj
            else:
                self._fobj = open_compressed(
                    self.filename_or_fobj, mode="w", encoding=self.encoding
                )

        return self._fobj

    def writerow(self, row):
        if self.writer is None:
            self.writer = csv.DictWriter(
                self.fobj,
                fieldnames=list(row.keys()),
                *self.writer_args,
                **self.writer_kwargs
            )
            self.writer.writeheader()

        self.writerow = self.writer.writerow
        return self.writerow(row)

    def __del__(self):
        self.close()

    def close(self):
        if self._fobj and not self._fobj.closed:
            self._fobj.close()
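# Illustrative usage of CsvLazyDictWriter above (filename is hypothetical): the
# field names are taken from the first row written.
#
#     writer = CsvLazyDictWriter("output.csv.gz")
#     writer.writerow({"name": "Alice", "age": 42})
#     writer.writerow({"name": "Bob", "age": 7})
#     writer.close()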


def execute_command(command):
    "Execute a command and return its output"
    command = shlex.split(command)
    try:
        process = subprocess.Popen(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except FileNotFoundError:
        raise RuntimeError("Command not found: {}".format(repr(command)))
    process.wait()
    # TODO: may use another codec to decode
    if process.returncode > 0:
        stderr = process.stderr.read().decode("utf-8")
        raise ValueError("Error executing command: {}".format(repr(stderr)))
    data = process.stdout.read().decode("utf-8")
    process.stdin.close()
    process.stdout.close()
    process.stderr.close()
    process.wait()
    return data
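# Illustrative usage of execute_command above (the command is just an example):
#
#     execute_command("echo hello")  # -> "hello\n"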
def uncompressed_size(filename):
    """Return the uncompressed size for a file by executing external commands

    Note: due to a limitation in the gzip format, uncompressed files greater
    than 4GiB will have a wrong value reported.
    """
    quoted_filename = shlex.quote(filename)
    # TODO: get the file type from file-magic, if available
    if str(filename).lower().endswith(".xz"):
        output = execute_command("xz --list {}".format(quoted_filename))
        compressed, uncompressed = regexp_sizes.findall(output)
        value, unit = uncompressed.split()
        value = float(value.replace(",", ""))
        return int(value * MULTIPLIERS[unit])

    elif str(filename).lower().endswith(".gz"):
        # XXX: gzip only uses 32 bits to store the uncompressed size, so if
        # the uncompressed size is greater than 4GiB, the value returned will
        # be incorrect.
        output = execute_command("gzip --list {}".format(quoted_filename))
        lines = [line.split() for line in output.splitlines()]
        header, data = lines[0], lines[1]
        gzip_data = dict(zip(header, data))
        return int(gzip_data["uncompressed"])

    else:
        raise ValueError('Unrecognized file type for "{}".'.format(filename))
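# Illustrative usage of uncompressed_size above (hypothetical filename; requires
# the xz or gzip command-line tools to be installed):
#
#     uncompressed_size("data.csv.xz")  # -> size in bytes, e.g. 1073741824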
# TODO: move all PostgreSQL-related utils to rows/plugins/postgresql.py
def get_psql_command(
    command,
    user=None,
    password=None,
    host=None,
    port=None,
    database_name=None,
    database_uri=None,
):

    if database_uri is None:
        if None in (user, password, host, port, database_name):
            raise ValueError(
                "Need to specify either `database_uri` or the complete information"
            )

        database_uri = "postgres://{user}:{password}@{host}:{port}/{name}".format(
            user=user, password=password, host=host, port=port, name=database_name
        )

    return "psql -c {} {}".format(shlex.quote(command), shlex.quote(database_uri))
def get_psql_copy_command(
    table_name_or_query,
    header,
    encoding="utf-8",
    user=None,
    password=None,
    host=None,
    port=None,
    database_name=None,
    database_uri=None,
    is_query=False,
    dialect=csv.excel,
    direction="FROM",
):

    direction = direction.upper()
    if direction not in ("FROM", "TO"):
        raise ValueError('`direction` must be "FROM" or "TO"')

    if not is_query:  # Table name
        source = table_name_or_query
    else:
        source = "(" + table_name_or_query + ")"
    if header is None:
        header = ""
    else:
        header = ", ".join(slug(field_name) for field_name in header)
        header = "({header}) ".format(header=header)
    copy = (
        r"\copy {source} {header}{direction} STDIN WITH("
        "DELIMITER '{delimiter}', "
        "QUOTE '{quote}', "
    )
    if direction == "FROM":
        copy += "FORCE_NULL {header}, "
    copy += "ENCODING '{encoding}', " "FORMAT CSV, HEADER);"

    copy_command = copy.format(
        source=source,
        header=header,
        direction=direction,
        delimiter=dialect.delimiter.replace("'", "''"),
        quote=dialect.quotechar.replace("'", "''"),
        encoding=encoding,
    )

    return get_psql_command(
        copy_command,
        user=user,
        password=password,
        host=host,
        port=port,
        database_name=database_name,
        database_uri=database_uri,
    )
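# Illustrative example for get_psql_copy_command above (credentials and table
# are placeholders): build the psql command used to stream a CSV into a table.
#
#     command = get_psql_copy_command(
#         "people",
#         header=["name", "age"],
#         database_uri="postgres://user:pw@localhost:5432/mydb",
#         direction="FROM",
#     )
#     # `command` is a "psql -c ..." string, ready for shlex.split/subprocess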
def pg_create_table_sql(schema, table_name, unlogged=False):
    field_names = list(schema.keys())
    field_types = list(schema.values())

    columns = [
        "{} {}".format(name, POSTGRESQL_TYPES.get(type_, DEFAULT_POSTGRESQL_TYPE))
        for name, type_ in zip(field_names, field_types)
    ]
    return SQL_CREATE_TABLE.format(
        pre_table="" if not unlogged else "UNLOGGED ",
        post_table=" IF NOT EXISTS",
        table_name=table_name, field_types=", ".join(columns),
    )
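# Illustrative example for pg_create_table_sql above (the schema is made up):
#
#     schema = {"name": rows.fields.TextField, "age": rows.fields.IntegerField}
#     pg_create_table_sql(schema, "people")
#     # -> 'CREATE TABLE IF NOT EXISTS "people" (name TEXT, age BIGINT)'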
def pg_execute_psql(database_uri, sql):
    return execute_command(
        get_psql_command(sql, database_uri=database_uri)
    )
def pgimport(
    filename,
    database_uri,
    table_name,
    encoding=None,
    dialect=None,
    create_table=True,
    schema=None,
    callback=None,
    timeout=0.1,
    chunk_size=8388608,
    max_samples=10000,
    unlogged=False,
):
    """Import data from CSV into PostgreSQL using the fastest method

    Required: psql command
    """
    # TODO: add an option to run parallel COPY processes
    # TODO: add logging to the process
    # TODO: detect when an error occurred and interrupt the process immediately
    if encoding is None:
        fobj = open_compressed(filename, mode="rb")
        sample_bytes = fobj.read(chunk_size)
        fobj.close()
        source = detect_local_source(filename, sample_bytes)
        encoding = source.encoding

    pg_encoding = encoding
    if pg_encoding in ("us-ascii", "ascii"):
        # TODO: convert all possible encodings
        pg_encoding = "SQL_ASCII"

    fobj = open_compressed(filename, mode="r", encoding=encoding)
    sample = fobj.read(chunk_size)
    fobj.close()

    if dialect is None:  # Detect dialect
        dialect = rows.plugins.csv.discover_dialect(
            sample.encode(encoding), encoding=encoding
        )
    elif isinstance(dialect, six.text_type):
        dialect = csv.get_dialect(dialect)
    # TODO: add an else clause to check whether dialect is an instance of the
    # correct class
    reader = csv.reader(io.StringIO(sample), dialect=dialect)
    csv_field_names = [slug(field_name) for field_name in next(reader)]
    if schema is None:
        field_names = csv_field_names
    else:
        field_names = list(schema.keys())
        if not set(csv_field_names).issubset(set(field_names)):
            raise ValueError('CSV field names are not a subset of schema field names')

    if create_table:
        # If we need to create the table, it is created based on the schema
        # (automatically identified or forced), not on the CSV directly
        # (the field order will be the schema's field order).
        if schema is None:
            schema = rows.fields.detect_types(
                csv_field_names,
                itertools.islice(reader, max_samples)
            )
        create_table_sql = pg_create_table_sql(schema, table_name, unlogged=unlogged)
        pg_execute_psql(database_uri, create_table_sql)
    # Prepare the psql command to be executed based on the collected metadata
    command = get_psql_copy_command(
        database_uri=database_uri,
        dialect=dialect,
        direction="FROM",
        encoding=pg_encoding,
        header=csv_field_names,
        table_name_or_query=table_name,
        is_query=False,
    )
    rows_imported, error = 0, None
    fobj = open_compressed(filename, mode="rb")
    try:
        process = subprocess.Popen(
            shlex.split(command),
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        data = fobj.read(chunk_size)
        total_written = 0
        while data != b"":
            written = process.stdin.write(data)
            total_written += written
            if callback:
                callback(written, total_written)
            data = fobj.read(chunk_size)
        stdout, stderr = process.communicate()
        if stderr != b"":
            for line in stderr.splitlines():
                if line.startswith(b"NOTICE:"):
                    continue
                else:
                    raise RuntimeError(stderr.decode("utf-8"))
        rows_imported = int(stdout.replace(b"COPY ", b"").strip())

    except FileNotFoundError:
        fobj.close()
        raise RuntimeError("Command `psql` not found")

    except BrokenPipeError:
        fobj.close()
        raise RuntimeError(process.stderr.read().decode("utf-8"))

    else:
        fobj.close()
        return {"bytes_written": total_written, "rows_imported": rows_imported}
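# Illustrative sketch for pgimport above (URI and filename are placeholders):
# stream a gzipped CSV into PostgreSQL, creating the table from detected types.
#
#     pgimport("data.csv.gz", "postgres://user:pw@localhost:5432/mydb", "my_table")
#     # -> {"bytes_written": ..., "rows_imported": ...}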
def pgexport(
    database_uri,
    table_name_or_query,
    filename,
    encoding="utf-8",
    dialect=csv.excel,
    callback=None,
    is_query=False,
    timeout=0.1,
    chunk_size=8388608,
):
    """Export data from PostgreSQL into a CSV file using the fastest method

    Required: psql command
    """
    # TODO: add logging to the process
    if isinstance(dialect, six.text_type):
        dialect = csv.get_dialect(dialect)
    # Prepare the psql command to be executed to export the data
    command = get_psql_copy_command(
        database_uri=database_uri,
        direction="TO",
        encoding=encoding,
        header=None,  # Needed when direction = 'TO'
        table_name_or_query=table_name_or_query,
        is_query=is_query,
        dialect=dialect,
    )
    fobj = open_compressed(filename, mode="wb")
    try:
        process = subprocess.Popen(
            shlex.split(command),
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        total_written = 0
        data = process.stdout.read(chunk_size)
        while data != b"":
            written = fobj.write(data)
            total_written += written
            if callback:
                callback(written, total_written)
            data = process.stdout.read(chunk_size)
        stdout, stderr = process.communicate()
        if stderr != b"":
            raise RuntimeError(stderr.decode("utf-8"))

    except FileNotFoundError:
        fobj.close()
        raise RuntimeError("Command `psql` not found")

    except BrokenPipeError:
        fobj.close()
        raise RuntimeError(process.stderr.read().decode("utf-8"))

    else:
        fobj.close()
        return {"bytes_written": total_written}
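# Illustrative sketch for pgexport above (placeholders): dump a table (or, with
# is_query=True, a query) into a possibly compressed CSV file.
#
#     pgexport("postgres://user:pw@localhost:5432/mydb", "my_table", "export.csv.gz")
#     # -> {"bytes_written": ...}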
def generate_schema(table, export_fields, output_format):
    """Generate a table schema in a specific output format

    Currently supported output formats: 'txt', 'csv', 'sql' and 'django'.
    The table name and all field names go through a slugifying process (the
    table name is taken from the file name).
    """
    if output_format in ("csv", "txt"):
        from rows import plugins

        data = [
            {
                "field_name": fieldname,
                "field_type": fieldtype.__name__.replace("Field", "").lower(),
            }
            for fieldname, fieldtype in table.fields.items()
            if fieldname in export_fields
        ]
        table = plugins.dicts.import_from_dicts(
            data, import_fields=["field_name", "field_type"]
        )
        if output_format == "txt":
            return plugins.txt.export_to_txt(table)
        elif output_format == "csv":
            return plugins.csv.export_to_csv(table).decode("utf-8")

    elif output_format == "sql":
        # TODO: may use the type mapping from rows.plugins.sqlite or postgresql
        sql_fields = {
            rows.fields.BinaryField: "BLOB",
            rows.fields.BoolField: "BOOL",
            rows.fields.IntegerField: "INT",
            rows.fields.FloatField: "FLOAT",
            rows.fields.PercentField: "FLOAT",
            rows.fields.DateField: "DATE",
            rows.fields.DatetimeField: "DATETIME",
            rows.fields.TextField: "TEXT",
            rows.fields.DecimalField: "FLOAT",
            rows.fields.EmailField: "TEXT",
            rows.fields.JSONField: "TEXT",
        }
        fields = [
            "    {} {}".format(field_name, sql_fields[field_type])
            for field_name, field_type in table.fields.items()
            if field_name in export_fields
        ]
        sql = (
            dedent(
                """
                CREATE TABLE IF NOT EXISTS {name} (
                {fields}
                );
                """
            )
            .strip()
            .format(name=table.name, fields=",\n".join(fields))
            + "\n"
        )
        return sql

    elif output_format == "django":
        django_fields = {
            rows.fields.BinaryField: "BinaryField",
            rows.fields.BoolField: "BooleanField",
            rows.fields.IntegerField: "IntegerField",
            rows.fields.FloatField: "FloatField",
            rows.fields.PercentField: "DecimalField",
            rows.fields.DateField: "DateField",
            rows.fields.DatetimeField: "DateTimeField",
            rows.fields.TextField: "TextField",
            rows.fields.DecimalField: "DecimalField",
            rows.fields.EmailField: "EmailField",
            rows.fields.JSONField: "JSONField",
        }
        table_name = "".join(word.capitalize() for word in table.name.split("_"))

        lines = ["from django.db import models"]
        if rows.fields.JSONField in [
            table.fields[field_name] for field_name in export_fields
        ]:
            lines.append("from django.contrib.postgres.fields import JSONField")
        lines.append("")

        lines.append("class {}(models.Model):".format(table_name))
        for field_name, field_type in table.fields.items():
            if field_name not in export_fields:
                continue

            if field_type is not rows.fields.JSONField:
                django_type = "models.{}()".format(django_fields[field_type])
            else:
                django_type = "JSONField()"
            lines.append("    {} = {}".format(field_name, django_type))

        result = "\n".join(lines) + "\n"
        return result
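# Illustrative usage of generate_schema above (`table` is assumed to be an
# already-imported rows.Table):
#
#     print(generate_schema(table, table.field_names, "django"))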
def load_schema(filename, context=None):
    """Load a schema from a file in any of the supported formats

    The table must have at least the fields `field_name` and `field_type`.
    `context` is a dict mapping each `field_type` value to a field class, like:
        {"text": rows.fields.TextField, "value": MyCustomField}
    """
    # TODO: load_schema must support Path objects
    table = import_from_uri(filename)
    field_names = table.field_names
    assert "field_name" in field_names
    assert "field_type" in field_names

    context = context or {
        key.replace("Field", "").lower(): getattr(rows.fields, key)
        for key in dir(rows.fields)
        if "Field" in key and key != "Field"
    }
    return OrderedDict([(row.field_name, context[row.field_type]) for row in table])
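# Illustrative usage of load_schema above (filename is hypothetical): the file
# must have `field_name` and `field_type` columns.
#
#     schema = load_schema("schema.csv")
#     # -> OrderedDict([("name", rows.fields.TextField), ...])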
# Shortcuts
csv2sqlite = csv_to_sqlite
sqlite2csv = sqlite_to_csv