
Merge branch 'main' of github.com:AYLIEN/news-signals-datasets into main
chrishokamp committed Mar 2, 2023
2 parents 7bbcfd4 + aecd165 commit b677f5f
Showing 5 changed files with 8,454 additions and 20 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -2,6 +2,20 @@

Check out this [colab notebook](https://drive.google.com/file/d/1iTjjeSt1S5WF0jJItH31DRe2C3IkZvz5/view?usp=sharing) to see some of the things you can do with the news-signals library.

## Generating a new Dataset

```shell

python bin/generate_dataset.py \
--start 2022/01/01 \
--end 2022/02/01 \
--input-csv resources/test/nasdaq100.small.csv \
--id-field "Wikidata ID" \
--name-field "Wikidata Label" \
--output-dataset-dir sample_dataset_output

```
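
The command above builds a signals dataset in `--output-dataset-dir`. A minimal sketch of loading the result afterwards — this assumes `SignalsDataset` is importable from `news_signals.signals_dataset` (whose `generate_dataset` returns `SignalsDataset.load(...)`, see below) and that you used the output path from the command above:

```python
# Sketch: load a dataset produced by bin/generate_dataset.py.
# "sample_dataset_output" is the --output-dataset-dir used above.
from pathlib import Path

from news_signals.signals_dataset import SignalsDataset

dataset = SignalsDataset.load(Path("sample_dataset_output"))
```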


#### Install news-signals in a new environment

62 changes: 62 additions & 0 deletions bin/generate_dataset.py
@@ -0,0 +1,62 @@
import argparse
from pathlib import Path

import arrow

from news_signals.signals_dataset import generate_dataset


def main(args):
    generate_dataset(
        input=Path(args.input_csv),
        output_dataset_dir=Path(args.output_dataset_dir),
        start=arrow.get(args.start).datetime,
        end=arrow.get(args.end).datetime,
        id_field=args.id_field,
        name_field=args.name_field,
        overwrite=args.overwrite,
        delete_tmp_files=True,
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--start',
        required=True,
        help="start date, e.g. 2020-1-1"
    )
    parser.add_argument(
        '--end',
        required=True,
        help="end date, e.g. 2020-2-1"
    )
    parser.add_argument(
        '--input-csv',
        required=True,
        help="csv file with entities"
    )
    parser.add_argument(
        '--id-field',
        default="Wikidata ID",
        help="column in csv which indicates the Wikidata ID"
    )
    parser.add_argument(
        '--name-field',
        default="Wikidata Label",
        help="column in csv which indicates the Wikidata label (entity name)"
    )
    parser.add_argument(
        '--output-dataset-dir',
        required=True,
        help="dir where dataset is stored"
    )
    parser.add_argument(
        '--overwrite',
        action="store_true",
    )
    return parser.parse_args()


if __name__ == '__main__':
    main(parse_args())
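
For reference, the script reads the entity list from the CSV named by `--input-csv`, using the columns named by `--id-field` and `--name-field` (defaults "Wikidata ID" and "Wikidata Label"). A hypothetical sketch of building a CSV in that shape — the rows are illustrative examples, not the contents of resources/test/nasdaq100.small.csv:

```python
# Sketch: write a minimal entity CSV in the shape generate_dataset.py expects.
# Column names match the --id-field / --name-field defaults; rows are examples only.
import pandas as pd

entities = pd.DataFrame([
    {"Wikidata ID": "Q312", "Wikidata Label": "Apple Inc."},
    {"Wikidata ID": "Q2283", "Wikidata Label": "Microsoft"},
])
entities.to_csv("my_entities.csv", index=False)
```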
33 changes: 23 additions & 10 deletions news_signals/signals_dataset.py
@@ -195,7 +195,7 @@ def ask_rmdir(dirpath, msg, yes="y"):
        shutil.rmtree(dirpath)


def make_query(params, start, end, period="+1DAY"):
def make_aylien_newsapi_query(params, start, end, period="+1DAY"):
    _start = arrow_to_aylien_date(arrow.get(start))
    _end = arrow_to_aylien_date(arrow.get(end))
    aql = params_to_aql(params)
@@ -210,9 +210,9 @@ def make_query(params, start, end, period="+1DAY"):
    return new_params


def reduce_story(s):
def reduce_aylien_story(s):
body = " ".join(s["body"].split()[:MAX_BODY_TOKENS])
smart_cats = extract_smart_tagger_categories(s)
smart_cats = extract_aylien_smart_tagger_categories(s)
reduced = {
"title": s["title"],
"body": body,
Expand All @@ -225,7 +225,7 @@ def reduce_story(s):
return reduced


def extract_smart_tagger_categories(s):
def extract_aylien_smart_tagger_categories(s):
    category_items = []
    for c in s["categories"]:
        if c["taxonomy"] == "aylien":
@@ -259,7 +259,8 @@ def retrieve_and_write_stories(
    ts: List,
    output_path: Path,
    num_stories: int = 20,
    stories_endpoint=newsapi.retrieve_stories
    stories_endpoint=newsapi.retrieve_stories,
    post_process_story=None,
):
    time_to_volume = dict(
        (arrow.get(x["published_at"]).datetime, x["count"]) for x in ts
@@ -284,9 +285,10 @@

        vol = time_to_volume[start]
        if vol > 0:
            params = make_query(params_template, start, end)
            params = make_aylien_newsapi_query(params_template, start, end)
            stories = stories_endpoint(params)
            stories = [reduce_story(s) for s in stories]
            if post_process_story is not None:
                stories = [post_process_story(s) for s in stories]
        else:
            stories = []
        output_item = {
@@ -305,7 +307,7 @@ def retrieve_and_write_timeseries(
    ts_endpoint=newsapi.retrieve_timeseries
) -> List:
    if not output_path.exists():
        params = make_query(params, start, end)
        params = make_aylien_newsapi_query(params, start, end)
        ts = ts_endpoint(params)
        write_json(ts, output_path)
    else:
@@ -343,6 +345,7 @@ def generate_dataset(
    delete_tmp_files: bool = False,
    stories_endpoint=newsapi.retrieve_stories,
    ts_endpoint=newsapi.retrieve_timeseries,
    post_process_story=None,
):

    """
@@ -375,6 +378,16 @@
    )
    output_dataset_dir.mkdir(parents=True, exist_ok=True)

    # optional, e.g. for reducing story fields
    if isinstance(post_process_story, str):
        try:
            post_process_story = globals()[post_process_story]
        except KeyError:
            raise NotImplementedError(
                f"Unknown function for processing stories: {post_process_story}"
            )

    for signal in tqdm.tqdm(signals_):
        if signal_exists(signal, output_dataset_dir):
            logger.info("signal exists already, skipping to next")
@@ -401,7 +414,8 @@
            ts,
            stories_path,
            num_stories=stories_per_day,
            stories_endpoint=stories_endpoint
            stories_endpoint=stories_endpoint,
            post_process_story=post_process_story
        )

        # now this signal is completely realized
@@ -417,4 +431,3 @@
            stories_path.unlink()

    return SignalsDataset.load(output_dataset_dir)

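
With these changes, `generate_dataset` accepts `post_process_story` either as a callable or as the name of a function defined in this module (resolved via `globals()`), so the old per-story reduction can still be requested explicitly. A rough sketch of the programmatic call, with illustrative paths and dates; passing the string "reduce_aylien_story" instead of the function should behave the same way:

```python
# Sketch: generate a dataset and reduce each story with reduce_aylien_story.
# Paths and dates are illustrative; adjust to your own entity CSV and window.
from pathlib import Path

import arrow

from news_signals.signals_dataset import generate_dataset, reduce_aylien_story

dataset = generate_dataset(
    input=Path("resources/test/nasdaq100.small.csv"),
    output_dataset_dir=Path("sample_dataset_output"),
    start=arrow.get("2022-01-01").datetime,
    end=arrow.get("2022-02-01").datetime,
    id_field="Wikidata ID",
    name_field="Wikidata Label",
    post_process_story=reduce_aylien_story,  # or the string "reduce_aylien_story"
)
```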
20 changes: 10 additions & 10 deletions news_signals/test_signals_dataset.py
@@ -40,17 +40,11 @@ def __call__(self, payload):

class MockStoriesEndPoint:

    def __init__(self):
        self.sample_stories = json.loads((resources / "sample_stories.json").read_text())

    def __call__(self, payload):
        return [{
            'id': 'test-id',
            'title': 'title',
            'title': 'link',
            'body': 'body',
            'categories': [{'taxonomy': 'aylien', 'id': 'ay.test.cat', 'score': 0.8}],
            'published_at': datetime_to_aylien_str(datetime.datetime(2023, 1, 1)),
            'links': {'permalink': 'test-link'},
            'language': 'en',
        }]
        return self.sample_stories


class MockWikidataClient:
@@ -125,6 +119,12 @@ def test_generate_dataset(self):
        self.assertIsInstance(signal.feeds_df, pd.DataFrame)
        for col in ["stories"]:
            self.assertIn(col, signal.feeds_df)

        # we know the stories should come from the mock endpoint
        assert all(
            len(tick) == len(self.stories_endpoint.sample_stories)
            for tick in signal.feeds_df['stories']
        )

        assert signal.params is not None
        assert signal.name is not None
