From d110de50824e78202d8831896f62939f6c86368b Mon Sep 17 00:00:00 2001 From: Chen Kai <281165273grape@gmail.com> Date: Mon, 22 Jun 2026 23:19:44 +0800 Subject: [PATCH] feat(xmss): deterministic fake prover for Shadow sims (ZEAM_SHADOW_XMSS_FAKE) Opt-in stub that replaces the real leanMultisig STARK prove/verify FFI with a deterministic SHA-256 counter-mode proof when ZEAM_SHADOW_XMSS_FAKE is set. Under Shadow the real prover runs as native compute charged to virtual time, an uncontrollable baseline on top of the injected --shadow-xmss-*-rate sleep; the stub removes that baseline so the injected rate is the sole prover-cost knob for cost-sweep experiments. - shadow_cost: one process-global fake toggle, env-only, off by default. - aggregation: aggregateType1 / mergeType1ToType2 / splitType2ByMessage / verifyType1 / verifyType2 emit the stub; the sim-cost sleeps still run on the virtual clock, only the real prove/verify is bypassed. - Stub bytes are seeded only from wire-bound inputs (message hash, slot, child-proof bytes, participant counts), never per-process pointer handles, so every node yields identical bytes for identical inputs (consensus-safe). Off by default: with no rate set and fake off, behavior is byte-identical to before. Verified on Zig 0.16.0: test-shadow-cost + zig build test (-Dprover=dummy) green, including 4 new fake-path tests. --- pkgs/xmss/src/aggregation.zig | 185 +++++++++++++++++++++++++++++++++- pkgs/xmss/src/shadow_cost.zig | 56 +++++++++- 2 files changed, 235 insertions(+), 6 deletions(-) diff --git a/pkgs/xmss/src/aggregation.zig b/pkgs/xmss/src/aggregation.zig index 2f5199422..b89ddf33c 100644 --- a/pkgs/xmss/src/aggregation.zig +++ b/pkgs/xmss/src/aggregation.zig @@ -128,6 +128,37 @@ fn appendAll(out: *ByteList512KiB, buf: []const u8) AggregationError!void { for (buf) |b| out.append(b) catch return AggregationError.ProofTooLarge; } +/// Fixed size of every fake (stub) proof. 
Well under MAX_AGGREGATE_PROOF_SIZE so a fake never +/// trips the ProofTooLarge guard the real wire can. +const FAKE_PROOF_SIZE: usize = 32 * 1024; + +/// Fill `out` (init'd, empty) with `len` deterministic bytes via SHA-256 counter mode over +/// seed_parts. Deterministic across nodes/runs: identical inputs -> identical bytes. +/// DETERMINISM CONTRACT: callers seed from EXACTLY what the real Type-1/Type-2 FFI binds for this +/// op — message_hash, slot, child-proof BYTES, participant counts — and NOTHING else; NEVER a +/// pointer/handle address (per-process, would break cross-node wire determinism and consensus). +fn fillFakeProof(out: *ByteList512KiB, len: usize, seed_parts: []const []const u8) AggregationError!void { + const Sha256 = std.crypto.hash.sha2.Sha256; + + var base: [Sha256.digest_length]u8 = undefined; + var seed_hasher = Sha256.init(.{}); + for (seed_parts) |part| seed_hasher.update(part); + seed_hasher.final(&base); + + var written: usize = 0; + var counter: u64 = 0; + while (written < len) : (counter += 1) { + var digest: [Sha256.digest_length]u8 = undefined; + var block_hasher = Sha256.init(.{}); + block_hasher.update(&base); + block_hasher.update(std.mem.asBytes(&counter)); + block_hasher.final(&digest); + const take = @min(digest.len, len - written); + for (digest[0..take]) |b| out.append(b) catch return AggregationError.ProofTooLarge; + written += take; + } +} + /// Aggregate raw XMSS signatures and child Type-1 proofs into a single Type-1 proof. /// /// - `raw_pks`/`raw_sigs`: parallel raw public-key + signature handles (must be equal length). @@ -151,9 +182,27 @@ pub fn aggregateType1( const allocator = std.heap.c_allocator; const num_children = children_pks.len; + // Aggregation flat-re-proves the raw signature set per att_data (children-folding is off by + // default), so the modeled cost scales with raw_pks.len. 
+ if (shadow_cost.fakeXmss()) { + // No pubkey/sig handles in the seed: per-process addresses would break cross-node determinism. + const raw_count: usize = raw_pks.len; + const seed_parts = allocator.alloc([]const u8, num_children + 3) catch + return AggregationError.Type1AggregateFailed; + defer allocator.free(seed_parts); + seed_parts[0] = message_hash[0..]; + seed_parts[1] = std.mem.asBytes(&slot); + for (children_proofs, 0..) |child, i| seed_parts[2 + i] = child.constSlice(); + seed_parts[2 + num_children] = std.mem.asBytes(&raw_count); + try fillFakeProof(out, FAKE_PROOF_SIZE, seed_parts); + // Keep modeling aggregation cost on the virtual clock; only the real prove is bypassed. + const shadow_delay_ns = shadow_cost.aggregateDelayNs(raw_pks.len); + if (shadow_delay_ns != 0) zeam_utils.sleepNs(shadow_delay_ns); + return; + } + var total_child_pks: usize = 0; for (children_pks) |c| total_child_pks += c.len; - const child_all_pks = allocator.alloc(*const hashsig.HashSigPublicKey, total_child_pks) catch return AggregationError.Type1AggregateFailed; defer allocator.free(child_all_pks); @@ -213,6 +262,8 @@ pub fn aggregateType1( if (shadow_delay_ns != 0) zeam_utils.sleepNs(shadow_delay_ns); } +// Wall-time of the REAL FFI prove only; with ZEAM_SHADOW_XMSS_FAKE set the fake branch returns before +// this, so the histogram reads ~0 by design. 
fn recordXmssProveDuration(start_ns: i128, num_raw: usize, num_children: usize) void { const end_ns = zeam_utils.monotonicTimestampNs(); const elapsed_ns: i128 = if (end_ns >= start_ns) end_ns - start_ns else 0; @@ -228,9 +279,11 @@ pub fn verifyType1( slot: u32, proof: *const ByteList512KiB, ) AggregationError!void { - const wire = proof.constSlice(); - const ok = xmss_verify_type_1(pks.ptr, pks.len, message_hash, slot, wire.ptr, wire.len); - if (!ok) return AggregationError.Type1VerifyFailed; + if (!shadow_cost.fakeXmss()) { + const wire = proof.constSlice(); + const ok = xmss_verify_type_1(pks.ptr, pks.len, message_hash, slot, wire.ptr, wire.len); + if (!ok) return AggregationError.Type1VerifyFailed; + } // Model verification cost on the virtual clock (no-op unless a rate is set). Runs on the worker. const shadow_delay_ns = shadow_cost.verifyDelayNs(pks.len); @@ -256,6 +309,20 @@ pub fn mergeType1ToType2( const allocator = std.heap.c_allocator; const num_parts = parts.len; + if (shadow_cost.fakeXmss()) { + // No pubkey handles in the seed: per-process addresses would break cross-node determinism. + const seed_parts = allocator.alloc([]const u8, num_parts + 1) catch + return AggregationError.Type2MergeFailed; + defer allocator.free(seed_parts); + for (parts, 0..) |part, i| seed_parts[i] = part.constSlice(); + seed_parts[num_parts] = std.mem.asBytes(&num_parts); + try fillFakeProof(out, FAKE_PROOF_SIZE, seed_parts); + // Model merge cost on the virtual clock; the sleep survives the early return. + const shadow_delay_ns = shadow_cost.mergeDelayNs(num_parts); + if (shadow_delay_ns != 0) zeam_utils.sleepNs(shadow_delay_ns); + return; + } + var total_pks: usize = 0; for (pks_per_part) |p| total_pks += p.len; @@ -323,6 +390,12 @@ pub fn splitType2ByMessage( const allocator = std.heap.c_allocator; const num_messages = pks_per_message.len; + if (shadow_cost.fakeXmss()) { + // No pubkey handles in the seed: per-process addresses would break cross-node determinism. 
+ try fillFakeProof(out, FAKE_PROOF_SIZE, &.{ type_2_proof.constSlice(), target_message_hash[0..] }); + return; + } + var total_pks: usize = 0; for (pks_per_message) |p| total_pks += p.len; @@ -375,6 +448,8 @@ pub fn verifyType2( if (pks_per_message.len != messages.len) return AggregationError.Type2VerifyFailed; if (messages.len == 0) return AggregationError.Type2VerifyFailed; + if (shadow_cost.fakeXmss()) return; + const allocator = std.heap.c_allocator; const num_messages = messages.len; @@ -570,3 +645,105 @@ test "aggregateType1 rejects mismatched raw pk/sig lengths" { aggregateType1(&pks, &sigs, none_pks, none_proofs, &message_hash, 0, 2, &out), ); } + +// Fake path needs no FFI setup; defer-reset keeps fake from leaking into the real-FFI negative tests above. + +test "fake: aggregateType1 then verifyType1 ok; no FFI setup needed" { + shadow_cost.setFakeForTest(true); + defer shadow_cost.resetForTest(); + const allocator = std.testing.allocator; + + var raw_pks = [_]*const hashsig.HashSigPublicKey{undefined}; + var raw_sigs = [_]*const hashsig.HashSigSignature{undefined}; + const none_pks: []const []*const hashsig.HashSigPublicKey = &.{}; + const none_proofs: []const ByteList512KiB = &.{}; + const message_hash = [_]u8{42} ** 32; + const slot: u32 = 7; + + var proof = try ByteList512KiB.init(allocator); + defer proof.deinit(); + try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot, 2, &proof); + + try std.testing.expectEqual(FAKE_PROOF_SIZE, proof.constSlice().len); + try std.testing.expect(proof.constSlice().len < MAX_AGGREGATE_PROOF_SIZE); + try verifyType1(&raw_pks, &message_hash, slot, &proof); +} + +test "fake: aggregateType1 is deterministic for identical inputs, differs on message/slot" { + shadow_cost.setFakeForTest(true); + defer shadow_cost.resetForTest(); + const allocator = std.testing.allocator; + + var raw_pks = [_]*const hashsig.HashSigPublicKey{undefined}; + var raw_sigs = [_]*const 
hashsig.HashSigSignature{undefined}; + const none_pks: []const []*const hashsig.HashSigPublicKey = &.{}; + const none_proofs: []const ByteList512KiB = &.{}; + const message_hash = [_]u8{1} ** 32; + const slot: u32 = 3; + + var p1 = try ByteList512KiB.init(allocator); + defer p1.deinit(); + var p2 = try ByteList512KiB.init(allocator); + defer p2.deinit(); + try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot, 2, &p1); + try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot, 2, &p2); + try std.testing.expectEqualSlices(u8, p1.constSlice(), p2.constSlice()); + + // Different message hash -> different bytes. + const other_hash = [_]u8{2} ** 32; + var p_msg = try ByteList512KiB.init(allocator); + defer p_msg.deinit(); + try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &other_hash, slot, 2, &p_msg); + try std.testing.expect(!std.mem.eql(u8, p1.constSlice(), p_msg.constSlice())); + + // Different slot -> different bytes. 
+ var p_slot = try ByteList512KiB.init(allocator); + defer p_slot.deinit(); + try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot + 1, 2, &p_slot); + try std.testing.expect(!std.mem.eql(u8, p1.constSlice(), p_slot.constSlice())); +} + +test "fake: merge->verifyType2 ok and split->verifyType1 ok" { + shadow_cost.setFakeForTest(true); + defer shadow_cost.resetForTest(); + const allocator = std.testing.allocator; + const slot: u32 = 5; + const msg_a = [_]u8{0xAA} ** 32; + const msg_b = [_]u8{0xBB} ** 32; + + var a_pks = [_]*const hashsig.HashSigPublicKey{undefined}; + var a_sigs = [_]*const hashsig.HashSigSignature{undefined}; + var b_pks = [_]*const hashsig.HashSigPublicKey{undefined}; + var b_sigs = [_]*const hashsig.HashSigSignature{undefined}; + const none_pks: []const []*const hashsig.HashSigPublicKey = &.{}; + const none_proofs: []const ByteList512KiB = &.{}; + + var t1a = try ByteList512KiB.init(allocator); + defer t1a.deinit(); + var t1b = try ByteList512KiB.init(allocator); + defer t1b.deinit(); + try aggregateType1(&a_pks, &a_sigs, none_pks, none_proofs, &msg_a, slot, 2, &t1a); + try aggregateType1(&b_pks, &b_sigs, none_pks, none_proofs, &msg_b, slot, 2, &t1b); + + const parts = [_]ByteList512KiB{ t1a, t1b }; + var pkpa = [_]*const hashsig.HashSigPublicKey{undefined}; + var pkpb = [_]*const hashsig.HashSigPublicKey{undefined}; + const pks_per_part = [_][]*const hashsig.HashSigPublicKey{ &pkpa, &pkpb }; + + var t2 = try ByteList512KiB.init(allocator); + defer t2.deinit(); + try mergeType1ToType2(&parts, &pks_per_part, 2, &t2); + try std.testing.expectEqual(FAKE_PROOF_SIZE, t2.constSlice().len); + + const bindings = [_]MessageBinding{ + .{ .hash = msg_a, .slot = slot }, + .{ .hash = msg_b, .slot = slot }, + }; + try verifyType2(&t2, &pks_per_part, &bindings); + + var recovered = try ByteList512KiB.init(allocator); + defer recovered.deinit(); + try splitType2ByMessage(&t2, &pks_per_part, &msg_a, 2, &recovered); + try 
std.testing.expectEqual(FAKE_PROOF_SIZE, recovered.constSlice().len); + try verifyType1(&a_pks, &msg_a, slot, &recovered); +} diff --git a/pkgs/xmss/src/shadow_cost.zig b/pkgs/xmss/src/shadow_cost.zig index 573d92a22..ae65e3f54 100644 --- a/pkgs/xmss/src/shadow_cost.zig +++ b/pkgs/xmss/src/shadow_cost.zig @@ -23,6 +23,24 @@ const ENV_AGG = "ZEAM_SHADOW_XMSS_AGGREGATE_SIGNATURES_RATE"; const ENV_VERIFY = "ZEAM_SHADOW_XMSS_VERIFY_AGGREGATED_SIGNATURES_RATE"; const ENV_MERGE = "ZEAM_SHADOW_XMSS_MERGE_RATE"; +// Fake-XMSS replaces real STARK prove/verify with a deterministic stub (see aggregation.zig). +// Set from the ZEAM_SHADOW_XMSS_FAKE env var in init() (no CLI flag — adding a CLI field overruns +// the zigcli parser's comptime quota). +// INVARIANT: ONE process-global bool, NOT a per-op flag — "fake" must be uniform across all 5 +// FFI wrappers in one process; a fake aggregate/split feeding a real merge/verify would mix +// incompatible byte formats and corrupt. Set once in init() before any worker thread reads it. +var fake_enabled: bool = false; +const ENV_FAKE = "ZEAM_SHADOW_XMSS_FAKE"; + +// Mirrors readEnvRate: a simple libc env read, no allocation. True iff the value is "1"/"true". +fn readEnvFake() bool { + const raw = std.c.getenv(ENV_FAKE) orelse return false; + const value = std.mem.span(raw); + if (std.mem.eql(u8, value, "1")) return true; + if (std.mem.eql(u8, value, "true")) return true; + return false; +} + // Zig 0.16 removed `std.process.getEnvVarOwned`; use libc `getenv` for a simple // process-env read (no allocation, no free). Returns null when unset/unparseable. fn readEnvRate(key: [*:0]const u8) ?f64 { @@ -30,12 +48,34 @@ fn readEnvRate(key: [*:0]const u8) ?f64 { return std.fmt.parseFloat(f64, std.mem.span(raw)) catch null; } -/// Resolve the shadow sim-cost rates. Precedence: CLI flag (non-null) > env var > off. -/// Call exactly once at node startup, before aggregation begins. 
+/// Resolve the shadow sim-cost rates and the fake-XMSS toggle. Rates: CLI flag (non-null) > +/// env var > off. Fake: env var only. Call exactly once at node startup, before aggregation begins. pub fn init(cli_agg: ?f64, cli_verify: ?f64, cli_merge: ?f64) void { agg_rate = cli_agg orelse readEnvRate(ENV_AGG); verify_rate = cli_verify orelse readEnvRate(ENV_VERIFY); merge_rate = cli_merge orelse readEnvRate(ENV_MERGE); + fake_enabled = readEnvFake(); +} + +/// Whether the XMSS prove/verify FFI is replaced by the deterministic stub. Read lock-free on +/// workers: set once in init() before any worker thread runs, read-only after. +pub fn fakeXmss() bool { + return fake_enabled; +} + +/// Test-only. Every test that flips a global MUST `defer shadow_cost.resetForTest()`. INVARIANT: +/// no real-crypto test may run with fake_enabled=true — an always-accept fake verify would green +/// a broken real verify. +pub fn resetForTest() void { + agg_rate = null; + verify_rate = null; + merge_rate = null; + fake_enabled = false; +} + +/// Test-only: enable/disable the fake path; pair with `defer resetForTest()`. +pub fn setFakeForTest(v: bool) void { + fake_enabled = v; } /// Nanoseconds to sleep to model aggregating `n` raw signatures. @@ -71,6 +111,7 @@ test "computeDelayNs: proportional to n / rate (qlean default rate)" { } test "mergeDelayNs: reads merge_rate set by init; proportional to n / rate" { + defer resetForTest(); // A non-null CLI arg wins over the env var, so this stays deterministic. init(null, null, 22.704); // 100 components / 22.704 per-sec = 4.40451... 
s ~= 4_404_510_000 ns @@ -80,6 +121,7 @@ test "mergeDelayNs: reads merge_rate set by init; proportional to n / rate" { } test "mergeDelayNs: off when merge_rate <= 0; zero when n == 0" { + defer resetForTest(); init(null, null, 0); try std.testing.expectEqual(@as(u64, 0), mergeDelayNs(100)); init(null, null, -5.0); @@ -87,3 +129,13 @@ test "mergeDelayNs: off when merge_rate <= 0; zero when n == 0" { init(null, null, 22.704); try std.testing.expectEqual(@as(u64, 0), mergeDelayNs(0)); } + +test "fakeXmss: off by default, setFakeForTest toggles" { + defer resetForTest(); + resetForTest(); + try std.testing.expect(!fakeXmss()); + setFakeForTest(true); + try std.testing.expect(fakeXmss()); + setFakeForTest(false); + try std.testing.expect(!fakeXmss()); +}