Skip to content

Commit

Permalink
update cli
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishokamp committed Apr 15, 2024
1 parent 4574922 commit b6970a4
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 2 deletions.
4 changes: 2 additions & 2 deletions news_signals/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@

def generate_dataset():
    """CLI entry point for `generate-dataset`.

    Forwards all command-line arguments to the generate_dataset.py script
    that lives alongside this module, and exits with that script's return
    code.
    """
    logger.info("Generating dataset via cli command `generate-dataset`")
    logger.info(f"Args: {sys.argv[1:]}")
    # NOTE: the diff artifact contained both the old (`../bin/...`) and new
    # script paths; the post-commit version runs the co-located script.
    # `path_to_file` is presumably the directory of this module — TODO confirm.
    sys.exit(subprocess.call(
        [sys.executable, str(path_to_file / 'generate_dataset.py')] + sys.argv[1:]
    ))
72 changes: 72 additions & 0 deletions news_signals/generate_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import argparse
import json
import logging
from pathlib import Path

import arrow

from news_signals.signals import AylienSignal
from news_signals.signals_dataset import generate_dataset, reduce_aylien_story
from news_signals.dataset_transformations import get_dataset_transform
from news_signals.log import create_logger


logger = create_logger(__name__, level=logging.INFO)


def main(args):
    """Generate a signals dataset from a JSON config, optionally transform it, and save it.

    :param args: argparse namespace with attributes `config` (path to a JSON
        config file) and `overwrite` (bool).
    :raises AssertionError: if the config specifies both `signal_configs`
        and `input`.
    """
    with open(args.config) as f:
        config = json.load(f)

    output_dataset_path = Path(config["output_dataset_dir"])
    logger.info(f"Beginning dataset generation with config {config}")
    # Exactly one input source is allowed: inline signal configs, or a path
    # to an input file of entities.
    if config.get('signal_configs') is not None and config.get('input') is not None:
        raise AssertionError("Cannot specify both signal_configs and input file path in dataset generation config")
    if config.get('signal_configs') is not None:
        # renamed from `input` to avoid shadowing the builtin
        dataset_input = [AylienSignal(**signal_config) for signal_config in config["signal_configs"]]
    else:
        dataset_input = Path(config["input"])
    dataset = generate_dataset(
        input=dataset_input,
        output_dataset_dir=output_dataset_path,
        gcs_bucket=config.get("gcs_bucket"),
        start=arrow.get(config["start"]).datetime,
        end=arrow.get(config["end"]).datetime,
        stories_per_day=config["stories_per_day"],
        name_field=config.get("name_field"),
        id_field=config.get("id_field"),
        surface_form_field=config.get("surface_form_field"),
        overwrite=args.overwrite,
        delete_tmp_files=True,
        compress=True,
        post_process_story=reduce_aylien_story
    )

    # Optional post-generation transformations, applied in config order.
    if config.get("transformations"):
        for t in config["transformations"]:
            logger.info(f"Applying transformation to dataset: {t['transform']}")
            transform = get_dataset_transform(t['transform'])
            transform(dataset, **t['params'])

    dataset.save(output_dataset_path, overwrite=True, compress=True)
    logger.info(f"Finished dataset generation, dataset saved here: {output_dataset_path}.tar.gz")


def parse_args():
    """Parse command-line options for dataset generation."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--config',
        required=True,
        help="config json file containing all settings to create new dataset",
    )
    arg_parser.add_argument(
        '--overwrite',
        action="store_true",
        help="whether to overwrite previous dataset if present at output_dataset_dir",
    )
    return arg_parser.parse_args()


# Script entry point: parse CLI args and run dataset generation.
if __name__ == '__main__':
    main(parse_args())
63 changes: 63 additions & 0 deletions news_signals/transform_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import argparse
import json
import logging
from pathlib import Path

from news_signals.signals_dataset import SignalsDataset
from news_signals.dataset_transformations import get_dataset_transform
from news_signals.log import create_logger


logger = create_logger(__name__, level=logging.INFO)


def main(args):
    """Load a dataset, apply config-driven transformations, and save the result.

    :param args: argparse namespace with `input_dataset_path`,
        `output_dataset_path` (may be None, meaning transform in place) and
        `config` (path to a JSON file listing transformations).
    """
    # Default to transforming the dataset in place.
    target = (
        args.input_dataset_path
        if args.output_dataset_path is None
        else args.output_dataset_path
    )
    output_dataset_path = Path(target)

    with open(args.config) as f:
        config = json.load(f)

    # Ask before clobbering an existing dataset.
    would_overwrite = (
        args.input_dataset_path == args.output_dataset_path
        or output_dataset_path.exists()
    )
    if would_overwrite:
        confirm = input(
            f"Are you sure you want to modify the dataset in {output_dataset_path}? "
            "Alternatively, you can set output_dataset_path to a new directory. (y|n) "
        )
        if confirm != "y":
            logger.info("Aborting")
            return

    dataset = SignalsDataset.load(args.input_dataset_path)

    # config is a list of transformation specs: each has a `transform` name
    # and a `params` dict.
    for spec in config:
        logger.info(f"Applying transformation to dataset: {spec['transform']}")
        get_dataset_transform(spec['transform'])(dataset, **spec['params'])

    # A .tar.gz target is saved compressed; otherwise use the save defaults.
    if str(output_dataset_path).endswith('.tar.gz'):
        dataset.save(output_dataset_path, overwrite=True, compress=True)
    else:
        dataset.save(output_dataset_path, overwrite=True)


def parse_args():
    """Parse command-line options for dataset transformation."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input-dataset-path',
        required=True,
        help="path to the existing dataset to transform",
    )
    parser.add_argument(
        '--output-dataset-path',
        help="where to save the transformed dataset; defaults to the input path (in place)",
    )
    parser.add_argument(
        '--config',
        # main() unconditionally opens this file, so require it here rather
        # than failing later with open(None).
        required=True,
        help="JSON string with config"
    )
    # NOTE: a stray trailing comma after this call previously created a dead
    # one-element tuple expression; removed.
    return parser.parse_args()


# Script entry point: parse CLI args and run the dataset transformation.
if __name__ == '__main__':
    main(parse_args())

0 comments on commit b6970a4

Please sign in to comment.