Skip to content

Commit 0c0164c

Browse files
committed
JIT wasm: call removeFunction on release of modules
Signed-off-by: Paul Guyot <pguyot@kallisys.net>
1 parent 77d6381 commit 0c0164c

1 file changed

Lines changed: 103 additions & 5 deletions

File tree

src/platforms/emscripten/src/lib/jit_stream_wasm.c

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#ifndef AVM_NO_JIT
3030

31+
#include <stdatomic.h>
3132
#include <stdlib.h>
3233
#include <string.h>
3334

@@ -43,6 +44,26 @@
4344
#include "sys.h"
4445
#include "term.h"
4546

47+
// Stable module ID counter to prevent ABA aliasing
48+
#ifndef AVM_NO_SMP
49+
static _Atomic uint32_t jit_next_module_id = 1;
50+
#else
51+
static uint32_t jit_next_module_id = 1;
52+
#endif
53+
54+
// Cross-thread invalidation ring: when a module is released, its ID and entry
55+
// count are pushed here so other threads can lazily call removeFunction().
56+
#ifndef AVM_NO_SMP
57+
#define JIT_INVALIDATION_RING_SIZE 64
58+
struct JITInvalidation
59+
{
60+
uint32_t module_id;
61+
uint32_t num_entries;
62+
};
63+
static struct JITInvalidation jit_invalidation_ring[JIT_INVALIDATION_RING_SIZE];
64+
static _Atomic uint32_t jit_invalidation_write_idx = 0;
65+
#endif
66+
4667
static term nif_jit_stream_module(Context *ctx, int argc, term argv[])
4768
{
4869
UNUSED(argc);
@@ -56,14 +77,72 @@ static const struct Nif jit_stream_module_nif = {
5677
.nif_ptr = nif_jit_stream_module
5778
};
5879

80+
// Remove per-thread wasmTable entries for a released module, and post to the
81+
// invalidation ring so other threads clean up their own entries lazily.
82+
// ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
83+
// clang-format off
84+
EM_JS(void, jit_release_module, (uint32_t module_id, int num_entries, int ring_ptr, int ring_size, int write_idx_ptr), {
85+
// Remove from this thread's cache
86+
var cache = Module._jitCache && Module._jitCache.get(module_id);
87+
if (cache) {
88+
for (var i = 0; i < num_entries; i++) {
89+
if (cache[i]) {
90+
removeFunction(cache[i]);
91+
}
92+
}
93+
Module._jitCache.delete(module_id);
94+
}
95+
96+
// Notify other threads via ring (SMP only).
97+
// Use a CAS loop so the write index is only advanced after the entry is written.
98+
if (ring_ptr) {
99+
var int32View = new Int32Array(HEAPU8.buffer);
100+
var oldIdx;
101+
do {
102+
oldIdx = Atomics.load(int32View, write_idx_ptr >> 2);
103+
var slot = oldIdx % ring_size;
104+
var entryBase = (ring_ptr + slot * 8) >> 2;
105+
Atomics.store(int32View, entryBase, module_id);
106+
Atomics.store(int32View, entryBase + 1, num_entries);
107+
} while (Atomics.compareExchange(int32View, write_idx_ptr >> 2, oldIdx, oldIdx + 1) !== oldIdx);
108+
}
109+
})
110+
// clang-format on
111+
59112
// Lazily compile the WASM module on this thread and return the addFunction
60-
// index for the given label. Module._jitCache is thread-local.
61-
EM_JS(int, jit_get_thread_func_ptr, (const uint8_t *wasm_binary, int wasm_size, int num_entries, int label), {
113+
// index for the given label. Module._jitCache is thread-local, keyed by stable
114+
// module_id to prevent ABA aliasing when C memory is freed and reused.
115+
// ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
116+
EM_JS(int, jit_get_thread_func_ptr, (uint32_t module_id, const uint8_t *wasm_binary, int wasm_size, int num_entries, int label, int ring_ptr, int ring_size, int write_idx_ptr), {
62117
if (!Module._jitCache) {
63118
Module._jitCache = new Map();
119+
Module._jitReadIdx = 0;
120+
}
121+
122+
// Drain invalidation ring: call removeFunction for any modules released
123+
// by other threads since we last checked (SMP only).
124+
if (ring_ptr) {
125+
var int32View = new Int32Array(HEAPU8.buffer);
126+
var writeIdx = Atomics.load(int32View, write_idx_ptr >> 2);
127+
while (Module._jitReadIdx < writeIdx) {
128+
var slot = Module._jitReadIdx % ring_size;
129+
var entryBase = (ring_ptr + slot * 8) >> 2;
130+
var entryModuleId = Atomics.load(int32View, entryBase);
131+
var entryNumEntries = Atomics.load(int32View, entryBase + 1);
132+
var staleCache = Module._jitCache.get(entryModuleId);
133+
if (staleCache) {
134+
for (var j = 0; j < entryNumEntries; j++) {
135+
if (staleCache[j]) {
136+
removeFunction(staleCache[j]);
137+
}
138+
}
139+
Module._jitCache.delete(entryModuleId);
140+
}
141+
Module._jitReadIdx++;
142+
}
64143
}
65144

66-
var cache = Module._jitCache.get(wasm_binary);
145+
var cache = Module._jitCache.get(module_id);
67146
if (!cache) {
68147
try {
69148
var bytes = new Uint8Array(HEAPU8.buffer, wasm_binary, wasm_size);
@@ -83,7 +162,7 @@ EM_JS(int, jit_get_thread_func_ptr, (const uint8_t *wasm_binary, int wasm_size,
83162
cache[i] = addFunction(func, 'iiii');
84163
}
85164
}
86-
Module._jitCache.set(wasm_binary, cache);
165+
Module._jitCache.set(module_id, cache);
87166
} catch (e) {
88167
err("JIT per-thread WASM compilation failed: " + e.message);
89168
return 0;
@@ -103,6 +182,7 @@ struct JITWasmHeader
103182
uint8_t *wasm_binary;
104183
uint32_t wasm_binary_size;
105184
uint32_t num_entries;
185+
uint32_t module_id;
106186
};
107187

108188
/**
@@ -132,7 +212,13 @@ ModuleNativeEntryPoint jit_wasm_get_entry_point(const void *native_code, int lab
132212
if (IS_NULL_PTR(header) || IS_NULL_PTR(header->wasm_binary)) {
133213
return NULL;
134214
}
135-
int fp = jit_get_thread_func_ptr(header->wasm_binary, header->wasm_binary_size, header->num_entries, label);
215+
#ifndef AVM_NO_SMP
216+
int fp = jit_get_thread_func_ptr(header->module_id, header->wasm_binary, header->wasm_binary_size, header->num_entries, label,
217+
(int) (uintptr_t) jit_invalidation_ring, JIT_INVALIDATION_RING_SIZE,
218+
(int) (uintptr_t) &jit_invalidation_write_idx);
219+
#else
220+
int fp = jit_get_thread_func_ptr(header->module_id, header->wasm_binary, header->wasm_binary_size, header->num_entries, label, 0, 0, 0);
221+
#endif
136222
return (ModuleNativeEntryPoint) (uintptr_t) fp;
137223
}
138224

@@ -187,6 +273,11 @@ static ModuleNativeEntryPoint compile_wasm_stream(const uint8_t *data, size_t da
187273
memcpy(header->wasm_binary, wasm_data, wasm_size);
188274
header->wasm_binary_size = wasm_size;
189275
header->num_entries = num_entries;
276+
#ifndef AVM_NO_SMP
277+
header->module_id = atomic_fetch_add_explicit(&jit_next_module_id, 1, memory_order_relaxed);
278+
#else
279+
header->module_id = jit_next_module_id++;
280+
#endif
190281

191282
header->lines_metadata = NULL;
192283
if (lines_offset > 0 && lines_offset < data_size) {
@@ -228,6 +319,13 @@ void sys_release_native_code(ModuleNativeEntryPoint entry_point)
228319
ModuleNativeEntryPoint *func_table = (ModuleNativeEntryPoint *) entry_point;
229320
struct JITWasmHeader *header = (struct JITWasmHeader *) (uintptr_t) func_table[-1];
230321
if (!IS_NULL_PTR(header)) {
322+
#ifndef AVM_NO_SMP
323+
jit_release_module(header->module_id, header->num_entries,
324+
(int) (uintptr_t) jit_invalidation_ring, JIT_INVALIDATION_RING_SIZE,
325+
(int) (uintptr_t) &jit_invalidation_write_idx);
326+
#else
327+
jit_release_module(header->module_id, header->num_entries, 0, 0, 0);
328+
#endif
231329
free(header->wasm_binary);
232330
free(header->lines_metadata);
233331
free(header);

0 commit comments

Comments
 (0)