diff --git a/README.md b/README.md index 16c2a28d..5ebddf4f 100644 --- a/README.md +++ b/README.md @@ -43,9 +43,9 @@ To execute the data extraction pipeline, follow these steps: The main script for the extraction pipeline is located at `src/stratigraphy/main.py`. A cli command is created to run this script. - Run `boreholes-extract-layers` to run the main extraction script. With the default options, the command will source all PDFs from the `data/Benchmark` directory and create PNG files in the `data/Benchmark/extract` directory. + Run `boreholes-extract-all` to run the main extraction script. The command will source all PDFs from the specified input directory and, with the default options, create PNG files in the `data/output/draw` directory. - Use `boreholes-extract-layers --help` to see all options for the extraction script. + Use `boreholes-extract-all --help` to see all options for the extraction script. 4. **Check the results** @@ -154,9 +154,9 @@ The project structure and the most important files are as follows: - `util/` : Utility scripts and modules. - `benchmark/` : Scripts to evaluate the data extraction. - `data/` : The data used by the project. - - `Benchmark/` : The directory containing the PDF files to be analyzed. - - `extract/` : The directory where the PNG files are saved. - - `predictions.json` : The output file of the project, containing the results of the data extraction process. + - `output/` : The directory where the output of the data extraction process is stored. + - `draw/` : The directory where the PNG files are saved. + - `predictions.json` : The output file of the project, containing the results of the data extraction process. - `tests/` : The tests for the project. - `README.md` : The README file for the project. 
diff --git a/pyproject.toml b/pyproject.toml index 317463ef..75e3a4f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,11 @@ experiment-tracking = [ visualize = [ "matplotlib==3.8.0" ] +devtools = [ + "tqdm" +] -all = ["swissgeol-boreholes-dataextraction[test, lint, experiment-tracking, visualize]"] +all = ["swissgeol-boreholes-dataextraction[test, lint, experiment-tracking, visualize, devtools]"] [project.scripts] boreholes-extract-all = "stratigraphy.main:click_pipeline" diff --git a/src/stratigraphy/get_files.py b/src/stratigraphy/get_files.py index e980abff..87df4dca 100644 --- a/src/stratigraphy/get_files.py +++ b/src/stratigraphy/get_files.py @@ -1,6 +1,5 @@ """Script to download the borehole profiles from the S3 bucket.""" -import os from pathlib import Path import boto3 @@ -14,15 +13,15 @@ @click.option("--bucket-name", default="stijnvermeeren-boreholes-data", help="The name of the bucket.") @click.option( "--remote-directory-name", - default="data_v2/validation", + default="", help="The name of the directory in the bucket to be downloaded.", ) @click.option( "--output-path", default=DATAPATH, type=click.Path(path_type=Path), help="The path to save the downloaded files." ) def download_directory_froms3( - bucket_name: str = "stijnvermeeren-boreholes-data", - remote_directory_name: str = "data_v2/validation", + bucket_name: str, + remote_directory_name: str, output_path: Path = DATAPATH, ): """Download a directory from S3 bucket. @@ -31,17 +30,17 @@ def download_directory_froms3( \f Args: - bucketName (str): The name of the bucket. - remoteDirectoryName (str): The name of the directory in the bucket to be downloaded. + bucket_name (str): The name of the bucket. + remote_directory_name (str): The name of the directory in the bucket to be downloaded. 
+ output_path (Path): Where to store the files locally """ # noqa: D301 s3_resource = boto3.resource("s3") bucket = s3_resource.Bucket(bucket_name) total_files = sum(1 for _ in bucket.objects.filter(Prefix=remote_directory_name)) # this is fast for obj in tqdm(bucket.objects.filter(Prefix=remote_directory_name), total=total_files): - Path(output_path / obj.key).parent.mkdir(parents=True, exist_ok=True) - if not os.path.exists(os.path.dirname(obj.key)): - os.makedirs(os.path.dirname(obj.key)) - bucket.download_file(obj.key, output_path / obj.key) # save to same path + if obj.key: + Path(output_path / obj.key).parent.mkdir(parents=True, exist_ok=True) + bucket.download_file(obj.key, output_path / obj.key) # save to same path if __name__ == "__main__": diff --git a/src/stratigraphy/main.py b/src/stratigraphy/main.py index 54a1b38e..f0411159 100644 --- a/src/stratigraphy/main.py +++ b/src/stratigraphy/main.py @@ -35,28 +35,26 @@ "-i", "--input-directory", type=click.Path(exists=True, path_type=Path), - default=DATAPATH / "Benchmark", help="Path to the input directory, or path to a single pdf file.", ) @click.option( "-g", "--ground-truth-path", type=click.Path(exists=False, path_type=Path), - default=DATAPATH / "Benchmark" / "ground_truth.json", help="Path to the ground truth file.", ) @click.option( "-o", "--out-directory", type=click.Path(path_type=Path), - default=DATAPATH / "Benchmark" / "evaluation", + default=DATAPATH / "output", help="Path to the output directory.", ) @click.option( "-p", "--predictions-path", type=click.Path(path_type=Path), - default=DATAPATH / "Benchmark" / "extract" / "predictions.json", + default=DATAPATH / "output" / "predictions.json", help="Path to the predictions file.", ) @click.option( @@ -144,8 +142,9 @@ def start_pipeline( temp_directory = DATAPATH / "_temp" # temporary directory to dump files for mlflow artifact logging - # check if directories exist and create them when neccessary - out_directory.mkdir(parents=True, exist_ok=True) 
+ # check if directories exist and create them when necessary + draw_directory = out_directory / "draw" + draw_directory.mkdir(parents=True, exist_ok=True) temp_directory.mkdir(parents=True, exist_ok=True) # if a file is specified instead of an input directory, copy the file to a temporary directory and work with that. @@ -216,7 +215,7 @@ def start_pipeline( predictions, number_of_truth_values = create_predictions_objects(predictions, ground_truth_path) if not skip_draw_predictions: - draw_predictions(predictions, input_directory, out_directory) + draw_predictions(predictions, input_directory, draw_directory) if number_of_truth_values: # only evaluate if ground truth is available metrics, document_level_metrics = evaluate_borehole_extraction(predictions, number_of_truth_values)