Skip to content

Commit 381a2d1

Browse files
committed
allow node.descendants(add_mwt=True)
Add the `add_mwt` option to `node.descendants()`, `node.children()` etc. The MWT objects do not have the same methods&properties as Node objects, so one has to be careful when using the resulting list with a mixure of both, but there are usecases when it is useful to have such a mixed list (with MWTs preceding their words, similarly to the CoNLL-U format).
1 parent 717af29 commit 381a2d1

1 file changed

Lines changed: 17 additions & 4 deletions

File tree

udapi/core/node.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,6 +1151,7 @@ class ListOfNodes(list):
11511151
nodes = node.children
11521152
nodes = node.children()
11531153
nodes = node.children(add_self=True, following_only=True)
1154+
nodes = node.descendants(add_self=True, add_mwt=True)
11541155
"""
11551156
__slots__ = ('origin',)
11561157

@@ -1164,16 +1165,28 @@ def __init__(self, iterable, origin):
11641165
super().__init__(iterable)
11651166
self.origin = origin
11661167

1167-
def __call__(self, add_self=False, following_only=False, preceding_only=False):
1168+
def __call__(self, add_self=False, following_only=False, preceding_only=False, add_mwt=False):
11681169
"""Returns a subset of nodes contained in this list as specified by the args."""
11691170
if add_self:
11701171
self.append(self.origin)
11711172
self.sort()
1173+
result = self
11721174
if preceding_only:
1173-
return [x for x in self if x._ord <= self.origin._ord]
1175+
result = [x for x in result if x._ord <= self.origin._ord]
11741176
if following_only:
1175-
return [x for x in self if x._ord >= self.origin._ord]
1176-
return self
1177+
result = [x for x in result if x._ord >= self.origin._ord]
1178+
if add_mwt:
1179+
new = []
1180+
last_mwt_id = -1
1181+
for node in result:
1182+
mwt = node.multiword_token
1183+
if mwt:
1184+
if node.ord > last_mwt_id:
1185+
last_mwt_id = mwt.words[-1].ord
1186+
new.append(mwt)
1187+
new.append(node)
1188+
result = new
1189+
return result
11771190

11781191

11791192
def find_minimal_common_treelet(*args):

0 commit comments

Comments
 (0)