Skip to content

Commit

Permalink
updated with events extension
Browse files Browse the repository at this point in the history
  • Loading branch information
tomhoper committed May 18, 2022
1 parent bddc597 commit 0992a31
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 88 deletions.
69 changes: 36 additions & 33 deletions dygie/spacy_interface/spacy_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,44 +13,47 @@
Span.set_extension("rels", default=[], force=True)
Doc.set_extension("span_ents", default=[], force=True)
Span.set_extension("label_", default=[], force=True)
Doc.set_extension("events", default=[], force=True)
Span.set_extension("events", default=[], force=True)


def prepare_spacy_doc(doc: Doc, prediction: Dict) -> Doc:
doc_rels = []
doc_evs = []
# store events as relations. include confidence scores in the relation tuple (TODO: add relation property)
if "predicted_events" in prediction:
for rels, ds in zip(prediction.get("predicted_events", []), doc.sents):
sent_rels = []
for rel in rels:
if len(rel)>=3:
trig = [r for r in rel if r[1]=="TRIGGER"]
arg0s = [r for r in rel if r[2]=="ARG0"]
#example arg0s: [[40, 43, 'ARG0', 12.1145, 1.0], [45, 45, 'ARG0', 11.3498, 1.0]]
arg1s = [r for r in rel if r[2]=="ARG1"]
e_trig = doc[trig[0][0]:trig[0][0]+1]
for arg0 in arg0s:
e_arg0 = doc[arg0[0] : arg0[1] + 1]
for arg1 in arg1s:
e_arg1 = doc[arg1[0] : arg1[1] + 1]
#here confidence is set as the minimum among {trigger,args}, as a conservative measure.
sent_rels.append({"ARG0":e_arg0,"ARG1":e_arg1,"RELATION_TRIGGER":e_trig,"CONF":min([arg0[4],arg1[4],trig[0][3]])})

doc_rels.append(sent_rels)
ds._.rels = sent_rels
doc._.rels = doc_rels
#TODO add doc._.span_ents too.
return doc
else:
for rels, ds in zip(prediction.get("predicted_relations", []), doc.sents):
sent_rels = []
for rel in rels:
e1 = doc[rel[0] : rel[1] + 1]
e2 = doc[rel[2] : rel[3] + 1]
tag = rel[4]
sent_rels.append((e1, e2, tag))
doc_rels.append(sent_rels)
ds._.rels = sent_rels
doc._.rels = doc_rels
for evs, ds in zip(prediction.get("predicted_events", []), doc.sents):
sent_evs = []
for ev in evs:
if len(ev)>=3:
trig = [r for r in ev if r[1]=="TRIGGER"]
arg0s = [r for r in ev if r[2]=="ARG0"]
#example arg0s: [[40, 43, 'ARG0', 12.1145, 1.0], [45, 45, 'ARG0', 11.3498, 1.0]]
arg1s = [r for r in ev if r[2]=="ARG1"]
e_trig = doc[trig[0][0]:trig[0][0]+1]
for arg0 in arg0s:
e_arg0 = doc[arg0[0] : arg0[1] + 1]
for arg1 in arg1s:
e_arg1 = doc[arg1[0] : arg1[1] + 1]
#here confidence is set as the minimum among {trigger,args}, as a conservative measure.
sent_evs.append({"ARG0":e_arg0,"ARG1":e_arg1,"RELATION_TRIGGER":e_trig,"CONF":min([arg0[4],arg1[4],trig[0][3]])})

doc_evs.append(sent_evs)
ds._.events = sent_evs
doc._.events = doc_evs
#TODO add doc._.span_ents too.

for rels, ds in zip(prediction.get("predicted_relations", []), doc.sents):
sent_rels = []
for rel in rels:
e1 = doc[rel[0] : rel[1] + 1]
e2 = doc[rel[2] : rel[3] + 1]
tag = rel[4]
sent_rels.append((e1, e2, tag))
doc_rels.append(sent_rels)
ds._.rels = sent_rels
doc._.rels = doc_rels
if "predicted_ner" not in prediction:
return doc
preds = [p for r in prediction.get("predicted_ner", []) for p in r]
# storing all span based entitis to doc._.span_ents
span_ents = []
Expand Down
Loading

0 comments on commit 0992a31

Please sign in to comment.