Copyright 2014-2019 Álvaro Justen https://github.com/turicas/rows/
This program is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with this program. If not, see http://www.gnu.org/licenses/.
from __future__ import unicode_literals
import io
import math
import six
from cached_property import cached_property
from rows.plugins.utils import create_table
from rows.utils import Source
try:
import fitz as pymupdf
pymupdf.TOOLS.mupdf_display_errors(False)
pymupdf_imported = True
except ImportError:
pymupdf_imported = False
try:
from pdfminer.converter import PDFPageAggregator, TextConverter
from pdfminer.layout import LAParams, LTTextBox, LTTextLine, LTChar, LTRect
from pdfminer.pdfdocument import PDFDocument
from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter, resolve1
from pdfminer.pdfpage import PDFPage
from pdfminer.pdfparser import PDFParser
import logging
logging.getLogger("pdfminer").setLevel(logging.ERROR)
PDFMINER_TEXT_TYPES = (LTTextBox, LTTextLine, LTChar)
PDFMINER_ALL_TYPES = (LTTextBox, LTTextLine, LTChar, LTRect)
pdfminer_imported = True
except ImportError:
pdfminer_imported = False
PDFMINER_TEXT_TYPES, PDFMINER_ALL_TYPES = None, None
# Examples:
#   extract_intervals("1,2,3") -> [1, 2, 3]
#   extract_intervals("1,2,5-10") -> [1, 2, 5, 6, 7, 8, 9, 10]
#   extract_intervals("1,2,5-10,3") -> [1, 2, 3, 5, 6, 7, 8, 9, 10]
#   extract_intervals("1,2,5-10,6,7") -> [1, 2, 5, 6, 7, 8, 9, 10]
def extract_intervals(text, repeat=False, sort=True):
    """Expand a comma-separated interval expression into a list of integers.

    Each item of ``text`` is either a single integer (``"3"``) or an
    inclusive range written as ``start-end`` (``"5-10"``). Blank items (an
    empty string, a trailing comma) are ignored.

    When ``repeat`` is false each value is kept only once. When ``sort`` is
    true the result is in ascending order; otherwise values keep the order in
    which they first appear in ``text``.

    Raises ``ValueError`` if an item is neither an integer nor a
    ``start-end`` pair of integers.

    >>> extract_intervals("1,2,3")
    [1, 2, 3]
    >>> extract_intervals("1,2,5-10")
    [1, 2, 5, 6, 7, 8, 9, 10]
    >>> extract_intervals("1,2,5-10,3")
    [1, 2, 3, 5, 6, 7, 8, 9, 10]
    >>> extract_intervals("1,2,5-10,6,7")
    [1, 2, 5, 6, 7, 8, 9, 10]
    >>> extract_intervals("3,1,3,2", sort=False)
    [3, 1, 2]
    """
    result = []
    for item in text.split(","):
        item = item.strip()
        if not item:
            continue
        if "-" in item:
            start_value, end_value = item.split("-")
            start_value = int(start_value.strip())
            end_value = int(end_value.strip())
            result.extend(range(start_value, end_value + 1))
        else:
            result.append(int(item))

    if not repeat:
        # De-duplicate keeping the first occurrence, so that `sort=False`
        # really preserves the order given by the caller (a plain `set`
        # would not).
        seen = set()
        unique = []
        for value in result:
            if value not in seen:
                seen.add(value)
                unique.append(value)
        result = unique
    if sort:
        result.sort()
    return result
def default_backend():
    """Return the name of the first PDF backend that could be imported.

    PyMuPDF is preferred over pdfminer.six. Raises ``ImportError`` when
    neither library was importable.
    """
    candidates = (
        ("pymupdf", pymupdf_imported),
        ("pdfminer.six", pdfminer_imported),
    )
    for backend_name, available in candidates:
        if available:
            return backend_name
    raise ImportError(
        "No PDF backend found. Did you install the dependencies (pymupdf or pdfminer.six)?"
    )
def number_of_pages(filename_or_fobj, backend=None):
    """Return the number of pages of a PDF, as reported by the backend.

    ``backend`` is a backend name or class; when omitted (or falsy) the
    result of ``default_backend()`` is used.
    """
    backend_class = get_backend(backend or default_backend())
    return backend_class(filename_or_fobj).number_of_pages
def pdf_to_text(filename_or_fobj, page_numbers=None, backend=None):
    """Yield the text of each selected page of a PDF (generator).

    ``page_numbers`` may be ``None`` (every page), a collection of 1-based
    page numbers, or a text expression such as ``"1,3-5"`` which is expanded
    with ``extract_intervals``. ``backend`` falls back to
    ``default_backend()`` when omitted.
    """
    if isinstance(page_numbers, six.text_type):
        page_numbers = extract_intervals(page_numbers)

    backend_class = get_backend(backend or default_backend())
    document = backend_class(filename_or_fobj)
    for page_text in document.extract_text(page_numbers=page_numbers):
        yield page_text
Base Backend class to parse PDF files
class PDFBackend(object):
    """Base Backend class to parse PDF files.

    Subclasses provide ``number_of_pages``, ``extract_text``, ``objects`` and
    ``text_objects``; this class only wraps the source and offers helpers
    built on top of those methods.
    """

    # Direction used when ordering along each axis: with `y_order == 1`
    # `get_cell_text` joins a cell's objects by ascending `y0`, any other
    # value gives descending `y0`. Both values are also handed to the
    # extraction algorithms by `pdf_table_lines`.
    x_order = 1
    y_order = 1

    def __init__(self, source):
        self.source = Source.from_file(source, plugin_name="pdf", mode="rb")

    # TODO(review): note carried over from the original source: "Filter out
    # objects outside table boundaries" (nothing in this class does that).

    @property
    def number_of_pages(self):
        "Number of pages in the document"
        raise NotImplementedError()

    def extract_text(self, page_numbers=None):
        "Return a string for each page in the document (generator)"
        raise NotImplementedError()

    def objects(self, page_numbers=None, starts_after=None, ends_before=None):
        "Return a list of objects for each page in the document (generator)"
        raise NotImplementedError()

    def text_objects(self, page_numbers=None, starts_after=None, ends_before=None):
        "Return a list of text objects for each page in the document (generator)"
        raise NotImplementedError()

    @property
    def text(self):
        "Text of every page, with pages separated by a blank line"
        return "\n\n".join(self.extract_text())

    def get_cell_text(self, cell):
        """Join the stripped text of the objects of ``cell``, one per line.

        ``cell`` is a list of objects exposing ``y0`` and ``text`` (or a
        falsy value, which gives ``""``). The list passed in is left
        untouched; ordering is done on a copy.
        """
        if not cell:
            return ""
        if self.y_order == 1:
            ordered = sorted(cell, key=lambda obj: obj.y0)
        else:
            ordered = sorted(cell, key=lambda obj: -obj.y0)
        return "\n".join(obj.text.strip() for obj in ordered)

    def __del__(self):
        # `source` is missing when `__init__` raised before assigning it;
        # `__del__` still runs in that case and must not fail.
        source = getattr(self, "source", None)
        if source is None:
            return
        if (
            source.should_close
            and hasattr(source.fobj, "closed")
            and not source.fobj.closed
        ):
            source.fobj.close()
class PDFMinerBackend(PDFBackend):
    """PDF backend implemented with pdfminer.six."""

    name = "pdfminer.six"
    # Descending y ordering; presumably because pdfminer's y axis grows from
    # the bottom of the page - TODO confirm.
    y_order = -1

    @cached_property
    def document(self):
        "Parsed `PDFDocument` of the source file object (created once)"
        pdf_parser = PDFParser(self.source.fobj)
        pdf_document = PDFDocument(pdf_parser)
        pdf_parser.set_document(pdf_document)
        return pdf_document

    @cached_property
    def number_of_pages(self):
        "Value of `Count` in the resolved `Pages` entry of the catalog"
        pages_entry = resolve1(self.document.catalog["Pages"])
        return pages_entry["Count"]

    def extract_text(self, page_numbers=None):
        """Yield the text of each selected page (generator).

        Pages are numbered from 1; ``page_numbers=None`` selects all pages.
        """
        all_pages = PDFPage.create_pages(self.document)
        for page_number, page in enumerate(all_pages, start=1):
            selected = page_numbers is None or page_number in page_numbers
            if not selected:
                continue

            # A fresh converter (and output buffer) is used for every page
            manager = PDFResourceManager()
            layout_params = LAParams()
            output = io.StringIO()
            converter = TextConverter(manager, output, laparams=layout_params)
            PDFPageInterpreter(manager, converter).process_page(page)
            yield output.getvalue()

    @staticmethod
    def convert_object(obj):
        """Convert a pdfminer layout object to `TextObject`/`RectObject`.

        Returns ``None`` for any other kind of object.
        """
        if isinstance(obj, PDFMINER_TEXT_TYPES):
            return TextObject(
                x0=obj.x0, y0=obj.y0, x1=obj.x1, y1=obj.y1, text=obj.get_text()
            )
        if isinstance(obj, LTRect):
            return RectObject(
                x0=obj.x0, y0=obj.y0, x1=obj.x1, y1=obj.y1, fill=obj.fill
            )
        return None

    def objects(
        self,
        page_numbers=None,
        starts_after=None,
        ends_before=None,
        desired_types=PDFMINER_ALL_TYPES,
    ):
        """Yield, for each selected page, the list of converted objects.

        Objects of a page are visited by descending ``y0``. Nothing is
        collected until an object satisfies ``starts_after`` (that object is
        itself collected); collection stops, and no further page is read,
        as soon as an object satisfies ``ends_before`` (that object is not
        collected). Both delimiters go through `get_delimiter_function`.
        """
        document = self.document
        manager = PDFResourceManager()
        layout_params = LAParams()
        aggregator = PDFPageAggregator(manager, laparams=layout_params)
        interpreter = PDFPageInterpreter(manager, aggregator)

        started = starts_after is None
        finished = False
        if not started:
            starts_after = get_delimiter_function(starts_after)
        if ends_before is not None:
            ends_before = get_delimiter_function(ends_before)

        for page_number, page in enumerate(PDFPage.create_pages(document), start=1):
            if page_numbers is not None and page_number not in page_numbers:
                continue

            interpreter.process_page(page)
            page_layout = aggregator.get_result()
            converted = sorted(
                (
                    PDFMinerBackend.convert_object(item)
                    for item in page_layout
                    if isinstance(item, desired_types)
                ),
                key=lambda item: -item.y0,
            )

            collected = []
            for item in converted:
                if not started:
                    if starts_after is None or not starts_after(item):
                        continue
                    started = True
                if ends_before is not None and ends_before(item):
                    finished = True
                    break
                collected.append(item)
            yield collected

            if finished:
                break

    def text_objects(self, page_numbers=None, starts_after=None, ends_before=None):
        "Same as `objects`, restricted to the pdfminer text types"
        return self.objects(
            page_numbers=page_numbers,
            starts_after=starts_after,
            ends_before=ends_before,
            desired_types=PDFMINER_TEXT_TYPES,
        )
class PyMuPDFBackend(PDFBackend):
    """PDF backend implemented with PyMuPDF (imported as `pymupdf`)."""

    name = "pymupdf"

    @cached_property
    def document(self):
        "PyMuPDF document, opened by URI when the source has one"
        if self.source.uri:
            return pymupdf.open(filename=self.source.uri, filetype="pdf")
        # TODO: may use a lot of memory
        contents = self.source.fobj.read()
        return pymupdf.open(stream=contents, filetype="pdf")

    @cached_property
    def number_of_pages(self):
        "Page count reported by the PyMuPDF document"
        return self.document.pageCount

    def extract_text(self, page_numbers=None):
        """Yield the text of each selected page (generator).

        Pages are numbered from 1; ``page_numbers=None`` selects all pages.
        The text of a page is the fifth field of each text block, joined
        with newlines.
        """
        document = self.document
        for page_index in range(document.pageCount):
            page_number = page_index + 1
            if page_numbers is not None and page_number not in page_numbers:
                continue
            blocks = document.loadPage(page_index).getTextBlocks()
            yield "\n".join(block[4] for block in blocks)

    def objects(self, page_numbers=None, starts_after=None, ends_before=None):
        """Yield, for each selected page, a list of `TextObject` (one per line).

        Objects of a page are visited ordered by ``(y0, x0)``. Nothing is
        collected until an object satisfies ``starts_after`` (that object is
        itself collected); collection stops, and no further page is read,
        as soon as an object satisfies ``ends_before`` (that object is not
        collected). Both delimiters go through `get_delimiter_function`.
        """
        document = self.document

        started = starts_after is None
        finished = False
        if not started:
            starts_after = get_delimiter_function(starts_after)
        if ends_before is not None:
            ends_before = get_delimiter_function(ends_before)

        for page_index in range(document.pageCount):
            page_number = page_index + 1
            if page_numbers is not None and page_number not in page_numbers:
                continue

            page = document.loadPage(page_index)
            page_objects = []
            for block in page.getText("dict")["blocks"]:
                # Only blocks of type 0 are read (presumably the text
                # blocks in PyMuPDF's dict output - TODO confirm)
                if block["type"] != 0:
                    continue
                for line in block["lines"]:
                    bbox = line["bbox"]
                    page_objects.append(
                        TextObject(
                            x0=bbox[0],
                            y0=bbox[1],
                            x1=bbox[2],
                            y1=bbox[3],
                            text=" ".join(span["text"] for span in line["spans"]),
                        )
                    )
            page_objects.sort(key=lambda item: (item.y0, item.x0))

            collected = []
            for item in page_objects:
                if not started:
                    if starts_after is None or not starts_after(item):
                        continue
                    started = True
                if ends_before is not None and ends_before(item):
                    finished = True
                    break
                collected.append(item)
            yield collected

            if finished:
                break

    # This backend only produces text objects, so both names share one method
    text_objects = objects
def get_delimiter_function(value):
    """Turn a delimiter specification into a predicate over page objects.

    ``value`` may be:

    - a string: the predicate is true for a `TextObject` whose stripped
      text equals it exactly;
    - an object with a ``search`` method (e.g. a compiled regular
      expression): true for a `TextObject` whose stripped text it finds a
      match in;
    - any other callable: called with the object, its result is converted
      to ``bool``.

    Raises ``ValueError`` for anything else, instead of returning ``None``
    (which made callers silently ignore the delimiter).
    """
    # NOTE(review): this module enables `unicode_literals`; under Python 2 a
    # text delimiter would be `unicode`, not `str` - consider
    # `six.string_types` if Python 2 still matters.
    if isinstance(value, str):  # regular string, match exactly
        return lambda obj: (isinstance(obj, TextObject) and obj.text.strip() == value)
    elif hasattr(value, "search"):  # regular expression
        return lambda obj: bool(
            isinstance(obj, TextObject) and value.search(obj.text.strip())
        )
    elif callable(value):  # function
        return lambda obj: bool(value(obj))
    raise ValueError(
        "Invalid delimiter {!r} (expected a string, a compiled regular "
        "expression or a callable)".format(value)
    )
class TextObject(object):
    """Text found in a page together with its bounding box coordinates."""

    def __init__(self, x0, y0, x1, y1, text):
        self.x0, self.x1 = x0, x1
        self.y0, self.y1 = y0, y1
        self.text = text

    @property
    def bbox(self):
        "Coordinates as the tuple `(x0, y0, x1, y1)`"
        return (self.x0, self.y0, self.x1, self.y1)

    def __repr__(self):
        # Texts whose repr is longer than 50 characters are cut to their
        # first 45 characters, followed by a "[...]" marker.
        shown = self.text
        if len(repr(shown)) > 50:
            shown = shown[:45] + "[...]"
        coordinates = ["{:.3f}".format(value) for value in self.bbox]
        return "<TextObject ({}) {}>".format(", ".join(coordinates), repr(shown))
class RectObject(object):
    """Rectangle found in a page: bounding box coordinates and fill value."""

    def __init__(self, x0, y0, x1, y1, fill):
        self.x0, self.x1 = x0, x1
        self.y0, self.y1 = y0, y1
        self.fill = fill

    @property
    def bbox(self):
        "Coordinates as the tuple `(x0, y0, x1, y1)`"
        return (self.x0, self.y0, self.x1, self.y1)

    def __repr__(self):
        coordinates = ["{:.3f}".format(value) for value in self.bbox]
        return "<RectObject ({}) fill={}>".format(", ".join(coordinates), self.fill)
class Group(object):
    """Helper class to group objects based on its positions and sizes

    Subclasses name the two object attributes to read through
    `dimension_0` (lower coordinate) and `dimension_1` (upper coordinate).
    """

    def __init__(self, minimum=float("inf"), maximum=float("-inf"), threshold=0):
        self.minimum = minimum
        self.maximum = maximum
        self.threshold = threshold
        self.objects = []

    def _span(self, obj):
        "Return the `(lower, upper)` coordinates of `obj` on this group's axis"
        return getattr(obj, self.dimension_0), getattr(obj, self.dimension_1)

    @property
    def min(self):
        "Lower limit of the group, widened by the threshold"
        return self.minimum - self.threshold

    @property
    def max(self):
        "Upper limit of the group, widened by the threshold"
        return self.maximum + self.threshold

    def contains(self, obj):
        "Tell whether the middle point of `obj` lies within `[min, max]`"
        lower, upper = self._span(obj)
        half_size = (upper - lower) / 2.0
        return self.min <= lower + half_size <= self.max

    def add(self, obj):
        "Store `obj` and stretch the group limits to cover it"
        self.objects.append(obj)
        lower, upper = self._span(obj)
        self.minimum = min(self.minimum, lower)
        self.maximum = max(self.maximum, upper)
class HorizontalGroup(Group):
    "Group whose limits are read from the `y0`/`y1` attributes of the objects"

    dimension_0, dimension_1 = "y0", "y1"
class VerticalGroup(Group):
    "Group whose limits are read from the `x0`/`x1` attributes of the objects"

    dimension_0, dimension_1 = "x0", "x1"
def group_objects(objs, threshold, axis):
    """Group objects whose middle points are close to each other on one axis.

    ``axis`` is ``"x"`` (groups built with `VerticalGroup`) or ``"y"``
    (`HorizontalGroup`). Each object joins the first existing group that
    contains it; otherwise it starts a new group created with ``threshold``.

    Returns a dict mapping the minimum coordinate of each group to the list
    of its objects. Groups that end up with the same minimum are merged
    under that key, so no object is dropped from the result.

    Raises ``ValueError`` if ``axis`` is neither ``"x"`` nor ``"y"``.
    """
    if axis == "x":
        GroupClass = VerticalGroup
    elif axis == "y":
        GroupClass = HorizontalGroup
    else:
        raise ValueError('Unknown axis "{}" (options are: x, y)'.format(axis))

    groups = []
    for obj in objs:
        for group in groups:
            if group.contains(obj):
                group.add(obj)
                break
        else:  # no existing group accepted the object
            group = GroupClass(threshold=threshold)
            group.add(obj)
            groups.append(group)

    result = {}
    for group in groups:
        # A plain `{minimum: objects}` mapping would let a later group
        # overwrite an earlier one sharing the same minimum.
        result.setdefault(group.minimum, []).extend(group.objects)
    return result
def contains_or_overlap(a, b):
    """Tell whether box ``b`` lies inside box ``a`` or has a corner inside it.

    Both boxes are ``(x_min, y_min, x_max, y_max)`` tuples; limits are
    inclusive.

    NOTE(review): only the corners of ``b`` are tested against ``a``, so a
    ``b`` that surrounds ``a`` or crosses it without a corner inside gives
    ``False`` - confirm this is intended.
    """
    a_x0, a_y0, a_x1, a_y1 = a
    b_x0, b_y0, b_x1, b_y1 = b

    fully_inside = b_x0 >= a_x0 and b_x1 <= a_x1 and b_y0 >= a_y0 and b_y1 <= a_y1
    corner_inside = any(
        a_x0 <= x <= a_x1 and a_y0 <= y <= a_y1
        for x in (b_x0, b_x1)
        for y in (b_y0, b_y1)
    )
    return fully_inside or corner_inside
def distance(a, b):
    "Euclidean distance between the `(x0, y0)` points of two objects"
    delta_x = a.x0 - b.x0
    delta_y = a.y0 - b.y0
    return math.sqrt(delta_x ** 2 + delta_y ** 2)
def closest_from_text(objs, text, strip=True):
    """Return the object closest to the one whose text is ``text``.

    The reference object is the first one in ``objs`` whose text equals
    ``text``; with ``strip=True`` both sides are compared after stripping
    surrounding whitespace, with ``strip=False`` they are compared as-is.
    Candidates are visited by increasing `distance` from the reference and
    the first one whose text differs from ``text`` (under the same
    comparison rule) is returned.

    Returns ``None`` when every object has the searched text. Raises
    ``IndexError`` when no object has it.
    """
    if strip:
        text = text.strip()

        def normalize(value):
            return value.strip()

    else:

        def normalize(value):
            return value

    desired_obj = [obj for obj in objs if normalize(obj.text) == text][0]
    for obj in sorted(objs, key=lambda row: distance(desired_obj, row)):
        # Use the same comparison rule that selected `desired_obj`;
        # otherwise, with `strip=False`, an object whose text carries
        # surrounding whitespace (including the reference object itself)
        # would be returned as "different".
        if normalize(obj.text) != text:
            return obj
    return None
def closest_same_line(objs, text):
    """Return the object with the greatest ``x0`` in the line holding ``text``.

    Lines are the groups given by ``group_objects(objs, 0.1, "y")``; the
    first line containing an object whose text equals ``text`` exactly is
    used. Returns ``None`` when no object has that text.
    """
    lines = group_objects(objs, 0.1, "y")
    for line_objs in lines.values():
        if any(obj.text == text for obj in line_objs):
            return max(line_objs, key=lambda row: row.x0)
    return None  # Not found
def same_column(objs, text):
    """Return the objects of the column holding ``text``, by descending ``y0``.

    Columns are the groups given by ``group_objects(objs, 0.1, "x")``. When
    several columns contain an object whose stripped text equals ``text``,
    the last one in iteration order is used. Returns ``[]`` when the text is
    not found.
    """
    columns = {
        x0: list(members) for x0, members in group_objects(objs, 0.1, "x").items()
    }
    matching_keys = [
        x0
        for x0, members in columns.items()
        if any(obj.text.strip() == text for obj in members)
    ]
    if not matching_keys:  # Text not found
        return []
    return sorted(columns[matching_keys[-1]], key=lambda row: -row.y0)
class ExtractionAlgorithm(object):
    """Base class of the table extraction algorithms.

    Subclasses define `table_bbox`, `x_intervals` and `y_intervals`; this
    class selects the text objects touching the table and distributes them
    into a matrix of cells.
    """

    def __init__(
        self, objects, text_objects, x_threshold, y_threshold, x_order, y_order
    ):
        self.objects = objects
        self.text_objects = text_objects
        self.x_threshold = x_threshold
        self.y_threshold = y_threshold
        self.x_order = x_order
        self.y_order = y_order

    @property
    def table_bbox(self):
        "Bounding box `(x0, y0, x1, y1)` of the table (defined by subclasses)"
        raise NotImplementedError

    @property
    def x_intervals(self):
        "`(start, end)` pairs delimiting the columns (defined by subclasses)"
        raise NotImplementedError

    @property
    def y_intervals(self):
        "`(start, end)` pairs delimiting the lines (defined by subclasses)"
        raise NotImplementedError

    @cached_property
    def selected_objects(self):
        "Text objects accepted by `contains_or_overlap` against `table_bbox`"
        table_bbox = self.table_bbox
        selected = []
        for obj in self.text_objects:
            if contains_or_overlap(table_bbox, obj.bbox):
                selected.append(obj)
        return selected

    def get_lines(self):
        """Return the table as a list of lines, each one a list of cells.

        A cell is the list of objects whose ``(x0, y0)`` point falls inside
        the column and line intervals (limits included), or ``None`` when
        there is none. An object is placed in the first cell that accepts
        it. Intervals are walked in reverse when the matching order is -1.
        """
        columns = list(self.x_intervals)
        if self.x_order == -1:
            columns.reverse()
        rows = list(self.y_intervals)
        if self.y_order == -1:
            rows.reverse()

        pending = list(self.selected_objects)
        matrix = []
        for y0, y1 in rows:
            line = []
            for x0, x1 in columns:
                cell, leftover = [], []
                for obj in pending:
                    inside = x0 <= obj.x0 <= x1 and y0 <= obj.y0 <= y1
                    (cell if inside else leftover).append(obj)
                # Objects already placed are not offered to later cells
                pending = leftover
                line.append(cell or None)
            matrix.append(line)
        return matrix
Extraction algorithm based on objects’ y values
class YGroupsAlgorithm(ExtractionAlgorithm):
    """Extraction algorithm based on objects' y values"""

    name = "y-groups"

    # TODO: filter out objects with empty text before grouping by y0 (but
    # consider them if inside table's bbox)
    # TODO: get y0 groups bbox and merge overlapping ones (overlapping only on
    # y, not on x). ex: imgs-33281.pdf/06.png should not remove bigger cells

    @cached_property
    def table_bbox(self):
        """Bounding box of the text objects that share a y group.

        Groups holding a single object are ignored; ``(0, 0, 0, 0)`` is
        returned when nothing is left.
        """
        y_groups = group_objects(self.text_objects, self.y_threshold, "y")
        members = [
            obj
            for group_objs in y_groups.values()
            if len(group_objs) >= 2  # Ignore floating text objects
            for obj in group_objs
        ]
        if not members:
            return (0, 0, 0, 0)
        return (
            min(obj.x0 for obj in members),
            min(obj.y0 for obj in members),
            max(obj.x1 for obj in members),
            max(obj.y1 for obj in members),
        )

    @staticmethod
    def _define_intervals(objs, min_attr, max_attr, threshold, axis):
        """Return the sorted, non-overlapping intervals covered by ``objs``.

        One interval is built per group from `group_objects`: it goes from
        the group key to the greatest ``max_attr`` among the group objects.
        ``min_attr`` is accepted but not used.
        """
        groups = group_objects(objs, threshold, axis)
        intervals = sorted(
            (start, max(map(max_attr, members))) for start, members in groups.items()
        )
        if not intervals:
            return []

        # Merge overlapping intervals
        merged = intervals[:1]
        for start, end in intervals[1:]:
            last_start, last_end = merged[-1]
            if start <= last_end or end <= last_end:
                merged[-1] = (last_start, max((last_end, end)))
            else:
                merged.append((start, end))
        return merged

    @cached_property
    def x_intervals(self):
        "Column intervals; orders `selected_objects` in place by `x0`"
        selected = self.selected_objects
        selected.sort(key=lambda item: item.x0)
        return self._define_intervals(
            selected,
            min_attr=lambda item: item.x0,
            max_attr=lambda item: item.x1,
            threshold=self.x_threshold,
            axis="x",
        )

    @cached_property
    def y_intervals(self):
        "Line intervals; orders `selected_objects` in place by descending `y1`"
        selected = self.selected_objects
        selected.sort(key=lambda item: -item.y1)
        return self._define_intervals(
            selected,
            min_attr=lambda item: item.y0,
            max_attr=lambda item: item.y1,
            threshold=self.y_threshold,
            axis="y",
        )
class HeaderPositionAlgorithm(YGroupsAlgorithm):
    """Extraction algorithm that uses the header objects as columns.

    The first line interval is taken as the header; in every other line an
    object belongs to the column of the header object it intersects on the
    x axis.
    """

    name = "header-position"

    @property
    def x_intervals(self):
        "Not used by this algorithm (columns come from the header objects)"
        raise NotImplementedError

    def get_lines(self):
        """Return the table as a list of lines, each one a list of cells.

        The first line holds one single-object cell per header object,
        ordered left to right; every other line holds, for each header
        object, the list (possibly empty) of objects intersecting it on x.
        Returns ``[]`` when there are no line intervals (e.g. a page
        without text objects).
        """
        # Evaluate `y_intervals` before ordering the objects: the inherited
        # property sorts `selected_objects` in place by `y1`, which would
        # otherwise undo the left-to-right ordering established below.
        y_intervals = list(self.y_intervals)
        if self.y_order == -1:
            y_intervals = list(reversed(y_intervals))
        if not y_intervals:
            return []

        objects = sorted(self.selected_objects, key=lambda obj: obj.x0)
        used, lines = [], []

        header_interval = y_intervals[0]
        header_objs = [
            obj for obj in objects if header_interval[0] <= obj.y0 <= header_interval[1]
        ]
        used.extend(header_objs)
        lines.append([[obj] for obj in header_objs])

        def x_intersects(a, b):
            return a.x0 < b.x1 and a.x1 > b.x0

        for y0, y1 in y_intervals[1:]:
            line_objs = [
                obj for obj in objects if obj not in used and y0 <= obj.y0 <= y1
            ]
            line = []
            for column in header_objs:
                y_objs = [
                    obj
                    for obj in line_objs
                    if obj not in used and x_intersects(column, obj)
                ]
                used.extend(y_objs)
                line.append(y_objs)
            lines.append(line)
            # TODO: may check if one of objects in line_objs is not in used and
            # raise an exception

        return lines
Extraction algorithm based on rectangles present in the page
class RectsBoundariesAlgorithm(ExtractionAlgorithm):
    """Extraction algorithm based on rectangles present in the page"""

    name = "rects-boundaries"

    def __init__(self, *args, **kwargs):
        super(RectsBoundariesAlgorithm, self).__init__(*args, **kwargs)
        # Only rectangles with a truthy `fill` take part in the detection
        self.rects = [
            obj for obj in self.objects if isinstance(obj, RectObject) and obj.fill
        ]

    @cached_property
    def table_bbox(self):
        """Bounding box `(x0, y0, x1, y1)` enclosing every selected rectangle.

        ``(0, 0, 0, 0)`` is returned when the page has no such rectangle
        (same fallback used by `YGroupsAlgorithm.table_bbox`), instead of
        failing on ``min()`` of an empty sequence.
        """
        if not self.rects:
            return (0, 0, 0, 0)
        y0 = min(obj.y0 for obj in self.rects)
        y1 = max(obj.y1 for obj in self.rects)
        x0 = min(obj.x0 for obj in self.rects)
        x1 = max(obj.x1 for obj in self.rects)
        return (x0, y0, x1, y1)

    @staticmethod
    def _clean_intersections(lines):
        """Drop the intervals that are contained in another interval.

        ``lines`` is a collection of ``(start, end)`` pairs; the result is a
        list, in the iteration order of ``lines``.
        """

        def other_line_contains(all_lines, search_line):
            return any(
                search_line != line2
                and search_line[0] >= line2[0]
                and search_line[1] <= line2[1]
                for line2 in all_lines
            )

        return [line for line in lines if not other_line_contains(lines, line)]

    @cached_property
    def x_intervals(self):
        "Sorted `(x0, x1)` pairs of the rectangles, without contained ones"
        x_intervals = set((obj.x0, obj.x1) for obj in self.rects)
        return sorted(self._clean_intersections(x_intervals))

    @cached_property
    def y_intervals(self):
        "Sorted `(y0, y1)` pairs of the rectangles, without contained ones"
        y_intervals = set((obj.y0, obj.y1) for obj in self.rects)
        return sorted(self._clean_intersections(y_intervals))
def subclasses(cls):
    "Return the set of all direct and indirect subclasses of `cls`"
    found = set()
    pending = list(cls.__subclasses__())
    while pending:
        child = pending.pop()
        if child in found:
            continue
        found.add(child)
        pending.extend(child.__subclasses__())
    return found
def algorithms():
    "Map the `name` of every `ExtractionAlgorithm` subclass to its class"
    available = {}
    for algorithm_class in subclasses(ExtractionAlgorithm):
        available[algorithm_class.name] = algorithm_class
    return available
def get_algorithm(algorithm):
    """Return the extraction algorithm class designated by ``algorithm``.

    ``algorithm`` is either the name of a registered algorithm (see
    `algorithms`) or an `ExtractionAlgorithm` subclass, which is returned
    unchanged.

    Raises ``ValueError`` for an unknown name and for any other kind of
    value (including objects that are not classes, which `issubclass` alone
    would reject with ``TypeError``).
    """
    available_algorithms = algorithms()
    if isinstance(algorithm, six.string_types):
        if algorithm in available_algorithms:
            return available_algorithms[algorithm]
    elif isinstance(algorithm, type) and issubclass(algorithm, ExtractionAlgorithm):
        return algorithm

    raise ValueError(
        'Unknown algorithm "{}" (options are: {})'.format(
            algorithm, ", ".join(available_algorithms.keys())
        )
    )
def backends():
    "Map the `name` of every `PDFBackend` subclass to its class"
    available = {}
    for backend_class in subclasses(PDFBackend):
        available[backend_class.name] = backend_class
    return available
def get_backend(backend):
    """Return the PDF backend class designated by ``backend``.

    ``backend`` is either the name of a registered backend (see `backends`)
    or a `PDFBackend` subclass, which is returned unchanged.

    Raises ``ValueError`` for an unknown name and for any other kind of
    value (including objects that are not classes, which `issubclass` alone
    would reject with ``TypeError``).
    """
    available_backends = backends()
    if isinstance(backend, six.string_types):
        if backend in available_backends:
            return available_backends[backend]
    elif isinstance(backend, type) and issubclass(backend, PDFBackend):
        return backend

    raise ValueError(
        'Unknown PDF backend "{}" (options are: {})'.format(
            backend, ", ".join(available_backends.keys())
        )
    )
def pdf_table_lines(
    source,
    page_numbers=None,
    algorithm="y-groups",
    starts_after=None,
    ends_before=None,
    x_threshold=0.5,
    y_threshold=0.5,
    backend=None,
):
    """Yield the lines of the table found in a PDF, as lists of cell texts.

    ``page_numbers`` may be ``None``, a collection of 1-based page numbers
    or a text expression expanded with `extract_intervals`.
    ``starts_after`` and ``ends_before`` are passed to the backend's
    ``objects`` method to delimit the region to read. ``algorithm`` and
    ``backend`` are resolved with `get_algorithm` and `get_backend`
    (``backend`` defaults to `default_backend`).

    The first line ever produced is taken as the header; when a later page
    starts with an identical line, that repetition is skipped.
    """
    if isinstance(page_numbers, six.text_type):
        page_numbers = extract_intervals(page_numbers)

    backend = backend or default_backend()
    # TODO: check if both backends accepts filename or fobj
    Backend = get_backend(backend)
    Algorithm = get_algorithm(algorithm)
    pdf_doc = Backend(source)

    pages = pdf_doc.objects(
        page_numbers=page_numbers, starts_after=starts_after, ends_before=ends_before
    )
    header = None
    for page in pages:
        objs = list(page)
        text_objs = [obj for obj in objs if isinstance(obj, TextObject)]
        extractor = Algorithm(
            objs, text_objs, x_threshold, y_threshold, pdf_doc.x_order, pdf_doc.y_order
        )
        lines = [
            [pdf_doc.get_cell_text(cell) for cell in row]
            for row in extractor.get_lines()
        ]
        for line_index, line in enumerate(lines):
            if line_index == 0:
                if header is None:
                    # Pages yielded before any content is found (e.g. before
                    # `starts_after` matches) have no lines, so the header
                    # is captured from the first line actually produced
                    # rather than from the first page.
                    header = line
                elif line == header:  # skip header repetition
                    continue
            yield line
def import_from_pdf(
    filename_or_fobj,
    page_numbers=None,
    starts_after=None,
    ends_before=None,
    backend=None,
    algorithm="y-groups",
    x_threshold=0.5,
    y_threshold=0.5,
    *args,
    **kwargs
):
    """Create a table from the lines extracted from a PDF.

    The extraction parameters are forwarded to `pdf_table_lines`; remaining
    positional and keyword arguments are forwarded to `create_table`. The
    table metadata records ``"pdf"`` as origin together with the source.
    """
    if isinstance(page_numbers, six.text_type):
        page_numbers = extract_intervals(page_numbers)
    backend = backend or default_backend()

    source = Source.from_file(filename_or_fobj, plugin_name="pdf", mode="rb")
    extraction_options = {
        "starts_after": starts_after,
        "ends_before": ends_before,
        "algorithm": algorithm,
        "x_threshold": x_threshold,
        "y_threshold": y_threshold,
        "backend": backend,
    }
    extracted_lines = pdf_table_lines(source, page_numbers, **extraction_options)
    metadata = {"imported_from": "pdf", "source": source}
    return create_table(extracted_lines, meta=metadata, *args, **kwargs)
# Call the function so it'll raise ImportError if no backend is available
default_backend()  # runs at import time; the returned name is discarded