Skip to content

Commit 7b2eaeb

Browse files
committed
OHTTP keys should be rotated
This pr addresses #445. It implements OHTTP-key rotation to payjoin-mailroom Mailroom operators can now decide the time interval for keys to be rotated. Also if a key has expired, a 422 error is returned to clients. Clients can handle they key-rotation via the cach-control header returned by the directory.
1 parent 8569be8 commit 7b2eaeb

5 files changed

Lines changed: 386 additions & 52 deletions

File tree

payjoin-mailroom/src/config.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub struct Config {
1212
pub storage_dir: PathBuf,
1313
#[serde(deserialize_with = "deserialize_duration_secs")]
1414
pub timeout: Duration,
15+
#[serde(deserialize_with = "deserialize_optional_duration_secs")]
16+
pub ohttp_keys_max_age: Option<Duration>,
1517
pub v1: Option<V1Config>,
1618
#[cfg(feature = "telemetry")]
1719
pub telemetry: Option<TelemetryConfig>,
@@ -85,6 +87,7 @@ impl Default for Config {
8587
listener: "[::]:8080".parse().expect("valid default listener address"),
8688
storage_dir: PathBuf::from("./data"),
8789
timeout: Duration::from_secs(30),
90+
ohttp_keys_max_age: Some(Duration::from_secs(7 * 24 * 60 * 60)),
8891
v1: None,
8992
#[cfg(feature = "telemetry")]
9093
telemetry: None,
@@ -104,17 +107,33 @@ where
104107
Ok(Duration::from_secs(secs))
105108
}
106109

110+
fn deserialize_optional_duration_secs<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
111+
where
112+
D: serde::Deserializer<'de>,
113+
{
114+
let secs: Option<u64> = Option::deserialize(deserializer)?;
115+
match secs {
116+
None => Ok(None),
117+
Some(0) => Err(<D::Error as serde::de::Error>::custom(
118+
"ohttp_keys_max_age must be greater than 0 seconds when set",
119+
)),
120+
Some(s) => Ok(Some(Duration::from_secs(s))),
121+
}
122+
}
123+
107124
impl Config {
108125
pub fn new(
109126
listener: ListenerAddress,
110127
storage_dir: PathBuf,
111128
timeout: Duration,
129+
ohttp_keys_max_age: Option<Duration>,
112130
v1: Option<V1Config>,
113131
) -> Self {
114132
Self {
115133
listener,
116134
storage_dir,
117135
timeout,
136+
ohttp_keys_max_age,
118137
v1,
119138
#[cfg(feature = "telemetry")]
120139
telemetry: None,

payjoin-mailroom/src/directory.rs

Lines changed: 182 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
use std::path::PathBuf;
12
use std::pin::Pin;
23
use std::str::FromStr;
34
use std::sync::Arc;
45
use std::task::{Context, Poll};
6+
use std::time::{Duration, Instant};
57

68
use anyhow::Result;
79
use axum::body::{Body, Bytes};
8-
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
10+
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL, CONTENT_TYPE};
911
use axum::http::{Method, Request, Response, StatusCode, Uri};
1012
use http_body_util::BodyExt;
1113
use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES};
14+
use tokio::sync::RwLock;
1215
use tracing::{debug, error, trace, warn};
1316

1417
use crate::db::{Db, Error as DbError, SendableError};
@@ -28,6 +31,77 @@ const V1_VERSION_UNSUPPORTED_RES_JSON: &str =
2831

2932
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
3033

34+
// Two-slot OHTTP key set supporting rotation overlap.
35+
//
36+
// Key IDs alternate between 0 and 1. Both slots are always populated.
37+
// The current key is served to new clients; both slots are accepted
38+
// for decapsulation so that clients with a cached previous key still
39+
// work during the grace window after a switch.
40+
#[derive(Debug)]
41+
pub(crate) struct KeySlot {
42+
pub(crate) server: ohttp::Server,
43+
}
44+
45+
#[derive(Debug)]
46+
pub struct KeyRotatingServer {
47+
keys: [Box<RwLock<KeySlot>>; 2],
48+
current: RwLock<(u8, Instant)>,
49+
}
50+
51+
impl KeyRotatingServer {
52+
pub(crate) fn new(
53+
slot0: KeySlot,
54+
slot1: KeySlot,
55+
current_key_id: u8,
56+
valid_until: Instant,
57+
) -> Self {
58+
assert!(current_key_id <= 1, "key_id must be 0 or 1");
59+
Self {
60+
keys: [Box::new(RwLock::new(slot0)), Box::new(RwLock::new(slot1))],
61+
current: RwLock::new((current_key_id, valid_until)),
62+
}
63+
}
64+
65+
pub async fn current_key_id(&self) -> u8 { self.current.read().await.0 }
66+
67+
pub async fn valid_until(&self) -> Instant { self.current.read().await.1 }
68+
69+
// Look up the server matching the key_id in an OHTTP message and
70+
// decapsulate. The first byte of an OHTTP encapsulated request is the
71+
// key identifier (RFC 9458 Section 4.3).
72+
pub async fn decapsulate(
73+
&self,
74+
ohttp_body: &[u8],
75+
) -> std::result::Result<(Vec<u8>, ohttp::ServerResponse), ohttp::Error> {
76+
let key_id = ohttp_body.first().copied().ok_or(ohttp::Error::Truncated)?;
77+
match self.keys.get(key_id as usize) {
78+
Some(slot) => slot.read().await.server.decapsulate(ohttp_body),
79+
None => Err(ohttp::Error::KeyId),
80+
}
81+
}
82+
83+
// Encode the current key's config for serving to clients.
84+
pub async fn encode_current(&self) -> std::result::Result<Vec<u8>, ohttp::Error> {
85+
let id = self.current_key_id().await;
86+
self.keys[id as usize].read().await.server.config().encode()
87+
}
88+
89+
// Flip which key is advertised to new clients and stamp the new expiry.
90+
// Anchored to Instant::now() at the moment of the actual switch so that
91+
// the next rotation cycle is measured from when the key became active,
92+
pub async fn switch(&self, interval: Duration) {
93+
let mut current = self.current.write().await;
94+
current.0 = 1 - current.0;
95+
current.1 = Instant::now() + interval;
96+
}
97+
98+
// Replace a slot with fresh key material.
99+
pub async fn overwrite(&self, key_id: u8, server: ohttp::Server) {
100+
assert!(key_id <= 1, "key_id must be 0 or 1");
101+
*self.keys[key_id as usize].write().await = KeySlot { server };
102+
}
103+
}
104+
31105
/// Opaque blocklist of Bitcoin addresses stored as script pubkeys.
32106
///
33107
/// Addresses are converted to `ScriptBuf` at parse time so that
@@ -91,7 +165,8 @@ fn parse_address_lines(text: &str) -> std::collections::HashSet<bitcoin::ScriptB
91165
#[derive(Clone)]
92166
pub struct Service<D: Db> {
93167
db: D,
94-
ohttp: ohttp::Server,
168+
ohttp: Arc<KeyRotatingServer>,
169+
ohttp_keys_max_age: Option<Duration>,
95170
sentinel_tag: SentinelTag,
96171
v1: Option<V1>,
97172
}
@@ -117,10 +192,18 @@ where
117192
}
118193

119194
impl<D: Db> Service<D> {
120-
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, v1: Option<V1>) -> Self {
121-
Self { db, ohttp, sentinel_tag, v1 }
195+
pub fn new(
196+
db: D,
197+
ohttp: Arc<KeyRotatingServer>,
198+
ohttp_keys_max_age: Option<Duration>,
199+
sentinel_tag: SentinelTag,
200+
v1: Option<V1>,
201+
) -> Self {
202+
Self { db, ohttp, ohttp_keys_max_age, sentinel_tag, v1 }
122203
}
123204

205+
pub fn ohttp_key_set(&self) -> &Arc<KeyRotatingServer> { &self.ohttp }
206+
124207
async fn serve_request<B>(&self, req: Request<B>) -> Result<Response<Body>>
125208
where
126209
B: axum::body::HttpBody<Data = Bytes> + Send + 'static,
@@ -200,10 +283,10 @@ impl<D: Db> Service<D> {
200283
.map_err(|e| HandlerError::BadRequest(anyhow::anyhow!(e.into())))?
201284
.to_bytes();
202285

203-
// Decapsulate OHTTP request
204286
let (bhttp_req, res_ctx) = self
205287
.ohttp
206288
.decapsulate(&ohttp_body)
289+
.await
207290
.map_err(|e| HandlerError::OhttpKeyRejection(e.into()))?;
208291
let mut cursor = std::io::Cursor::new(bhttp_req);
209292
let req = bhttp::Message::read_bhttp(&mut cursor)
@@ -380,11 +463,31 @@ impl<D: Db> Service<D> {
380463
async fn get_ohttp_keys(&self) -> Result<Response<Body>, HandlerError> {
381464
let ohttp_keys = self
382465
.ohttp
383-
.config()
384-
.encode()
466+
.encode_current()
467+
.await
385468
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
386469
let mut res = Response::new(full(ohttp_keys));
387470
res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/ohttp-keys"));
471+
if let Some(max_age) = self.ohttp_keys_max_age {
472+
// Subtract ROTATION_GRACE / 3 so clients refresh their cached key
473+
// slightly before the rotation boundary, staying well within the
474+
// grace window where the old key is still accepted.
475+
let remaining = self
476+
.ohttp
477+
.valid_until()
478+
.await
479+
.saturating_duration_since(Instant::now())
480+
.min(max_age)
481+
.saturating_sub(ROTATION_GRACE / 3);
482+
res.headers_mut().insert(
483+
CACHE_CONTROL,
484+
HeaderValue::from_str(&format!(
485+
"public, s-maxage={}, immutable",
486+
remaining.as_secs()
487+
))
488+
.expect("valid header value"),
489+
);
490+
}
388491
Ok(res)
389492
}
390493

@@ -412,6 +515,54 @@ impl<D: Db> Service<D> {
412515
}
413516
}
414517

518+
// Grace period after a switch during which the old key is still
519+
// accepted for decapsulation.
520+
const ROTATION_GRACE: Duration = Duration::from_secs(30);
521+
522+
// Background task that rotates OHTTP keys on a fixed interval.
523+
//
524+
// 1. Sleep until the current key is about to expire (valid_until - ROTATION_GRACE/2).
525+
// 2. Switch to the standby slot; stamp valid_until = now + interval.
526+
// 3. Sleep until the old key's grace window has elapsed.
527+
// 4. Overwrite the old slot with fresh key material for the next cycle.
528+
pub fn spawn_key_rotation(keyset: Arc<KeyRotatingServer>, keys_dir: PathBuf, interval: Duration) {
529+
tokio::spawn(async move {
530+
loop {
531+
// Sleep until just before the current key expires.
532+
let valid_until = keyset.valid_until().await;
533+
let switch_at = valid_until.checked_sub(ROTATION_GRACE / 2).unwrap_or(valid_until);
534+
tokio::time::sleep_until(switch_at.into()).await;
535+
536+
// Capture old key id before switching, then switch.
537+
// `switch` stamps valid_until = Instant::now() + interval, anchored
538+
// to the actual moment the new key goes live.
539+
let old_key_id = keyset.current_key_id().await;
540+
keyset.switch(interval).await;
541+
let new_key_id = 1 - old_key_id;
542+
tracing::info!("Switched OHTTP serving: key_id {old_key_id} -> {new_key_id}");
543+
544+
// Wait until the old key's grace window has fully elapsed before
545+
// overwriting it, so in-flight clients using the old key still succeed.
546+
// The old key was valid until (new valid_until - interval), so its
547+
// grace window ends at (new valid_until - interval + ROTATION_GRACE).
548+
let valid_until = keyset.valid_until().await;
549+
let overwrite_at =
550+
valid_until.checked_sub(interval).unwrap_or(valid_until) + ROTATION_GRACE;
551+
tokio::time::sleep_until(overwrite_at.into()).await;
552+
553+
let config = crate::key_config::gen_ohttp_server_config_with_id(old_key_id)
554+
.expect("OHTTP key generation must not fail");
555+
let _ = tokio::fs::remove_file(keys_dir.join(format!("{old_key_id}.ikm"))).await;
556+
crate::key_config::persist_key_config(&config, &keys_dir)
557+
.await
558+
.expect("OHTTP key persistence must not fail");
559+
560+
keyset.overwrite(old_key_id, config.into_server()).await;
561+
tracing::info!("Overwrote OHTTP key_id {old_key_id} with fresh material");
562+
}
563+
});
564+
}
565+
415566
fn handle_peek<E: SendableError>(
416567
result: Result<Arc<Vec<u8>>, DbError<E>>,
417568
timeout_response: Response<Body>,
@@ -485,8 +636,8 @@ impl HandlerError {
485636
}
486637
HandlerError::OhttpKeyRejection(e) => {
487638
const OHTTP_KEY_REJECTION_RES_JSON: &str = r#"{"type":"https://iana.org/assignments/http-problem-types#ohttp-key", "title": "key identifier unknown"}"#;
488-
warn!("Bad request: Key configuration rejected: {}", e);
489-
*res.status_mut() = StatusCode::BAD_REQUEST;
639+
warn!("Key configuration rejected: {}", e);
640+
*res.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
490641
res.headers_mut()
491642
.insert(CONTENT_TYPE, HeaderValue::from_static("application/problem+json"));
492643
*res.body_mut() = full(OHTTP_KEY_REJECTION_RES_JSON);
@@ -592,9 +743,17 @@ mod tests {
592743
async fn test_service(v1: Option<V1>) -> Service<FilesDb> {
593744
let dir = tempfile::tempdir().expect("tempdir");
594745
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
595-
let ohttp: ohttp::Server =
596-
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
597-
Service::new(db, ohttp, SentinelTag::new([0u8; 32]), v1)
746+
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
747+
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
748+
// valid_until = now + a generous test interval so nothing rotates during tests
749+
let valid_until = Instant::now() + Duration::from_secs(3600);
750+
let keyset = Arc::new(KeyRotatingServer::new(
751+
KeySlot { server: c0.into_server() },
752+
KeySlot { server: c1.into_server() },
753+
0,
754+
valid_until,
755+
));
756+
Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), v1)
598757
}
599758

600759
/// A valid ShortId encoded as bech32 for use in URL paths.
@@ -826,9 +985,16 @@ mod tests {
826985
let dir = tempfile::tempdir().expect("tempdir");
827986
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
828987
let db = MetricsDb::new(db, metrics);
829-
let ohttp: ohttp::Server =
830-
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
831-
let svc = Service::new(db, ohttp, SentinelTag::new([0u8; 32]), None);
988+
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
989+
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
990+
let valid_until = Instant::now() + Duration::from_secs(3600);
991+
let keyset = Arc::new(KeyRotatingServer::new(
992+
KeySlot { server: c0.into_server() },
993+
KeySlot { server: c1.into_server() },
994+
0,
995+
valid_until,
996+
));
997+
let svc = Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), None);
832998

833999
let id = valid_short_id_path();
8341000
let res = svc
@@ -849,7 +1015,7 @@ mod tests {
8491015
use opentelemetry::KeyValue;
8501016
use opentelemetry_sdk::metrics::data::{AggregatedMetrics, MetricData};
8511017

852-
// This checks that counter value is 1 as post_mailbox was called once
1018+
// This checks that counter value is 1 as post_mailbox was called once
8531019
// Also confirms the v2 label is recorded
8541020
match db_metric.data() {
8551021
AggregatedMetrics::U64(MetricData::Sum(sum)) => {

0 commit comments

Comments
 (0)