|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import Dict, Iterable |
| 3 | +from typing import Any, Dict, Iterable |
4 | 4 |
|
5 | | -from .schema import Schema |
6 | 5 | from .match_planner import MatchPlanner |
7 | 6 |
|
8 | 7 |
|
9 | 8 | class LookupPlanner: |
10 | 9 | """ |
11 | | - Pure planning: produces a lookup tree (python field names). |
12 | | - Does NOT emit Mongo stages. |
| 10 | + Builds a lookup tree keyed by *field names* (not db_field), suitable for StageBuilder._walk_lookups. |
| 11 | +
|
| 12 | + Inputs: |
| 13 | + - select_related: mongoengine select_related spec |
| 14 | + - bucket_prefixes: iterable of db_field dotted prefixes produced by MatchPlanner.bucket() |
| 15 | +
|
| 16 | + Output: |
| 17 | + - tree: dict like {"items": {"parent": {}}, "parent": {"gp": {}}} |
13 | 18 | """ |
14 | 19 |
|
15 | 20 | def plan_from_select_related(self, select_related) -> dict: |
16 | | - return self.build_related_tree(select_related) |
| 21 | + return self._tree_from_select_related(select_related) |
17 | 22 |
|
18 | 23 | def plan(self, doc_cls, select_related, bucket_prefixes: Iterable[str]) -> dict: |
19 | | - tree = {} |
| 24 | + tree: Dict[str, Any] = {} |
| 25 | + |
| 26 | + # 1) bucket-prefix-derived tree FIRST (filter stages happen earlier) |
| 27 | + for prefix in bucket_prefixes or (): |
| 28 | + if not prefix: |
| 29 | + continue |
| 30 | + p_tree = self._tree_from_db_prefix(doc_cls, prefix) |
| 31 | + self._merge_tree(tree, p_tree) |
| 32 | + |
| 33 | + # 2) select_related tree AFTER (hydrate after filtering) |
20 | 34 | if select_related: |
21 | | - tree = self.merge_trees(tree, self.build_related_tree(select_related)) |
22 | | - tree = self.merge_trees(tree, self.auto_tree_from_bucket_prefixes(doc_cls, bucket_prefixes)) |
23 | | - return tree |
| 35 | + sr_tree = self.plan_from_select_related(select_related) |
| 36 | + self._merge_tree(tree, sr_tree) |
24 | 37 |
|
25 | | - @staticmethod |
26 | | - def build_related_tree(fields) -> dict: |
27 | | - tree = {} |
28 | | - for f in fields or []: |
29 | | - parts = f.split("__") |
30 | | - node = tree |
31 | | - for p in parts: |
32 | | - node = node.setdefault(p, {}) |
33 | | - node[""] = True |
34 | 38 | return tree |
35 | 39 |
|
36 | | - @staticmethod |
37 | | - def merge_trees(a: dict, b: dict) -> dict: |
38 | | - if not a: |
39 | | - return dict(b or {}) |
40 | | - if not b: |
41 | | - return dict(a) |
42 | | - out = dict(a) |
43 | | - for k, v in b.items(): |
44 | | - if k not in out: |
45 | | - out[k] = v |
46 | | - else: |
47 | | - if isinstance(out[k], dict) and isinstance(v, dict): |
48 | | - out[k] = LookupPlanner.merge_trees(out[k], v) |
49 | | - return out |
| 40 | + # ---------------- internals ---------------- |
50 | 41 |
|
51 | | - def auto_tree_from_bucket_prefixes(self, root_doc_cls, bucket_prefixes: Iterable[str]) -> dict: |
| 42 | + def _tree_from_db_prefix(self, doc_cls, db_prefix: str) -> dict: |
52 | 43 | """ |
53 | | - Bucket prefixes are db_field dotted (e.g. "target.gp"). |
54 | | - We build a tree using python field names, with safe GenericRef traversal. |
| 44 | + Convert db_field dotted path like "target.gp" into a field-name tree like {"target": {"gp": {}}}. |
| 45 | +
|
| 46 | + Key behavior: |
| 47 | + - ReferenceField: if there are more segments, traverse into referenced document |
| 48 | + - GenericReferenceField: if there are more segments and next segment is a COMMON ReferenceField |
| 49 | + across choices, traverse into representative choice document so later segments can be planned. |
55 | 50 | """ |
56 | | - tree: Dict[str, dict] = {} |
| 51 | + from mongoengine.fields import ( |
| 52 | + EmbeddedDocumentField, |
| 53 | + EmbeddedDocumentListField, |
| 54 | + ListField, |
| 55 | + ReferenceField, |
| 56 | + GenericReferenceField, |
| 57 | + MapField, |
| 58 | + DictField, |
| 59 | + ) |
| 60 | + |
| 61 | + parts = [p for p in db_prefix.split(".") if p] |
| 62 | + if not parts: |
| 63 | + return {} |
| 64 | + |
| 65 | + cur_doc = doc_cls |
| 66 | + root: Dict[str, Any] = {} |
| 67 | + node = root |
| 68 | + |
| 69 | + i = 0 |
| 70 | + while i < len(parts): |
| 71 | + if cur_doc is None: |
| 72 | + break |
| 73 | + |
| 74 | + db_part = parts[i] |
| 75 | + fld = self._get_field_by_db_part(cur_doc, db_part) |
| 76 | + if fld is None: |
| 77 | + break |
| 78 | + |
| 79 | + field_name = fld.name |
| 80 | + node = node.setdefault(field_name, {}) |
| 81 | + |
| 82 | + is_last = (i == len(parts) - 1) |
| 83 | + |
| 84 | + # unwrap list wrapper for leaf checks |
| 85 | + leaf = fld |
| 86 | + while isinstance(leaf, ListField): |
| 87 | + leaf = leaf.field |
| 88 | + |
| 89 | + # ---- embedded boundary: descend schema |
| 90 | + if isinstance(fld, EmbeddedDocumentField): |
| 91 | + cur_doc = fld.document_type |
| 92 | + i += 1 |
| 93 | + continue |
57 | 94 |
|
58 | | - for dotted_prefix in bucket_prefixes: |
59 | | - if not dotted_prefix: |
| 95 | + if isinstance(fld, EmbeddedDocumentListField) or ( |
| 96 | + isinstance(fld, ListField) and isinstance(getattr(fld, "field", None), EmbeddedDocumentField) |
| 97 | + ): |
| 98 | + embedded_dt = getattr(fld, "document_type", None) |
| 99 | + if embedded_dt is None and isinstance(getattr(fld, "field", None), EmbeddedDocumentField): |
| 100 | + embedded_dt = fld.field.document_type |
| 101 | + cur_doc = embedded_dt |
| 102 | + i += 1 |
60 | 103 | continue |
61 | 104 |
|
62 | | - parts = dotted_prefix.split(".") |
63 | | - node = tree |
64 | | - cur = root_doc_cls |
| 105 | + # ---- MapField / DictField: lookup happens at this node; deeper handled by MatchPlanner $expr rewrites |
| 106 | + if isinstance(fld, (MapField, DictField)): |
| 107 | + break |
65 | 108 |
|
66 | | - for idx, db_part in enumerate(parts): |
67 | | - if cur is None: |
| 109 | + # ---- ReferenceField: keep traversing if more segments remain |
| 110 | + if isinstance(leaf, ReferenceField): |
| 111 | + if is_last: |
68 | 112 | break |
| 113 | + cur_doc = getattr(leaf, "document_type_obj", None) or getattr(leaf, "document_type", None) |
| 114 | + i += 1 |
| 115 | + continue |
69 | 116 |
|
70 | | - field_name, fld = Schema.resolve_field_name(cur, db_part) |
71 | | - if not fld: |
| 117 | + # ---- GenericReferenceField: |
| 118 | + # If next segment is a COMMON ReferenceField across choices, traverse into representative choice doc |
| 119 | + if isinstance(leaf, GenericReferenceField): |
| 120 | + if is_last: |
72 | 121 | break |
73 | 122 |
|
74 | | - node = node.setdefault(field_name, {}) |
| 123 | + next_part = parts[i + 1] |
| 124 | + common_ref_field, _common_target = MatchPlanner.generic_common_ref(leaf, next_part) |
| 125 | + if common_ref_field is None: |
| 126 | + # cannot safely traverse beyond generic |
| 127 | + break |
75 | 128 |
|
76 | | - from mongoengine.fields import ReferenceField, GenericReferenceField, EmbeddedDocumentField, \ |
77 | | - EmbeddedDocumentListField |
| 129 | + # Ensure the tree includes the common-ref child |
| 130 | + # (StageBuilder will use this to emit lookup on target.<common_ref>) |
| 131 | + node = node.setdefault(common_ref_field.name, {}) |
78 | 132 |
|
79 | | - leaf = Schema.unwrap_list_leaf(fld) |
| 133 | + # Traverse schema as if we're in a representative choice class |
| 134 | + # so we can plan deeper segments (like ...gp.age... -> prefix target.gp) |
| 135 | + doc_classes = MatchPlanner._safe_resolve_generic_choices(leaf) |
| 136 | + cur_doc = doc_classes[0] if doc_classes else None |
80 | 137 |
|
81 | | - if isinstance(leaf, ReferenceField): |
82 | | - cur = getattr(leaf, "document_type_obj", None) or getattr(leaf, "document_type", None) |
83 | | - continue |
| 138 | + # We consumed "next_part" by inserting common_ref_field.name |
| 139 | + i += 2 |
| 140 | + continue |
84 | 141 |
|
85 | | - if isinstance(leaf, GenericReferenceField): |
86 | | - if idx < len(parts) - 1: |
87 | | - next_part = parts[idx + 1] |
88 | | - common_ref_field, _common_target = MatchPlanner.generic_common_ref(leaf, next_part) |
89 | | - if common_ref_field is None: |
90 | | - cur = None |
91 | | - break |
| 142 | + # ---- scalar: can't traverse further |
| 143 | + break |
92 | 144 |
|
93 | | - from mongoengine.document import _DocumentRegistry |
94 | | - ch0 = (leaf.choices or ())[0] |
95 | | - cur = _DocumentRegistry.get(ch0 if isinstance(ch0, str) else ch0.__name__) |
96 | | - continue |
| 145 | + return root |
97 | 146 |
|
98 | | - cur = None |
99 | | - break |
| 147 | + @staticmethod |
| 148 | + def _merge_tree(dst: dict, src: dict) -> None: |
| 149 | + for k, v in (src or {}).items(): |
| 150 | + if k not in dst: |
| 151 | + dst[k] = v if isinstance(v, dict) else {} |
| 152 | + else: |
| 153 | + if isinstance(dst[k], dict) and isinstance(v, dict): |
| 154 | + LookupPlanner._merge_tree(dst[k], v) |
100 | 155 |
|
101 | | - if isinstance(fld, (EmbeddedDocumentField, EmbeddedDocumentListField)) or getattr(leaf, "document_type", |
102 | | - None): |
103 | | - cur = getattr(leaf, "document_type", None) or getattr(leaf, "document_type_obj", None) |
104 | | - continue |
| 156 | + @staticmethod |
| 157 | + def _get_field_by_db_part(doc_cls, db_part: str): |
| 158 | + if doc_cls is None: |
| 159 | + return None |
105 | 160 |
|
106 | | - cur = None |
| 161 | + fld = doc_cls._fields.get(db_part) |
| 162 | + if fld is not None: |
| 163 | + return fld |
107 | 164 |
|
108 | | - node[""] = True |
| 165 | + for _name, f in doc_cls._fields.items(): |
| 166 | + if getattr(f, "db_field", None) == db_part: |
| 167 | + return f |
109 | 168 |
|
110 | | - return tree |
| 169 | + return None |
| 170 | + |
| 171 | + # ---- select_related converter (keep / adapt to your queryset format) |
| 172 | + def _tree_from_select_related(self, select_related) -> dict: |
| 173 | + if not select_related: |
| 174 | + return {} |
| 175 | + |
| 176 | + if isinstance(select_related, (list, tuple, set)): |
| 177 | + paths = [] |
| 178 | + for p in select_related: |
| 179 | + if isinstance(p, str) and p: |
| 180 | + paths.append(p.replace("__", ".")) |
| 181 | + return self._tree_from_paths(paths) |
| 182 | + |
| 183 | + return {} |
| 184 | + |
| 185 | + @staticmethod |
| 186 | + def _tree_from_paths(paths: Iterable[str]) -> dict: |
| 187 | + root: Dict[str, Any] = {} |
| 188 | + for p in paths: |
| 189 | + if not p: |
| 190 | + continue |
| 191 | + parts = [x for x in p.split(".") if x] |
| 192 | + node = root |
| 193 | + for part in parts: |
| 194 | + node = node.setdefault(part, {}) |
| 195 | + return root |
0 commit comments