Skip to content

Commit b6f5d8f

Browse files
authored
feat: ✨ Add ProfileLoader.required_module_names method
1 parent 56affd9 commit b6f5d8f

2 files changed

Lines changed: 51 additions & 0 deletions

File tree

injection/loaders.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ class ProfileLoader:
149149
def __is_empty(self) -> bool:
150150
return not self.module_subsets
151151

152+
def required_module_names(self, name: str | None = None, /) -> frozenset[str]:
153+
names = {self.module.name}
154+
155+
if name is not None:
156+
names.add(name)
157+
158+
subsets = (self.__walk_subsets_for(name) for name in names)
159+
return frozenset(itertools.chain.from_iterable(subsets))
160+
152161
def init(self) -> Self:
153162
self.__init_subsets_for(self.module)
154163
return self
@@ -179,6 +188,12 @@ def __is_initialized(self, module: Module) -> bool:
179188
def __mark_initialized(self, module: Module) -> None:
180189
self.__initialized_modules.add(module.name)
181190

191+
def __walk_subsets_for(self, module_name: str) -> Iterator[str]:
192+
yield module_name
193+
194+
for name in self.module_subsets.get(module_name, ()):
195+
yield from self.__walk_subsets_for(name)
196+
182197

183198
@runtime_checkable
184199
class LoadedProfile(Protocol):

tests/loaders/test_profile_loader.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,42 @@
88

99

1010
class TestProfileLoader:
11+
def test_required_module_names_with_success_return_frozenset(self):
12+
loader = ProfileLoader(
13+
{
14+
mod().name: ["a", "b", "c"],
15+
"dev": ["m", "n", "o"],
16+
"b": ["x", "y", "z"],
17+
"c": ["i", "j"],
18+
}
19+
)
20+
assert loader.required_module_names() == {
21+
mod().name,
22+
"a",
23+
"b",
24+
"c",
25+
"x",
26+
"y",
27+
"z",
28+
"i",
29+
"j",
30+
}
31+
32+
def test_required_module_names_with_name_return_frozenset(self):
33+
loader = ProfileLoader(
34+
{
35+
mod().name: ["a"],
36+
"dev": ["z"],
37+
"test": ["i"],
38+
}
39+
)
40+
assert loader.required_module_names("dev") == {
41+
mod().name,
42+
"dev",
43+
"a",
44+
"z",
45+
}
46+
1147
def test_load_with_success(self):
1248
profile_name = "test"
1349
global_profile_name = uuid4().hex

0 commit comments

Comments
 (0)