Commit: minor refactoring
redur committed Apr 10, 2024
1 parent 3f37f57 commit 10e9c6e
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions src/stratigraphy/benchmark/score.py
@@ -79,6 +79,10 @@ def evaluate_matching(predictions: dict, number_of_truth_values: dict) -> tuple[
         "Number Elements": [],
         "Number wrong elements": [],
     }
+    # A separate list is required to calculate the overall depth interval accuracy,
+    # as the depth interval accuracy is not calculated for documents with no correct
+    # material predictions.
+    depth_interval_accuracies = []
     for filename, file_prediction in predictions.items():
         hits = 0
         depth_interval_hits = 0
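
The comment above refers to documents where depth_interval_occurences stays at zero: for those, None is stored in the per-document column (see the hunk at line 110 below), and sum() cannot average a column that contains None. A minimal sketch with hypothetical values, not part of the commit:

# Hypothetical contents of document_level_metrics["Depth_interval_accuracy"]:
column = [0.8, None, 0.5]

# sum(column) would raise TypeError because of the None entry, so the overall
# average has to be taken over a list that only ever holds defined values,
# which is what depth_interval_accuracies provides.
defined = [v for v in column if v is not None]
print(f"{sum(defined) / len(defined):.1%}")  # 65.0%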
@@ -93,7 +97,7 @@ def evaluate_matching(predictions: dict, number_of_truth_values: dict) -> tuple[
                     depth_interval_occurences += 1
 
             if parse_text(layer.material_description.text) == "":
-                print("Empty string found in predictions")
+                logger.warning("Empty string found in predictions")
         tp = hits
         fp = len(file_prediction.layers) - tp
         fn = number_of_truth_values[filename] - tp
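
The new logger.warning call assumes a module-level logger in score.py; its definition is outside this diff, but a conventional setup would be:

import logging

logger = logging.getLogger(__name__)  # assumed near the top of score.py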
@@ -110,22 +114,24 @@ def evaluate_matching(predictions: dict, number_of_truth_values: dict) -> tuple[
         document_level_metrics["F1"].append(f1(precision, recall))
         document_level_metrics["Number Elements"].append(number_of_truth_values[filename])
         document_level_metrics["Number wrong elements"].append(fn + fp)
-        document_level_metrics["Depth_interval_accuracy"].append(depth_interval_hits / depth_interval_occurences)
+        try:
+            document_level_metrics["Depth_interval_accuracy"].append(depth_interval_hits / depth_interval_occurences)
+            depth_interval_accuracies.append(depth_interval_hits / depth_interval_occurences)
+        except ZeroDivisionError:
+            document_level_metrics["Depth_interval_accuracy"].append(None)
 
     if len(document_level_metrics["precision"]):
         overall_precision = sum(document_level_metrics["precision"]) / len(document_level_metrics["precision"])
         overall_recall = sum(document_level_metrics["recall"]) / len(document_level_metrics["recall"])
-        overall_depth_interval_accuracy = sum(document_level_metrics["Depth_interval_accuracy"]) / len(
-            document_level_metrics["Depth_interval_accuracy"]
-        )
+        overall_depth_interval_accuracy = sum(depth_interval_accuracies) / len(depth_interval_accuracies)
     else:
         overall_precision = 0
         overall_recall = 0
 
     logging.info("Macro avg:")
     logging.info(
-        f"F1: {f1(overall_precision, overall_recall):.1%},"
-        f"precision: {overall_precision:.1%}, recall: {overall_recall:.1%},"
+        f"F1: {f1(overall_precision, overall_recall):.1%}, "
+        f"precision: {overall_precision:.1%}, recall: {overall_recall:.1%}, "
         f"depth_interval_accuracy: {overall_depth_interval_accuracy:.1%}"
     )

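A hypothetical two-document walk-through of the guarded accumulation above: the first document has matched depth intervals, the second has none, so only the first contributes to the overall average.

# Stand-ins for the structures used in evaluate_matching (hypothetical values).
table_column = []              # document_level_metrics["Depth_interval_accuracy"]
depth_interval_accuracies = []

for hits, occurences in [(3, 4), (0, 0)]:  # second document: no occurrences
    try:
        table_column.append(hits / occurences)
        depth_interval_accuracies.append(hits / occurences)
    except ZeroDivisionError:
        table_column.append(None)

overall = sum(depth_interval_accuracies) / len(depth_interval_accuracies)
print(table_column, f"{overall:.1%}")  # [0.75, None] 75.0%
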
@@ -191,7 +197,6 @@ def create_predictions_objects(predictions: dict, ground_truth_path: Path) -> tu
 
         predictions_objects[file_name] = prediction_object
         if ground_truth_is_present:
-            print(file_name)
             predictions_objects[file_name].evaluate(ground_truth.for_file(file_name))
             number_of_truth_values[file_name] = ground_truth.for_file(file_name).num_layers
 
