Skip to content

Commit 28fd42a

Browse files
refactor(pipeline_builder): improved stages.
1 parent 4e5f538 commit 28fd42a

8 files changed

Lines changed: 1865 additions & 996 deletions

File tree

Lines changed: 157 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,195 @@
11
from __future__ import annotations
22

3-
from typing import Dict, Iterable
3+
from typing import Any, Dict, Iterable
44

5-
from .schema import Schema
65
from .match_planner import MatchPlanner
76

87

98
class LookupPlanner:
    """
    Builds a lookup tree keyed by *field names* (not db_field), suitable for
    StageBuilder._walk_lookups.

    Inputs:
    - select_related: a list/tuple/set/frozenset of "a__b" style paths
      (a single path string is accepted too)
    - bucket_prefixes: iterable of db_field dotted prefixes produced by
      MatchPlanner.bucket()

    Output:
    - tree: nested dicts like {"items": {"parent": {}}, "parent": {"gp": {}}};
      an empty dict marks a leaf.
    """

    def plan_from_select_related(self, select_related) -> dict:
        """Return the lookup tree described by ``select_related`` alone."""
        return self._tree_from_select_related(select_related)

    def plan(self, doc_cls, select_related, bucket_prefixes: Iterable[str]) -> dict:
        """
        Build the combined lookup tree for one query.

        Args:
            doc_cls: root document class the db_field prefixes are resolved against.
            select_related: relation paths to hydrate (see the class docstring).
            bucket_prefixes: db_field dotted prefixes; ``None`` and empty
                entries are ignored.

        Returns:
            A new dict tree. Keys derived from bucket prefixes are inserted
            before keys that only appear in ``select_related``.
        """
        tree: Dict[str, Any] = {}

        # 1) bucket-prefix-derived tree FIRST (filter stages happen earlier)
        for prefix in bucket_prefixes or ():
            if not prefix:
                continue
            self._merge_tree(tree, self._tree_from_db_prefix(doc_cls, prefix))

        # 2) select_related tree AFTER (hydrate after filtering)
        if select_related:
            self._merge_tree(tree, self.plan_from_select_related(select_related))

        return tree

    # ---------------- internals ----------------

    def _tree_from_db_prefix(self, doc_cls, db_prefix: str) -> dict:
        """
        Convert a db_field dotted path like "target.gp" into a field-name tree
        like {"target": {"gp": {}}}.

        Key behavior:
        - Embedded documents (single or list): descend into the embedded schema.
        - MapField / DictField: stop at this node.
        - ReferenceField: if more segments remain, traverse into the referenced
          document.
        - GenericReferenceField: if more segments remain and the next segment is
          a COMMON ReferenceField across choices, add it as a child and continue
          from the first resolved choice document.
        - Anything else (scalar / unknown segment): stop; the tree built so far
          is returned.
        """
        # Imported lazily so the select_related-only paths of this class do not
        # require mongoengine to be importable.
        from mongoengine.fields import (
            EmbeddedDocumentField,
            EmbeddedDocumentListField,
            ListField,
            ReferenceField,
            GenericReferenceField,
            MapField,
            DictField,
        )

        parts = [p for p in db_prefix.split(".") if p]
        if not parts:
            return {}

        cur_doc = doc_cls
        root: Dict[str, Any] = {}
        node = root

        i = 0
        while i < len(parts):
            if cur_doc is None:
                break

            fld = self._get_field_by_db_part(cur_doc, parts[i])
            if fld is None:
                break

            # The tree is keyed by python field name, not by db_field.
            node = node.setdefault(fld.name, {})

            is_last = i == len(parts) - 1

            # unwrap list wrapper(s) for leaf checks
            leaf = fld
            while isinstance(leaf, ListField):
                leaf = leaf.field

            # ---- embedded boundary: descend schema
            if isinstance(fld, EmbeddedDocumentField):
                cur_doc = fld.document_type
                i += 1
                continue

            inner = getattr(fld, "field", None)
            if isinstance(fld, EmbeddedDocumentListField) or (
                isinstance(fld, ListField) and isinstance(inner, EmbeddedDocumentField)
            ):
                embedded_dt = getattr(fld, "document_type", None)
                if embedded_dt is None and isinstance(inner, EmbeddedDocumentField):
                    embedded_dt = inner.document_type
                cur_doc = embedded_dt
                i += 1
                continue

            # ---- MapField / DictField: lookup happens at this node;
            # deeper handled by MatchPlanner $expr rewrites
            if isinstance(fld, (MapField, DictField)):
                break

            # ---- ReferenceField: keep traversing if more segments remain
            if isinstance(leaf, ReferenceField):
                if is_last:
                    break
                cur_doc = self._referenced_document(leaf)
                i += 1
                continue

            # ---- GenericReferenceField:
            # If next segment is a COMMON ReferenceField across choices,
            # traverse into a representative choice doc
            if isinstance(leaf, GenericReferenceField):
                if is_last:
                    break

                next_part = parts[i + 1]
                common_ref_field, _common_target = MatchPlanner.generic_common_ref(leaf, next_part)
                if common_ref_field is None:
                    # cannot safely traverse beyond generic
                    break

                # Ensure the tree includes the common-ref child
                # (StageBuilder will use this to emit lookup on target.<common_ref>)
                node = node.setdefault(common_ref_field.name, {})

                # Traverse schema as if we're in a representative choice class
                # so we can plan deeper segments (like ...gp.age... -> prefix target.gp)
                doc_classes = MatchPlanner._safe_resolve_generic_choices(leaf)
                cur_doc = doc_classes[0] if doc_classes else None

                # We consumed "next_part" by inserting common_ref_field.name
                i += 2
                continue

            # ---- scalar: can't traverse further
            break

        return root

    @staticmethod
    def _referenced_document(ref_field):
        """
        Return the document class a reference field points to, or ``None``.

        ``document_type`` is consulted first because ``document_type_obj`` may
        still hold the class *name* (a ``str``) for references declared by name;
        a string is never returned, since callers go on to read ``._fields``.
        """
        target = getattr(ref_field, "document_type", None)
        if target is None:
            target = getattr(ref_field, "document_type_obj", None)
        return None if isinstance(target, str) else target

    @staticmethod
    def _merge_tree(dst: dict, src: dict) -> None:
        """
        Merge ``src`` into ``dst`` in place.

        Subtrees are copied structurally rather than stored by reference, so
        later merges into ``dst`` never mutate ``src``. Non-dict values in
        ``src`` become empty leaves; an existing non-dict value in ``dst`` is
        left untouched. ``src`` may be ``None``.
        """
        for key, sub in (src or {}).items():
            child = dst.setdefault(key, {})
            if isinstance(child, dict) and isinstance(sub, dict):
                LookupPlanner._merge_tree(child, sub)

    @staticmethod
    def _get_field_by_db_part(doc_cls, db_part: str):
        """
        Resolve one db_field path segment to a field of ``doc_cls``.

        The segment is matched against each field's ``db_field`` first, because
        callers pass db_field paths; the python attribute name is only a
        fallback. Returns ``None`` when ``doc_cls`` is ``None`` or nothing matches.
        """
        if doc_cls is None:
            return None

        fields = doc_cls._fields
        for candidate in fields.values():
            if getattr(candidate, "db_field", None) == db_part:
                return candidate

        return fields.get(db_part)

    # ---- select_related converter (keep / adapt to your queryset format)
    def _tree_from_select_related(self, select_related) -> dict:
        """
        Convert "a__b" style relation paths into a tree.

        Accepts a single path string or a list/tuple/set/frozenset of them;
        non-string and empty entries are skipped. Any other input yields ``{}``.
        """
        if not select_related:
            return {}

        if isinstance(select_related, str):
            # A bare string is one path, not an iterable of characters.
            select_related = (select_related,)

        if not isinstance(select_related, (list, tuple, set, frozenset)):
            return {}

        paths = [p.replace("__", ".") for p in select_related if isinstance(p, str) and p]
        return self._tree_from_paths(paths)

    @staticmethod
    def _tree_from_paths(paths: Iterable[str]) -> dict:
        """Build a nested-dict tree from dotted paths, ignoring empty segments."""
        root: Dict[str, Any] = {}
        for path in paths:
            if not path:
                continue
            node = root
            for part in path.split("."):
                if part:
                    node = node.setdefault(part, {})
        return root

0 commit comments

Comments
 (0)