@@ -27,6 +27,141 @@ use crate::sandbox::metrics::{
2727 METRIC_ACTIVE_WASM_SANDBOXES , METRIC_SANDBOX_LOADS , METRIC_TOTAL_WASM_SANDBOXES ,
2828} ;
2929
30+ // All the logic around when to restore is nicely encapsulated here,
31+ // so that it would be harder for a `WasmSandbox` to end up in an
32+ // un-restored state.
33+ mod backing_sandbox {
34+ use super :: * ;
35+ #[ derive( Debug ) ]
36+ pub ( super ) enum BackingSandbox {
37+ /// A sandbox which has a clean copy of the runtime in it
38+ Clean ( MultiUseSandbox ) ,
39+ /// A sandbox which has had a wasm component/module loaded into
40+ /// it, but has not yet run any code from that
41+ Loaded ( MultiUseSandbox ) ,
42+ /// A sandbox which came from a `LoadedWasmSandbox`, and
43+ /// therefore presumably has run user code
44+ Dirty ( MultiUseSandbox ) ,
45+ /// A non-existent sandbox, used as an internal implementation
46+ /// detail of a few methods.
47+ Missing ,
48+ }
49+ impl BackingSandbox {
50+ pub ( super ) fn clean ( & mut self , snapshot : Arc < Snapshot > ) -> Result < ( ) > {
51+ * self = match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
52+ BackingSandbox :: Clean ( x) => BackingSandbox :: Clean ( x) ,
53+ BackingSandbox :: Loaded ( _) => {
54+ return Err ( new_error ! (
55+ "internal invariant violation: cleaning loaded backing sandbox"
56+ ) ) ;
57+ }
58+ BackingSandbox :: Dirty ( mut x) => {
59+ x. restore ( snapshot) ?;
60+ BackingSandbox :: Clean ( x)
61+ }
62+ BackingSandbox :: Missing => {
63+ return Err ( new_error ! (
64+ "internal invariant violation: cleaning missing backing sandbox"
65+ ) ) ;
66+ }
67+ } ;
68+ Ok ( ( ) )
69+ }
70+ pub ( super ) fn load_via_restore ( & mut self , snapshot : Arc < Snapshot > ) -> Result < ( ) > {
71+ * self = match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
72+ BackingSandbox :: Clean ( mut x) | BackingSandbox :: Dirty ( mut x) => {
73+ x. restore ( snapshot) ?;
74+ BackingSandbox :: Loaded ( x)
75+ }
76+ BackingSandbox :: Loaded ( _) => {
77+ return Err ( new_error ! (
78+ "internal invariant violation: loading loaded backing sandbox"
79+ ) ) ;
80+ }
81+ BackingSandbox :: Missing => {
82+ return Err ( new_error ! (
83+ "internal invariant violation: loading missing backing sandbox"
84+ ) ) ;
85+ }
86+ } ;
87+ Ok ( ( ) )
88+ }
89+ pub ( super ) fn load_via_fn (
90+ & mut self ,
91+ load : impl FnOnce ( & mut MultiUseSandbox ) -> Result < ( ) > ,
92+ ) -> Result < ( ) > {
93+ * self = match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
94+ BackingSandbox :: Clean ( mut x) => {
95+ load ( & mut x) ?;
96+ BackingSandbox :: Loaded ( x)
97+ }
98+ _ => {
99+ return Err ( new_error ! (
100+ "internal invariant violation: loading non-clean backing sandbox"
101+ ) ) ;
102+ }
103+ } ;
104+ Ok ( ( ) )
105+ }
106+ pub ( super ) fn get_loaded ( & mut self ) -> Result < MultiUseSandbox > {
107+ match std:: mem:: replace ( self , BackingSandbox :: Missing ) {
108+ BackingSandbox :: Loaded ( x) => Ok ( x) ,
109+ _ => Err ( new_error ! (
110+ "internal invariant violation: encountered non-loaded backing sandbox"
111+ ) ) ,
112+ }
113+ }
114+ }
115+
116+ #[ cfg( test) ]
117+ mod tests {
118+ use super :: super :: tests:: * ;
119+ use super :: * ;
120+ #[ test]
121+ fn test_backing_sandbox_use_marks_dirty ( ) -> Result < ( ) > {
122+ let mut sb = SandboxBuilder :: new ( ) . build ( ) ?;
123+ sb. register (
124+ "GetTimeSinceBootMicrosecond" ,
125+ get_time_since_boot_microsecond,
126+ ) ?;
127+ let sb = sb. load_runtime ( ) ?;
128+ let lb = sb. load_module ( get_test_file_path ( "RunWasm.aot" ) ?) ?;
129+ let sb = lb. unload_module ( ) ?;
130+ assert ! ( matches!( sb. inner, super :: BackingSandbox :: Dirty ( _) ) ) ;
131+ Ok ( ( ) )
132+ }
133+
134+ #[ test]
135+ fn test_dirty_backing_sandbox_cannot_be_loaded_via_fn ( ) -> Result < ( ) > {
136+ let mut sb = SandboxBuilder :: new ( ) . build ( ) ?;
137+ sb. register (
138+ "GetTimeSinceBootMicrosecond" ,
139+ get_time_since_boot_microsecond,
140+ ) ?;
141+ let sb = sb. load_runtime ( ) ?;
142+ let lb = sb. load_module ( get_test_file_path ( "RunWasm.aot" ) ?) ?;
143+ let mut sb = lb. unload_module ( ) ?;
144+ assert ! ( sb. inner. load_via_fn( |_| Ok ( ( ) ) ) . is_err( ) ) ;
145+ Ok ( ( ) )
146+ }
147+
148+ #[ test]
149+ fn test_dirty_backing_sandbox_cannot_be_gotten_as_loaded ( ) -> Result < ( ) > {
150+ let mut sb = SandboxBuilder :: new ( ) . build ( ) ?;
151+ sb. register (
152+ "GetTimeSinceBootMicrosecond" ,
153+ get_time_since_boot_microsecond,
154+ ) ?;
155+ let sb = sb. load_runtime ( ) ?;
156+ let lb = sb. load_module ( get_test_file_path ( "RunWasm.aot" ) ?) ?;
157+ let mut sb = lb. unload_module ( ) ?;
158+ assert ! ( sb. inner. get_loaded( ) . is_err( ) ) ;
159+ Ok ( ( ) )
160+ }
161+ }
162+ }
163+ use backing_sandbox:: * ;
164+
30165/// A sandbox with just the Wasm engine loaded into memory. `WasmSandbox`es
31166/// are not yet ready to execute guest functions.
32167///
@@ -37,7 +172,7 @@ pub struct WasmSandbox {
37172 // inner is an Option<MultiUseSandbox> as we need to take ownership of it
38173 // We implement drop on the WasmSandbox to decrement the count of Sandboxes when it is dropped
39174 // because of this we cannot implement drop without making inner an Option (alternatively we could make MultiUseSandbox Copy but that would introduce other issues)
40- inner : Option < MultiUseSandbox > ,
175+ inner : BackingSandbox ,
41176 // Snapshot of state of an initial WasmSandbox (runtime loaded, but no guest module code loaded).
42177 // Used for LoadedWasmSandbox to be able restore state back to WasmSandbox
43178 snapshot : Option < Arc < Snapshot > > ,
@@ -54,7 +189,7 @@ impl WasmSandbox {
54189 metrics:: gauge!( METRIC_ACTIVE_WASM_SANDBOXES ) . increment ( 1 ) ;
55190 metrics:: counter!( METRIC_TOTAL_WASM_SANDBOXES ) . increment ( 1 ) ;
56191 Ok ( WasmSandbox {
57- inner : Some ( inner) ,
192+ inner : BackingSandbox :: Clean ( inner) ,
58193 snapshot : Some ( snapshot) ,
59194 } )
60195 }
@@ -64,35 +199,49 @@ impl WasmSandbox {
64199 /// the snapshot has already been created in that case.
65200 /// Expects a snapshot of the state where wasm runtime is loaded, but no guest module code is loaded.
66201 pub ( super ) fn new_from_loaded (
67- mut loaded : MultiUseSandbox ,
202+ loaded : MultiUseSandbox ,
68203 snapshot : Arc < Snapshot > ,
69204 ) -> Result < Self > {
70- loaded. restore ( snapshot. clone ( ) ) ?;
71205 metrics:: gauge!( METRIC_ACTIVE_WASM_SANDBOXES ) . increment ( 1 ) ;
72206 metrics:: counter!( METRIC_TOTAL_WASM_SANDBOXES ) . increment ( 1 ) ;
73207 Ok ( WasmSandbox {
74- inner : Some ( loaded) ,
208+ inner : BackingSandbox :: Dirty ( loaded) ,
75209 snapshot : Some ( snapshot) ,
76210 } )
77211 }
78212
213+ fn clean_inner ( & mut self ) -> Result < ( ) > {
214+ let snapshot = self . snapshot . as_ref ( ) . ok_or ( new_error ! (
215+ "internal invariant violation: Snapshot is missing"
216+ ) ) ?;
217+ self . inner . clean ( snapshot. clone ( ) )
218+ }
219+
79220 /// Load a Wasm module at the given path into the sandbox and return a `LoadedWasmSandbox`
80221 /// able to execute code in the loaded Wasm Module.
81222 ///
82223 /// Before you can call guest functions in the sandbox, you must call
83224 /// this function and use the returned value to call guest functions.
84225 pub fn load_module ( mut self , file : impl AsRef < Path > ) -> Result < LoadedWasmSandbox > {
85- let inner = self
86- . inner
87- . as_mut ( )
88- . ok_or_else ( || new_error ! ( "WasmSandbox is None" ) ) ?;
89-
90- if let Ok ( len) = inner. map_file_cow ( file. as_ref ( ) , MAPPED_BINARY_VA , None ) {
91- inner. call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len) ) ?;
92- } else {
93- let wasm_bytes = std:: fs:: read ( file) ?;
94- load_wasm_module_from_bytes ( inner, wasm_bytes) ?;
95- }
226+ self . clean_inner ( ) ?;
227+
228+ self . inner . load_via_fn ( |inner| {
229+ if let Ok ( len) = inner. map_file_cow ( file. as_ref ( ) , MAPPED_BINARY_VA , None ) {
230+ inner. call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len) ) ?;
231+ } else {
232+ let wasm_bytes = std:: fs:: read ( file) ?;
233+ load_wasm_module_from_bytes ( inner, wasm_bytes) ?;
234+ }
235+ Ok ( ( ) )
236+ } ) ?;
237+
238+ self . finalize_module_load ( )
239+ }
240+
241+ /// Load a Wasm module by restoring a Hyperlight snapshot taken
242+ /// from a `LoadedWasmSandbox`.
243+ pub fn load_from_snapshot ( mut self , snapshot : Arc < Snapshot > ) -> Result < LoadedWasmSandbox > {
244+ self . inner . load_via_restore ( snapshot) ?;
96245
97246 self . finalize_module_load ( )
98247 }
@@ -114,24 +263,25 @@ impl WasmSandbox {
114263 base : * mut libc:: c_void ,
115264 len : usize ,
116265 ) -> Result < LoadedWasmSandbox > {
117- let inner = self
118- . inner
119- . as_mut ( )
120- . ok_or_else ( || new_error ! ( "WasmSandbox is None" ) ) ?;
121-
122- let guest_base: usize = MAPPED_BINARY_VA as usize ;
123- let rgn = MemoryRegion {
124- host_region : base as usize ..base. wrapping_add ( len) as usize ,
125- guest_region : guest_base..guest_base + len,
126- flags : MemoryRegionFlags :: READ | MemoryRegionFlags :: EXECUTE ,
127- region_type : MemoryRegionType :: Heap ,
128- } ;
129- if let Ok ( ( ) ) = unsafe { inner. map_region ( & rgn) } {
130- inner. call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len as u64 ) ) ?;
131- } else {
132- let wasm_bytes = unsafe { std:: slice:: from_raw_parts ( base as * const u8 , len) . to_vec ( ) } ;
133- load_wasm_module_from_bytes ( inner, wasm_bytes) ?;
134- }
266+ self . clean_inner ( ) ?;
267+
268+ self . inner . load_via_fn ( |inner| {
269+ let guest_base: usize = MAPPED_BINARY_VA as usize ;
270+ let rgn = MemoryRegion {
271+ host_region : base as usize ..base. wrapping_add ( len) as usize ,
272+ guest_region : guest_base..guest_base + len,
273+ flags : MemoryRegionFlags :: READ | MemoryRegionFlags :: EXECUTE ,
274+ region_type : MemoryRegionType :: Heap ,
275+ } ;
276+ if let Ok ( ( ) ) = unsafe { inner. map_region ( & rgn) } {
277+ inner. call :: < ( ) > ( "LoadWasmModulePhys" , ( MAPPED_BINARY_VA , len as u64 ) ) ?;
278+ } else {
279+ let wasm_bytes =
280+ unsafe { std:: slice:: from_raw_parts ( base as * const u8 , len) . to_vec ( ) } ;
281+ load_wasm_module_from_bytes ( inner, wasm_bytes) ?;
282+ }
283+ Ok ( ( ) )
284+ } ) ?;
135285
136286 self . finalize_module_load ( )
137287 }
@@ -142,26 +292,26 @@ impl WasmSandbox {
142292 /// Before you can call guest functions in the sandbox, you must call
143293 /// this function and use the returned value to call guest functions.
144294 pub fn load_module_from_buffer ( mut self , buffer : & [ u8 ] ) -> Result < LoadedWasmSandbox > {
145- let inner = self
146- . inner
147- . as_mut ( )
148- . ok_or_else ( || new_error ! ( "WasmSandbox is None" ) ) ?;
295+ self . clean_inner ( ) ?;
149296
150297 // TODO: get rid of this clone
151- load_wasm_module_from_bytes ( inner, buffer. to_vec ( ) ) ?;
298+ self . inner
299+ . load_via_fn ( |inner| load_wasm_module_from_bytes ( inner, buffer. to_vec ( ) ) ) ?;
152300
153301 self . finalize_module_load ( )
154302 }
155303
156304 /// Helper function to finalize module loading and create LoadedWasmSandbox
157305 fn finalize_module_load ( mut self ) -> Result < LoadedWasmSandbox > {
158306 metrics:: counter!( METRIC_SANDBOX_LOADS ) . increment ( 1 ) ;
159- match ( self . inner . take ( ) , self . snapshot . take ( ) ) {
160- ( Some ( sandbox) , Some ( snapshot) ) => LoadedWasmSandbox :: new ( sandbox, snapshot) ,
161- _ => Err ( new_error ! (
162- "WasmSandbox/snapshot is None, cannot load module"
163- ) ) ,
164- }
307+
308+ let sandbox = self . inner . get_loaded ( ) ?;
309+
310+ let snapshot = self . snapshot . take ( ) . ok_or ( new_error ! (
311+ "internal invariant violation: Snapshot is missing"
312+ ) ) ?;
313+
314+ LoadedWasmSandbox :: new ( sandbox, snapshot)
165315 }
166316}
167317
@@ -199,15 +349,15 @@ mod tests {
199349 use hyperlight_host:: { HyperlightError , is_hypervisor_present} ;
200350
201351 use super :: * ;
202- use crate :: sandbox:: sandbox_builder:: SandboxBuilder ;
352+ pub ( super ) use crate :: sandbox:: sandbox_builder:: SandboxBuilder ;
203353
204354 #[ test]
205355 fn test_new_sandbox ( ) -> Result < ( ) > {
206356 let _sandbox = SandboxBuilder :: new ( ) . build ( ) ?;
207357 Ok ( ( ) )
208358 }
209359
210- fn get_time_since_boot_microsecond ( ) -> Result < i64 > {
360+ pub ( super ) fn get_time_since_boot_microsecond ( ) -> Result < i64 > {
211361 let res = std:: time:: SystemTime :: now ( )
212362 . duration_since ( std:: time:: SystemTime :: UNIX_EPOCH ) ?
213363 . as_micros ( ) ;
@@ -473,6 +623,46 @@ mod tests {
473623 }
474624 }
475625
626+ #[ test]
627+ fn test_load_from_snapshot ( ) {
628+ let mut sandbox = SandboxBuilder :: new ( ) . build ( ) . unwrap ( ) ;
629+ sandbox
630+ . register (
631+ "GetTimeSinceBootMicrosecond" ,
632+ get_time_since_boot_microsecond,
633+ )
634+ . unwrap ( ) ;
635+ let sb = sandbox. load_runtime ( ) . unwrap ( ) ;
636+
637+ let helloworld_wasm = get_test_file_path ( "HelloWorld.aot" ) . unwrap ( ) ;
638+ let runwasm_wasm = get_test_file_path ( "RunWasm.aot" ) . unwrap ( ) ;
639+
640+ // load one module, and make sure that a function in it
641+ // can be called
642+ let mut lb1 = sb. load_module ( helloworld_wasm) . unwrap ( ) ;
643+ let result: i32 = lb1
644+ . call_guest_function ( "HelloWorld" , "Message from Rust Test" . to_string ( ) )
645+ . unwrap ( ) ;
646+ assert_eq ! ( result, 0 ) ;
647+ let snapshot = lb1. snapshot ( ) . unwrap ( ) ;
648+
649+ // load another module, and make sure that a function in
650+ // it can be called
651+ let sb = lb1. unload_module ( ) . unwrap ( ) ;
652+ let mut lb2 = sb. load_module ( runwasm_wasm) . unwrap ( ) ;
653+ let result: i32 = lb2. call_guest_function ( "CalcFib" , 10i32 ) . unwrap ( ) ;
654+ assert_eq ! ( result, 55 ) ;
655+
656+ // reload the first module via snapshot, and make sure the
657+ // original function can be called again
658+ let sb = lb2. unload_module ( ) . unwrap ( ) ;
659+ let mut lb3 = sb. load_from_snapshot ( snapshot) . unwrap ( ) ;
660+ let result: i32 = lb3
661+ . call_guest_function ( "HelloWorld" , "Message from Rust Test" . to_string ( ) )
662+ . unwrap ( ) ;
663+ assert_eq ! ( result, 0 ) ;
664+ }
665+
476666 #[ test]
477667 fn test_load_module_buffer ( ) {
478668 let sandboxes = get_test_wasm_sandboxes ( ) . unwrap ( ) ;
@@ -496,7 +686,7 @@ mod tests {
496686 }
497687 }
498688
499- fn get_test_file_path ( filename : & str ) -> Result < String > {
689+ pub ( super ) fn get_test_file_path ( filename : & str ) -> Result < String > {
500690 #[ cfg( debug_assertions) ]
501691 let config = "debug" ;
502692 #[ cfg( not( debug_assertions) ) ]
0 commit comments