-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathunion.py
More file actions
135 lines (112 loc) · 4.81 KB
/
union.py
File metadata and controls
135 lines (112 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#
# This file is part of libdestruct (https://github.com/mrindeciso/libdestruct).
# Copyright (c) 2026 Roberto Alessandro Bertolini. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project root for details.
#
from __future__ import annotations

from typing import TYPE_CHECKING

from libdestruct.common.obj import obj

if TYPE_CHECKING:  # pragma: no cover
    from libdestruct.backing.resolver import Resolver
class union(obj):
    """A union value, supporting both tagged (single active variant) and plain (all variants overlaid) modes."""

    _variant: obj | None
    """The single active variant (tagged union mode)."""

    _variants: dict[str, obj]
    """Named variants (plain union mode)."""

    _frozen_bytes: bytes | None
    """The frozen bytes of the full union region."""

    def __init__(
        self: union,
        resolver: Resolver | None,
        variant: obj | None,
        max_size: int,
        variants: dict[str, obj] | None = None,
    ) -> None:
        """Initialize the union.

        Args:
            resolver: The backing resolver.
            variant: The single active variant (tagged union mode, None for plain unions).
            max_size: The size of the union (max of all variant sizes).
            variants: Named variants dict (plain union mode, None for tagged unions).
        """
        super().__init__(resolver)
        self._variant = variant
        self._variants = variants or {}
        self.size = max_size
        self._frozen_bytes = None

    @property
    def variant(self: union) -> obj | None:
        """Return the active variant object (tagged union mode)."""
        return self._variant

    def _read_region(self: union) -> bytes:
        """Read the live bytes of the full union region, ignoring any frozen snapshot.

        Without a resolver there is no backing memory, so a zero-filled region of
        ``self.size`` bytes is returned instead.
        """
        if self.resolver is None:
            return b"\x00" * self.size
        return self.resolver.resolve(self.size, 0)

    def get(self: union) -> object:
        """Return the value of the union.

        Returns:
            In tagged mode, the active variant's value. In plain mode, a dict
            mapping each variant name to its value. ``None`` when the union has
            neither an active variant nor named variants.
        """
        if self._variant is not None:
            return self._variant.get()
        if self._variants:
            return {name: v.get() for name, v in self._variants.items()}
        return None

    def _set(self: union, value: object) -> None:
        """Set the value of the active variant.

        Raises:
            RuntimeError: If there is no active variant (plain or empty union).
        """
        if self._variant is None:
            raise RuntimeError("Cannot set the value of a union without an active variant.")
        self._variant._set(value)

    def to_dict(self: union) -> object:
        """Return a JSON-serializable representation of the union.

        Mirrors :meth:`get`: the active variant's dict form in tagged mode, a
        name -> dict mapping in plain mode, or ``None`` for an empty union.
        """
        if self._variant is not None:
            return self._variant.to_dict()
        if self._variants:
            return {name: v.to_dict() for name, v in self._variants.items()}
        return None

    def to_bytes(self: union) -> bytes:
        """Return the full union-sized region as bytes.

        Once frozen, the snapshot taken by :meth:`freeze` is returned instead of
        the live contents.
        """
        if self._frozen_bytes is not None:
            return self._frozen_bytes
        return self._read_region()

    def freeze(self: union) -> None:
        """Freeze the union and all its variants.

        The whole union region is snapshotted (not just the active variant's
        bytes) so that :meth:`reset` can restore bytes outside the variant too.
        """
        self._frozen_bytes = self._read_region()
        if self._variant is not None:
            self._variant.freeze()
        for v in self._variants.values():
            v.freeze()
        super().freeze()

    def diff(self: union) -> tuple[object, object] | dict[str, tuple[object, object]]:
        """Return the difference between the frozen and current value.

        Returns:
            In tagged mode, the active variant's diff. Otherwise a dict mapping
            each variant name to its diff (empty if there are no variants).
        """
        if self._variant is not None:
            return self._variant.diff()
        return {name: v.diff() for name, v in self._variants.items()}

    def reset(self: union) -> None:
        """Reset the union to its frozen value by restoring the full frozen byte region.

        Without a resolver there is no backing memory to write to, so this is a
        no-op once the union has been frozen.

        Raises:
            RuntimeError: If the union has not been frozen.
        """
        if self._frozen_bytes is None:
            raise RuntimeError("Cannot reset a union that has not been frozen.")
        if self.resolver is not None:
            self.resolver.modify(self.size, 0, self._frozen_bytes)

    def to_str(self: union, indent: int = 0) -> str:
        """Return a string representation of the union.

        ``indent`` is forwarded to the active variant in tagged mode; plain and
        empty unions are rendered on a single line and ignore it.
        """
        if self._variant is not None:
            return self._variant.to_str(indent)
        if self._variants:
            members = ", ".join(self._variants)
            return f"union({members})"
        return "union(empty)"

    def __getattr__(self: union, name: str) -> object:
        """Delegate attribute access to named variants or the active variant.

        Only invoked when normal attribute lookup fails. Named variants (plain
        mode) take precedence; otherwise the lookup is forwarded to the active
        variant (tagged mode). Internal fields are read via
        ``object.__getattribute__`` so that a partially-initialised instance
        (e.g. before ``__init__`` has run) cannot recurse back into this method.
        """
        try:
            variants = object.__getattribute__(self, "_variants")
        except AttributeError:
            pass
        else:
            if name in variants:
                return variants[name]

        try:
            variant = object.__getattribute__(self, "_variant")
            if variant is not None:
                # An AttributeError from the variant is deliberately reported
                # as a missing attribute on the union itself (below).
                return getattr(variant, name)
        except AttributeError:
            pass

        raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")