Skip to content

Commit 9505890

Browse files
committed
JIT WASM32: call removeFunction on release of modules
Signed-off-by: Paul Guyot <pguyot@kallisys.net>
1 parent 57ff1e4 commit 9505890

1 file changed

Lines changed: 105 additions & 5 deletions

File tree

src/platforms/emscripten/src/lib/jit_stream_wasm.c

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
#ifndef AVM_NO_JIT
3030

31+
#include <stdatomic.h>
3132
#include <stdlib.h>
3233
#include <string.h>
3334

@@ -43,6 +44,26 @@
4344
#include "sys.h"
4445
#include "term.h"
4546

47+
// Stable module ID counter to prevent ABA aliasing
48+
#ifndef AVM_NO_SMP
49+
static _Atomic uint32_t jit_next_module_id = 1;
50+
#else
51+
static uint32_t jit_next_module_id = 1;
52+
#endif
53+
54+
// Cross-thread invalidation ring: when a module is released, its ID and entry
55+
// count are pushed here so other threads can lazily call removeFunction().
56+
#ifndef AVM_NO_SMP
57+
#define JIT_INVALIDATION_RING_SIZE 64
58+
struct JITInvalidation
59+
{
60+
uint32_t module_id;
61+
uint32_t num_entries;
62+
};
63+
static struct JITInvalidation jit_invalidation_ring[JIT_INVALIDATION_RING_SIZE];
64+
static _Atomic uint32_t jit_invalidation_write_idx = 0;
65+
#endif
66+
4667
static term nif_jit_stream_module(Context *ctx, int argc, term argv[])
4768
{
4869
UNUSED(argc);
@@ -56,14 +77,74 @@ static const struct Nif jit_stream_module_nif = {
5677
.nif_ptr = nif_jit_stream_module
5778
};
5879

80+
// Remove per-thread wasmTable entries for a released module, and post to the
81+
// invalidation ring so other threads clean up their own entries lazily.
82+
// ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
83+
EM_JS_DEPS(jit_release_module, "$removeFunction")
84+
// clang-format off
85+
EM_JS(void, jit_release_module, (uint32_t module_id, int num_entries, int ring_ptr, int ring_size, int write_idx_ptr), {
86+
// Remove from this thread's cache
87+
var cache = Module._jitCache && Module._jitCache.get(module_id);
88+
if (cache) {
89+
for (var i = 0; i < num_entries; i++) {
90+
if (cache[i]) {
91+
removeFunction(cache[i]);
92+
}
93+
}
94+
Module._jitCache.delete(module_id);
95+
}
96+
97+
// Notify other threads via ring (SMP only).
98+
// Use a CAS loop so the write index is only advanced after the entry is written.
99+
if (ring_ptr) {
100+
var int32View = new Int32Array(HEAPU8.buffer);
101+
var oldIdx;
102+
do {
103+
oldIdx = Atomics.load(int32View, write_idx_ptr >> 2);
104+
var slot = oldIdx % ring_size;
105+
var entryBase = (ring_ptr + slot * 8) >> 2;
106+
Atomics.store(int32View, entryBase, module_id);
107+
Atomics.store(int32View, entryBase + 1, num_entries);
108+
} while (Atomics.compareExchange(int32View, write_idx_ptr >> 2, oldIdx, oldIdx + 1) !== oldIdx);
109+
}
110+
})
111+
// clang-format on
112+
59113
// Lazily compile the WASM module on this thread and return the addFunction
60-
// index for the given label. Module._jitCache is thread-local.
61-
EM_JS(int, jit_get_thread_func_ptr, (const uint8_t *wasm_binary, int wasm_size, int num_entries, int label), {
114+
// index for the given label. Module._jitCache is thread-local, keyed by stable
115+
// module_id to prevent ABA aliasing when C memory is freed and reused.
116+
// ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
117+
EM_JS_DEPS(jit_get_thread_func_ptr, "$addFunction,$removeFunction")
118+
EM_JS(int, jit_get_thread_func_ptr, (uint32_t module_id, const uint8_t *wasm_binary, int wasm_size, int num_entries, int label, int ring_ptr, int ring_size, int write_idx_ptr), {
62119
if (!Module._jitCache) {
63120
Module._jitCache = new Map();
121+
Module._jitReadIdx = 0;
122+
}
123+
124+
// Drain invalidation ring: call removeFunction for any modules released
125+
// by other threads since we last checked (SMP only).
126+
if (ring_ptr) {
127+
var int32View = new Int32Array(HEAPU8.buffer);
128+
var writeIdx = Atomics.load(int32View, write_idx_ptr >> 2);
129+
while (Module._jitReadIdx < writeIdx) {
130+
var slot = Module._jitReadIdx % ring_size;
131+
var entryBase = (ring_ptr + slot * 8) >> 2;
132+
var entryModuleId = Atomics.load(int32View, entryBase);
133+
var entryNumEntries = Atomics.load(int32View, entryBase + 1);
134+
var staleCache = Module._jitCache.get(entryModuleId);
135+
if (staleCache) {
136+
for (var j = 0; j < entryNumEntries; j++) {
137+
if (staleCache[j]) {
138+
removeFunction(staleCache[j]);
139+
}
140+
}
141+
Module._jitCache.delete(entryModuleId);
142+
}
143+
Module._jitReadIdx++;
144+
}
64145
}
65146

66-
var cache = Module._jitCache.get(wasm_binary);
147+
var cache = Module._jitCache.get(module_id);
67148
if (!cache) {
68149
try {
69150
var bytes = new Uint8Array(HEAPU8.buffer, wasm_binary, wasm_size);
@@ -83,7 +164,7 @@ EM_JS(int, jit_get_thread_func_ptr, (const uint8_t *wasm_binary, int wasm_size,
83164
cache[i] = addFunction(func, 'iiii');
84165
}
85166
}
86-
Module._jitCache.set(wasm_binary, cache);
167+
Module._jitCache.set(module_id, cache);
87168
} catch (e) {
88169
err("JIT per-thread WASM compilation failed: " + e.message);
89170
return 0;
@@ -103,6 +184,7 @@ struct JITWasmHeader
103184
uint8_t *wasm_binary;
104185
uint32_t wasm_binary_size;
105186
uint32_t num_entries;
187+
uint32_t module_id;
106188
};
107189

108190
/**
@@ -132,7 +214,13 @@ ModuleNativeEntryPoint jit_wasm_get_entry_point(const void *native_code, int lab
132214
if (IS_NULL_PTR(header) || IS_NULL_PTR(header->wasm_binary)) {
133215
return NULL;
134216
}
135-
int fp = jit_get_thread_func_ptr(header->wasm_binary, header->wasm_binary_size, header->num_entries, label);
217+
#ifndef AVM_NO_SMP
218+
int fp = jit_get_thread_func_ptr(header->module_id, header->wasm_binary, header->wasm_binary_size, header->num_entries, label,
219+
(int) (uintptr_t) jit_invalidation_ring, JIT_INVALIDATION_RING_SIZE,
220+
(int) (uintptr_t) &jit_invalidation_write_idx);
221+
#else
222+
int fp = jit_get_thread_func_ptr(header->module_id, header->wasm_binary, header->wasm_binary_size, header->num_entries, label, 0, 0, 0);
223+
#endif
136224
return (ModuleNativeEntryPoint) (uintptr_t) fp;
137225
}
138226

@@ -187,6 +275,11 @@ static ModuleNativeEntryPoint compile_wasm_stream(const uint8_t *data, size_t da
187275
memcpy(header->wasm_binary, wasm_data, wasm_size);
188276
header->wasm_binary_size = wasm_size;
189277
header->num_entries = num_entries;
278+
#ifndef AVM_NO_SMP
279+
header->module_id = atomic_fetch_add_explicit(&jit_next_module_id, 1, memory_order_relaxed);
280+
#else
281+
header->module_id = jit_next_module_id++;
282+
#endif
190283

191284
header->lines_metadata = NULL;
192285
if (lines_offset > 0 && lines_offset < data_size) {
@@ -228,6 +321,13 @@ void sys_release_native_code(ModuleNativeEntryPoint entry_point)
228321
ModuleNativeEntryPoint *func_table = (ModuleNativeEntryPoint *) entry_point;
229322
struct JITWasmHeader *header = (struct JITWasmHeader *) (uintptr_t) func_table[-1];
230323
if (!IS_NULL_PTR(header)) {
324+
#ifndef AVM_NO_SMP
325+
jit_release_module(header->module_id, header->num_entries,
326+
(int) (uintptr_t) jit_invalidation_ring, JIT_INVALIDATION_RING_SIZE,
327+
(int) (uintptr_t) &jit_invalidation_write_idx);
328+
#else
329+
jit_release_module(header->module_id, header->num_entries, 0, 0, 0);
330+
#endif
231331
free(header->wasm_binary);
232332
free(header->lines_metadata);
233333
free(header);

0 commit comments

Comments
 (0)