diff --git a/news_signals/cli.py b/news_signals/cli.py
index 8052ae0..1007781 100644
--- a/news_signals/cli.py
+++ b/news_signals/cli.py
@@ -15,5 +15,5 @@ def generate_dataset():
     logger.info("Generating dataset via cli command `generate-dataset`")
-    print(f"Args: {sys.argv[1:]}")
-    sys.exit(subprocess.call([sys.executable, str(path_to_file / '../bin/generate_dataset.py')] + sys.argv[1:]))
+    logger.info(f"Args: {sys.argv[1:]}")
+    sys.exit(subprocess.call([sys.executable, str(path_to_file / 'generate_dataset.py')] + sys.argv[1:]))
diff --git a/news_signals/generate_dataset.py b/news_signals/generate_dataset.py
new file mode 100644
index 0000000..3b3ef76
--- /dev/null
+++ b/news_signals/generate_dataset.py
@@ -0,0 +1,72 @@
+import argparse
+import json
+import logging
+from pathlib import Path
+
+import arrow
+
+from news_signals.signals import AylienSignal
+from news_signals.signals_dataset import generate_dataset, reduce_aylien_story
+from news_signals.dataset_transformations import get_dataset_transform
+from news_signals.log import create_logger
+
+
+logger = create_logger(__name__, level=logging.INFO)
+
+
+def main(args):
+    with open(args.config) as f:
+        config = json.load(f)
+
+    output_dataset_path = Path(config["output_dataset_dir"])
+    logger.info(f"Beginning dataset generation with config {config}")
+    if config.get('signal_configs', None) is not None and config.get('input', None) is not None:
+        raise AssertionError("Cannot specify both signal_configs and input file path in dataset generation config")
+    if config.get('signal_configs', None) is not None:
+        aylien_signals = [AylienSignal(**signal_config) for signal_config in config["signal_configs"]]
+        input = aylien_signals
+    else:
+        input = Path(config["input"])
+    dataset = generate_dataset(
+        input=input,
+        output_dataset_dir=output_dataset_path,
+        gcs_bucket=config.get("gcs_bucket", None),
+        start=arrow.get(config["start"]).datetime,
+        end=arrow.get(config["end"]).datetime,
+        stories_per_day=config["stories_per_day"],
+        name_field=config.get("name_field", None),
+        id_field=config.get("id_field", None),
+        surface_form_field=config.get("surface_form_field", None),
+        overwrite=args.overwrite,
+        delete_tmp_files=True,
+        compress=True,
+        post_process_story=reduce_aylien_story
+    )
+
+    if config.get("transformations"):
+        for t in config["transformations"]:
+            logger.info(f"Applying transformation to dataset: {t['transform']}")
+            transform = get_dataset_transform(t['transform'])
+            transform(dataset, **t['params'])
+
+    dataset.save(output_dataset_path, overwrite=True, compress=True)
+    logger.info(f"Finished dataset generation, dataset saved here: {output_dataset_path}.tar.gz")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--config',
+        required=True,
+        help="config json file containing all settings to create new dataset"
+    )
+    parser.add_argument(
+        '--overwrite',
+        action="store_true",
+        help="whether to overwrite previous dataset if present at output_dataset_dir"
+    )
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    main(parse_args())
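
For reference, the new generate_dataset.py entry point takes every setting from the JSON file passed via --config. A minimal sketch of such a config is shown below; the keys are the ones the script reads, while the paths, dates, field names, and the transform name are illustrative placeholders rather than values taken from this patch. Supply either "input" or "signal_configs", not both.

    {
        "input": "data/entities.csv",
        "output_dataset_dir": "datasets/example-dataset",
        "gcs_bucket": null,
        "start": "2023-01-01",
        "end": "2023-02-01",
        "stories_per_day": 20,
        "name_field": "name",
        "id_field": "id",
        "surface_form_field": "surface_forms",
        "transformations": [
            {"transform": "example_transform", "params": {}}
        ]
    }

Assuming the console script is installed as generate-dataset (as the cli.py log message suggests), one possible invocation is:

    generate-dataset --config config.json --overwrite

or, equivalently, running the module directly:

    python news_signals/generate_dataset.py --config config.json --overwrite
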
diff --git a/news_signals/transform_dataset.py b/news_signals/transform_dataset.py
new file mode 100644
index 0000000..b8b19ec
--- /dev/null
+++ b/news_signals/transform_dataset.py
@@ -0,0 +1,63 @@
+import argparse
+import json
+import logging
+from pathlib import Path
+
+from news_signals.signals_dataset import SignalsDataset
+from news_signals.dataset_transformations import get_dataset_transform
+from news_signals.log import create_logger
+
+
+logger = create_logger(__name__, level=logging.INFO)
+
+
+def main(args):
+    if args.output_dataset_path is None:
+        output_dataset_path = Path(args.input_dataset_path)
+    else:
+        output_dataset_path = Path(args.output_dataset_path)
+
+    with open(args.config) as f:
+        config = json.load(f)
+
+    if (args.input_dataset_path == args.output_dataset_path) or output_dataset_path.exists():
+        confirm = input(
+            f"Are you sure you want to modify the dataset in {output_dataset_path}? "
+            "Alternatively, you can set output_dataset_path to a new directory. (y|n) "
+        )
+        if confirm != "y":
+            logger.info("Aborting")
+            return
+
+    dataset = SignalsDataset.load(args.input_dataset_path)
+
+    # config is a list of transformations
+    for t in config:
+        logger.info(f"Applying transformation to dataset: {t['transform']}")
+        transform = get_dataset_transform(t['transform'])
+        transform(dataset, **t['params'])
+
+    if str(output_dataset_path).endswith('.tar.gz'):
+        dataset.save(output_dataset_path, overwrite=True, compress=True)
+    else:
+        dataset.save(output_dataset_path, overwrite=True)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--input-dataset-path',
+        required=True,
+    )
+    parser.add_argument(
+        '--output-dataset-path',
+    )
+    parser.add_argument(
+        '--config',
+        help="path to JSON config file listing the transformations to apply"
+    )
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    main(parse_args())
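
For reference, transform_dataset.py expects its --config file to contain a JSON list of transformations, each entry holding a "transform" name that get_dataset_transform can resolve and a "params" dict passed through to that transform. A minimal sketch with placeholder names and parameters (not taken from this patch):

    [
        {
            "transform": "example_transform",
            "params": {"example_param": "value"}
        }
    ]

Assuming a dataset previously saved as datasets/example-dataset.tar.gz, one possible invocation is:

    python news_signals/transform_dataset.py --input-dataset-path datasets/example-dataset.tar.gz --output-dataset-path datasets/example-dataset-transformed --config transforms.json

If --output-dataset-path is omitted, equals the input path, or already exists, the script asks for confirmation before modifying the dataset in place.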