Skip to content

Commit

Permalink
Merge pull request #87 from swisstopo/LGVISIUM-80-Refactor-the-ground…
Browse files Browse the repository at this point in the history
…water-object-and-its-evaluation

Close #LGVISIUM-80: Refactor the groundwater object and its evaluation
  • Loading branch information
dcleres authored Nov 7, 2024
2 parents a5d0947 + 4e3d50e commit b2998f0
Show file tree
Hide file tree
Showing 30 changed files with 1,242 additions and 653 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pipeline_run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.11'
- name: Create Environment and run pipeline
shell: bash
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: 3.10.14
python-version: '3.11'
- uses: pre-commit/[email protected]
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.11'
- name: Create Environment and run tests
shell: bash
run: |
Expand Down
56 changes: 56 additions & 0 deletions example/example_gw_groundtruth.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"example_borehole_profile.pdf": {
"groundwater": [
{
"date": "2016-04-18",
"depth": 2.22,
"elevation": 448.07
},
{
"date": "2016-04-20",
"depth": 3.22,
"elevation": 447.07
}
],
"layers": [],
"metadata": {
"coordinates": {
"E": 615790,
"N": 157500
},
"drilling_date": "1995-09-03",
"drilling_methods": null,
"original_name": "",
"project_name": "",
"reference_elevation": 788.6,
"total_depth": null
}
},
"example_borehole_profile_2.pdf": {
"groundwater": [
{
"date": "2016-04-18",
"depth": 2.22,
"elevation": 448.07
},
{
"date": "2016-04-20",
"depth": 3.22,
"elevation": 447.07
}
],
"layers": [],
"metadata": {
"coordinates": {
"E": 615790,
"N": 157500
},
"drilling_date": "1995-09-03",
"drilling_methods": null,
"original_name": "",
"project_name": "",
"reference_elevation": 788.6,
"total_depth": null
}
}
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "swissgeol-boreholes-dataextraction"
version = "0.0.1-dev"
description = "Python project to analyse borehole profiles."
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
dependencies = [
"boto3",
"pandas",
Expand Down
21 changes: 10 additions & 11 deletions src/stratigraphy/annotations/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import fitz
import pandas as pd
from dotenv import load_dotenv
from stratigraphy.data_extractor.data_extractor import FeatureOnPage
from stratigraphy.depthcolumn.depthcolumn import DepthColumn
from stratigraphy.depths_materials_column_pairs.depths_materials_column_pairs import DepthsMaterialsColumnPairs
from stratigraphy.groundwater.groundwater_extraction import GroundwaterInformationOnPage
from stratigraphy.layer.layer import LayerPrediction
from stratigraphy.groundwater.groundwater_extraction import Groundwater
from stratigraphy.layer.layer import Layer
from stratigraphy.metadata.coordinate_extraction import Coordinate
from stratigraphy.metadata.elevation_extraction import Elevation
from stratigraphy.text.textblock import TextBlock
Expand Down Expand Up @@ -90,7 +91,7 @@ def draw_predictions(
draw_coordinates(shape, coordinates)
if elevation is not None and page_number == elevation.page:
draw_elevation(shape, elevation)
for groundwater_entry in file_prediction.groundwater_entries:
for groundwater_entry in file_prediction.groundwater.groundwater:
if page_number == groundwater_entry.page:
draw_groundwater(shape, groundwater_entry)
draw_depth_columns_and_material_rect(
Expand All @@ -103,7 +104,7 @@ def draw_predictions(
page.derotation_matrix,
[
layer
for layer in file_prediction.layers
for layer in file_prediction.layers.get_all_layers()
if layer.material_description.page_number == page_number
],
)
Expand Down Expand Up @@ -197,19 +198,19 @@ def draw_coordinates(shape: fitz.Shape, coordinates: Coordinate) -> None:
shape.finish(color=fitz.utils.getColor("purple"))


def draw_groundwater(shape: fitz.Shape, groundwater_entry: GroundwaterInformationOnPage) -> None:
"""Draw a bounding box around the area of the page where the coordinates were extracted from.
def draw_groundwater(shape: fitz.Shape, groundwater_entry: FeatureOnPage[Groundwater]) -> None:
"""Draw a bounding box around the area of the page where the groundwater information was extracted from.
Args:
shape (fitz.Shape): The shape object for drawing.
groundwater_entry (GroundwaterInformationOnPage): The groundwater information to draw.
groundwater_entry (FeatureOnPage[Groundwater]): The groundwater information to draw.
"""
shape.draw_rect(groundwater_entry.rect)
shape.finish(color=fitz.utils.getColor("pink"))


def draw_elevation(shape: fitz.Shape, elevation: Elevation) -> None:
"""Draw a bounding box around the area of the page where the coordinates were extracted from.
"""Draw a bounding box around the area of the page where the elevation were extracted from.
Args:
shape (fitz.Shape): The shape object for drawing.
Expand All @@ -219,9 +220,7 @@ def draw_elevation(shape: fitz.Shape, elevation: Elevation) -> None:
shape.finish(color=fitz.utils.getColor("blue"))


def draw_material_descriptions(
shape: fitz.Shape, derotation_matrix: fitz.Matrix, layers: list[LayerPrediction]
) -> None:
def draw_material_descriptions(shape: fitz.Shape, derotation_matrix: fitz.Matrix, layers: list[Layer]) -> None:
"""Draw information about material descriptions on a pdf page.
In particular, this function:
Expand Down
2 changes: 1 addition & 1 deletion src/stratigraphy/benchmark/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class GroundTruth:
"""Ground truth data for the stratigraphy benchmark."""

def __init__(self, path: Path):
def __init__(self, path: Path) -> None:
self.ground_truth = defaultdict(dict)

# Load the ground truth data
Expand Down
71 changes: 39 additions & 32 deletions src/stratigraphy/benchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from stratigraphy.evaluation.evaluation_dataclasses import Metrics


class DatasetMetrics:
class OverallMetrics:
"""Keeps track of a particular metrics for all documents in a dataset."""

# TODO: Currently, some methods for averaging metrics are in the Metrics class.
Expand Down Expand Up @@ -58,23 +58,32 @@ def get_metrics_list(self) -> list[Metrics]:
return list(self.metrics.values())


class DatasetMetricsCatalog:
class OverallMetricsCatalog:
"""Keeps track of all different relevant metrics that are computed for a dataset."""

def __init__(self):
self.metrics: dict[str, DatasetMetrics] = {}
def __init__(self, languages: set[str]):
self.layer_metrics = OverallMetrics()
self.depth_interval_metrics = OverallMetrics()
self.groundwater_metrics = OverallMetrics()
self.groundwater_depth_metrics = OverallMetrics()
self.languages = languages

# Initialize language-specific metrics
for lang in languages:
setattr(self, f"{lang}_layer_metrics", OverallMetrics())
setattr(self, f"{lang}_depth_interval_metrics", OverallMetrics())

def document_level_metrics_df(self) -> pd.DataFrame:
"""Return a DataFrame with all the document level metrics."""
all_series = [
self.metrics["layer"].to_dataframe("F1", lambda metric: metric.f1),
self.metrics["layer"].to_dataframe("precision", lambda metric: metric.precision),
self.metrics["layer"].to_dataframe("recall", lambda metric: metric.recall),
self.metrics["depth_interval"].to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision),
self.metrics["layer"].to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn),
self.metrics["layer"].to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn),
self.metrics["groundwater"].to_dataframe("groundwater", lambda metric: metric.f1),
self.metrics["groundwater_depth"].to_dataframe("groundwater_depth", lambda metric: metric.f1),
self.layer_metrics.to_dataframe("F1", lambda metric: metric.f1),
self.layer_metrics.to_dataframe("precision", lambda metric: metric.precision),
self.layer_metrics.to_dataframe("recall", lambda metric: metric.recall),
self.depth_interval_metrics.to_dataframe("Depth_interval_accuracy", lambda metric: metric.precision),
self.layer_metrics.to_dataframe("Number Elements", lambda metric: metric.tp + metric.fn),
self.layer_metrics.to_dataframe("Number wrong elements", lambda metric: metric.fp + metric.fn),
self.groundwater_metrics.to_dataframe("groundwater", lambda metric: metric.f1),
self.groundwater_depth_metrics.to_dataframe("groundwater_depth", lambda metric: metric.f1),
]
document_level_metrics = pd.DataFrame()
for series in all_series:
Expand All @@ -86,21 +95,19 @@ def metrics_dict(self) -> dict[str, float]:
# Initialize a defaultdict to automatically return 0.0 for missing keys
result = defaultdict(lambda: None)

# Safely compute groundwater metrics using .get() to avoid KeyErrors
groundwater_metrics = Metrics.micro_average(
self.metrics.get("groundwater", DatasetMetrics()).get_metrics_list()
)
groundwater_depth_metrics = Metrics.micro_average(
self.metrics.get("groundwater_depth", DatasetMetrics()).get_metrics_list()
)
# Compute the micro-average metrics for the groundwater and groundwater depth metrics
groundwater_metrics = Metrics.micro_average(self.groundwater_metrics.metrics.values())
groundwater_depth_metrics = Metrics.micro_average(self.groundwater_depth_metrics.metrics.values())

# Populate the basic metrics
result.update(
{
"F1": self.metrics.get("layer", DatasetMetrics()).pseudo_macro_f1(),
"recall": self.metrics.get("layer", DatasetMetrics()).macro_recall(),
"precision": self.metrics.get("layer", DatasetMetrics()).macro_precision(),
"depth_interval_accuracy": self.metrics.get("depth_interval", DatasetMetrics()).macro_precision(),
"F1": self.layer_metrics.pseudo_macro_f1() if self.layer_metrics else None,
"recall": self.layer_metrics.macro_recall() if self.layer_metrics else None,
"precision": self.layer_metrics.macro_precision() if self.layer_metrics else None,
"depth_interval_accuracy": self.depth_interval_metrics.macro_precision()
if self.depth_interval_metrics
else None,
"groundwater_f1": groundwater_metrics.f1,
"groundwater_recall": groundwater_metrics.recall,
"groundwater_precision": groundwater_metrics.precision,
Expand All @@ -111,16 +118,16 @@ def metrics_dict(self) -> dict[str, float]:
)

# Add dynamic language-specific metrics only if they exist
for lang in ["de", "fr"]:
layer_key = f"{lang}_layer"
depth_key = f"{lang}_depth_interval"
for lang in self.languages:
layer_key = f"{lang}_layer_metrics"
depth_key = f"{lang}_depth_interval_metrics"

if layer_key in self.metrics:
result[f"{lang}_F1"] = self.metrics[layer_key].pseudo_macro_f1()
result[f"{lang}_recall"] = self.metrics[layer_key].macro_recall()
result[f"{lang}_precision"] = self.metrics[layer_key].macro_precision()
if getattr(self, layer_key) and getattr(self, layer_key).metrics:
result[f"{lang}_F1"] = getattr(self, layer_key).pseudo_macro_f1()
result[f"{lang}_recall"] = getattr(self, layer_key).macro_recall()
result[f"{lang}_precision"] = getattr(self, layer_key).macro_precision()

if depth_key in self.metrics:
result[f"{lang}_depth_interval_accuracy"] = self.metrics[depth_key].macro_precision()
if getattr(self, depth_key) and getattr(self, depth_key).metrics:
result[f"{lang}_depth_interval_accuracy"] = getattr(self, depth_key).macro_precision()

return dict(result) # Convert defaultdict back to a regular dict
Loading

1 comment on commit b2998f0

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/stratigraphy
   __init__.py8188%11
   extract.py1931930%3–485
   get_files.py19190%3–47
   main.py1131130%3–309
src/stratigraphy/benchmark
   metrics.py594229%22–25, 29–32, 36–39, 46–49, 53–54, 58, 65–74, 78–91, 96–133
src/stratigraphy/data_extractor
   data_extractor.py75593%30, 39, 52, 129, 174
src/stratigraphy/depthcolumn
   boundarydepthcolumnvalidator.py412051%47, 57, 60, 81–84, 110–128, 140–149
   depthcolumn.py2278264%28, 33, 65, 82, 87, 93, 109–117, 138, 146, 154–155, 171–173, 188, 195, 208–209, 226, 254–273, 311, 319–320, 336–338, 383, 402–410, 421, 426, 433, 464, 469–476, 491–492, 534–575
   depthcolumnentry.py361072%17, 21, 37, 52, 55, 72, 81, 98–100
   find_depth_columns.py1061982%42–43, 73, 86, 180–181, 225–245
src/stratigraphy/depths_materials_column_pairs
   depths_materials_column_pairs.py19763%23, 34, 55–60
src/stratigraphy/evaluation
   evaluation_dataclasses.py491178%52, 71–74, 90, 104, 125–131, 147
   groundwater_evaluator.py48198%78
   metadata_evaluator.py391464%50–69, 92–99
   utility.py342041%48–57, 72–93
src/stratigraphy/groundwater
   groundwater_extraction.py1569937%52, 94, 127–132, 140, 167–171, 186–206, 217–306, 322–354
   utility.py393315%10–17, 30–47, 59–73, 88–102
src/stratigraphy/layer
   layer.py582066%31, 41, 59–96, 108–109, 121, 139
   layer_identifier_column.py795530%16–17, 20, 28, 43, 52, 61, 69–73, 81, 89, 106–111, 123, 136, 149–150, 169–172, 195–205, 219–246
src/stratigraphy/lines
   geometric_line_utilities.py86298%81, 131
   line.py51492%25, 50, 60, 110
   linesquadtree.py46198%75
src/stratigraphy/metadata
   coordinate_extraction.py110595%30, 68, 98–99, 111
   elevation_extraction.py906033%34–39, 47, 55, 63, 79–87, 124–138, 150–153, 165–197, 212–220, 228–232
   language_detection.py181328%17–23, 37–45
   metadata.py662464%27, 83, 101–127, 146–155, 195–198, 206
src/stratigraphy/text
   description_block_splitter.py70297%24, 139
   extract_text.py29390%19, 53–54
   find_description.py642856%27–35, 50–63, 79–95, 172–175
   textblock.py80989%28, 56, 64, 89, 101, 124, 145, 154, 183
src/stratigraphy/util
   dataclasses.py32391%37–39
   interval.py1045547%29–32, 37–40, 46, 52, 56, 66–68, 107–153, 174, 180–196
   predictions.py1387943%78, 86–94, 135, 158–184, 215–244, 259–295, 318–343, 349–353, 364
   util.py341362%41, 69–76, 90–92, 116–117
TOTAL2443106556% 

Tests Skipped Failures Errors Time
100 0 💤 0 ❌ 0 🔥 7.922s ⏱️

Please sign in to comment.