Skip to content

Commit d0400d1

Browse files
committed
OHTTP keys should be rotated
This pr addresses #445. It implements OHTTP-key rotation to payjoin-mailroom Mailroom operators can now decide the time interval for keys to be rotated. Also if a key has expired, a 422 error is returned to clients. Clients can handle they key-rotation via the cach-control header returned by the directory.
1 parent 8569be8 commit d0400d1

5 files changed

Lines changed: 380 additions & 51 deletions

File tree

payjoin-mailroom/src/config.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ pub struct Config {
1212
pub storage_dir: PathBuf,
1313
#[serde(deserialize_with = "deserialize_duration_secs")]
1414
pub timeout: Duration,
15+
#[serde(deserialize_with = "deserialize_optional_duration_secs")]
16+
pub ohttp_keys_max_age: Option<Duration>,
1517
pub v1: Option<V1Config>,
1618
#[cfg(feature = "telemetry")]
1719
pub telemetry: Option<TelemetryConfig>,
@@ -85,6 +87,7 @@ impl Default for Config {
8587
listener: "[::]:8080".parse().expect("valid default listener address"),
8688
storage_dir: PathBuf::from("./data"),
8789
timeout: Duration::from_secs(30),
90+
ohttp_keys_max_age: Some(Duration::from_secs(7 * 24 * 60 * 60)),
8891
v1: None,
8992
#[cfg(feature = "telemetry")]
9093
telemetry: None,
@@ -104,17 +107,33 @@ where
104107
Ok(Duration::from_secs(secs))
105108
}
106109

110+
fn deserialize_optional_duration_secs<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
111+
where
112+
D: serde::Deserializer<'de>,
113+
{
114+
let secs: Option<u64> = Option::deserialize(deserializer)?;
115+
match secs {
116+
None => Ok(None),
117+
Some(0) => Err(<D::Error as serde::de::Error>::custom(
118+
"ohttp_keys_max_age must be greater than 0 seconds when set",
119+
)),
120+
Some(s) => Ok(Some(Duration::from_secs(s))),
121+
}
122+
}
123+
107124
impl Config {
108125
pub fn new(
109126
listener: ListenerAddress,
110127
storage_dir: PathBuf,
111128
timeout: Duration,
129+
ohttp_keys_max_age: Option<Duration>,
112130
v1: Option<V1Config>,
113131
) -> Self {
114132
Self {
115133
listener,
116134
storage_dir,
117135
timeout,
136+
ohttp_keys_max_age,
118137
v1,
119138
#[cfg(feature = "telemetry")]
120139
telemetry: None,

payjoin-mailroom/src/directory.rs

Lines changed: 174 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
use std::path::PathBuf;
12
use std::pin::Pin;
23
use std::str::FromStr;
4+
use std::sync::atomic::{AtomicU8, Ordering};
35
use std::sync::Arc;
46
use std::task::{Context, Poll};
7+
use std::time::{Duration, Instant};
58

69
use anyhow::Result;
710
use axum::body::{Body, Bytes};
8-
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
11+
use axum::http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CACHE_CONTROL, CONTENT_TYPE};
912
use axum::http::{Method, Request, Response, StatusCode, Uri};
1013
use http_body_util::BodyExt;
1114
use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES};
15+
use tokio::sync::RwLock;
1216
use tracing::{debug, error, trace, warn};
1317

1418
use crate::db::{Db, Error as DbError, SendableError};
@@ -28,6 +32,79 @@ const V1_VERSION_UNSUPPORTED_RES_JSON: &str =
2832

2933
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
3034

35+
// Two-slot OHTTP key set supporting rotation overlap.
36+
//
37+
// Key IDs alternate between 0 and 1. Both slots are always populated.
38+
// The current key is served to new clients; both slots are accepted
39+
// for decapsulation so that clients with a cached previous key still
40+
// work during the grace window after a switch.
41+
#[derive(Debug)]
42+
pub(crate) struct KeySlot {
43+
pub(crate) server: ohttp::Server,
44+
pub(crate) valid_from: Instant,
45+
}
46+
47+
#[derive(Debug)]
48+
pub struct KeyRotatingServer {
49+
keys: [RwLock<KeySlot>; 2],
50+
current_key_id: AtomicU8,
51+
}
52+
53+
impl KeyRotatingServer {
54+
pub(crate) fn new(slot0: KeySlot, slot1: KeySlot, current_key_id: u8) -> Self {
55+
assert!(current_key_id <= 1, "key_id must be 0 or 1");
56+
Self {
57+
keys: [RwLock::new(slot0), RwLock::new(slot1)],
58+
current_key_id: AtomicU8::new(current_key_id),
59+
}
60+
}
61+
62+
pub fn current_key_id(&self) -> u8 { self.current_key_id.load(Ordering::Acquire) }
63+
64+
pub async fn current_valid_from(&self) -> Instant {
65+
let id = self.current_key_id();
66+
self.keys[id as usize].read().await.valid_from
67+
}
68+
69+
pub fn next_key_id(&self) -> u8 { 1 - self.current_key_id() }
70+
71+
// Look up the server matching the key_id in an OHTTP message and
72+
// decapsulate. The first byte of an OHTTP encapsulated request is the
73+
// key identifier (RFC 9458 Section 4.3).
74+
pub async fn decapsulate(
75+
&self,
76+
ohttp_body: &[u8],
77+
) -> std::result::Result<(Vec<u8>, ohttp::ServerResponse), ohttp::Error> {
78+
let key_id = ohttp_body.first().copied().ok_or(ohttp::Error::Truncated)?;
79+
match self.keys.get(key_id as usize) {
80+
Some(slot) => slot.read().await.server.decapsulate(ohttp_body),
81+
None => Err(ohttp::Error::KeyId),
82+
}
83+
}
84+
85+
// Encode the current key's config for serving to clients.
86+
pub async fn encode_current(&self) -> std::result::Result<Vec<u8>, ohttp::Error> {
87+
let id = self.current_key_id();
88+
self.keys[id as usize].read().await.server.config().encode()
89+
}
90+
91+
// Flip which key is advertised to new clients.
92+
pub fn switch(&self) {
93+
let old = self.current_key_id();
94+
self.current_key_id.store(1 - old, Ordering::Release);
95+
}
96+
97+
// Replace a slot with fresh key material.
98+
pub async fn overwrite(&self, key_id: u8, server: ohttp::Server) {
99+
assert!(key_id <= 1, "key_id must be 0 or 1");
100+
*self.keys[key_id as usize].write().await = KeySlot { server, valid_from: Instant::now() };
101+
}
102+
103+
pub async fn valid_from(&self, key_id: u8) -> Instant {
104+
self.keys[key_id as usize].read().await.valid_from
105+
}
106+
}
107+
31108
/// Opaque blocklist of Bitcoin addresses stored as script pubkeys.
32109
///
33110
/// Addresses are converted to `ScriptBuf` at parse time so that
@@ -91,7 +168,8 @@ fn parse_address_lines(text: &str) -> std::collections::HashSet<bitcoin::ScriptB
91168
#[derive(Clone)]
92169
pub struct Service<D: Db> {
93170
db: D,
94-
ohttp: ohttp::Server,
171+
ohttp: Arc<KeyRotatingServer>,
172+
ohttp_keys_max_age: Option<Duration>,
95173
sentinel_tag: SentinelTag,
96174
v1: Option<V1>,
97175
}
@@ -117,10 +195,18 @@ where
117195
}
118196

119197
impl<D: Db> Service<D> {
120-
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, v1: Option<V1>) -> Self {
121-
Self { db, ohttp, sentinel_tag, v1 }
198+
pub fn new(
199+
db: D,
200+
ohttp: Arc<KeyRotatingServer>,
201+
ohttp_keys_max_age: Option<Duration>,
202+
sentinel_tag: SentinelTag,
203+
v1: Option<V1>,
204+
) -> Self {
205+
Self { db, ohttp, ohttp_keys_max_age, sentinel_tag, v1 }
122206
}
123207

208+
pub fn ohttp_key_set(&self) -> &Arc<KeyRotatingServer> { &self.ohttp }
209+
124210
async fn serve_request<B>(&self, req: Request<B>) -> Result<Response<Body>>
125211
where
126212
B: axum::body::HttpBody<Data = Bytes> + Send + 'static,
@@ -200,10 +286,10 @@ impl<D: Db> Service<D> {
200286
.map_err(|e| HandlerError::BadRequest(anyhow::anyhow!(e.into())))?
201287
.to_bytes();
202288

203-
// Decapsulate OHTTP request
204289
let (bhttp_req, res_ctx) = self
205290
.ohttp
206291
.decapsulate(&ohttp_body)
292+
.await
207293
.map_err(|e| HandlerError::OhttpKeyRejection(e.into()))?;
208294
let mut cursor = std::io::Cursor::new(bhttp_req);
209295
let req = bhttp::Message::read_bhttp(&mut cursor)
@@ -380,11 +466,22 @@ impl<D: Db> Service<D> {
380466
async fn get_ohttp_keys(&self) -> Result<Response<Body>, HandlerError> {
381467
let ohttp_keys = self
382468
.ohttp
383-
.config()
384-
.encode()
469+
.encode_current()
470+
.await
385471
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
386472
let mut res = Response::new(full(ohttp_keys));
387473
res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/ohttp-keys"));
474+
if let Some(max_age) = self.ohttp_keys_max_age {
475+
let remaining = max_age.saturating_sub(self.ohttp.current_valid_from().await.elapsed());
476+
res.headers_mut().insert(
477+
CACHE_CONTROL,
478+
HeaderValue::from_str(&format!(
479+
"public, s-maxage={}, immutable",
480+
remaining.as_secs()
481+
))
482+
.expect("valid header value"),
483+
);
484+
}
388485
Ok(res)
389486
}
390487

@@ -412,6 +509,56 @@ impl<D: Db> Service<D> {
412509
}
413510
}
414511

512+
// Grace period after a switch during which the old key is still
513+
// accepted for decapsulation.
514+
const ROTATION_GRACE: Duration = Duration::from_secs(30);
515+
516+
// Background task that rotates OHTTP keys on a fixed interval.
517+
//
518+
// Each cycle: sleep until the current key is about to expire, switch
519+
// to the other slot, wait out the grace period, then overwrite the
520+
// old slot with fresh key material so it is ready for the next cycle.
521+
pub fn spawn_key_rotation(keyset: Arc<KeyRotatingServer>, keys_dir: PathBuf, interval: Duration) {
522+
tokio::spawn(async move {
523+
loop {
524+
let switch_delay = {
525+
let valid_from = keyset.current_valid_from().await;
526+
let switch_at = valid_from + interval - ROTATION_GRACE / 2;
527+
switch_at.saturating_duration_since(Instant::now())
528+
};
529+
tokio::time::sleep(switch_delay).await;
530+
531+
let old_key_id = keyset.current_key_id();
532+
keyset.switch();
533+
let new_key_id = 1 - old_key_id;
534+
tracing::info!("Switched OHTTP serving: key_id {old_key_id} -> {new_key_id}");
535+
536+
let overwrite_delay = {
537+
let valid_from = keyset.valid_from(old_key_id).await;
538+
let overwrite_at = valid_from + interval + ROTATION_GRACE;
539+
overwrite_at.saturating_duration_since(Instant::now())
540+
};
541+
tokio::time::sleep(overwrite_delay).await;
542+
543+
let config = match crate::key_config::gen_ohttp_server_config_with_id(old_key_id) {
544+
Ok(c) => c,
545+
Err(e) => {
546+
tracing::error!("Failed to generate OHTTP key: {e}");
547+
continue;
548+
}
549+
};
550+
let _ = tokio::fs::remove_file(keys_dir.join(format!("{old_key_id}.ikm"))).await;
551+
if let Err(e) = crate::key_config::persist_key_config(&config, &keys_dir).await {
552+
tracing::error!("Failed to persist OHTTP key: {e}");
553+
continue;
554+
}
555+
556+
keyset.overwrite(old_key_id, config.into_server()).await;
557+
tracing::info!("Overwrote OHTTP key_id {old_key_id} with fresh material");
558+
}
559+
});
560+
}
561+
415562
fn handle_peek<E: SendableError>(
416563
result: Result<Arc<Vec<u8>>, DbError<E>>,
417564
timeout_response: Response<Body>,
@@ -485,8 +632,8 @@ impl HandlerError {
485632
}
486633
HandlerError::OhttpKeyRejection(e) => {
487634
const OHTTP_KEY_REJECTION_RES_JSON: &str = r#"{"type":"https://iana.org/assignments/http-problem-types#ohttp-key", "title": "key identifier unknown"}"#;
488-
warn!("Bad request: Key configuration rejected: {}", e);
489-
*res.status_mut() = StatusCode::BAD_REQUEST;
635+
warn!("Key configuration rejected: {}", e);
636+
*res.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
490637
res.headers_mut()
491638
.insert(CONTENT_TYPE, HeaderValue::from_static("application/problem+json"));
492639
*res.body_mut() = full(OHTTP_KEY_REJECTION_RES_JSON);
@@ -592,9 +739,15 @@ mod tests {
592739
async fn test_service(v1: Option<V1>) -> Service<FilesDb> {
593740
let dir = tempfile::tempdir().expect("tempdir");
594741
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
595-
let ohttp: ohttp::Server =
596-
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
597-
Service::new(db, ohttp, SentinelTag::new([0u8; 32]), v1)
742+
let now = Instant::now();
743+
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
744+
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
745+
let keyset = Arc::new(KeyRotatingServer::new(
746+
KeySlot { server: c0.into_server(), valid_from: now },
747+
KeySlot { server: c1.into_server(), valid_from: now },
748+
0,
749+
));
750+
Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), v1)
598751
}
599752

600753
/// A valid ShortId encoded as bech32 for use in URL paths.
@@ -826,9 +979,15 @@ mod tests {
826979
let dir = tempfile::tempdir().expect("tempdir");
827980
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
828981
let db = MetricsDb::new(db, metrics);
829-
let ohttp: ohttp::Server =
830-
crate::key_config::gen_ohttp_server_config().expect("ohttp config").into();
831-
let svc = Service::new(db, ohttp, SentinelTag::new([0u8; 32]), None);
982+
let now = Instant::now();
983+
let c0 = crate::key_config::gen_ohttp_server_config_with_id(0).expect("ohttp config");
984+
let c1 = crate::key_config::gen_ohttp_server_config_with_id(1).expect("ohttp config");
985+
let keyset = Arc::new(KeyRotatingServer::new(
986+
KeySlot { server: c0.into_server(), valid_from: now },
987+
KeySlot { server: c1.into_server(), valid_from: now },
988+
0,
989+
));
990+
let svc = Service::new(db, keyset, None, SentinelTag::new([0u8; 32]), None);
832991

833992
let id = valid_short_id_path();
834993
let res = svc

0 commit comments

Comments
 (0)