Skip to content

Commit c4f4c64

Browse files
committed
CABI: improve and add cooperative thread built-ins
1 parent 39ae5c2 commit c4f4c64

2 files changed

Lines changed: 243 additions & 48 deletions

File tree

design/mvp/canonical-abi/definitions.py

Lines changed: 98 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
class Trap(BaseException): pass
1818
class CoreWebAssemblyException(BaseException): pass
19+
class ThreadExit(BaseException): pass
1920

2021
def trap():
2122
raise Trap()
@@ -294,7 +295,7 @@ class ComponentInstance:
294295
handles: Table[ResourceHandle | Waitable | WaitableSet | ErrorContext]
295296
threads: Table[Thread]
296297
may_leave: bool
297-
may_block: bool
298+
sync_before_return: bool
298299
backpressure: int
299300
exclusive: Optional[Task]
300301
num_waiting_to_enter: int
@@ -306,11 +307,17 @@ def __init__(self, store, parent = None):
306307
self.handles = Table()
307308
self.threads = Table()
308309
self.may_leave = True
309-
self.may_block = True
310+
self.sync_before_return = False
310311
self.backpressure = 0
311312
self.exclusive = None
312313
self.num_waiting_to_enter = 0
313314

315+
def ready_threads(self) -> list[Thread]:
316+
return [t for t in self.threads.array if t and t.waiting() and t.ready()]
317+
318+
def may_block(self):
319+
return not self.sync_before_return or len(self.ready_threads()) > 0
320+
314321
def reflexive_ancestors(self) -> set[ComponentInstance]:
315322
s = set()
316323
inst = self
@@ -487,7 +494,10 @@ def ready(self):
487494
def __init__(self, task, thread_func):
488495
def cont_func(cancelled):
489496
assert(self.running() and not cancelled)
490-
thread_func()
497+
try:
498+
thread_func()
499+
except ThreadExit:
500+
pass
491501
return None
492502
self.cont = cont_new(cont_func)
493503
self.ready_func = None
@@ -497,7 +507,7 @@ def cont_func(cancelled):
497507
self.storage = [0,0]
498508
assert(self.suspended())
499509

500-
def resume_later(self):
510+
def unsuspend(self):
501511
assert(self.suspended())
502512
self.ready_func = lambda: True
503513
self.task.inst.store.waiting.append(self)
@@ -507,18 +517,25 @@ def resume(self, cancelled):
507517
assert(not self.running() and (self.cancellable or not cancelled))
508518
if self.waiting():
509519
assert(cancelled or self.ready())
510-
self.ready_func = None
511-
self.task.inst.store.waiting.remove(self)
520+
self.stop_waiting()
512521
thread = self
513522
while thread is not None:
514523
cont = thread.cont
515524
thread.cont = None
516525
(thread.cont, switch_to) = resume(cont, cancelled, thread)
526+
if switch_to is None and self.task.inst.sync_before_return:
527+
switch_to = random.choice(self.task.inst.ready_threads())
528+
switch_to.stop_waiting()
517529
thread = switch_to
518530
cancelled = Cancelled.FALSE
519531

532+
def stop_waiting(self):
533+
assert(self.waiting())
534+
self.ready_func = None
535+
self.task.inst.store.waiting.remove(self)
536+
520537
def suspend(self, cancellable) -> Cancelled:
521-
assert(self.running() and self.task.inst.may_block)
538+
assert(self.running() and self.task.inst.may_block())
522539
if self.task.deliver_pending_cancel(cancellable):
523540
return Cancelled.TRUE
524541
self.cancellable = cancellable
@@ -527,7 +544,7 @@ def suspend(self, cancellable) -> Cancelled:
527544
return cancelled
528545

529546
def wait_until(self, ready_func, cancellable = False) -> Cancelled:
530-
assert(self.running() and self.task.inst.may_block)
547+
assert(self.running() and self.task.inst.may_block())
531548
if self.task.deliver_pending_cancel(cancellable):
532549
return Cancelled.TRUE
533550
if ready_func() and not DETERMINISTIC_PROFILE and random.randint(0,1):
@@ -538,7 +555,7 @@ def wait_until(self, ready_func, cancellable = False) -> Cancelled:
538555

539556
def yield_until(self, ready_func, cancellable) -> Cancelled:
540557
assert(self.running())
541-
if self.task.inst.may_block:
558+
if self.task.inst.may_block():
542559
return self.wait_until(ready_func, cancellable)
543560
else:
544561
assert(ready_func())
@@ -547,7 +564,7 @@ def yield_until(self, ready_func, cancellable) -> Cancelled:
547564
def yield_(self, cancellable) -> Cancelled:
548565
return self.yield_until(lambda: True, cancellable)
549566

550-
def switch_to(self, cancellable, other: Thread) -> Cancelled:
567+
def suspend_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
551568
assert(self.running() and other.suspended())
552569
if self.task.deliver_pending_cancel(cancellable):
553570
return Cancelled.TRUE
@@ -556,11 +573,31 @@ def switch_to(self, cancellable, other: Thread) -> Cancelled:
556573
assert(self.running() and (cancellable or not cancelled))
557574
return cancelled
558575

559-
def yield_to(self, cancellable, other: Thread) -> Cancelled:
576+
def yield_to_suspended(self, cancellable, other: Thread) -> ResumeArg:
560577
assert(self.running() and other.suspended())
561578
self.ready_func = lambda: True
562579
self.task.inst.store.waiting.append(self)
563-
return self.switch_to(cancellable, other)
580+
return self.suspend_to_suspended(cancellable, other)
581+
582+
def suspend_then_promote(self, cancellable, other: Thread) -> ResumeArg:
583+
assert(self.running())
584+
if other.waiting() and other.ready():
585+
other.stop_waiting()
586+
return self.suspend_to_suspended(cancellable, other)
587+
else:
588+
return self.suspend(cancellable)
589+
590+
def yield_then_promote(self, cancellable, other: Thread) -> ResumeArg:
591+
assert(self.running())
592+
if other.waiting() and other.ready():
593+
other.stop_waiting()
594+
return self.yield_to_suspended(cancellable, other)
595+
else:
596+
return self.yield_(cancellable)
597+
598+
def exit(self):
599+
assert(self.running() and self.task.inst.may_block())
600+
raise ThreadExit()
564601

565602
#### Waitable State
566603

@@ -701,8 +738,8 @@ def has_backpressure():
701738
assert(self.inst.exclusive is None)
702739
self.inst.exclusive = self
703740
else:
704-
assert(self.inst.may_block)
705-
self.inst.may_block = False
741+
assert(not self.inst.sync_before_return)
742+
self.inst.sync_before_return = True
706743
self.register_thread(thread)
707744
return True
708745

@@ -753,8 +790,8 @@ def return_(self, result):
753790
trap_if(self.state == Task.State.RESOLVED)
754791
trap_if(self.num_borrows > 0)
755792
if not self.ft.async_:
756-
assert(not self.inst.may_block)
757-
self.inst.may_block = True
793+
assert(self.inst.sync_before_return)
794+
self.inst.sync_before_return = False
758795
assert(result is not None)
759796
self.on_resolve(result)
760797
self.state = Task.State.RESOLVED
@@ -2096,7 +2133,7 @@ def thread_func():
20962133
else:
20972134
event = (EventCode.NONE, 0, 0)
20982135
case CallbackCode.WAIT:
2099-
trap_if(not inst.may_block)
2136+
trap_if(not inst.may_block())
21002137
wset = inst.handles.get(si)
21012138
trap_if(not isinstance(wset, WaitableSet))
21022139
event = wset.wait_until(lambda: not inst.exclusive, cancellable = True)
@@ -2140,7 +2177,7 @@ def call_and_trap_on_throw(callee, args):
21402177
def canon_lower(opts, ft, callee: FuncInst, flat_args):
21412178
thread = current_thread()
21422179
trap_if(not thread.task.inst.may_leave)
2143-
trap_if(not thread.task.inst.may_block and ft.async_ and not opts.async_)
2180+
trap_if(not thread.task.inst.may_block() and ft.async_ and not opts.async_)
21442181

21452182
subtask = Subtask()
21462183
cx = LiftLowerContext(opts, thread.task.inst, subtask)
@@ -2328,7 +2365,7 @@ def canon_waitable_set_new():
23282365
def canon_waitable_set_wait(cancellable, mem, si, ptr):
23292366
inst = current_thread().task.inst
23302367
trap_if(not inst.may_leave)
2331-
trap_if(not inst.may_block)
2368+
trap_if(not inst.may_block())
23322369
wset = inst.handles.get(si)
23332370
trap_if(not isinstance(wset, WaitableSet))
23342371
event = wset.wait(cancellable)
@@ -2383,7 +2420,7 @@ def canon_waitable_join(wi, si):
23832420
def canon_subtask_cancel(async_, i):
23842421
thread = current_thread()
23852422
trap_if(not thread.task.inst.may_leave)
2386-
trap_if(not thread.task.inst.may_block and not async_)
2423+
trap_if(not thread.task.inst.may_block() and not async_)
23872424
subtask = thread.task.inst.handles.get(i)
23882425
trap_if(not isinstance(subtask, Subtask))
23892426
trap_if(subtask.resolve_delivered())
@@ -2444,7 +2481,7 @@ def canon_stream_write(stream_t, opts, i, ptr, n):
24442481
def stream_copy(EndT, BufferT, event_code, stream_t, opts, i, ptr, n):
24452482
thread = current_thread()
24462483
trap_if(not thread.task.inst.may_leave)
2447-
trap_if(not thread.task.inst.may_block and not opts.async_)
2484+
trap_if(not thread.task.inst.may_block() and not opts.async_)
24482485

24492486
e = thread.task.inst.handles.get(i)
24502487
trap_if(not isinstance(e, EndT))
@@ -2499,7 +2536,7 @@ def canon_future_write(future_t, opts, i, ptr):
24992536
def future_copy(EndT, BufferT, event_code, future_t, opts, i, ptr):
25002537
thread = current_thread()
25012538
trap_if(not thread.task.inst.may_leave)
2502-
trap_if(not thread.task.inst.may_block and not opts.async_)
2539+
trap_if(not thread.task.inst.may_block() and not opts.async_)
25032540

25042541
e = thread.task.inst.handles.get(i)
25052542
trap_if(not isinstance(e, EndT))
@@ -2552,7 +2589,7 @@ def canon_future_cancel_write(future_t, async_, i):
25522589
def cancel_copy(EndT, event_code, stream_or_future_t, async_, i):
25532590
thread = current_thread()
25542591
trap_if(not thread.task.inst.may_leave)
2555-
trap_if(not thread.task.inst.may_block and not async_)
2592+
trap_if(not thread.task.inst.may_block() and not async_)
25562593
e = thread.task.inst.handles.get(i)
25572594
trap_if(not isinstance(e, EndT))
25582595
trap_if(e.shared.t != stream_or_future_t.t)
@@ -2620,22 +2657,22 @@ def thread_func():
26202657
task.register_thread(new_thread)
26212658
return [new_thread.index]
26222659

2623-
### 🧵 `canon thread.resume-later`
2660+
### 🧵 `canon thread.unsuspend`
26242661

2625-
def canon_thread_resume_later(i):
2662+
def canon_thread_unsuspend(i):
26262663
thread = current_thread()
26272664
trap_if(not thread.task.inst.may_leave)
26282665
other_thread = thread.task.inst.threads.get(i)
26292666
trap_if(not other_thread.suspended())
2630-
other_thread.resume_later()
2667+
other_thread.unsuspend()
26312668
return []
26322669

26332670
### 🧵 `canon thread.suspend`
26342671

26352672
def canon_thread_suspend(cancellable):
26362673
thread = current_thread()
26372674
trap_if(not thread.task.inst.may_leave)
2638-
trap_if(not thread.task.inst.may_block)
2675+
trap_if(not thread.task.inst.may_block())
26392676
cancelled = thread.suspend(cancellable)
26402677
return [cancelled]
26412678

@@ -2647,26 +2684,54 @@ def canon_thread_yield(cancellable):
26472684
cancelled = thread.yield_(cancellable)
26482685
return [cancelled]
26492686

2650-
### 🧵 `canon thread.switch-to`
2687+
### 🧵 `canon thread.suspend-to-suspended`
26512688

2652-
def canon_thread_switch_to(cancellable, i):
2689+
def canon_thread_suspend_to_suspended(cancellable, i):
26532690
thread = current_thread()
26542691
trap_if(not thread.task.inst.may_leave)
26552692
other_thread = thread.task.inst.threads.get(i)
26562693
trap_if(not other_thread.suspended())
2657-
cancelled = thread.switch_to(cancellable, other_thread)
2694+
cancelled = thread.suspend_to_suspended(cancellable, other_thread)
26582695
return [cancelled]
26592696

2660-
### 🧵 `canon thread.yield-to`
2697+
### 🧵 `canon thread.yield-to-suspended`
26612698

2662-
def canon_thread_yield_to(cancellable, i):
2699+
def canon_thread_yield_to_suspended(cancellable, i):
26632700
thread = current_thread()
26642701
trap_if(not thread.task.inst.may_leave)
26652702
other_thread = thread.task.inst.threads.get(i)
26662703
trap_if(not other_thread.suspended())
2667-
cancelled = thread.yield_to(cancellable, other_thread)
2704+
cancelled = thread.yield_to_suspended(cancellable, other_thread)
2705+
return [cancelled]
2706+
2707+
### 🧵 `canon thread.suspend-then-promote`
2708+
2709+
def canon_thread_suspend_then_promote(cancellable, i):
2710+
thread = current_thread()
2711+
trap_if(not thread.task.inst.may_leave)
2712+
trap_if(not thread.task.inst.may_block())
2713+
other_thread = thread.task.inst.threads.get(i)
2714+
cancelled = thread.suspend_then_promote(cancellable, other_thread)
26682715
return [cancelled]
26692716

2717+
### 🧵 `canon thread.yield-then-promote`
2718+
2719+
def canon_thread_yield_then_promote(cancellable, i):
2720+
thread = current_thread()
2721+
trap_if(not thread.task.inst.may_leave)
2722+
other_thread = thread.task.inst.threads.get(i)
2723+
cancelled = thread.yield_then_promote(cancellable, other_thread)
2724+
return [cancelled]
2725+
2726+
### 🧵 `canon thread.exit`
2727+
2728+
def canon_thread_exit():
2729+
thread = current_thread()
2730+
trap_if(not thread.task.inst.may_leave)
2731+
trap_if(not thread.task.inst.may_block())
2732+
thread.exit()
2733+
assert(False)
2734+
26702735
### 📝 `canon error-context.new`
26712736

26722737
@dataclass

0 commit comments

Comments
 (0)