-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4574922
commit b6970a4
Showing 3 changed files with 137 additions and 2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import argparse | ||
import json | ||
import logging | ||
from pathlib import Path | ||
|
||
import arrow | ||
|
||
from news_signals.signals import AylienSignal | ||
from news_signals.signals_dataset import generate_dataset, reduce_aylien_story | ||
from news_signals.dataset_transformations import get_dataset_transform | ||
from news_signals.log import create_logger | ||
|
||
|
||
# Module-level logger for this script (created via the project's logging helper).
logger = create_logger(__name__, level=logging.INFO)
|
||
|
||
def main(args):
    """Generate a news-signals dataset from a JSON config file.

    Reads the config at ``args.config``, builds the dataset with
    ``generate_dataset`` (either from inline ``signal_configs`` or from an
    ``input`` file path), optionally applies the configured transformations,
    and saves the compressed result under ``output_dataset_dir``.

    Args:
        args: parsed CLI namespace with ``config`` (path to JSON config)
            and ``overwrite`` (bool) attributes.

    Raises:
        AssertionError: if the config specifies both ``signal_configs``
            and ``input`` — they are mutually exclusive.
    """
    with open(args.config) as f:
        config = json.load(f)

    output_dataset_path = Path(config["output_dataset_dir"])
    logger.info(f"Beginning dataset generation with config {config}")
    # Exactly one source of signals may be supplied: inline signal configs
    # or a path to an input file.
    if config.get('signal_configs') is not None and config.get('input') is not None:
        raise AssertionError("Cannot specify both signal_configs and input file path in dataset generation config")
    if config.get('signal_configs') is not None:
        # Renamed local from `input` to avoid shadowing the builtin.
        dataset_input = [
            AylienSignal(**signal_config)
            for signal_config in config["signal_configs"]
        ]
    else:
        dataset_input = Path(config["input"])
    dataset = generate_dataset(
        input=dataset_input,
        output_dataset_dir=output_dataset_path,
        gcs_bucket=config.get("gcs_bucket", None),
        start=arrow.get(config["start"]).datetime,
        end=arrow.get(config["end"]).datetime,
        stories_per_day=config["stories_per_day"],
        name_field=config.get("name_field", None),
        id_field=config.get("id_field", None),
        surface_form_field=config.get("surface_form_field", None),
        overwrite=args.overwrite,
        delete_tmp_files=True,
        compress=True,
        post_process_story=reduce_aylien_story
    )

    # Optional post-generation transformations, applied in config order.
    if config.get("transformations"):
        for t in config["transformations"]:
            logger.info(f"Applying transformation to dataset: {t['transform']}")
            transform = get_dataset_transform(t['transform'])
            transform(dataset, **t['params'])

    dataset.save(output_dataset_path, overwrite=True, compress=True)
    logger.info(f"Finished dataset generation, dataset saved here: {output_dataset_path}.tar.gz")
|
||
|
||
def parse_args(argv=None):
    """Parse command-line arguments for dataset generation.

    Args:
        argv: optional list of argument strings; defaults to
            ``sys.argv[1:]`` when ``None`` (pass a list for testing).

    Returns:
        argparse.Namespace with ``config`` and ``overwrite`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        required=True,
        help="config json file containing all settings to create new dataset"
    )
    parser.add_argument(
        '--overwrite',
        action="store_true",
        help="whether to overwrite previous dataset if present at output_dataset_dir"
    )
    return parser.parse_args(argv)
|
||
|
||
if __name__ == '__main__':
    # Script entry point: parse CLI flags, then run dataset generation.
    cli_args = parse_args()
    main(cli_args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import argparse | ||
import json | ||
import logging | ||
from pathlib import Path | ||
|
||
from news_signals.signals_dataset import SignalsDataset | ||
from news_signals.dataset_transformations import get_dataset_transform | ||
from news_signals.log import create_logger | ||
|
||
|
||
# Module-level logger for this script (created via the project's logging helper).
logger = create_logger(__name__, level=logging.INFO)
|
||
|
||
def main(args):
    """Apply configured transformations to an existing signals dataset.

    Loads the dataset at ``input_dataset_path``, applies each transformation
    listed in the JSON config (a list of ``{transform, params}`` entries),
    and saves the result — compressed when the output path ends with
    ``.tar.gz``. Asks for interactive confirmation before modifying a
    dataset in place or overwriting an existing output path.
    """
    target = args.output_dataset_path
    output_dataset_path = Path(args.input_dataset_path if target is None else target)

    with open(args.config) as f:
        config = json.load(f)

    # Guard against accidental in-place modification / overwrite.
    in_place = args.input_dataset_path == args.output_dataset_path
    if in_place or output_dataset_path.exists():
        answer = input(
            f"Are you sure you want to modify the dataset in {output_dataset_path}? "
            "Alternatively, you can set output_dataset_path to a new directory. (y|n) "
        )
        if answer != "y":
            logger.info("Aborting")
            return

    dataset = SignalsDataset.load(args.input_dataset_path)

    # config is a list of transformations
    for spec in config:
        logger.info(f"Applying transformation to dataset: {spec['transform']}")
        get_dataset_transform(spec['transform'])(dataset, **spec['params'])

    wants_archive = str(output_dataset_path).endswith('.tar.gz')
    if wants_archive:
        dataset.save(output_dataset_path, overwrite=True, compress=True)
    else:
        dataset.save(output_dataset_path, overwrite=True)
|
||
|
||
def parse_args(argv=None):
    """Parse command-line arguments for dataset transformation.

    Fixes vs. the previous version: removes a stray trailing comma after the
    final ``add_argument`` call, corrects the ``--config`` help text (it is a
    file path, not an inline JSON string — ``main`` opens it with ``open``),
    and marks ``--config`` required since ``main`` unconditionally reads it.

    Args:
        argv: optional list of argument strings; defaults to
            ``sys.argv[1:]`` when ``None`` (pass a list for testing).

    Returns:
        argparse.Namespace with ``input_dataset_path``,
        ``output_dataset_path`` and ``config`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input-dataset-path',
        required=True,
        help="path to the existing dataset to transform"
    )
    parser.add_argument(
        '--output-dataset-path',
        help="where to save the transformed dataset; "
             "defaults to modifying the input dataset in place"
    )
    parser.add_argument(
        '--config',
        required=True,  # main() unconditionally opens this file
        help="path to a JSON config file listing the transformations to apply"
    )
    return parser.parse_args(argv)
|
||
|
||
if __name__ == '__main__':
    # Script entry point: parse CLI flags, then run the transformations.
    cli_args = parse_args()
    main(cli_args)