Skip to content

Commit 3c68454

Browse files
committed
feat(virtq): add reset API
Signed-off-by: Tomasz Andrzejak <andreiltd@gmail.com>
1 parent 528a5c0 commit 3c68454

4 files changed

Lines changed: 327 additions & 0 deletions

File tree

src/hyperlight_common/src/virtq/consumer.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,12 @@ impl<M: MemOps + Clone, N: Notifier> VirtqConsumer<M, N> {
441441

442442
Ok(Bytes::from(buf))
443443
}
444+
445+
/// Reset ring and inflight state to initial values.
446+
pub fn reset(&mut self) {
447+
self.inner.reset();
448+
self.inflight.fill(None);
449+
}
444450
}
445451

446452
/// Parse a descriptor chain into entry/completion buffer elements.
@@ -630,4 +636,48 @@ mod tests {
630636
assert_eq!(data.as_ref(), b"abc");
631637
consumer.complete(completion).unwrap();
632638
}
639+
640+
#[test]
641+
fn test_virtq_consumer_reset() {
642+
let ring = make_ring(16);
643+
let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
644+
645+
// Submit and poll (but do not complete)
646+
let se = producer.chain().completion(16).build().unwrap();
647+
producer.submit(se).unwrap();
648+
649+
let (_entry, completion) = consumer.poll(1024).unwrap().unwrap();
650+
assert!(consumer.inflight.iter().any(|s| s.is_some()));
651+
652+
// Complete first so we do not leak
653+
consumer.complete(completion).unwrap();
654+
655+
consumer.reset();
656+
657+
assert!(consumer.inflight.iter().all(|s| s.is_none()));
658+
assert_eq!(consumer.inner.num_inflight(), 0);
659+
}
660+
661+
#[test]
662+
fn test_virtq_consumer_reset_clears_inflight() {
663+
let ring = make_ring(16);
664+
let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
665+
666+
// Submit two entries and poll both
667+
let se1 = producer.chain().completion(16).build().unwrap();
668+
producer.submit(se1).unwrap();
669+
let se2 = producer.chain().completion(16).build().unwrap();
670+
producer.submit(se2).unwrap();
671+
672+
let (_e1, c1) = consumer.poll(1024).unwrap().unwrap();
673+
let (_e2, c2) = consumer.poll(1024).unwrap().unwrap();
674+
// Complete both before reset
675+
consumer.complete(c1).unwrap();
676+
consumer.complete(c2).unwrap();
677+
678+
consumer.reset();
679+
680+
assert!(consumer.inflight.iter().all(|s| s.is_none()));
681+
assert_eq!(consumer.inner.num_inflight(), 0);
682+
}
633683
}

src/hyperlight_common/src/virtq/pool.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,12 @@ impl<const N: usize> Slab<N> {
516516
pub const fn slot_size() -> usize {
517517
N
518518
}
519+
520+
/// Reset the slab to initial state which is all slots free.
521+
pub fn reset(&mut self) {
522+
self.used_slots.clear();
523+
self.last_free_run = None;
524+
}
519525
}
520526

521527
#[inline]
@@ -547,6 +553,13 @@ impl<const L: usize, const U: usize> BufferPool<L, U> {
547553
inner: inner.into(),
548554
})
549555
}
556+
557+
/// Reset the pool to initial state
558+
pub fn reset(&self) {
559+
let mut inner = self.inner.borrow_mut();
560+
inner.lower.reset();
561+
inner.upper.reset();
562+
}
550563
}
551564

552565
#[cfg(all(test, loom))]
@@ -1171,6 +1184,87 @@ mod tests {
11711184
slab.dealloc(a3).unwrap();
11721185
slab.dealloc(a4).unwrap();
11731186
}
1187+
1188+
#[test]
1189+
fn test_slab_reset_returns_to_initial_state() {
1190+
let mut slab = make_slab::<256>(4096);
1191+
let initial_free = slab.free_bytes();
1192+
let initial_cap = slab.capacity();
1193+
1194+
// Allocate some slots
1195+
let _a1 = slab.alloc(256).unwrap();
1196+
let _a2 = slab.alloc(512).unwrap();
1197+
assert!(slab.free_bytes() < initial_free);
1198+
1199+
slab.reset();
1200+
1201+
assert_eq!(slab.free_bytes(), initial_free);
1202+
assert_eq!(slab.capacity(), initial_cap);
1203+
assert!(slab.last_free_run.is_none());
1204+
assert_eq!(slab.used_slots.count_ones(..), 0);
1205+
1206+
// Should be able to allocate the full capacity again
1207+
let a = slab.alloc(initial_cap).unwrap();
1208+
assert_eq!(a.len, initial_cap);
1209+
}
1210+
1211+
#[test]
1212+
fn test_slab_reset_matches_new() {
1213+
let base = align_up(0x10000, 256) as u64;
1214+
let region = 4096;
1215+
1216+
let fresh = Slab::<256>::new(base, region).unwrap();
1217+
1218+
let mut used = Slab::<256>::new(base, region).unwrap();
1219+
let _a = used.alloc(256).unwrap();
1220+
let _b = used.alloc(1024).unwrap();
1221+
used.reset();
1222+
1223+
assert_eq!(used.free_bytes(), fresh.free_bytes());
1224+
assert_eq!(used.capacity(), fresh.capacity());
1225+
assert_eq!(
1226+
used.used_slots.count_ones(..),
1227+
fresh.used_slots.count_ones(..)
1228+
);
1229+
assert!(used.last_free_run.is_none());
1230+
assert!(fresh.last_free_run.is_none());
1231+
}
1232+
1233+
#[test]
1234+
fn test_buffer_pool_reset_returns_to_initial_state() {
1235+
let pool = make_pool::<256, 4096>(0x20000);
1236+
1237+
// Allocate from both tiers
1238+
let a1 = pool.inner.borrow_mut().alloc(128).unwrap();
1239+
let a2 = pool.inner.borrow_mut().alloc(8192).unwrap();
1240+
assert!(a1.len > 0);
1241+
assert!(a2.len > 0);
1242+
1243+
pool.reset();
1244+
1245+
let inner = pool.inner.borrow();
1246+
assert_eq!(inner.lower.used_slots.count_ones(..), 0);
1247+
assert_eq!(inner.upper.used_slots.count_ones(..), 0);
1248+
assert!(inner.lower.last_free_run.is_none());
1249+
assert!(inner.upper.last_free_run.is_none());
1250+
}
1251+
1252+
#[test]
1253+
fn test_buffer_pool_reset_allows_reallocation() {
1254+
let pool = make_pool::<256, 4096>(0x20000);
1255+
1256+
// Fill up some allocations
1257+
let mut allocs = Vec::new();
1258+
for _ in 0..5 {
1259+
allocs.push(pool.inner.borrow_mut().alloc(256).unwrap());
1260+
}
1261+
1262+
pool.reset();
1263+
1264+
// Should be able to allocate as if fresh
1265+
let a = pool.inner.borrow_mut().alloc(256).unwrap();
1266+
assert!(a.len > 0);
1267+
}
11741268
}
11751269

11761270
#[cfg(test)]

src/hyperlight_common/src/virtq/producer.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,13 @@ where
329329
}
330330
Ok(())
331331
}
332+
333+
/// Reset ring and inflight state to initial values.
334+
/// Does not reset the buffer pool; call pool.reset() separately if needed.
335+
pub fn reset(&mut self) {
336+
self.inner.reset();
337+
self.inflight.fill(None);
338+
}
332339
}
333340

334341
/// Builder for configuring a descriptor chain's buffer layout.
@@ -787,4 +794,45 @@ mod tests {
787794
assert_eq!(cqe.token, token);
788795
assert_eq!(&cqe.data[..], b"response data");
789796
}
797+
798+
#[test]
799+
fn test_virtq_producer_reset() {
800+
let ring = make_ring(16);
801+
let (mut producer, mut consumer, _notifier) = make_test_producer(&ring);
802+
803+
// Submit and complete a round trip
804+
let mut se = producer.chain().entry(32).completion(64).build().unwrap();
805+
se.write_all(b"hello").unwrap();
806+
producer.submit(se).unwrap();
807+
808+
let (entry, completion) = consumer.poll(1024).unwrap().unwrap();
809+
assert_eq!(entry.data().as_ref(), b"hello");
810+
consumer.complete(completion).unwrap();
811+
let _ = producer.poll().unwrap().unwrap();
812+
813+
// Now reset
814+
producer.reset();
815+
816+
// All inflight slots should be None
817+
assert!(producer.inflight.iter().all(|s| s.is_none()));
818+
// Ring state should be back to initial
819+
assert_eq!(producer.inner.num_free(), producer.inner.len());
820+
}
821+
822+
#[test]
823+
fn test_virtq_producer_reset_clears_inflight() {
824+
let ring = make_ring(16);
825+
let (mut producer, _consumer, _notifier) = make_test_producer(&ring);
826+
827+
// Submit without completing
828+
let se = producer.chain().completion(64).build().unwrap();
829+
producer.submit(se).unwrap();
830+
831+
assert!(producer.inflight.iter().any(|s| s.is_some()));
832+
833+
producer.reset();
834+
835+
assert!(producer.inflight.iter().all(|s| s.is_none()));
836+
assert_eq!(producer.inner.num_free(), producer.inner.len());
837+
}
790838
}

src/hyperlight_common/src/virtq/ring.rs

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,12 @@ impl RingCursor {
369369
pub fn wrap(&self) -> bool {
370370
self.wrap
371371
}
372+
373+
/// Reset cursor to initial state.
374+
pub fn reset(&mut self) {
375+
self.head = 0;
376+
self.wrap = true;
377+
}
372378
}
373379

374380
/// Producer (driver) side of a packed virtqueue.
@@ -817,6 +823,18 @@ impl<M: MemOps> RingProducer<M> {
817823

818824
Ok(should_notify(evt, self.len() as u16, old, new))
819825
}
826+
827+
/// Reset to initial state matching a freshly zeroed ring.
828+
pub fn reset(&mut self) {
829+
let size = self.desc_table.len();
830+
self.avail_cursor.reset();
831+
self.used_cursor.reset();
832+
self.num_free = size;
833+
self.id_free.clear();
834+
self.id_free.extend(0..size as u16);
835+
self.id_num.iter_mut().for_each(|n| *n = 0);
836+
self.event_flags_shadow = EventFlags::ENABLE;
837+
}
820838
}
821839

822840
/// Consumer (device) side of a packed virtqueue.
@@ -1164,6 +1182,16 @@ impl<M: MemOps> RingConsumer<M> {
11641182

11651183
Ok(should_notify(evt, self.desc_table.len() as u16, old, new))
11661184
}
1185+
1186+
/// Reset to initial state matching a freshly zeroed ring.
1187+
/// Does not reallocate internal buffers.
1188+
pub fn reset(&mut self) {
1189+
self.avail_cursor.reset();
1190+
self.used_cursor.reset();
1191+
self.id_num.iter_mut().for_each(|n| *n = 0);
1192+
self.num_inflight = 0;
1193+
self.event_flags_shadow = EventFlags::ENABLE;
1194+
}
11671195
}
11681196

11691197
/// Common packed-ring notification decision:
@@ -2974,6 +3002,113 @@ pub(crate) mod tests {
29743002
let (_, _) = consumer.poll_available().unwrap();
29753003
}
29763004
}
3005+
3006+
#[test]
3007+
fn test_ring_cursor_reset() {
3008+
let mut cursor = RingCursor::new(16);
3009+
cursor.advance_by(5);
3010+
assert_eq!(cursor.head(), 5);
3011+
3012+
cursor.reset();
3013+
assert_eq!(cursor, RingCursor::new(16));
3014+
assert_eq!(cursor.head(), 0);
3015+
assert!(cursor.wrap());
3016+
}
3017+
3018+
#[test]
3019+
fn test_ring_cursor_reset_after_wrap() {
3020+
let mut cursor = RingCursor::new(4);
3021+
// Advance past the wrap point
3022+
cursor.advance_by(5);
3023+
assert_eq!(cursor.head(), 1);
3024+
assert!(!cursor.wrap());
3025+
3026+
cursor.reset();
3027+
assert_eq!(cursor.head(), 0);
3028+
assert!(cursor.wrap());
3029+
}
3030+
3031+
#[test]
3032+
fn test_ring_producer_reset_matches_new() {
3033+
let ring = make_ring(8);
3034+
let fresh = make_producer(&ring);
3035+
3036+
let mut used = make_producer(&ring);
3037+
// Mutate state
3038+
used.submit_one(0x1000, 64, false).unwrap();
3039+
used.submit_one(0x2000, 128, true).unwrap();
3040+
3041+
used.reset();
3042+
3043+
assert_eq!(used.avail_cursor, fresh.avail_cursor);
3044+
assert_eq!(used.used_cursor, fresh.used_cursor);
3045+
assert_eq!(used.num_free, fresh.num_free);
3046+
assert_eq!(used.id_free.len(), fresh.id_free.len());
3047+
assert_eq!(used.id_num.as_slice(), fresh.id_num.as_slice());
3048+
assert_eq!(used.event_flags_shadow, fresh.event_flags_shadow);
3049+
}
3050+
3051+
#[test]
3052+
fn test_ring_producer_reset_id_free_complete() {
3053+
let ring = make_ring(8);
3054+
let mut producer = make_producer(&ring);
3055+
3056+
// Submit and consume several descriptors
3057+
for i in 0..4u64 {
3058+
producer.submit_one(0x1000 + i * 0x100, 64, false).unwrap();
3059+
}
3060+
assert_eq!(producer.num_free, 4);
3061+
3062+
producer.reset();
3063+
3064+
assert_eq!(producer.num_free, 8);
3065+
assert_eq!(producer.id_free.len(), 8);
3066+
// All IDs 0..8 should be present
3067+
for id in 0..8u16 {
3068+
assert!(producer.id_free.contains(&id));
3069+
}
3070+
}
3071+
3072+
#[test]
3073+
fn test_ring_consumer_reset_matches_new() {
3074+
let ring = make_ring(8);
3075+
let fresh = make_consumer(&ring);
3076+
3077+
let mut used = make_consumer(&ring);
3078+
3079+
// Submit from producer side so consumer has something to poll
3080+
let mut producer = make_producer(&ring);
3081+
producer.submit_one(0x1000, 64, false).unwrap();
3082+
3083+
// Consumer polls the available descriptor
3084+
let (id, _chain) = used.poll_available().unwrap();
3085+
used.submit_used(id, 64).unwrap();
3086+
3087+
used.reset();
3088+
3089+
assert_eq!(used.avail_cursor, fresh.avail_cursor);
3090+
assert_eq!(used.used_cursor, fresh.used_cursor);
3091+
assert_eq!(used.id_num.as_slice(), fresh.id_num.as_slice());
3092+
assert_eq!(used.num_inflight, fresh.num_inflight);
3093+
assert_eq!(used.event_flags_shadow, fresh.event_flags_shadow);
3094+
}
3095+
3096+
#[test]
3097+
fn test_ring_consumer_reset_clears_inflight() {
3098+
let ring = make_ring(8);
3099+
let mut producer = make_producer(&ring);
3100+
let mut consumer = make_consumer(&ring);
3101+
3102+
// Submit and poll two items (consume but do not complete)
3103+
producer.submit_one(0x1000, 64, false).unwrap();
3104+
producer.submit_one(0x2000, 64, false).unwrap();
3105+
let _ = consumer.poll_available().unwrap();
3106+
let _ = consumer.poll_available().unwrap();
3107+
assert_eq!(consumer.num_inflight, 2);
3108+
3109+
consumer.reset();
3110+
assert_eq!(consumer.num_inflight, 0);
3111+
}
29773112
}
29783113

29793114
#[cfg(test)]

0 commit comments

Comments
 (0)