Skip to content

Commit 2b19f47

Browse files
committed
Add eliasfano and rrrvector to c++ and rust
1 parent 8b26069 commit 2b19f47

10 files changed

Lines changed: 1824 additions & 48 deletions
Lines changed: 209 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,219 @@
11
<#@ template language="C#" #>
22
<#@ import namespace="Genbox.FastData.Generator.Enums" #>
3+
<#@ import namespace="Genbox.FastData.Generator.Helpers" #>
4+
<#@ import namespace="Genbox.FastData.Generator.Extensions" #>
35

46
<#@ parameter type="Genbox.FastData.Generator.CPlusPlus.TemplateModel" name="Model" #>
57
<#@ parameter type="Genbox.FastData.Generator.Template.CommonDataModel" name="Common" #>
8+
<#@ parameter type="Genbox.FastData.Generators.Contexts.EliasFanoContext" name="Context" #>
9+
10+
// Number of low-order key bits stored verbatim per element in lower_bits.
// contains() shifts the key right by this amount to obtain its "high" part.
static constexpr int32_t lower_bit_count = <#= Context.LowerBitCount #>;
11+
// Bit vector of 64-bit words addressed as (position >> 6, position & 63).
// select_zero() scans its zero bits; contains() walks its one bits and derives
// each element's high part as (bit position - element rank).
<#= Model.GetFieldModifier(true) #>std::array<uint64_t, <#= Context.UpperBits.Length.ToStringInvariant() #>> upper_bits = {
12+
<#= FormatHelper.FormatColumns(Context.UpperBits, x => Model.ToValueLabel(x)) #>
13+
};
14+
<#
15+
// lower_bits and lower_mask are only emitted when elements have a low part.
if (Context.LowerBitCount != 0)
16+
{
17+
#>
18+
19+
// Packed low parts: contains() reads the element of rank r as lower_bit_count bits
// starting at bit offset r * lower_bit_count, possibly straddling two words.
<#= Model.GetFieldModifier(true) #>std::array<uint64_t, <#= Context.LowerBits.Length.ToStringInvariant() #>> lower_bits = {
20+
<#= FormatHelper.FormatColumns(Context.LowerBits, x => Model.ToValueLabel(x)) #>
21+
};
22+
23+
// Mask applied both to the key and to each field extracted from lower_bits.
// NOTE(review): presumably (1 << lower_bit_count) - 1 -- confirm against the generator.
static constexpr uint64_t lower_mask = <#= Model.ToValueLabel(Context.LowerMask) #>;
24+
<#
25+
}
26+
#>
27+
28+
// select_zero(rank) starts its scan at sample_positions[rank >> sample_rate_shift]
// and assumes (index << sample_rate_shift) zero bits precede that position.
static constexpr int32_t sample_rate_shift = <#= Context.SampleRateShift #>;
29+
<#= Model.GetFieldModifier(true) #>std::array<int32_t, <#= Context.SamplePositions.Length.ToStringInvariant() #>> sample_positions = {
30+
<#= FormatHelper.FormatColumns(Context.SamplePositions, x => Model.ToValueLabel(x)) #>
31+
};
32+
33+
// Returns the number of one bits in `value` (0..64).
static constexpr int popcount(uint64_t value) noexcept {
    int ones = 0;

    // Each iteration clears the lowest set bit, so the loop runs once per one bit.
    for (uint64_t rest = value; rest != 0; rest &= rest - 1)
        ++ones;

    return ones;
}
41+
42+
// Returns the index of the lowest set bit of `value`, or 64 when no bit is set.
static constexpr int trailing_zero_count(uint64_t value) noexcept {
    if (value == 0)
        return 64;

    // Shift right until the lowest set bit reaches bit 0, counting the shifts.
    int shifts = 0;
    for (uint64_t probe = value; (probe & 1ULL) == 0; probe >>= 1)
        ++shifts;

    return shifts;
}
54+
55+
static constexpr int select_bit_in_word(uint64_t word, int rank) noexcept {
56+
if (static_cast<uint32_t>(rank) >= 64u)
57+
return -1;
58+
59+
int remaining = rank;
60+
uint64_t value = word;
61+
62+
while (remaining > 0) {
63+
if (value == 0)
64+
return -1;
65+
66+
value &= value - 1;
67+
remaining--;
68+
}
69+
70+
if (value == 0)
71+
return -1;
72+
73+
return trailing_zero_count(value);
74+
}
75+
76+
// Returns the bit position in upper_bits of the zero bit with the given zero-based
// `rank`, or -1 when `rank` is negative, lies beyond the sampled range, or no such
// zero bit exists among the valid bits.
//
// The scan starts at sample_positions[rank >> sample_rate_shift], which is taken to
// be preceded by exactly (sample index << sample_rate_shift) zero bits.
static constexpr int64_t select_zero(int64_t rank) noexcept {
    if (rank < 0)
        return -1;

    const size_t sample_index = static_cast<size_t>(rank >> sample_rate_shift);
    if (sample_index >= sample_positions.size())
        return -1;

    int64_t zero_rank = static_cast<int64_t>(sample_index) << sample_rate_shift;
    const int64_t start_position = sample_positions[sample_index];
    size_t word_index = static_cast<size_t>(start_position >> 6);
    int start_bit = static_cast<int>(start_position & 63);

    for (; word_index < upper_bits.size(); word_index++) {
        // Only the last word can be partially used. When the bit length is an exact
        // multiple of 64 and the array holds no spare word, that last word is completely
        // valid, so 64 is emitted instead of (length & 63) == 0, which would hide every
        // zero bit stored in it.
        const int valid_bits = word_index == upper_bits.size() - 1 ? <#= ((Context.UpperBitLength & 63) == 0 && (long)Context.UpperBits.Length * 64 == (long)Context.UpperBitLength) ? 64 : (Context.UpperBitLength & 63) #> : 64;
        const uint64_t valid_mask = valid_bits == 64 ? std::numeric_limits<uint64_t>::max() : ((1ULL << valid_bits) - 1);
        uint64_t zeros = ~upper_bits[word_index] & valid_mask;

        // In the first word, ignore the bits below the sampled start position.
        if (start_bit > 0) {
            zeros &= ~((1ULL << start_bit) - 1);
            start_bit = 0;
        }

        const int zero_count = popcount(zeros);
        if (zero_count == 0)
            continue;

        // The requested zero is in this word: 0 <= rank - zero_rank < zero_count.
        if (zero_rank + zero_count > rank) {
            const int rank_in_word = static_cast<int>(rank - zero_rank);
            const int bit_in_word = select_bit_in_word(zeros, rank_in_word);
            return (static_cast<int64_t>(word_index) << 6) + bit_in_word;
        }

        zero_rank += zero_count;
    }

    return -1;
}
6114

7115
public:
8116
// Membership test over the Elias-Fano encoded set.
//
// The key is split into a high part (key >> lower_bit_count) and a low part
// (key & lower_mask). The scan of upper_bits starts right after the zero bit of rank
// (high - 1); each one bit found there belongs to the element whose rank is tracked
// in `rank`, and that element's high part is (bit position - rank). Elements are
// visited until one has a larger high part, a larger low part within the same high
// part, or the elements run out.
<#= Model.MethodAttribute #>
<#= Model.GetMethodModifier(true) #>bool contains(const <#= Model.KeyTypeName #> <#= Common.InputKeyName #>)<#= Model.PostMethodModifier #> {
<#= Model.GetMethodHeader(MethodType.Contains) #>

    const int64_t value = static_cast<int64_t>(<#= Common.LookupKeyName #>);
    const int64_t high = value >> lower_bit_count;

    // select_zero() reports failure as -1. Test that result before adding 1: the sum
    // would be 0, which is indistinguishable from a valid start position and would let
    // a negative `high` continue with a bogus positive rank.
    int64_t position = 0;
    if (high != 0) {
        const int64_t previous_zero = select_zero(high - 1);
        if (previous_zero < 0)
            return false;

        position = previous_zero + 1;
    }

    // Number of elements whose high part is smaller than `high`.
    int64_t rank = position - high;
    if (static_cast<uint64_t>(rank) >= item_count)
        return false;

    size_t curr_word = static_cast<size_t>(position >> 6);

    if (curr_word >= upper_bits.size())
        return false;

    // Remaining one bits of the current word, starting at `position`.
    uint64_t window = upper_bits[curr_word] & (std::numeric_limits<uint64_t>::max() << static_cast<uint32_t>(position & 63));
<#
    if (Context.LowerBitCount != 0)
    {
#>    const uint64_t target_low = static_cast<uint64_t>(value) & lower_mask;
    int64_t lower_bits_offset = rank * lower_bit_count;

    while (true) {
        // Advance to the next word that still contains a one bit.
        while (window == 0) {
            curr_word++;
            if (curr_word >= upper_bits.size())
                return false;

            window = upper_bits[curr_word];
        }

        const int trailing = trailing_zero_count(window);
        const int64_t one_position = (static_cast<int64_t>(curr_word) << 6) + trailing;
        const int64_t current_high = one_position - rank;

        if (current_high >= high) {
            if (current_high > high)
                return false;

            const size_t word_index = static_cast<size_t>(lower_bits_offset >> 6);
            const int start_bit = static_cast<int>(lower_bits_offset & 63);

            uint64_t current_low = 0;
            if (start_bit + lower_bit_count <= 64)
                current_low = (lower_bits[word_index] >> start_bit) & lower_mask;
            else {
                // The field straddles two words; start_bit is non-zero on this path.
                const uint64_t lower = lower_bits[word_index] >> start_bit;
                const uint64_t upper = lower_bits[word_index + 1] << (64 - start_bit);
                current_low = (lower | upper) & lower_mask;
            }

            if (current_low == target_low)
                return true;

            if (current_low > target_low)
                return false;
        }

        // Consume this one bit and move on to the next element.
        window &= window - 1;
        rank++;

        if (static_cast<uint64_t>(rank) >= item_count)
            return false;

        lower_bits_offset += lower_bit_count;
    }
<#
    }
    else
    {
#>    while (true) {
        // Advance to the next word that still contains a one bit.
        while (window == 0) {
            curr_word++;
            if (curr_word >= upper_bits.size())
                return false;

            window = upper_bits[curr_word];
        }

        const int trailing = trailing_zero_count(window);
        const int64_t one_position = (static_cast<int64_t>(curr_word) << 6) + trailing;
        const int64_t current_high = one_position - rank;

        // Without low parts, the first element that is not smaller decides the result.
        if (current_high >= high)
            return current_high == high;

        // Consume this one bit and move on to the next element.
        window &= window - 1;
        rank++;

        if (static_cast<uint64_t>(rank) >= item_count)
            return false;
    }
<#
    }
#>}
Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,100 @@
11
<#@ template language="C#" #>
22
<#@ import namespace="Genbox.FastData.Generator.Enums" #>
3+
<#@ import namespace="Genbox.FastData.Enums" #>
4+
<#@ import namespace="Genbox.FastData.Generator.Helpers" #>
5+
<#@ import namespace="System" #>
6+
<#@ import namespace="Genbox.FastData.Generator.Extensions" #>
37

48
<#@ parameter type="Genbox.FastData.Generator.CPlusPlus.TemplateModel" name="Model" #>
59
<#@ parameter type="Genbox.FastData.Generator.Template.CommonDataModel" name="Common" #>
10+
<#@ parameter type="Genbox.FastData.Generators.Contexts.RrrBitVectorContext" name="Context" #>
11+
12+
<#
    // C++ expression that converts the lookup key to the uint64_t domain used by the
    // RRR range check. Unsigned and char keys are simply cast; signed keys are first
    // XOR-ed with their type's lowest value (which flips the sign bit) and cast to the
    // same-width unsigned type. Any other key type is rejected at generation time.
    string lookupKey = Common.LookupKeyName;
    string widened = $"static_cast<uint64_t>({lookupKey})";

    string mapSource = Model.KeyType switch
    {
        KeyType.UInt64 => lookupKey,
        KeyType.Char => widened,
        KeyType.Byte => widened,
        KeyType.UInt16 => widened,
        KeyType.UInt32 => widened,
        KeyType.SByte => $"static_cast<uint64_t>(static_cast<uint8_t>({lookupKey} ^ std::numeric_limits<int8_t>::lowest()))",
        KeyType.Int16 => $"static_cast<uint64_t>(static_cast<uint16_t>({lookupKey} ^ std::numeric_limits<int16_t>::lowest()))",
        KeyType.Int32 => $"static_cast<uint64_t>(static_cast<uint32_t>({lookupKey} ^ std::numeric_limits<int32_t>::lowest()))",
        KeyType.Int64 => $"static_cast<uint64_t>({lookupKey} ^ std::numeric_limits<int64_t>::lowest())",
        _ => throw new InvalidOperationException("RRR bitvector only supports integral key types.")
    };
#>
27+
// Inclusive bounds of the mapped key range; contains() rejects anything outside
// [rrr_min_value, rrr_max_value] and subtracts rrr_min_value before indexing.
static constexpr uint64_t rrr_min_value = <#= Model.ToValueLabel(Context.MinValue) #>;
28+
static constexpr uint64_t rrr_max_value = <#= Model.ToValueLabel(Context.MaxValue) #>;
29+
// Number of bit positions per block: contains() divides the normalized key by it to
// get the block index and uses the remainder as the bit index inside the block.
static constexpr int32_t rrr_block_size = <#= Context.BlockSize #>;
30+
// Per-block class, passed to decode_bit() as the number of one bits still to place;
// a class of 0 makes contains() return false without decoding.
<#= Model.GetFieldModifier(true) #>std::array<uint8_t, <#= Context.Classes.Length.ToStringInvariant() #>> rrr_classes = {
31+
<#= FormatHelper.FormatColumns(Context.Classes, x => Model.ToValueLabel(x)) #>
32+
};
33+
// Per-block rank, passed to decode_bit() together with the block's class to
// reconstruct individual bits of the block.
<#= Model.GetFieldModifier(true) #>std::array<uint32_t, <#= Context.Offsets.Length.ToStringInvariant() #>> rrr_offsets = {
34+
<#= FormatHelper.FormatColumns(Context.Offsets, x => Model.ToValueLabel(x)) #>
35+
};
36+
37+
// Returns the binomial coefficient C(n, k), or 0 when k is outside [0, n].
//
// The running product is kept in 64 bits: the intermediate value
// result * (n - (k - i)) can exceed the int32_t range even when the final
// coefficient fits (for example C(30, 15)), and signed overflow is undefined
// behavior. The caller is still responsible for C(n, k) itself fitting in int32_t.
static constexpr int32_t binomial(int32_t n, int32_t k) noexcept {
    if (k < 0 || k > n)
        return 0;

    if (k == 0 || k == n)
        return 1;

    // C(n, k) == C(n, n - k); iterate over the smaller of the two.
    if (k > n - k)
        k = n - k;

    int64_t result = 1;

    // After step i the value equals C(n - k + i, i), so every division is exact.
    for (int32_t i = 1; i <= k; i++)
        result = result * (n - (k - i)) / i;

    return static_cast<int32_t>(result);
}
54+
55+
static constexpr bool decode_bit(int32_t class_value, uint32_t rank, int32_t target_bit) noexcept {
56+
int32_t remaining = class_value;
57+
58+
for (int32_t bit = rrr_block_size - 1; bit >= 0; bit--) {
59+
if (remaining == 0)
60+
return false;
61+
62+
const int32_t comb = binomial(bit, remaining);
63+
bool is_set = false;
64+
65+
if (rank >= static_cast<uint32_t>(comb)) {
66+
rank -= static_cast<uint32_t>(comb);
67+
remaining--;
68+
is_set = true;
69+
}
70+
else
71+
is_set = false;
72+
73+
if (bit == target_bit)
74+
return is_set;
75+
}
76+
77+
return false;
78+
}
679

780
public:
881
// Membership test over the RRR-encoded bit vector: maps the key to uint64_t,
// rejects it when outside [rrr_min_value, rrr_max_value], then decodes the single
// bit that corresponds to the key from its block's class and rank.
<#= Model.MethodAttribute #>
<#= Model.GetMethodModifier(true) #>bool contains(const <#= Model.KeyTypeName #> <#= Common.InputKeyName #>)<#= Model.PostMethodModifier #> {
<#= Model.GetMethodHeader(MethodType.Contains) #>

    const uint64_t mapped = <#= mapSource #>;

    const bool in_range = mapped >= rrr_min_value && mapped <= rrr_max_value;
    if (!in_range)
        return false;

    // Locate the block and the bit inside it.
    const uint64_t block_width = static_cast<uint64_t>(rrr_block_size);
    const uint64_t offset = mapped - rrr_min_value;
    const int32_t block = static_cast<int32_t>(offset / block_width);
    const int32_t bit = static_cast<int32_t>(offset % block_width);

    // A class of 0 means the block has no set bits at all.
    const int32_t ones_in_block = rrr_classes[static_cast<size_t>(block)];
    if (ones_in_block == 0)
        return false;

    return decode_bit(ones_in_block, rrr_offsets[static_cast<size_t>(block)], bit);
}

0 commit comments

Comments
 (0)