Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 181 additions & 4 deletions pkgs/xmss/src/aggregation.zig
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,37 @@ fn appendAll(out: *ByteList512KiB, buf: []const u8) AggregationError!void {
for (buf) |b| out.append(b) catch return AggregationError.ProofTooLarge;
}

/// Fixed size of every fake (stub) proof. Well under MAX_AGGREGATE_PROOF_SIZE so a fake never
/// trips the ProofTooLarge guard the real wire can.
const FAKE_PROOF_SIZE: usize = 32 * 1024;

/// Fill `out` (init'd, empty) with `len` deterministic bytes via SHA-256 counter mode over
/// seed_parts. Deterministic across nodes/runs: identical inputs -> identical bytes.
/// DETERMINISM CONTRACT: callers seed from EXACTLY what the real Type-1/Type-2 FFI binds for this
/// op — message_hash, slot, child-proof BYTES, participant counts — and NOTHING else; NEVER a
/// pointer/handle address (per-process, would break cross-node wire determinism and consensus).
fn fillFakeProof(out: *ByteList512KiB, len: usize, seed_parts: []const []const u8) AggregationError!void {
const Sha256 = std.crypto.hash.sha2.Sha256;

var base: [Sha256.digest_length]u8 = undefined;
var seed_hasher = Sha256.init(.{});
for (seed_parts) |part| seed_hasher.update(part);
seed_hasher.final(&base);

var written: usize = 0;
var counter: u64 = 0;
while (written < len) : (counter += 1) {
var digest: [Sha256.digest_length]u8 = undefined;
var block_hasher = Sha256.init(.{});
block_hasher.update(&base);
block_hasher.update(std.mem.asBytes(&counter));
block_hasher.final(&digest);
const take = @min(digest.len, len - written);
for (digest[0..take]) |b| out.append(b) catch return AggregationError.ProofTooLarge;
written += take;
}
}

/// Aggregate raw XMSS signatures and child Type-1 proofs into a single Type-1 proof.
///
/// - `raw_pks`/`raw_sigs`: parallel raw public-key + signature handles (must be equal length).
Expand All @@ -151,9 +182,27 @@ pub fn aggregateType1(
const allocator = std.heap.c_allocator;
const num_children = children_pks.len;

// Aggregation flat-re-proves the raw signature set per att_data (children-folding is off by
// default), so the modeled cost scales with raw_pks.len.
if (shadow_cost.fakeXmss()) {
// No pubkey/sig handles in the seed: per-process addresses would break cross-node determinism.
const raw_count: usize = raw_pks.len;
const seed_parts = allocator.alloc([]const u8, num_children + 3) catch
return AggregationError.Type1AggregateFailed;
defer allocator.free(seed_parts);
seed_parts[0] = message_hash[0..];
seed_parts[1] = std.mem.asBytes(&slot);
for (children_proofs, 0..) |child, i| seed_parts[2 + i] = child.constSlice();
seed_parts[2 + num_children] = std.mem.asBytes(&raw_count);
try fillFakeProof(out, FAKE_PROOF_SIZE, seed_parts);
// Keep modeling aggregation cost on the virtual clock; only the real prove is bypassed.
const shadow_delay_ns = shadow_cost.aggregateDelayNs(raw_pks.len);
if (shadow_delay_ns != 0) zeam_utils.sleepNs(shadow_delay_ns);
return;
}

var total_child_pks: usize = 0;
for (children_pks) |c| total_child_pks += c.len;

const child_all_pks = allocator.alloc(*const hashsig.HashSigPublicKey, total_child_pks) catch
return AggregationError.Type1AggregateFailed;
defer allocator.free(child_all_pks);
Expand Down Expand Up @@ -213,6 +262,8 @@ pub fn aggregateType1(
if (shadow_delay_ns != 0) zeam_utils.sleepNs(shadow_delay_ns);
}

// Wall-time of the REAL FFI prove only; under --shadow-xmss-fake the fake branch returns before
// this, so the histogram reads ~0 by design.
fn recordXmssProveDuration(start_ns: i128, num_raw: usize, num_children: usize) void {
const end_ns = zeam_utils.monotonicTimestampNs();
const elapsed_ns: i128 = if (end_ns >= start_ns) end_ns - start_ns else 0;
Expand All @@ -228,9 +279,11 @@ pub fn verifyType1(
slot: u32,
proof: *const ByteList512KiB,
) AggregationError!void {
const wire = proof.constSlice();
const ok = xmss_verify_type_1(pks.ptr, pks.len, message_hash, slot, wire.ptr, wire.len);
if (!ok) return AggregationError.Type1VerifyFailed;
if (!shadow_cost.fakeXmss()) {
const wire = proof.constSlice();
const ok = xmss_verify_type_1(pks.ptr, pks.len, message_hash, slot, wire.ptr, wire.len);
if (!ok) return AggregationError.Type1VerifyFailed;
}

// Model verification cost on the virtual clock (no-op unless a rate is set). Runs on the worker.
const shadow_delay_ns = shadow_cost.verifyDelayNs(pks.len);
Expand All @@ -256,6 +309,20 @@ pub fn mergeType1ToType2(
const allocator = std.heap.c_allocator;
const num_parts = parts.len;

if (shadow_cost.fakeXmss()) {
// No pubkey handles in the seed: per-process addresses would break cross-node determinism.
const seed_parts = allocator.alloc([]const u8, num_parts + 1) catch
return AggregationError.Type2MergeFailed;
defer allocator.free(seed_parts);
for (parts, 0..) |part, i| seed_parts[i] = part.constSlice();
seed_parts[num_parts] = std.mem.asBytes(&num_parts);
try fillFakeProof(out, FAKE_PROOF_SIZE, seed_parts);
// Model merge cost on the virtual clock; the sleep survives the early return.
const shadow_delay_ns = shadow_cost.mergeDelayNs(num_parts);
if (shadow_delay_ns != 0) zeam_utils.sleepNs(shadow_delay_ns);
return;
}

var total_pks: usize = 0;
for (pks_per_part) |p| total_pks += p.len;

Expand Down Expand Up @@ -323,6 +390,12 @@ pub fn splitType2ByMessage(
const allocator = std.heap.c_allocator;
const num_messages = pks_per_message.len;

if (shadow_cost.fakeXmss()) {
// No pubkey handles in the seed: per-process addresses would break cross-node determinism.
try fillFakeProof(out, FAKE_PROOF_SIZE, &.{ type_2_proof.constSlice(), target_message_hash[0..] });
return;
}

var total_pks: usize = 0;
for (pks_per_message) |p| total_pks += p.len;

Expand Down Expand Up @@ -375,6 +448,8 @@ pub fn verifyType2(
if (pks_per_message.len != messages.len) return AggregationError.Type2VerifyFailed;
if (messages.len == 0) return AggregationError.Type2VerifyFailed;

if (shadow_cost.fakeXmss()) return;

const allocator = std.heap.c_allocator;
const num_messages = messages.len;

Expand Down Expand Up @@ -570,3 +645,105 @@ test "aggregateType1 rejects mismatched raw pk/sig lengths" {
aggregateType1(&pks, &sigs, none_pks, none_proofs, &message_hash, 0, 2, &out),
);
}

// Fake path needs no FFI setup; defer-reset keeps fake from leaking into the real-FFI negative tests above.

test "fake: aggregateType1 then verifyType1 ok; no FFI setup needed" {
shadow_cost.setFakeForTest(true);
defer shadow_cost.resetForTest();
const allocator = std.testing.allocator;

var raw_pks = [_]*const hashsig.HashSigPublicKey{undefined};
var raw_sigs = [_]*const hashsig.HashSigSignature{undefined};
const none_pks: []const []*const hashsig.HashSigPublicKey = &.{};
const none_proofs: []const ByteList512KiB = &.{};
const message_hash = [_]u8{42} ** 32;
const slot: u32 = 7;

var proof = try ByteList512KiB.init(allocator);
defer proof.deinit();
try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot, 2, &proof);

try std.testing.expectEqual(FAKE_PROOF_SIZE, proof.constSlice().len);
try std.testing.expect(proof.constSlice().len < MAX_AGGREGATE_PROOF_SIZE);
try verifyType1(&raw_pks, &message_hash, slot, &proof);
}

test "fake: aggregateType1 is deterministic for identical inputs, differs on message/slot" {
shadow_cost.setFakeForTest(true);
defer shadow_cost.resetForTest();
const allocator = std.testing.allocator;

var raw_pks = [_]*const hashsig.HashSigPublicKey{undefined};
var raw_sigs = [_]*const hashsig.HashSigSignature{undefined};
const none_pks: []const []*const hashsig.HashSigPublicKey = &.{};
const none_proofs: []const ByteList512KiB = &.{};
const message_hash = [_]u8{1} ** 32;
const slot: u32 = 3;

var p1 = try ByteList512KiB.init(allocator);
defer p1.deinit();
var p2 = try ByteList512KiB.init(allocator);
defer p2.deinit();
try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot, 2, &p1);
try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot, 2, &p2);
try std.testing.expectEqualSlices(u8, p1.constSlice(), p2.constSlice());

// Different message hash -> different bytes.
const other_hash = [_]u8{2} ** 32;
var p_msg = try ByteList512KiB.init(allocator);
defer p_msg.deinit();
try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &other_hash, slot, 2, &p_msg);
try std.testing.expect(!std.mem.eql(u8, p1.constSlice(), p_msg.constSlice()));

// Different slot -> different bytes.
var p_slot = try ByteList512KiB.init(allocator);
defer p_slot.deinit();
try aggregateType1(&raw_pks, &raw_sigs, none_pks, none_proofs, &message_hash, slot + 1, 2, &p_slot);
try std.testing.expect(!std.mem.eql(u8, p1.constSlice(), p_slot.constSlice()));
}

test "fake: merge->verifyType2 ok and split->verifyType1 ok" {
shadow_cost.setFakeForTest(true);
defer shadow_cost.resetForTest();
const allocator = std.testing.allocator;
const slot: u32 = 5;
const msg_a = [_]u8{0xAA} ** 32;
const msg_b = [_]u8{0xBB} ** 32;

var a_pks = [_]*const hashsig.HashSigPublicKey{undefined};
var a_sigs = [_]*const hashsig.HashSigSignature{undefined};
var b_pks = [_]*const hashsig.HashSigPublicKey{undefined};
var b_sigs = [_]*const hashsig.HashSigSignature{undefined};
const none_pks: []const []*const hashsig.HashSigPublicKey = &.{};
const none_proofs: []const ByteList512KiB = &.{};

var t1a = try ByteList512KiB.init(allocator);
defer t1a.deinit();
var t1b = try ByteList512KiB.init(allocator);
defer t1b.deinit();
try aggregateType1(&a_pks, &a_sigs, none_pks, none_proofs, &msg_a, slot, 2, &t1a);
try aggregateType1(&b_pks, &b_sigs, none_pks, none_proofs, &msg_b, slot, 2, &t1b);

const parts = [_]ByteList512KiB{ t1a, t1b };
var pkpa = [_]*const hashsig.HashSigPublicKey{undefined};
var pkpb = [_]*const hashsig.HashSigPublicKey{undefined};
const pks_per_part = [_][]*const hashsig.HashSigPublicKey{ &pkpa, &pkpb };

var t2 = try ByteList512KiB.init(allocator);
defer t2.deinit();
try mergeType1ToType2(&parts, &pks_per_part, 2, &t2);
try std.testing.expectEqual(FAKE_PROOF_SIZE, t2.constSlice().len);

const bindings = [_]MessageBinding{
.{ .hash = msg_a, .slot = slot },
.{ .hash = msg_b, .slot = slot },
};
try verifyType2(&t2, &pks_per_part, &bindings);

var recovered = try ByteList512KiB.init(allocator);
defer recovered.deinit();
try splitType2ByMessage(&t2, &pks_per_part, &msg_a, 2, &recovered);
try std.testing.expectEqual(FAKE_PROOF_SIZE, recovered.constSlice().len);
try verifyType1(&a_pks, &msg_a, slot, &recovered);
}
56 changes: 54 additions & 2 deletions pkgs/xmss/src/shadow_cost.zig
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,59 @@ const ENV_AGG = "ZEAM_SHADOW_XMSS_AGGREGATE_SIGNATURES_RATE";
const ENV_VERIFY = "ZEAM_SHADOW_XMSS_VERIFY_AGGREGATED_SIGNATURES_RATE";
const ENV_MERGE = "ZEAM_SHADOW_XMSS_MERGE_RATE";

// Fake-XMSS replaces real STARK prove/verify with a deterministic stub (see aggregation.zig).
// Set from the ZEAM_SHADOW_XMSS_FAKE env var in init() (no CLI flag — adding a CLI field overruns
// the zigcli parser's comptime quota).
// INVARIANT: ONE process-global bool, NOT a per-op flag — "fake" must be uniform across all 5
// FFI wrappers in one process; a fake aggregate/split feeding a real merge/verify would mix
// incompatible byte formats and corrupt. Set once in init() before any worker thread reads it.
var fake_enabled: bool = false;
const ENV_FAKE = "ZEAM_SHADOW_XMSS_FAKE";

// Mirrors readEnvRate: a simple libc env read, no allocation. True iff the value is "1"/"true".
fn readEnvFake() bool {
const raw = std.c.getenv(ENV_FAKE) orelse return false;
const value = std.mem.span(raw);
if (std.mem.eql(u8, value, "1")) return true;
if (std.mem.eql(u8, value, "true")) return true;
return false;
}

// Zig 0.16 removed `std.process.getEnvVarOwned`; use libc `getenv` for a simple
// process-env read (no allocation, no free). Returns null when unset/unparseable.
fn readEnvRate(key: [*:0]const u8) ?f64 {
const raw = std.c.getenv(key) orelse return null;
return std.fmt.parseFloat(f64, std.mem.span(raw)) catch null;
}

/// Resolve the shadow sim-cost rates. Precedence: CLI flag (non-null) > env var > off.
/// Call exactly once at node startup, before aggregation begins.
/// Resolve the shadow sim-cost rates and the fake-XMSS toggle. Rates: CLI flag (non-null) >
/// env var > off. Fake: env var only. Call exactly once at node startup, before aggregation begins.
pub fn init(cli_agg: ?f64, cli_verify: ?f64, cli_merge: ?f64) void {
agg_rate = cli_agg orelse readEnvRate(ENV_AGG);
verify_rate = cli_verify orelse readEnvRate(ENV_VERIFY);
merge_rate = cli_merge orelse readEnvRate(ENV_MERGE);
fake_enabled = readEnvFake();
}

/// Whether the XMSS prove/verify FFI is replaced by the deterministic stub. Read lock-free on
/// workers: set once in init() before any worker thread runs, read-only after.
pub fn fakeXmss() bool {
return fake_enabled;
}

/// Test-only. Every test that flips a global MUST `defer shadow_cost.resetForTest()`. INVARIANT:
/// no real-crypto test may run with fake_enabled=true — an always-accept fake verify would green
/// a broken real verify.
pub fn resetForTest() void {
agg_rate = null;
verify_rate = null;
merge_rate = null;
fake_enabled = false;
}

/// Test-only: enable/disable the fake path; pair with `defer resetForTest()`.
pub fn setFakeForTest(v: bool) void {
fake_enabled = v;
}

/// Nanoseconds to sleep to model aggregating `n` raw signatures.
Expand Down Expand Up @@ -71,6 +111,7 @@ test "computeDelayNs: proportional to n / rate (qlean default rate)" {
}

test "mergeDelayNs: reads merge_rate set by init; proportional to n / rate" {
defer resetForTest();
// A non-null CLI arg wins over the env var, so this stays deterministic.
init(null, null, 22.704);
// 100 components / 22.704 per-sec = 4.40451... s ~= 4_404_510_000 ns
Expand All @@ -80,10 +121,21 @@ test "mergeDelayNs: reads merge_rate set by init; proportional to n / rate" {
}

test "mergeDelayNs: off when merge_rate <= 0; zero when n == 0" {
defer resetForTest();
init(null, null, 0);
try std.testing.expectEqual(@as(u64, 0), mergeDelayNs(100));
init(null, null, -5.0);
try std.testing.expectEqual(@as(u64, 0), mergeDelayNs(100));
init(null, null, 22.704);
try std.testing.expectEqual(@as(u64, 0), mergeDelayNs(0));
}

test "fakeXmss: off by default, setFakeForTest toggles" {
defer resetForTest();
resetForTest();
try std.testing.expect(!fakeXmss());
setFakeForTest(true);
try std.testing.expect(fakeXmss());
setFakeForTest(false);
try std.testing.expect(!fakeXmss());
}
Loading