Refactor extraction pipeline.2 #23

Merged
Changes from 6 commits
30 changes: 8 additions & 22 deletions src/stratigraphy/benchmark/score.py
@@ -1,6 +1,5 @@
"""Evaluate the predictions against the ground truth."""

import json
import logging
import os
from pathlib import Path
@@ -9,7 +8,6 @@
from dotenv import load_dotenv
from stratigraphy import DATAPATH
from stratigraphy.benchmark.ground_truth import GroundTruth
from stratigraphy.util.draw import draw_predictions
from stratigraphy.util.util import parse_text

load_dotenv()
@@ -56,34 +54,20 @@ def f1(precision: float, recall: float) -> float:
return 0


def evaluate_matching(
predictions_path: Path, ground_truth_path: Path, directory: Path, out_directory: Path, skip_draw_predictions: bool
) -> tuple[dict, pd.DataFrame]:
def evaluate_matching(predictions: dict, number_of_truth_values: dict) -> tuple[dict, pd.DataFrame]:
"""Calculate F1, precision and recall for the predictions.

Calculate F1, precision and recall for the individual documents as well as overall.
The individual document metrics are returned as a DataFrame.

Args:
predictions_path (Path): Path to the predictions.json file.
ground_truth_path (Path): Path to the ground truth annotated data.
directory (Path): Path to the directory containing the pdf files.
out_directory (Path): Path to the directory where the evaluation images should be saved.
skip_draw_predictions (bool): Whether to draw the predictions on the pdf pages.
predictions (dict): The predictions.
number_of_truth_values (dict): The number of ground truth values per file.

Returns:
tuple[dict, pd.DataFrame]: A tuple containing the overall F1, precision and recall as a dictionary and the
individual document metrics as a DataFrame.
"""
ground_truth = GroundTruth(ground_truth_path)
with open(predictions_path) as in_file:
predictions = json.load(in_file)

predictions, number_of_truth_values = _add_ground_truth_to_predictions(predictions, ground_truth)

if not skip_draw_predictions:
draw_predictions(predictions, directory, out_directory)

document_level_metrics = {
"document_name": [],
"F1": [],
@@ -137,16 +121,18 @@ def evaluate_matching(
}, pd.DataFrame(document_level_metrics)


def _add_ground_truth_to_predictions(predictions: dict, ground_truth: GroundTruth) -> (dict, dict):
def add_ground_truth_to_predictions(predictions: dict, ground_truth_path: Path) -> tuple[dict, dict]:
"""Add the ground truth to the predictions.

Args:
predictions (dict): The predictions.
ground_truth (GroundTruth): The ground truth.
ground_truth_path (Path): The path to the ground truth file.

Returns:
(dict, dict): The predictions with the ground truth added, and the number of ground truth values per file.
tuple[dict, dict]: The predictions with the ground truth added, and the number of ground truth values per file.
"""
ground_truth = GroundTruth(ground_truth_path)

number_of_truth_values = {}
for file, file_predictions in predictions.items():
ground_truth_for_file = ground_truth.for_file(file)
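With this refactor, score.py no longer reads predictions.json, instantiates GroundTruth internally, or triggers drawing; those responsibilities move to the caller. A minimal sketch of how the two public functions now compose (the file paths below are placeholders, not part of this diff):

```python
import json
from pathlib import Path

from stratigraphy.benchmark.score import add_ground_truth_to_predictions, evaluate_matching

# Placeholder paths for illustration only.
predictions_path = Path("output/predictions.json")
ground_truth_path = Path("data/ground_truth.json")

with open(predictions_path) as in_file:
    predictions = json.load(in_file)

# Attach the ground truth to each file's predictions, then score them.
predictions, number_of_truth_values = add_ground_truth_to_predictions(predictions, ground_truth_path)
metrics, document_level_metrics = evaluate_matching(predictions, number_of_truth_values)

print(metrics)                 # overall F1, precision and recall
print(document_level_metrics)  # per-document metrics as a DataFrame
```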
6 changes: 2 additions & 4 deletions src/stratigraphy/extract.py
@@ -7,7 +7,6 @@

import fitz

from stratigraphy.line_detection import extract_lines, line_detection_params
from stratigraphy.util import find_depth_columns
from stratigraphy.util.dataclasses import Line
from stratigraphy.util.depthcolumn import DepthColumn
@@ -25,13 +24,14 @@
logger = logging.getLogger(__name__)


def process_page(page: fitz.Page, **params: dict) -> list[dict]:
def process_page(page: fitz.Page, geometric_lines, **params: dict) -> list[dict]:
"""Process a single page of a pdf.

Finds all descriptions and depth intervals on the page and matches them.

Args:
page (fitz.Page): The page to process.
geometric_lines (list[Line]): The geometric lines of the page.
**params (dict): Additional parameters for the matching pipeline.

Returns:
@@ -97,8 +97,6 @@ def process_page(page: fitz.Page, **params: dict) -> list[dict]:
continue
filtered_pairs = [item for index, item in enumerate(pairs) if index not in to_delete]

geometric_lines = extract_lines(page, line_detection_params)

groups = [] # list of matched depth intervals and text blocks
# groups is of the form: ["depth_interval": BoundaryInterval, "block": TextBlock]
if len(filtered_pairs): # match depth column items with material description
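Since process_page no longer calls extract_lines itself, the caller is expected to run line detection once per page and pass the result in, which is what the updated main.py does. A sketch under that assumption (the pdf path and the empty matching_params are placeholders):

```python
import fitz

from stratigraphy.extract import process_page
from stratigraphy.line_detection import extract_lines, line_detection_params

matching_params = {}  # placeholder; normally loaded from the matching config

with fitz.Document("example_borehole.pdf") as doc:  # placeholder file name
    for page in doc:
        # Line detection is now done by the caller and handed to process_page.
        geometric_lines = extract_lines(page, line_detection_params)
        layer_predictions, depths_materials_column_pairs = process_page(
            page, geometric_lines, **matching_params
        )
```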
24 changes: 8 additions & 16 deletions src/stratigraphy/line_detection.py
@@ -1,7 +1,6 @@
"""Script for line detection in pdf pages."""

import os
from pathlib import Path

import cv2
import fitz
@@ -88,26 +87,19 @@ def extract_lines(page: fitz.Page, line_detection_params: dict) -> list[Line]:
return lines


def draw_lines_on_pdfs(input_directory: Path, line_detection_params: dict):
def draw_lines_on_page(filename: str, page: fitz.Page, geometric_lines: list[Line]):
"""Draw lines on pdf pages and stores them as artifacts in mlflow.

Note: now the function draw_lines_on_pdfs may not even be needed any more.

Args:
input_directory (Path): The directory containing the pdf files.
line_detection_params (dict): The parameters for the line detection algorithm.
filename (str): The filename of the pdf.
page (fitz.Page): The page to draw lines on.
geometric_lines (list[Line]): The lines to draw on the pdf page.
"""
if not mlflow_tracking:
raise Warning("MLFlow tracking is not enabled. MLFLow is required to store the images.")
import mlflow

for root, _dirs, files in os.walk(input_directory):
output = {}
for filename in files:
if filename.endswith(".pdf"):
in_path = os.path.join(root, filename)
output[filename] = {}

with fitz.Document(in_path) as doc:
for page_index, page in enumerate(doc):
lines = extract_lines(page, line_detection_params)
img = plot_lines(page, lines, scale_factor=line_detection_params["pdf_scale_factor"])
mlflow.log_image(img, f"pages/{filename}_page_{page_index}_lines.png")
img = plot_lines(page, geometric_lines, scale_factor=line_detection_params["pdf_scale_factor"])
mlflow.log_image(img, f"pages/{filename}_page_{page.number + 1}_lines.png")
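draw_lines_on_page now operates on a single page with lines supplied by the caller, instead of walking an input directory itself. A sketch of the new call, assuming MLflow tracking is enabled since the image is logged as an mlflow artifact (the file name is a placeholder):

```python
import fitz

from stratigraphy.line_detection import draw_lines_on_page, extract_lines, line_detection_params

filename = "example_borehole.pdf"  # placeholder file name

with fitz.Document(filename) as doc:
    for page in doc:
        geometric_lines = extract_lines(page, line_detection_params)
        # Logs one image per page as pages/<filename>_page_<n>_lines.png.
        draw_lines_on_page(filename, page, geometric_lines)
```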
59 changes: 40 additions & 19 deletions src/stratigraphy/main.py
@@ -3,16 +3,17 @@
import json
import logging
import os
import shutil
from pathlib import Path

import click
import fitz
from dotenv import load_dotenv

from stratigraphy import DATAPATH
from stratigraphy.benchmark.score import evaluate_matching
from stratigraphy.extract import perform_matching
from stratigraphy.line_detection import draw_lines_on_pdfs, line_detection_params
from stratigraphy.benchmark.score import add_ground_truth_to_predictions, evaluate_matching
from stratigraphy.extract import process_page
from stratigraphy.line_detection import draw_lines_on_page, extract_lines, line_detection_params
from stratigraphy.util.draw import draw_predictions
from stratigraphy.util.util import flatten, read_params

load_dotenv()
@@ -103,32 +104,52 @@ def start_pipeline(

# if a file is specified instead of an input directory, copy the file to a temporary directory and work with that.
if input_directory.is_file():
if (temp_directory / "single_file").is_dir():
shutil.rmtree(temp_directory / "single_file")
file_iterator = [(input_directory.parent, None, [input_directory.name])]
else:
file_iterator = os.walk(input_directory)
# process the individual pdf files
predictions = {}
for root, _dirs, files in file_iterator:
for filename in files:
if filename.endswith(".pdf"):
in_path = os.path.join(root, filename)
logger.info("Processing file: %s", in_path)
predictions[filename] = {}

with fitz.Document(in_path) as doc:
for page_index, page in enumerate(doc):
page_number = page_index + 1
logger.info("Processing page %s", page_number)

geometric_lines = extract_lines(page, line_detection_params)
layer_predictions, depths_materials_column_pairs = process_page(
page, geometric_lines, **matching_params
)

predictions[filename][f"page_{page_number}"] = {
"layers": layer_predictions,
"depths_materials_column_pairs": depths_materials_column_pairs,
}
if draw_lines:
logger.info("Drawing lines on pdf pages.")
draw_lines_on_page(filename, page, geometric_lines)

Path.mkdir(temp_directory / "single_file")
shutil.copy(input_directory, temp_directory / "single_file")
input_directory = temp_directory / "single_file"

# run the matching pipeline and save the result
predictions = perform_matching(input_directory, **matching_params)
with open(predictions_path, "w") as file:
file.write(json.dumps(predictions))

# evaluate the predictions
metrics, document_level_metrics = evaluate_matching(
predictions_path, ground_truth_path, input_directory, out_directory, skip_draw_predictions
)
predictions, number_of_truth_values = add_ground_truth_to_predictions(predictions, ground_truth_path)

if not skip_draw_predictions:
draw_predictions(predictions, input_directory, out_directory)

metrics, document_level_metrics = evaluate_matching(predictions, number_of_truth_values)
document_level_metrics.to_csv(temp_directory / "document_level_metrics.csv") # mlflow.log_artifact expects a file

if mlflow_tracking:
mlflow.log_metrics(metrics)
mlflow.log_artifact(temp_directory / "document_level_metrics.csv")

if draw_lines:
logger.info("Drawing lines on pdf pages.")
draw_lines_on_pdfs(input_directory, line_detection_params=line_detection_params)


if __name__ == "__main__":
start_pipeline()
5 changes: 3 additions & 2 deletions src/stratigraphy/util/draw.py
@@ -31,12 +31,13 @@ def draw_predictions(predictions: dict, directory: Path, out_directory: Path) ->
- Assignments of material description text blocks to depth intervals (if available)

Args:
predictions (dict): Content of the predictions.json file..
predictions (dict): Content of the predictions.json file.
directory (Path): Path to the directory containing the pdf files.
out_directory (Path): Path to the output directory where the images are saved.
"""
if directory.is_file(): # deal with the case when we pass a file instead of a directory
directory = directory.parent
for file in predictions:
logger.info(f"Evaluating {file}.")
with fitz.Document(directory / file) as doc:
for page_index, page in enumerate(doc):
page_number = page_index + 1