Skip to content

Commit 54cf49c

Browse files
committed
Change sandbox restore to be a bit more judicious about when to restore
When we are loading from a snapshot, we can get away with just a restore and never going through a wasmtime load. Signed-off-by: Lucy Menon <168595099+syntactically@users.noreply.github.com>
1 parent d8cc07f commit 54cf49c

2 files changed

Lines changed: 245 additions & 52 deletions

File tree

src/hyperlight_wasm/src/sandbox/loaded_wasm_sandbox.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,14 @@ impl LoadedWasmSandbox {
112112
}
113113
}
114114

115-
/// Unload the wasm module and return a `WasmSandbox` that can be used to load another module.
115+
/// Unload the wasm module and return a `WasmSandbox` that can be
116+
/// used to load another module.
116117
///
117-
/// This method internally calls [`restore()`](Self::restore) to reset the sandbox to its
118-
/// pre-module state, which also clears any poisoned state. This means `unload_module()`
119-
/// can be called on a poisoned sandbox to recover it.
118+
/// This method defers calling [`restore()`](Self::restore) to
119+
/// reset the sandbox to its pre-module state until a new module
120+
/// is loaded. However, the sandbox will always be restored when a
121+
/// new module is loaded, so a poisoned sandbox can be recovered
122+
/// by unloading and reloading a module.
120123
pub fn unload_module(mut self) -> Result<WasmSandbox> {
121124
let sandbox = self
122125
.inner

src/hyperlight_wasm/src/sandbox/wasm_sandbox.rs

Lines changed: 238 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,141 @@ use crate::sandbox::metrics::{
2727
METRIC_ACTIVE_WASM_SANDBOXES, METRIC_SANDBOX_LOADS, METRIC_TOTAL_WASM_SANDBOXES,
2828
};
2929

30+
// All the logic around when to restore is nicely encapsulated here,
31+
// so that it would be harder for a `WasmSandbox` to end up in an
32+
// un-restored state.
33+
mod backing_sandbox {
34+
use super::*;
35+
#[derive(Debug)]
36+
pub(super) enum BackingSandbox {
37+
/// A sandbox which has a clean copy of the runtime in it
38+
Clean(MultiUseSandbox),
39+
/// A sandbox which has had a wasm component/module loaded into
40+
/// it, but has not yet run any code from that component/module
41+
Loaded(MultiUseSandbox),
42+
/// A sandbox which came from a `LoadedWasmSandbox`, and
43+
/// therefore presumably has run user code
44+
Dirty(MultiUseSandbox),
45+
/// A non-existent sandbox, used as an internal implementation
46+
/// detail of a few methods.
47+
Missing,
48+
}
49+
impl BackingSandbox {
50+
pub(super) fn clean(&mut self, snapshot: Arc<Snapshot>) -> Result<()> {
51+
*self = match std::mem::replace(self, BackingSandbox::Missing) {
52+
BackingSandbox::Clean(x) => BackingSandbox::Clean(x),
53+
BackingSandbox::Loaded(_) => {
54+
return Err(new_error!(
55+
"internal invariant violation: cleaning loaded backing sandbox"
56+
));
57+
}
58+
BackingSandbox::Dirty(mut x) => {
59+
x.restore(snapshot)?;
60+
BackingSandbox::Clean(x)
61+
}
62+
BackingSandbox::Missing => {
63+
return Err(new_error!(
64+
"internal invariant violation: cleaning missing backing sandbox"
65+
));
66+
}
67+
};
68+
Ok(())
69+
}
70+
pub(super) fn load_via_restore(&mut self, snapshot: Arc<Snapshot>) -> Result<()> {
71+
*self = match std::mem::replace(self, BackingSandbox::Missing) {
72+
BackingSandbox::Clean(mut x) | BackingSandbox::Dirty(mut x) => {
73+
x.restore(snapshot)?;
74+
BackingSandbox::Loaded(x)
75+
}
76+
BackingSandbox::Loaded(_) => {
77+
return Err(new_error!(
78+
"internal invariant violation: loading loaded backing sandbox"
79+
));
80+
}
81+
BackingSandbox::Missing => {
82+
return Err(new_error!(
83+
"internal invariant violation: loading missing backing sandbox"
84+
));
85+
}
86+
};
87+
Ok(())
88+
}
89+
pub(super) fn load_via_fn(
90+
&mut self,
91+
load: impl FnOnce(&mut MultiUseSandbox) -> Result<()>,
92+
) -> Result<()> {
93+
*self = match std::mem::replace(self, BackingSandbox::Missing) {
94+
BackingSandbox::Clean(mut x) => {
95+
load(&mut x)?;
96+
BackingSandbox::Loaded(x)
97+
}
98+
_ => {
99+
return Err(new_error!(
100+
"internal invariant violation: loading non-clean backing sandbox"
101+
));
102+
}
103+
};
104+
Ok(())
105+
}
106+
pub(super) fn get_loaded(&mut self) -> Result<MultiUseSandbox> {
107+
match std::mem::replace(self, BackingSandbox::Missing) {
108+
BackingSandbox::Loaded(x) => Ok(x),
109+
_ => Err(new_error!(
110+
"internal invariant violation: encountered non-loaded backing sandbox"
111+
)),
112+
}
113+
}
114+
}
115+
116+
#[cfg(test)]
117+
mod tests {
118+
use super::super::tests::*;
119+
use super::*;
120+
#[test]
121+
fn test_backing_sandbox_use_marks_dirty() -> Result<()> {
122+
let mut sb = SandboxBuilder::new().build()?;
123+
sb.register(
124+
"GetTimeSinceBootMicrosecond",
125+
get_time_since_boot_microsecond,
126+
)?;
127+
let sb = sb.load_runtime()?;
128+
let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
129+
let sb = lb.unload_module()?;
130+
assert!(matches!(sb.inner, super::BackingSandbox::Dirty(_)));
131+
Ok(())
132+
}
133+
134+
#[test]
135+
fn test_dirty_backing_sandbox_cannot_be_loaded_via_fn() -> Result<()> {
136+
let mut sb = SandboxBuilder::new().build()?;
137+
sb.register(
138+
"GetTimeSinceBootMicrosecond",
139+
get_time_since_boot_microsecond,
140+
)?;
141+
let sb = sb.load_runtime()?;
142+
let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
143+
let mut sb = lb.unload_module()?;
144+
assert!(sb.inner.load_via_fn(|_| Ok(())).is_err());
145+
Ok(())
146+
}
147+
148+
#[test]
149+
fn test_dirty_backing_sandbox_cannot_be_gotten_as_loaded() -> Result<()> {
150+
let mut sb = SandboxBuilder::new().build()?;
151+
sb.register(
152+
"GetTimeSinceBootMicrosecond",
153+
get_time_since_boot_microsecond,
154+
)?;
155+
let sb = sb.load_runtime()?;
156+
let lb = sb.load_module(get_test_file_path("RunWasm.aot")?)?;
157+
let mut sb = lb.unload_module()?;
158+
assert!(sb.inner.get_loaded().is_err());
159+
Ok(())
160+
}
161+
}
162+
}
163+
use backing_sandbox::*;
164+
30165
/// A sandbox with just the Wasm engine loaded into memory. `WasmSandbox`es
31166
/// are not yet ready to execute guest functions.
32167
///
@@ -37,7 +172,7 @@ pub struct WasmSandbox {
37172
// inner is a BackingSandbox (rather than a bare MultiUseSandbox) as we need to take ownership of it
38173
// We implement drop on the WasmSandbox to decrement the count of Sandboxes when it is dropped
39174
// because of this we cannot implement drop without making inner a type the sandbox can be moved out of, leaving `BackingSandbox::Missing` behind (alternatively we could make MultiUseSandbox Copy but that would introduce other issues)
40-
inner: Option<MultiUseSandbox>,
175+
inner: BackingSandbox,
41176
// Snapshot of state of an initial WasmSandbox (runtime loaded, but no guest module code loaded).
42177
// Used for LoadedWasmSandbox to be able to restore state back to WasmSandbox
43178
snapshot: Option<Arc<Snapshot>>,
@@ -54,7 +189,7 @@ impl WasmSandbox {
54189
metrics::gauge!(METRIC_ACTIVE_WASM_SANDBOXES).increment(1);
55190
metrics::counter!(METRIC_TOTAL_WASM_SANDBOXES).increment(1);
56191
Ok(WasmSandbox {
57-
inner: Some(inner),
192+
inner: BackingSandbox::Clean(inner),
58193
snapshot: Some(snapshot),
59194
})
60195
}
@@ -64,35 +199,49 @@ impl WasmSandbox {
64199
/// the snapshot has already been created in that case.
65200
/// Expects a snapshot of the state where wasm runtime is loaded, but no guest module code is loaded.
66201
pub(super) fn new_from_loaded(
67-
mut loaded: MultiUseSandbox,
202+
loaded: MultiUseSandbox,
68203
snapshot: Arc<Snapshot>,
69204
) -> Result<Self> {
70-
loaded.restore(snapshot.clone())?;
71205
metrics::gauge!(METRIC_ACTIVE_WASM_SANDBOXES).increment(1);
72206
metrics::counter!(METRIC_TOTAL_WASM_SANDBOXES).increment(1);
73207
Ok(WasmSandbox {
74-
inner: Some(loaded),
208+
inner: BackingSandbox::Dirty(loaded),
75209
snapshot: Some(snapshot),
76210
})
77211
}
78212

213+
fn clean_inner(&mut self) -> Result<()> {
214+
let snapshot = self.snapshot.as_ref().ok_or(new_error!(
215+
"internal invariant violation: Snapshot is missing"
216+
))?;
217+
self.inner.clean(snapshot.clone())
218+
}
219+
79220
/// Load a Wasm module at the given path into the sandbox and return a `LoadedWasmSandbox`
80221
/// able to execute code in the loaded Wasm Module.
81222
///
82223
/// Before you can call guest functions in the sandbox, you must call
83224
/// this function and use the returned value to call guest functions.
84225
pub fn load_module(mut self, file: impl AsRef<Path>) -> Result<LoadedWasmSandbox> {
85-
let inner = self
86-
.inner
87-
.as_mut()
88-
.ok_or_else(|| new_error!("WasmSandbox is None"))?;
89-
90-
if let Ok(len) = inner.map_file_cow(file.as_ref(), MAPPED_BINARY_VA, None) {
91-
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len))?;
92-
} else {
93-
let wasm_bytes = std::fs::read(file)?;
94-
load_wasm_module_from_bytes(inner, wasm_bytes)?;
95-
}
226+
self.clean_inner()?;
227+
228+
self.inner.load_via_fn(|inner| {
229+
if let Ok(len) = inner.map_file_cow(file.as_ref(), MAPPED_BINARY_VA, None) {
230+
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len))?;
231+
} else {
232+
let wasm_bytes = std::fs::read(file)?;
233+
load_wasm_module_from_bytes(inner, wasm_bytes)?;
234+
}
235+
Ok(())
236+
})?;
237+
238+
self.finalize_module_load()
239+
}
240+
241+
/// Load a Wasm module by restoring a Hyperlight snapshot taken
242+
/// from a `LoadedWasmSandbox`.
243+
pub fn load_from_snapshot(mut self, snapshot: Arc<Snapshot>) -> Result<LoadedWasmSandbox> {
244+
self.inner.load_via_restore(snapshot)?;
96245

97246
self.finalize_module_load()
98247
}
@@ -114,24 +263,25 @@ impl WasmSandbox {
114263
base: *mut libc::c_void,
115264
len: usize,
116265
) -> Result<LoadedWasmSandbox> {
117-
let inner = self
118-
.inner
119-
.as_mut()
120-
.ok_or_else(|| new_error!("WasmSandbox is None"))?;
121-
122-
let guest_base: usize = MAPPED_BINARY_VA as usize;
123-
let rgn = MemoryRegion {
124-
host_region: base as usize..base.wrapping_add(len) as usize,
125-
guest_region: guest_base..guest_base + len,
126-
flags: MemoryRegionFlags::READ | MemoryRegionFlags::EXECUTE,
127-
region_type: MemoryRegionType::Heap,
128-
};
129-
if let Ok(()) = unsafe { inner.map_region(&rgn) } {
130-
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len as u64))?;
131-
} else {
132-
let wasm_bytes = unsafe { std::slice::from_raw_parts(base as *const u8, len).to_vec() };
133-
load_wasm_module_from_bytes(inner, wasm_bytes)?;
134-
}
266+
self.clean_inner()?;
267+
268+
self.inner.load_via_fn(|inner| {
269+
let guest_base: usize = MAPPED_BINARY_VA as usize;
270+
let rgn = MemoryRegion {
271+
host_region: base as usize..base.wrapping_add(len) as usize,
272+
guest_region: guest_base..guest_base + len,
273+
flags: MemoryRegionFlags::READ | MemoryRegionFlags::EXECUTE,
274+
region_type: MemoryRegionType::Heap,
275+
};
276+
if let Ok(()) = unsafe { inner.map_region(&rgn) } {
277+
inner.call::<()>("LoadWasmModulePhys", (MAPPED_BINARY_VA, len as u64))?;
278+
} else {
279+
let wasm_bytes =
280+
unsafe { std::slice::from_raw_parts(base as *const u8, len).to_vec() };
281+
load_wasm_module_from_bytes(inner, wasm_bytes)?;
282+
}
283+
Ok(())
284+
})?;
135285

136286
self.finalize_module_load()
137287
}
@@ -142,26 +292,26 @@ impl WasmSandbox {
142292
/// Before you can call guest functions in the sandbox, you must call
143293
/// this function and use the returned value to call guest functions.
144294
pub fn load_module_from_buffer(mut self, buffer: &[u8]) -> Result<LoadedWasmSandbox> {
145-
let inner = self
146-
.inner
147-
.as_mut()
148-
.ok_or_else(|| new_error!("WasmSandbox is None"))?;
295+
self.clean_inner()?;
149296

150297
// TODO: get rid of this clone
151-
load_wasm_module_from_bytes(inner, buffer.to_vec())?;
298+
self.inner
299+
.load_via_fn(|inner| load_wasm_module_from_bytes(inner, buffer.to_vec()))?;
152300

153301
self.finalize_module_load()
154302
}
155303

156304
/// Helper function to finalize module loading and create LoadedWasmSandbox
157305
fn finalize_module_load(mut self) -> Result<LoadedWasmSandbox> {
158306
metrics::counter!(METRIC_SANDBOX_LOADS).increment(1);
159-
match (self.inner.take(), self.snapshot.take()) {
160-
(Some(sandbox), Some(snapshot)) => LoadedWasmSandbox::new(sandbox, snapshot),
161-
_ => Err(new_error!(
162-
"WasmSandbox/snapshot is None, cannot load module"
163-
)),
164-
}
307+
308+
let sandbox = self.inner.get_loaded()?;
309+
310+
let snapshot = self.snapshot.take().ok_or(new_error!(
311+
"internal invariant violation: Snapshot is missing"
312+
))?;
313+
314+
LoadedWasmSandbox::new(sandbox, snapshot)
165315
}
166316
}
167317

@@ -199,15 +349,15 @@ mod tests {
199349
use hyperlight_host::{HyperlightError, is_hypervisor_present};
200350

201351
use super::*;
202-
use crate::sandbox::sandbox_builder::SandboxBuilder;
352+
pub(super) use crate::sandbox::sandbox_builder::SandboxBuilder;
203353

204354
#[test]
205355
fn test_new_sandbox() -> Result<()> {
206356
let _sandbox = SandboxBuilder::new().build()?;
207357
Ok(())
208358
}
209359

210-
fn get_time_since_boot_microsecond() -> Result<i64> {
360+
pub(super) fn get_time_since_boot_microsecond() -> Result<i64> {
211361
let res = std::time::SystemTime::now()
212362
.duration_since(std::time::SystemTime::UNIX_EPOCH)?
213363
.as_micros();
@@ -473,6 +623,46 @@ mod tests {
473623
}
474624
}
475625

626+
#[test]
627+
fn test_load_from_snapshot() {
628+
let mut sandbox = SandboxBuilder::new().build().unwrap();
629+
sandbox
630+
.register(
631+
"GetTimeSinceBootMicrosecond",
632+
get_time_since_boot_microsecond,
633+
)
634+
.unwrap();
635+
let sb = sandbox.load_runtime().unwrap();
636+
637+
let helloworld_wasm = get_test_file_path("HelloWorld.aot").unwrap();
638+
let runwasm_wasm = get_test_file_path("RunWasm.aot").unwrap();
639+
640+
// load one module, and make sure that a function in it
641+
// can be called
642+
let mut lb1 = sb.load_module(helloworld_wasm).unwrap();
643+
let result: i32 = lb1
644+
.call_guest_function("HelloWorld", "Message from Rust Test".to_string())
645+
.unwrap();
646+
assert_eq!(result, 0);
647+
let snapshot = lb1.snapshot().unwrap();
648+
649+
// load another module, and make sure that a function in
650+
// it can be called
651+
let sb = lb1.unload_module().unwrap();
652+
let mut lb2 = sb.load_module(runwasm_wasm).unwrap();
653+
let result: i32 = lb2.call_guest_function("CalcFib", 10i32).unwrap();
654+
assert_eq!(result, 55);
655+
656+
// reload the first module via snapshot, and make sure the
657+
// original function can be called again
658+
let sb = lb2.unload_module().unwrap();
659+
let mut lb3 = sb.load_from_snapshot(snapshot).unwrap();
660+
let result: i32 = lb3
661+
.call_guest_function("HelloWorld", "Message from Rust Test".to_string())
662+
.unwrap();
663+
assert_eq!(result, 0);
664+
}
665+
476666
#[test]
477667
fn test_load_module_buffer() {
478668
let sandboxes = get_test_wasm_sandboxes().unwrap();
@@ -496,7 +686,7 @@ mod tests {
496686
}
497687
}
498688

499-
fn get_test_file_path(filename: &str) -> Result<String> {
689+
pub(super) fn get_test_file_path(filename: &str) -> Result<String> {
500690
#[cfg(debug_assertions)]
501691
let config = "debug";
502692
#[cfg(not(debug_assertions))]

0 commit comments

Comments
 (0)