Skip to content

Commit

Permalink
Merge pull request #84 from swisstopo/LGVISIUM-82-Refactor-the-FilePr…
Browse files Browse the repository at this point in the history
…ediction-Object-and-its-evaluation

Close #LGVISIUM-82 & Close #LGVISIUM-97: Refactor the FilePrediction object and fix entrypoint for score.py
  • Loading branch information
dcleres authored Oct 22, 2024
2 parents ed163b0 + a8b6bbd commit 94afd15
Show file tree
Hide file tree
Showing 18 changed files with 916 additions and 349 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/pipeline_run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,7 @@ jobs:
source env/bin/activate
pip install -e .
echo "Running pipeline"
boreholes-extract-all -l -i example/example_borehole_profile.pdf -o example/ -p example/predictions.json -m example/metadata.json -g example/example_groundtruth.json -pa all
boreholes-extract-all -l -i example/example_borehole_profile.pdf -o example/ -p example/predictions.json -m example/metadata.json -g example/example_groundtruth.json -pa all
echo "Running scoring script"
boreholes-score --ground-truth-path example/example_groundtruth.json --predictions-path example/predictions.json --no-mlflow-tracking
10 changes: 10 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
"justMyCode": true,
"python": "${workspaceFolder}/swisstopo/bin/python3",
},
{
"name": "Python: Run scoring",
"type": "debugpy",
"request": "launch",
"module": "src.stratigraphy.benchmark.score",
"args": [],
"cwd": "${workspaceFolder}",
"justMyCode": true,
"python": "./swisstopo/bin/python3",
},
{
"name": "Python: Run label studio to GT",
"type": "debugpy",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ all = ["swissgeol-boreholes-dataextraction[test, lint, experiment-tracking, visu
boreholes-extract-all = "stratigraphy.main:click_pipeline"
boreholes-extract-metadata = "stratigraphy.main:click_pipeline_metadata"
boreholes-download-profiles = "stratigraphy.get_files:download_directory_froms3"
boreholes-score = "stratigraphy.benchmark.score:main"

[tool.ruff.lint]
select = [
Expand Down
126 changes: 70 additions & 56 deletions src/stratigraphy/annotations/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import fitz
import pandas as pd
from dotenv import load_dotenv
from stratigraphy.depthcolumn.depthcolumn import DepthColumn
from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs
from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage
from stratigraphy.layer.layer import LayerPrediction
from stratigraphy.metadata.coordinate_extraction import Coordinate
from stratigraphy.metadata.elevation_extraction import Elevation
from stratigraphy.text.textblock import TextBlock
from stratigraphy.util.interval import BoundaryInterval
from stratigraphy.util.predictions import FilePredictions
from stratigraphy.util.predictions import OverallFilePredictions

load_dotenv()

Expand All @@ -25,7 +27,7 @@


def draw_predictions(
predictions: dict[str, FilePredictions],
predictions: OverallFilePredictions,
directory: Path,
out_directory: Path,
document_level_metadata_metrics: pd.DataFrame,
Expand All @@ -51,66 +53,78 @@ def draw_predictions(
"""
if directory.is_file(): # deal with the case when we pass a file instead of a directory
directory = directory.parent
for file_name, file_prediction in predictions.items():
logger.info("Drawing predictions for file %s", file_name)
for file_prediction in predictions.file_predictions_list:
logger.info("Drawing predictions for file %s", file_prediction.file_name)

depths_materials_column_pairs = file_prediction.depths_materials_columns_pairs
coordinates = file_prediction.metadata.coordinates
elevation = file_prediction.metadata.elevation

# Assess the correctness of the metadata
is_coordinates_correct = document_level_metadata_metrics.loc[file_name].coordinate
is_elevation_correct = document_level_metadata_metrics.loc[file_name].elevation

with fitz.Document(directory / file_name) as doc:
for page_index, page in enumerate(doc):
page_number = page_index + 1
shape = page.new_shape() # Create a shape object for drawing
if page_number == 1:
draw_metadata(
if file_prediction.file_name in document_level_metadata_metrics.index:
is_coordinates_correct = document_level_metadata_metrics.loc[file_prediction.file_name].coordinate
is_elevation_correct = document_level_metadata_metrics.loc[file_prediction.file_name].elevation
else:
logger.warning(
"Metrics for file %s not found in document_level_metadata_metrics.", file_prediction.file_name
)
is_coordinates_correct = False
is_elevation_correct = False

try:
with fitz.Document(directory / file_prediction.file_name) as doc:
for page_index, page in enumerate(doc):
page_number = page_index + 1
shape = page.new_shape() # Create a shape object for drawing
if page_number == 1:
draw_metadata(
shape,
page.derotation_matrix,
page.rotation,
coordinates,
is_coordinates_correct,
elevation,
is_elevation_correct,
)
if coordinates is not None and page_number == coordinates.page:
draw_coordinates(shape, coordinates)
if elevation is not None and page_number == elevation.page:
draw_elevation(shape, elevation)
for groundwater_entry in file_prediction.groundwater_entries:
if page_number == groundwater_entry.page:
draw_groundwater(shape, groundwater_entry)
draw_depth_columns_and_material_rect(
shape,
page.derotation_matrix,
page.rotation,
coordinates,
is_coordinates_correct,
elevation,
is_elevation_correct,
[pair for pair in depths_materials_column_pairs if pair.page == page_number],
)
if coordinates is not None and page_number == coordinates.page:
draw_coordinates(shape, coordinates)
if elevation is not None and page_number == elevation.page:
draw_elevation(shape, elevation)
for groundwater_entry in file_prediction.groundwater_entries:
if page_number == groundwater_entry.page:
draw_groundwater(shape, groundwater_entry)
draw_depth_columns_and_material_rect(
shape,
page.derotation_matrix,
[pair for pair in depths_materials_column_pairs if pair["page"] == page_number],
)
draw_material_descriptions(
shape,
page.derotation_matrix,
[
layer
for layer in file_prediction.layers
if layer.material_description.page_number == page_number
],
)
shape.commit() # Commit all the drawing operations to the page
draw_material_descriptions(
shape,
page.derotation_matrix,
[
layer
for layer in file_prediction.layers
if layer.material_description.page_number == page_number
],
)
shape.commit() # Commit all the drawing operations to the page

tmp_file_path = out_directory / f"{file_prediction.file_name}_page{page_number}.png"
fitz.utils.get_pixmap(page, matrix=fitz.Matrix(2, 2), clip=page.rect).save(tmp_file_path)

tmp_file_path = out_directory / f"{file_name}_page{page_number}.png"
fitz.utils.get_pixmap(page, matrix=fitz.Matrix(2, 2), clip=page.rect).save(tmp_file_path)
if mlflow_tracking: # This is only executed if MLFlow tracking is enabled
try:
import mlflow

if mlflow_tracking: # This is only executed if MLFlow tracking is enabled
try:
import mlflow
mlflow.log_artifact(tmp_file_path, artifact_path="pages")
except NameError:
logger.warning("MLFlow could not be imported. Skipping logging of artifact.")

mlflow.log_artifact(tmp_file_path, artifact_path="pages")
except NameError:
logger.warning("MLFlow could not be imported. Skipping logging of artifact.")
except (FileNotFoundError, fitz.FileDataError) as e:
logger.error("Error opening file %s: %s", file_prediction.file_name, e)
continue

logger.info("Finished drawing predictions for file %s", file_name)
logger.info("Finished drawing predictions for file %s", file_prediction.file_name)


def draw_metadata(
Expand Down Expand Up @@ -238,7 +252,7 @@ def draw_material_descriptions(


def draw_depth_columns_and_material_rect(
shape: fitz.Shape, derotation_matrix: fitz.Matrix, depths_materials_column_pairs: list
shape: fitz.Shape, derotation_matrix: fitz.Matrix, depths_materials_column_pairs: list[DepthsMaterialsColumnPairs]
):
"""Draw depth columns as well as the material rects on a pdf page.
Expand All @@ -253,17 +267,17 @@ def draw_depth_columns_and_material_rect(
depths_materials_column_pairs (list[DepthsMaterialsColumnPairs]): List of depth column / material description rect pairs.
"""
for pair in depths_materials_column_pairs:
depth_column = pair["depth_column"]
material_description_rect = pair["material_description_rect"]
depth_column: DepthColumn = pair.depth_column
material_description_rect = pair.material_description_rect

if depth_column is not None: # Draw rectangle for depth columns
if depth_column: # Draw rectangle for depth columns
shape.draw_rect(
fitz.Rect(depth_column["rect"]) * derotation_matrix,
fitz.Rect(depth_column.rect()) * derotation_matrix,
)
shape.finish(color=fitz.utils.getColor("green"))
for depth_column_entry in depth_column["entries"]: # Draw rectangle for depth column entries
for depth_column_entry in depth_column.entries: # Draw rectangle for depth column entries
shape.draw_rect(
fitz.Rect(depth_column_entry["rect"]) * derotation_matrix,
fitz.Rect(depth_column_entry.rect) * derotation_matrix,
)
shape.finish(color=fitz.utils.getColor("purple"))

Expand Down
69 changes: 46 additions & 23 deletions src/stratigraphy/benchmark/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for keeping track of metrics such as the F1-score, precision and recall."""

from collections import defaultdict
from collections.abc import Callable

import pandas as pd
Expand Down Expand Up @@ -52,6 +53,10 @@ def to_dataframe(self, name: str, fn: Callable[[Metrics], float]) -> pd.DataFram
series = pd.Series({filename: fn(metric) for filename, metric in self.metrics.items()})
return series.to_frame(name=name)

def get_metrics_list(self) -> list[Metrics]:
"""Return a list of all metrics."""
return list(self.metrics.values())


class DatasetMetricsCatalog:
"""Keeps track of all different relevant metrics that are computed for a dataset."""
Expand All @@ -78,26 +83,44 @@ def document_level_metrics_df(self) -> pd.DataFrame:

def metrics_dict(self) -> dict[str, float]:
"""Return a dictionary with the overall metrics."""
groundwater_metrics = Metrics.micro_average(self.metrics["groundwater"].metrics.values())
groundwater_depth_metrics = Metrics.micro_average(self.metrics["groundwater_depth"].metrics.values())

return {
"F1": self.metrics["layer"].pseudo_macro_f1(),
"recall": self.metrics["layer"].macro_recall(),
"precision": self.metrics["layer"].macro_precision(),
"depth_interval_accuracy": self.metrics["depth_interval"].macro_precision(),
"de_F1": self.metrics["de_layer"].pseudo_macro_f1(),
"de_recall": self.metrics["de_layer"].macro_recall(),
"de_precision": self.metrics["de_layer"].macro_precision(),
"de_depth_interval_accuracy": self.metrics["de_depth_interval"].macro_precision(),
"fr_F1": self.metrics["fr_layer"].pseudo_macro_f1(),
"fr_recall": self.metrics["fr_layer"].macro_recall(),
"fr_precision": self.metrics["fr_layer"].macro_precision(),
"fr_depth_interval_accuracy": self.metrics["fr_depth_interval"].macro_precision(),
"groundwater_f1": groundwater_metrics.f1,
"groundwater_recall": groundwater_metrics.recall,
"groundwater_precision": groundwater_metrics.precision,
"groundwater_depth_f1": groundwater_depth_metrics.f1,
"groundwater_depth_recall": groundwater_depth_metrics.recall,
"groundwater_depth_precision": groundwater_depth_metrics.precision,
}
# Initialize a defaultdict that returns None for missing keys
result = defaultdict(lambda: None)

# Safely compute groundwater metrics using .get() to avoid KeyErrors
groundwater_metrics = Metrics.micro_average(
self.metrics.get("groundwater", DatasetMetrics()).get_metrics_list()
)
groundwater_depth_metrics = Metrics.micro_average(
self.metrics.get("groundwater_depth", DatasetMetrics()).get_metrics_list()
)

# Populate the basic metrics
result.update(
{
"F1": self.metrics.get("layer", DatasetMetrics()).pseudo_macro_f1(),
"recall": self.metrics.get("layer", DatasetMetrics()).macro_recall(),
"precision": self.metrics.get("layer", DatasetMetrics()).macro_precision(),
"depth_interval_accuracy": self.metrics.get("depth_interval", DatasetMetrics()).macro_precision(),
"groundwater_f1": groundwater_metrics.f1,
"groundwater_recall": groundwater_metrics.recall,
"groundwater_precision": groundwater_metrics.precision,
"groundwater_depth_f1": groundwater_depth_metrics.f1,
"groundwater_depth_recall": groundwater_depth_metrics.recall,
"groundwater_depth_precision": groundwater_depth_metrics.precision,
}
)

# Add dynamic language-specific metrics only if they exist
for lang in ["de", "fr"]:
layer_key = f"{lang}_layer"
depth_key = f"{lang}_depth_interval"

if layer_key in self.metrics:
result[f"{lang}_F1"] = self.metrics[layer_key].pseudo_macro_f1()
result[f"{lang}_recall"] = self.metrics[layer_key].macro_recall()
result[f"{lang}_precision"] = self.metrics[layer_key].macro_precision()

if depth_key in self.metrics:
result[f"{lang}_depth_interval_accuracy"] = self.metrics[depth_key].macro_precision()

return dict(result) # Convert defaultdict back to a regular dict
Loading

1 comment on commit 94afd15

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/stratigraphy
   __init__.py8188%11
   extract.py1921920%3–477
   get_files.py19190%3–47
   main.py1191190%3–314
src/stratigraphy/benchmark
   ground_truth.py21195%47
src/stratigraphy/data_extractor
   data_extractor.py57395%33, 66, 103
src/stratigraphy/depthcolumn
   boundarydepthcolumnvalidator.py412051%47, 57, 60, 81–84, 110–128, 140–149
   depthcolumn.py2238064%25, 29, 50, 67, 72, 78, 86–92, 113, 116, 124–125, 141–143, 150, 157, 165–166, 176, 193–209, 247, 255–256, 272–274, 312, 331–339, 350, 355, 362, 393, 398–405, 420–421, 464–506
   depthcolumnentry.py361072%17, 21, 37, 52, 55, 72, 81, 98–100
   find_depth_columns.py1061982%42–43, 73, 86, 180–181, 225–245
src/stratigraphy/depths_materials_column_pairs
   depths_materials_column_pairs.py18667%23, 34, 55–59
src/stratigraphy/evaluation
   evaluation_dataclasses.py491667%24, 33, 42–44, 52, 71–74, 90, 104, 125–131, 137
   metadata_evaluator.py381463%52–71, 94–101
src/stratigraphy/groundwater
   groundwater_extraction.py1479833%44, 52, 83, 98, 106, 125, 152–156, 171–191, 202–291, 307–339
   utility.py393315%10–17, 30–47, 59–73, 88–102
src/stratigraphy/layer
   layer.py371851%29, 39, 57–94, 106–107
   layer_identifier_column.py745230%16–17, 20, 28, 43, 47, 51, 59–63, 66, 74, 91–96, 99, 112, 125–126, 148–158, 172–199
src/stratigraphy/lines
   geometric_line_utilities.py86298%81, 131
   line.py51492%25, 50, 60, 110
   linesquadtree.py46198%75
src/stratigraphy/metadata
   coordinate_extraction.py108595%30, 64, 94–95, 107
   elevation_extraction.py795234%34, 42, 50, 66–69, 106–120, 132–135, 147–179, 194–202, 210–214
   language_detection.py181328%17–23, 37–45
   metadata.py662464%27, 83, 101–127, 146–155, 195–198, 206
src/stratigraphy/text
   description_block_splitter.py70297%24, 139
   extract_text.py29390%19, 53–54
   find_description.py642856%27–35, 50–63, 79–95, 172–175
   textblock.py80989%28, 56, 64, 89, 101, 124, 145, 154, 183
src/stratigraphy/util
   dataclasses.py32391%37–39
   interval.py1045547%29–32, 37–40, 46, 52, 56, 66–68, 107–153, 174, 180–196
   predictions.py1054161%74–78, 86–94, 171–203, 242, 265–283
   util.py391756%41, 69–76, 90–92, 116–117, 129–133
TOTAL220796057% 

Tests Skipped Failures Errors Time
88 0 💤 0 ❌ 0 🔥 7.098s ⏱️

Please sign in to comment.