Copyright 2014-2019 Álvaro Justen https://github.com/turicas/rows/
This program is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with this program. If not, see http://www.gnu.org/licenses/.
from __future__ import unicode_literals
import datetime
import sqlite3
import string
from pathlib import Path
import six
import rows.fields as fields
from rows.plugins.utils import (
create_table,
ipartition,
make_unique_name,
prepare_to_export,
)
from rows.utils import Source
SQL_TABLE_NAMES = 'SELECT name FROM sqlite_master WHERE type="table"'
SQL_CREATE_TABLE = 'CREATE TABLE IF NOT EXISTS "{table_name}" ({field_types})'
SQL_SELECT_ALL = 'SELECT * FROM "{table_name}"'
SQL_INSERT = 'INSERT INTO "{table_name}" ({field_names}) VALUES ({placeholders})'
SQLITE_TYPES = {
fields.BinaryField: "BLOB",
fields.BoolField: "INTEGER",
fields.DateField: "TEXT",
fields.DatetimeField: "TEXT",
fields.DecimalField: "REAL",
fields.FloatField: "REAL",
fields.IntegerField: "INTEGER",
fields.PercentField: "REAL",
fields.TextField: "TEXT",
}
DEFAULT_TYPE = "BLOB"
def _python_to_sqlite(field_types):
def convert_value(field_type, value):
if field_type in (
fields.BinaryField,
fields.BoolField,
fields.FloatField,
fields.IntegerField,
fields.TextField,
):
return value
elif field_type in (fields.DateField, fields.DatetimeField):
if value is None:
return None
elif isinstance(value, (datetime.date, datetime.datetime)):
return value.isoformat()
elif isinstance(value, (six.binary_type, six.text_type)):
return value
else:
raise ValueError("Cannot serialize date value: {}".format(repr(value)))
elif field_type in (fields.DecimalField, fields.PercentField):
return float(value) if not fields.is_null(value) else None
else: # don't know this field
return field_type.serialize(value)
def convert_row(row):
return [
convert_value(field_type, value)
for field_type, value in zip(field_types, row)
]
return convert_row
def get_source(filename_or_connection):
if isinstance(filename_or_connection, (six.binary_type, six.text_type, Path)):
connection = sqlite3.connect(filename_or_connection)
uri = filename_or_connection
input_is_uri = should_close = True
else: # already a connection
connection = filename_or_connection
input_is_uri = should_close = False
uri = None
Try to get filename inspecting the database
cursor = connection.cursor()
for _, name, filename in cursor.execute("PRAGMA database_list"):
if name == "main" and filename is not None:
uri = filename
cursor.close()
TODO: may improve Source for non-fobj cases (when open() is not needed)
source = Source.from_file(
connection,
plugin_name="sqlite",
mode=None,
is_file=bool(uri),
local=bool(uri),
should_close=should_close,
)
source.uri = uri
return source
Verify if a given table name is valid for rows
.
def _valid_table_name(name):
Rules: - Should start with a letter or ‘_’ - Letters can be capitalized or not - Acceps letters, numbers and _
if name[0] not in "_" + string.ascii_letters or not set(name).issubset(
"_" + string.ascii_letters + string.digits
):
return False
else:
return True
Return a rows.Table with data from SQLite database.
def import_from_sqlite(
filename_or_connection,
table_name="table1",
query=None,
query_args=None,
*args,
**kwargs
):
source = get_source(filename_or_connection)
connection = source.fobj
cursor = connection.cursor()
if query is None:
if not _valid_table_name(table_name):
raise ValueError("Invalid table name: {}".format(table_name))
query = SQL_SELECT_ALL.format(table_name=table_name)
if query_args is None:
query_args = tuple()
table_rows = list(cursor.execute(query, query_args)) # TODO: may be lazy
header = [six.text_type(info[0]) for info in cursor.description]
cursor.close()
TODO: should close connection also?
meta = {"imported_from": "sqlite", "source": source}
return create_table([header] + table_rows, meta=meta, *args, **kwargs)
def export_to_sqlite(
table,
filename_or_connection,
table_name=None,
table_name_format="table{index}",
batch_size=100,
callback=None,
*args,
**kwargs
):
TODO: should add transaction support?
prepared_table = prepare_to_export(table, *args, **kwargs)
source = get_source(filename_or_connection)
connection = source.fobj
cursor = connection.cursor()
if table_name is None:
table_names = [item[0] for item in cursor.execute(SQL_TABLE_NAMES)]
table_name = make_unique_name(
table_name_format.format(index=1),
existing_names=table_names,
name_format=table_name_format,
start=1,
)
elif not _valid_table_name(table_name):
raise ValueError("Invalid table name: {}".format(table_name))
field_names = next(prepared_table)
field_types = list(map(table.fields.get, field_names))
columns = [
"{} {}".format(field_name, SQLITE_TYPES.get(field_type, DEFAULT_TYPE))
for field_name, field_type in zip(field_names, field_types)
]
cursor.execute(
SQL_CREATE_TABLE.format(table_name=table_name, field_types=", ".join(columns))
)
insert_sql = SQL_INSERT.format(
table_name=table_name,
field_names=", ".join(field_names),
placeholders=", ".join("?" for _ in field_names),
)
_convert_row = _python_to_sqlite(field_types)
if callback is None:
for batch in ipartition(prepared_table, batch_size):
cursor.executemany(insert_sql, map(_convert_row, batch))
else:
total_written = 0
for batch in ipartition(prepared_table, batch_size):
cursor.executemany(insert_sql, map(_convert_row, batch))
written = len(batch)
total_written += written
callback(written, total_written)
connection.commit()
return connection