Skip to content

Commit

Permalink
Add a function which marks start_index and end_index on a tree for it…
Browse files Browse the repository at this point in the history
…s spans
  • Loading branch information
AngledLuffa committed Jan 15, 2025
1 parent 94c3151 commit 2934b7f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
16 changes: 16 additions & 0 deletions stanza/models/constituency/parse_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,19 @@ def write_treebank(trees, out_file, fmt="{}"):
for tree in trees:
fout.write(fmt.format(tree))
fout.write("\n")

def mark_spans(self):
self._mark_spans(0)

def _mark_spans(self, start_index):
self.start_index = start_index

if len(self.children) == 0:
self.end_index = start_index + 1
return

for child in self.children:
child._mark_spans(start_index)
start_index = child.end_index

self.end_index = start_index
14 changes: 14 additions & 0 deletions stanza/tests/constituency/test_parse_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,17 @@ def test_reverse():
assert len(trees) == 1
reversed_tree = trees[0].reverse()
assert str(reversed_tree) == "(ROOT (S (VP (S (VP (VP (NP (NNS antennae) (NP (POS 's) (NNP Jennifer))) (VB lick)) (TO to))) (VBP want)) (NP (PRP I))))"

def test_mark_spans():
text = "(ROOT (S (NP (PRP I)) (VP (VBP want) (S (VP (TO to) (VP (VB lick) (NP (NP (NNP Jennifer) (POS 's)) (NNS antennae))))))))"
trees = tree_reader.read_trees(text)
assert len(trees) == 1
tree = trees[0]

tree.mark_spans()

assert tree.start_index == 0
assert tree.end_index == 7
for idx, pt in enumerate(tree.yield_preterminals()):
assert pt.start_index == idx
assert pt.end_index == idx + 1

0 comments on commit 2934b7f

Please sign in to comment.