Skip to content

Commit 0767504

Browse files
committed
OHTTP keys should be rotated
This pr addresses #445. It implements OHTTP-key rotation to payjoin-mailroom Mailroom operators can now decide the time interval for keys to be rotated. Also if a key has expired, a 422 error is returned to clients. Clients can handle they key-rotation via the cach-control header returned by the directory. fix spawn rotation
1 parent 45d286f commit 0767504

5 files changed

Lines changed: 522 additions & 53 deletions

File tree

payjoin-mailroom/src/config.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub struct Config {
1212
pub storage_dir: PathBuf,
1313
#[serde(deserialize_with = "deserialize_duration_secs")]
1414
pub timeout: Duration,
15+
#[serde(deserialize_with = "deserialize_optional_duration_secs")]
16+
pub ohttp_keys_max_age: Option<Duration>,
1517
pub v1: Option<V1Config>,
1618
#[cfg(feature = "telemetry")]
1719
pub telemetry: Option<TelemetryConfig>,
@@ -85,6 +87,7 @@ impl Default for Config {
8587
listener: "[::]:8080".parse().expect("valid default listener address"),
8688
storage_dir: PathBuf::from("./data"),
8789
timeout: Duration::from_secs(30),
90+
ohttp_keys_max_age: None, //Some(Duration::from_secs(30)),
8891
v1: None,
8992
#[cfg(feature = "telemetry")]
9093
telemetry: None,
@@ -104,17 +107,33 @@ where
104107
Ok(Duration::from_secs(secs))
105108
}
106109

110+
fn deserialize_optional_duration_secs<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
111+
where
112+
D: serde::Deserializer<'de>,
113+
{
114+
let secs: Option<u64> = Option::deserialize(deserializer)?;
115+
match secs {
116+
None => Ok(None),
117+
Some(0) => Err(<D::Error as serde::de::Error>::custom(
118+
"ohttp_keys_max_age must be greater than 0 seconds when set",
119+
)),
120+
Some(s) => Ok(Some(Duration::from_secs(s))),
121+
}
122+
}
123+
107124
impl Config {
108125
pub fn new(
109126
listener: ListenerAddress,
110127
storage_dir: PathBuf,
111128
timeout: Duration,
129+
ohttp_keys_max_age: Option<Duration>,
112130
v1: Option<V1Config>,
113131
) -> Self {
114132
Self {
115133
listener,
116134
storage_dir,
117135
timeout,
136+
ohttp_keys_max_age,
118137
v1,
119138
#[cfg(feature = "telemetry")]
120139
telemetry: None,

payjoin-mailroom/src/directory.rs

Lines changed: 200 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
use std::path::PathBuf;
12
use std::pin::Pin;
23
use std::str::FromStr;
34
use std::sync::Arc;
45
use std::task::{Context, Poll};
6+
use std::time::{Duration, Instant};
57

68
use anyhow::Result;
79
use axum::body::{Body, Bytes};
8-
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
10+
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL, CONTENT_TYPE};
911
use axum::http::{Method, Request, Response, StatusCode, Uri};
1012
use http_body_util::BodyExt;
1113
use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES};
14+
use tokio::sync::RwLock;
1215
use tracing::{debug, error, trace, warn};
1316

1417
use crate::db::{Db, Error as DbError, SendableError};
@@ -28,6 +31,83 @@ const V1_VERSION_UNSUPPORTED_RES_JSON: &str =
2831

2932
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
3033

34+
// Two-slot OHTTP key set supporting rotation overlap.
35+
//
36+
// Key IDs alternate between 0 and 1. Both slots are always populated.
37+
// The current key is served to new clients; both slots are accepted
38+
// for decapsulation so that clients with a cached previous key still
39+
// work during the grace window after a switch.
40+
#[derive(Debug)]
41+
pub(crate) struct KeySlot {
42+
pub(crate) server: ohttp::Server,
43+
}
44+
45+
#[derive(Debug)]
46+
struct ActiveKey {
47+
key_id: u8,
48+
valid_until: Instant,
49+
}
50+
51+
#[derive(Debug)]
52+
pub struct KeyRotatingServer {
53+
keys: [Box<RwLock<KeySlot>>; 2],
54+
current: RwLock<ActiveKey>,
55+
}
56+
57+
impl KeyRotatingServer {
58+
pub(crate) fn new(
59+
slot0: KeySlot,
60+
slot1: KeySlot,
61+
current_key_id: u8,
62+
valid_until: Instant,
63+
) -> Self {
64+
assert!(current_key_id <= 1, "key_id must be 0 or 1");
65+
Self {
66+
keys: [Box::new(RwLock::new(slot0)), Box::new(RwLock::new(slot1))],
67+
current: RwLock::new(ActiveKey { key_id: current_key_id, valid_until }),
68+
}
69+
}
70+
71+
pub async fn current_key_id(&self) -> u8 { self.current.read().await.key_id }
72+
73+
pub async fn valid_until(&self) -> Instant { self.current.read().await.valid_until }
74+
75+
// Look up the server matching the key_id in an OHTTP message and
76+
// decapsulate. The first byte of an OHTTP encapsulated request is the
77+
// key identifier (RFC 9458 Section 4.3).
78+
pub async fn decapsulate(
79+
&self,
80+
ohttp_body: &[u8],
81+
) -> std::result::Result<(Vec<u8>, ohttp::ServerResponse), ohttp::Error> {
82+
let key_id = ohttp_body.first().copied().ok_or(ohttp::Error::Truncated)?;
83+
match self.keys.get(key_id as usize) {
84+
Some(slot) => slot.read().await.server.decapsulate(ohttp_body),
85+
None => Err(ohttp::Error::KeyId),
86+
}
87+
}
88+
89+
// Encode the current key's config for serving to clients.
90+
pub async fn encode_current(&self) -> std::result::Result<Vec<u8>, ohttp::Error> {
91+
let id = self.current_key_id().await;
92+
self.keys[id as usize].read().await.server.config().encode()
93+
}
94+
95+
// Flip which key is advertised to new clients and stamp the new expiry.
96+
// Anchored to Instant::now() at the moment of the actual switch so that
97+
// the next rotation cycle is measured from when the key became active,
98+
pub async fn switch(&self, interval: Duration) {
99+
let mut current = self.current.write().await;
100+
current.key_id = 1 - current.key_id;
101+
current.valid_until = Instant::now() + interval;
102+
}
103+
104+
// Replace a slot with fresh key material.
105+
pub async fn overwrite(&self, key_id: u8, server: ohttp::Server) {
106+
assert!(key_id <= 1, "key_id must be 0 or 1");
107+
*self.keys[key_id as usize].write().await = KeySlot { server };
108+
}
109+
}
110+
31111
/// Opaque blocklist of Bitcoin addresses stored as script pubkeys.
32112
///
33113
/// Addresses are converted to `ScriptBuf` at parse time so that
@@ -91,7 +171,8 @@ fn parse_address_lines(text: &str) -> std::collections::HashSet<bitcoin::ScriptB
91171
#[derive(Clone)]
92172
pub struct Service<D: Db> {
93173
db: D,
94-
ohttp: ohttp::Server,
174+
ohttp: Arc<KeyRotatingServer>,
175+
ohttp_keys_max_age: Option<Duration>,
95176
sentinel_tag: SentinelTag,
96177
v1: Option<V1>,
97178
}
@@ -117,10 +198,18 @@ where
117198
}
118199

119200
impl<D: Db> Service<D> {
120-
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, v1: Option<V1>) -> Self {
121-
Self { db, ohttp, sentinel_tag, v1 }
201+
pub fn new(
202+
db: D,
203+
ohttp: Arc<KeyRotatingServer>,
204+
ohttp_keys_max_age: Option<Duration>,
205+
sentinel_tag: SentinelTag,
206+
v1: Option<V1>,
207+
) -> Self {
208+
Self { db, ohttp, ohttp_keys_max_age, sentinel_tag, v1 }
122209
}
123210

211+
pub fn ohttp_key_set(&self) -> &Arc<KeyRotatingServer> { &self.ohttp }
212+
124213
async fn serve_request<B>(&self, req: Request<B>) -> Result<Response<Body>>
125214
where
126215
B: axum::body::HttpBody<Data = Bytes> + Send + 'static,
@@ -200,10 +289,10 @@ impl<D: Db> Service<D> {
200289
.map_err(|e| HandlerError::BadRequest(anyhow::anyhow!(e.into())))?
201290
.to_bytes();
202291

203-
// Decapsulate OHTTP request
204292
let (bhttp_req, res_ctx) = self
205293
.ohttp
206294
.decapsulate(&ohttp_body)
295+
.await
207296
.map_err(|e| HandlerError::OhttpKeyRejection(e.into()))?;
208297
let mut cursor = std::io::Cursor::new(bhttp_req);
209298
let req = bhttp::Message::read_bhttp(&mut cursor)
@@ -380,11 +469,31 @@ impl<D: Db> Service<D> {
380469
async fn get_ohttp_keys(&self) -> Result<Response<Body>, HandlerError> {
381470
let ohttp_keys = self
382471
.ohttp
383-
.config()
384-
.encode()
472+
.encode_current()
473+
.await
385474
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
386475
let mut res = Response::new(full(ohttp_keys));
387476
res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/ohttp-keys"));
477+
if let Some(max_age) = self.ohttp_keys_max_age {
478+
// Subtract ROTATION_GRACE / 3 so clients refresh their cached key
479+
// slightly before the rotation boundary, staying well within the
480+
// grace window where the old key is still accepted.
481+
let remaining = self
482+
.ohttp
483+
.valid_until()
484+
.await
485+
.saturating_duration_since(Instant::now())
486+
.min(max_age)
487+
.saturating_add(ROTATION_GRACE / 3);
488+
res.headers_mut().insert(
489+
CACHE_CONTROL,
490+
HeaderValue::from_str(&format!(
491+
"public, s-maxage={}, immutable",
492+
remaining.as_secs()
493+
))
494+
.expect("valid header value"),
495+
);
496+
}
388497
Ok(res)
389498
}
390499

@@ -412,6 +521,66 @@ impl<D: Db> Service<D> {
412521
}
413522
}
414523

524+
// Grace period after a switch during which the old key is still
525+
// accepted for decapsulation.
526+
const ROTATION_GRACE: Duration = Duration::from_secs(30);
527+
528+
// Background task that rotates OHTTP keys on a fixed interval.
529+
//
530+
// 1. Sleep until the current key is about to expire (valid_until - ROTATION_GRACE/2).
531+
// 2. Switch to the standby slot; stamp valid_until = now + interval.
532+
// 3. Sleep until the old key's grace window has elapsed.
533+
// 4. Overwrite the old slot with fresh key material for the next cycle.
534+
pub fn spawn_key_rotation(keyset: Arc<KeyRotatingServer>, keys_dir: PathBuf, interval: Duration) {
535+
tokio::spawn(async move {
536+
loop {
537+
// Sleep until just before the current key expires.
538+
let valid_until = keyset.valid_until().await;
539+
tracing::info!("Sleeping until {:?}", valid_until);
540+
//let switch_at = valid_until.checked_sub(ROTATION_GRACE / 2).unwrap_or(valid_until);
541+
tokio::time::sleep_until(valid_until.into()).await;
542+
543+
// Capture old key id before switching, then switch.
544+
let old_key_id = keyset.current_key_id().await;
545+
let new_key_id = 1 - old_key_id;
546+
547+
tracing::info!(
548+
"---------------------------------------------------------------------------"
549+
);
550+
551+
// Touch the new active key file *after* overwriting the old slot so
552+
// its mtime is newest on disk. On restart,
553+
// and derives valid_until from its age.
554+
let active_path = keys_dir.join(format!("{new_key_id}.ikm"));
555+
let times = std::fs::FileTimes::new().set_modified(std::time::SystemTime::now());
556+
match std::fs::File::open(&active_path).and_then(|f| f.set_times(times)) {
557+
Ok(()) => {}
558+
Err(e) => tracing::warn!("Failed to change mtime {}: {e}", active_path.display()),
559+
}
560+
561+
// `switch` stamps valid_until = Instant::now() + interval, anchored
562+
// to the actual moment the new key goes live.
563+
keyset.switch(interval).await;
564+
565+
tracing::info!("Switched OHTTP serving: From key_id {old_key_id} -> TO {new_key_id}");
566+
567+
// Wait until the old key's grace window has fully elapsed before
568+
// overwriting it, so in-flight clients using the old key still succeed.
569+
tokio::time::sleep(ROTATION_GRACE).await;
570+
571+
let config = crate::key_config::gen_ohttp_server_config_with_id(old_key_id)
572+
.expect("OHTTP key generation must not fail");
573+
let _ = tokio::fs::remove_file(keys_dir.join(format!("{old_key_id}.ikm"))).await;
574+
crate::key_config::persist_key_config(&config, &keys_dir)
575+
.await
576+
.expect("OHTTP key persistence must not fail");
577+
578+
keyset.overwrite(old_key_id, config.into_server()).await;
579+
tracing::info!("Overwrote OHTTP key_id {old_key_id} with fresh material");
580+
}
581+
});
582+
}
583+
415584
fn handle_peek<E: SendableError>(
416585
result: Result<Arc<Vec<u8>>, DbError<E>>,
417586
timeout_response: Response<Body>,
@@ -485,8 +654,8 @@ impl HandlerError {
485654
}
486655
HandlerError::OhttpKeyRejection(e) => {
487656
const OHTTP_KEY_REJECTION_RES_JSON: &str = r#"{"type":"https://iana.org/assignments/http-problem-types#ohttp-key", "title": "key identifier unknown"}"#;
488-
warn!("Bad request: Key configuration rejected: {}", e);
489-
*res.status_mut() = StatusCode::BAD_REQUEST;
657+
warn!("Key configuration rejected: {}", e);
658+
*res.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
490659
res.headers_mut()
491660
.insert(CONTENT_TYPE, HeaderValue::from_static("application/problem+json"));
492661
*res.body_mut() = full(OHTTP_KEY_REJECTION_RES_JSON);
@@ -592,9 +761,17 @@ mod tests {
592761
async fn test_service(v1: Option<V1>) -> Service<FilesDb> {
593762
let dir = tempfile::tempdir().expect("tempdir");
594763
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
595-
let ohttp: ohttp::Server =
596-
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
597-
Service::new(db, ohttp, SentinelTag::new([0u8; 32]), v1)
764+
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
765+
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
766+
// valid_until = now + a generous test interval so nothing rotates during tests
767+
let valid_until = Instant::now() + Duration::from_secs(3600);
768+
let keyset = Arc::new(KeyRotatingServer::new(
769+
KeySlot { server: c0.into_server() },
770+
KeySlot { server: c1.into_server() },
771+
0,
772+
valid_until,
773+
));
774+
Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), v1)
598775
}
599776

600777
/// A valid ShortId encoded as bech32 for use in URL paths.
@@ -826,9 +1003,16 @@ mod tests {
8261003
let dir = tempfile::tempdir().expect("tempdir");
8271004
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
8281005
let db = MetricsDb::new(db, metrics);
829-
let ohttp: ohttp::Server =
830-
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
831-
let svc = Service::new(db, ohttp, SentinelTag::new([0u8; 32]), None);
1006+
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
1007+
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
1008+
let valid_until = Instant::now() + Duration::from_secs(3600);
1009+
let keyset = Arc::new(KeyRotatingServer::new(
1010+
KeySlot { server: c0.into_server() },
1011+
KeySlot { server: c1.into_server() },
1012+
0,
1013+
valid_until,
1014+
));
1015+
let svc = Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), None);
8321016

8331017
let id = valid_short_id_path();
8341018
let res = svc
@@ -849,7 +1033,7 @@ mod tests {
8491033
use opentelemetry::KeyValue;
8501034
use opentelemetry_sdk::metrics::data::{AggregatedMetrics, MetricData};
8511035

852-
// This checks that counter value is 1 as post_mailbox was called once
1036+
// This checks that counter value is 1 as post_mailbox was called once
8531037
// Also confirms the v2 label is recorded
8541038
match db_metric.data() {
8551039
AggregatedMetrics::U64(MetricData::Sum(sum)) => {

0 commit comments

Comments
 (0)