2828
2929#ifndef AVM_NO_JIT
3030
31+ #include <stdatomic.h>
3132#include <stdlib.h>
3233#include <string.h>
3334
4344#include "sys.h"
4445#include "term.h"
4546
47+ // Stable module ID counter to prevent ABA aliasing
48+ #ifndef AVM_NO_SMP
49+ static _Atomic uint32_t jit_next_module_id = 1 ;
50+ #else
51+ static uint32_t jit_next_module_id = 1 ;
52+ #endif
53+
54+ // Cross-thread invalidation ring: when a module is released, its ID and entry
55+ // count are pushed here so other threads can lazily call removeFunction().
56+ #ifndef AVM_NO_SMP
57+ #define JIT_INVALIDATION_RING_SIZE 64
58+ struct JITInvalidation
59+ {
60+ uint32_t module_id ;
61+ uint32_t num_entries ;
62+ };
63+ static struct JITInvalidation jit_invalidation_ring [JIT_INVALIDATION_RING_SIZE ];
64+ static _Atomic uint32_t jit_invalidation_write_idx = 0 ;
65+ #endif
66+
4667static term nif_jit_stream_module (Context * ctx , int argc , term argv [])
4768{
4869 UNUSED (argc );
@@ -56,14 +77,72 @@ static const struct Nif jit_stream_module_nif = {
5677 .nif_ptr = nif_jit_stream_module
5778};
5879
80+ // Remove per-thread wasmTable entries for a released module, and post to the
81+ // invalidation ring so other threads clean up their own entries lazily.
82+ // ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
83+ // clang-format off
84+ EM_JS (void , jit_release_module , (uint32_t module_id , int num_entries , int ring_ptr , int ring_size , int write_idx_ptr ), {
85+ // Remove from this thread's cache
86+ var cache = Module ._jitCache && Module ._jitCache .get (module_id );
87+ if (cache ) {
88+ for (var i = 0 ; i < num_entries ; i ++ ) {
89+ if (cache [i ]) {
90+ removeFunction (cache [i ]);
91+ }
92+ }
93+ Module ._jitCache .delete (module_id );
94+ }
95+
96+ // Notify other threads via ring (SMP only).
97+ // Use a CAS loop so the write index is only advanced after the entry is written.
98+ if (ring_ptr ) {
99+ var int32View = new Int32Array (HEAPU8 .buffer );
100+ var oldIdx ;
101+ do {
102+ oldIdx = Atomics .load (int32View , write_idx_ptr >> 2 );
103+ var slot = oldIdx % ring_size ;
104+ var entryBase = (ring_ptr + slot * 8 ) >> 2 ;
105+ Atomics .store (int32View , entryBase , module_id );
106+ Atomics .store (int32View , entryBase + 1 , num_entries );
107+ } while (Atomics .compareExchange (int32View , write_idx_ptr >> 2 , oldIdx , oldIdx + 1 ) != = oldIdx );
108+ }
109+ })
110+ // clang-format on
111+
59112// Lazily compile the WASM module on this thread and return the addFunction
60- // index for the given label. Module._jitCache is thread-local.
61- EM_JS (int , jit_get_thread_func_ptr , (const uint8_t * wasm_binary , int wasm_size , int num_entries , int label ), {
113+ // index for the given label. Module._jitCache is thread-local, keyed by stable
114+ // module_id to prevent ABA aliasing when C memory is freed and reused.
115+ // ring_ptr/ring_size/write_idx_ptr are 0 in non-SMP builds.
116+ EM_JS (int , jit_get_thread_func_ptr , (uint32_t module_id , const uint8_t * wasm_binary , int wasm_size , int num_entries , int label , int ring_ptr , int ring_size , int write_idx_ptr ), {
62117 if (!Module ._jitCache ) {
63118 Module ._jitCache = new Map ();
119+ Module ._jitReadIdx = 0 ;
120+ }
121+
122+ // Drain invalidation ring: call removeFunction for any modules released
123+ // by other threads since we last checked (SMP only).
124+ if (ring_ptr ) {
125+ var int32View = new Int32Array (HEAPU8 .buffer );
126+ var writeIdx = Atomics .load (int32View , write_idx_ptr >> 2 );
127+ while (Module ._jitReadIdx < writeIdx ) {
128+ var slot = Module ._jitReadIdx % ring_size ;
129+ var entryBase = (ring_ptr + slot * 8 ) >> 2 ;
130+ var entryModuleId = Atomics .load (int32View , entryBase );
131+ var entryNumEntries = Atomics .load (int32View , entryBase + 1 );
132+ var staleCache = Module ._jitCache .get (entryModuleId );
133+ if (staleCache ) {
134+ for (var j = 0 ; j < entryNumEntries ; j ++ ) {
135+ if (staleCache [j ]) {
136+ removeFunction (staleCache [j ]);
137+ }
138+ }
139+ Module ._jitCache .delete (entryModuleId );
140+ }
141+ Module ._jitReadIdx ++ ;
142+ }
64143 }
65144
66- var cache = Module ._jitCache .get (wasm_binary );
145+ var cache = Module ._jitCache .get (module_id );
67146 if (!cache ) {
68147 try {
69148 var bytes = new Uint8Array (HEAPU8 .buffer , wasm_binary , wasm_size );
@@ -83,7 +162,7 @@ EM_JS(int, jit_get_thread_func_ptr, (const uint8_t *wasm_binary, int wasm_size,
83162 cache [i ] = addFunction (func , 'iiii' );
84163 }
85164 }
86- Module ._jitCache .set (wasm_binary , cache );
165+ Module ._jitCache .set (module_id , cache );
87166 } catch (e ) {
88167 err ("JIT per-thread WASM compilation failed: " + e .message );
89168 return 0 ;
@@ -103,6 +182,7 @@ struct JITWasmHeader
103182 uint8_t * wasm_binary ;
104183 uint32_t wasm_binary_size ;
105184 uint32_t num_entries ;
185+ uint32_t module_id ;
106186};
107187
108188/**
@@ -132,7 +212,13 @@ ModuleNativeEntryPoint jit_wasm_get_entry_point(const void *native_code, int lab
132212 if (IS_NULL_PTR (header ) || IS_NULL_PTR (header -> wasm_binary )) {
133213 return NULL ;
134214 }
135- int fp = jit_get_thread_func_ptr (header -> wasm_binary , header -> wasm_binary_size , header -> num_entries , label );
215+ #ifndef AVM_NO_SMP
216+ int fp = jit_get_thread_func_ptr (header -> module_id , header -> wasm_binary , header -> wasm_binary_size , header -> num_entries , label ,
217+ (int ) (uintptr_t ) jit_invalidation_ring , JIT_INVALIDATION_RING_SIZE ,
218+ (int ) (uintptr_t ) & jit_invalidation_write_idx );
219+ #else
220+ int fp = jit_get_thread_func_ptr (header -> module_id , header -> wasm_binary , header -> wasm_binary_size , header -> num_entries , label , 0 , 0 , 0 );
221+ #endif
136222 return (ModuleNativeEntryPoint ) (uintptr_t ) fp ;
137223}
138224
@@ -187,6 +273,11 @@ static ModuleNativeEntryPoint compile_wasm_stream(const uint8_t *data, size_t da
187273 memcpy (header -> wasm_binary , wasm_data , wasm_size );
188274 header -> wasm_binary_size = wasm_size ;
189275 header -> num_entries = num_entries ;
276+ #ifndef AVM_NO_SMP
277+ header -> module_id = atomic_fetch_add_explicit (& jit_next_module_id , 1 , memory_order_relaxed );
278+ #else
279+ header -> module_id = jit_next_module_id ++ ;
280+ #endif
190281
191282 header -> lines_metadata = NULL ;
192283 if (lines_offset > 0 && lines_offset < data_size ) {
@@ -228,6 +319,13 @@ void sys_release_native_code(ModuleNativeEntryPoint entry_point)
228319 ModuleNativeEntryPoint * func_table = (ModuleNativeEntryPoint * ) entry_point ;
229320 struct JITWasmHeader * header = (struct JITWasmHeader * ) (uintptr_t ) func_table [-1 ];
230321 if (!IS_NULL_PTR (header )) {
322+ #ifndef AVM_NO_SMP
323+ jit_release_module (header -> module_id , header -> num_entries ,
324+ (int ) (uintptr_t ) jit_invalidation_ring , JIT_INVALIDATION_RING_SIZE ,
325+ (int ) (uintptr_t ) & jit_invalidation_write_idx );
326+ #else
327+ jit_release_module (header -> module_id , header -> num_entries , 0 , 0 , 0 );
328+ #endif
231329 free (header -> wasm_binary );
232330 free (header -> lines_metadata );
233331 free (header );
0 commit comments