2828
2929#ifndef AVM_NO_JIT
3030
31+ #include <stdatomic.h>
3132#include <stdlib.h>
3233#include <string.h>
3334
4344#include "sys.h"
4445#include "term.h"
4546
47+ // Stable module ID counter to prevent ABA aliasing
48+ #ifndef AVM_NO_SMP
49+ static _Atomic uint32_t jit_next_module_id = 1 ;
50+ #else
51+ static uint32_t jit_next_module_id = 1 ;
52+ #endif
53+
54+ // Cross-thread invalidation ring: when a module is released, its ID and entry
55+ // count are pushed here so other threads can lazily call removeFunction().
56+ #ifndef AVM_NO_SMP
57+ #define JIT_INVALIDATION_RING_SIZE 64
58+ struct JITInvalidation
59+ {
60+ uint32_t module_id ;
61+ uint32_t num_entries ;
62+ };
63+ static struct JITInvalidation jit_invalidation_ring [JIT_INVALIDATION_RING_SIZE ];
64+ static _Atomic uint32_t jit_invalidation_write_idx = 0 ;
65+ #endif
66+
4667static term nif_jit_stream_module (Context * ctx , int argc , term argv [])
4768{
4869 UNUSED (argc );
@@ -56,14 +77,74 @@ static const struct Nif jit_stream_module_nif = {
5677 .nif_ptr = nif_jit_stream_module
5778};
5879
80+ // Remove per-thread wasmTable entries for a released module, and post to the
81+ // invalidation ring so other threads clean up their own entries lazily.
82+ // ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
83+ EM_JS_DEPS (jit_release_module , "$removeFunction" )
84+ // clang-format off
85+ EM_JS (void , jit_release_module , (uint32_t module_id , int num_entries , int ring_ptr , int ring_size , int write_idx_ptr ), {
86+ // Remove from this thread's cache
87+ var cache = Module ._jitCache && Module ._jitCache .get (module_id );
88+ if (cache ) {
89+ for (var i = 0 ; i < num_entries ; i ++ ) {
90+ if (cache [i ]) {
91+ removeFunction (cache [i ]);
92+ }
93+ }
94+ Module ._jitCache .delete (module_id );
95+ }
96+
97+ // Notify other threads via ring (SMP only).
98+ // Use a CAS loop so the write index is only advanced after the entry is written.
99+ if (ring_ptr ) {
100+ var int32View = new Int32Array (HEAPU8 .buffer );
101+ var oldIdx ;
102+ do {
103+ oldIdx = Atomics .load (int32View , write_idx_ptr >> 2 );
104+ var slot = oldIdx % ring_size ;
105+ var entryBase = (ring_ptr + slot * 8 ) >> 2 ;
106+ Atomics .store (int32View , entryBase , module_id );
107+ Atomics .store (int32View , entryBase + 1 , num_entries );
108+ } while (Atomics .compareExchange (int32View , write_idx_ptr >> 2 , oldIdx , oldIdx + 1 ) != = oldIdx );
109+ }
110+ })
111+ // clang-format on
112+
59113// Lazily compile the WASM module on this thread and return the addFunction
60- // index for the given label. Module._jitCache is thread-local.
61- EM_JS (int , jit_get_thread_func_ptr , (const uint8_t * wasm_binary , int wasm_size , int num_entries , int label ), {
114+ // index for the given label. Module._jitCache is thread-local, keyed by stable
115+ // module_id to prevent ABA aliasing when C memory is freed and reused.
116+ // ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
117+ EM_JS_DEPS (jit_get_thread_func_ptr , "$addFunction,$removeFunction" )
118+ EM_JS (int , jit_get_thread_func_ptr , (uint32_t module_id , const uint8_t * wasm_binary , int wasm_size , int num_entries , int label , int ring_ptr , int ring_size , int write_idx_ptr ), {
62119 if (!Module ._jitCache ) {
63120 Module ._jitCache = new Map ();
121+ Module ._jitReadIdx = 0 ;
122+ }
123+
124+ // Drain invalidation ring: call removeFunction for any modules released
125+ // by other threads since we last checked (SMP only).
126+ if (ring_ptr ) {
127+ var int32View = new Int32Array (HEAPU8 .buffer );
128+ var writeIdx = Atomics .load (int32View , write_idx_ptr >> 2 );
129+ while (Module ._jitReadIdx < writeIdx ) {
130+ var slot = Module ._jitReadIdx % ring_size ;
131+ var entryBase = (ring_ptr + slot * 8 ) >> 2 ;
132+ var entryModuleId = Atomics .load (int32View , entryBase );
133+ var entryNumEntries = Atomics .load (int32View , entryBase + 1 );
134+ var staleCache = Module ._jitCache .get (entryModuleId );
135+ if (staleCache ) {
136+ for (var j = 0 ; j < entryNumEntries ; j ++ ) {
137+ if (staleCache [j ]) {
138+ removeFunction (staleCache [j ]);
139+ }
140+ }
141+ Module ._jitCache .delete (entryModuleId );
142+ }
143+ Module ._jitReadIdx ++ ;
144+ }
64145 }
65146
66- var cache = Module ._jitCache .get (wasm_binary );
147+ var cache = Module ._jitCache .get (module_id );
67148 if (!cache ) {
68149 try {
69150 var bytes = new Uint8Array (HEAPU8 .buffer , wasm_binary , wasm_size );
@@ -83,7 +164,7 @@ EM_JS(int, jit_get_thread_func_ptr, (const uint8_t *wasm_binary, int wasm_size,
83164 cache [i ] = addFunction (func , 'iiii' );
84165 }
85166 }
86- Module ._jitCache .set (wasm_binary , cache );
167+ Module ._jitCache .set (module_id , cache );
87168 } catch (e ) {
88169 err ("JIT per-thread WASM compilation failed: " + e .message );
89170 return 0 ;
@@ -103,6 +184,7 @@ struct JITWasmHeader
103184 uint8_t * wasm_binary ;
104185 uint32_t wasm_binary_size ;
105186 uint32_t num_entries ;
187+ uint32_t module_id ;
106188};
107189
108190/**
@@ -132,7 +214,13 @@ ModuleNativeEntryPoint jit_wasm_get_entry_point(const void *native_code, int lab
132214 if (IS_NULL_PTR (header ) || IS_NULL_PTR (header -> wasm_binary )) {
133215 return NULL ;
134216 }
135- int fp = jit_get_thread_func_ptr (header -> wasm_binary , header -> wasm_binary_size , header -> num_entries , label );
217+ #ifndef AVM_NO_SMP
218+ int fp = jit_get_thread_func_ptr (header -> module_id , header -> wasm_binary , header -> wasm_binary_size , header -> num_entries , label ,
219+ (int ) (uintptr_t ) jit_invalidation_ring , JIT_INVALIDATION_RING_SIZE ,
220+ (int ) (uintptr_t ) & jit_invalidation_write_idx );
221+ #else
222+ int fp = jit_get_thread_func_ptr (header -> module_id , header -> wasm_binary , header -> wasm_binary_size , header -> num_entries , label , 0 , 0 , 0 );
223+ #endif
136224 return (ModuleNativeEntryPoint ) (uintptr_t ) fp ;
137225}
138226
@@ -187,6 +275,11 @@ static ModuleNativeEntryPoint compile_wasm_stream(const uint8_t *data, size_t da
187275 memcpy (header -> wasm_binary , wasm_data , wasm_size );
188276 header -> wasm_binary_size = wasm_size ;
189277 header -> num_entries = num_entries ;
278+ #ifndef AVM_NO_SMP
279+ header -> module_id = atomic_fetch_add_explicit (& jit_next_module_id , 1 , memory_order_relaxed );
280+ #else
281+ header -> module_id = jit_next_module_id ++ ;
282+ #endif
190283
191284 header -> lines_metadata = NULL ;
192285 if (lines_offset > 0 && lines_offset < data_size ) {
@@ -228,6 +321,13 @@ void sys_release_native_code(ModuleNativeEntryPoint entry_point)
228321 ModuleNativeEntryPoint * func_table = (ModuleNativeEntryPoint * ) entry_point ;
229322 struct JITWasmHeader * header = (struct JITWasmHeader * ) (uintptr_t ) func_table [-1 ];
230323 if (!IS_NULL_PTR (header )) {
324+ #ifndef AVM_NO_SMP
325+ jit_release_module (header -> module_id , header -> num_entries ,
326+ (int ) (uintptr_t ) jit_invalidation_ring , JIT_INVALIDATION_RING_SIZE ,
327+ (int ) (uintptr_t ) & jit_invalidation_write_idx );
328+ #else
329+ jit_release_module (header -> module_id , header -> num_entries , 0 , 0 , 0 );
330+ #endif
231331 free (header -> wasm_binary );
232332 free (header -> lines_metadata );
233333 free (header );
0 commit comments