1+ use std:: path:: PathBuf ;
12use std:: pin:: Pin ;
23use std:: str:: FromStr ;
34use std:: sync:: Arc ;
45use std:: task:: { Context , Poll } ;
6+ use std:: time:: { Duration , Instant } ;
57
68use anyhow:: Result ;
79use axum:: body:: { Body , Bytes } ;
8- use axum:: http:: header:: { HeaderValue , ACCESS_CONTROL_ALLOW_ORIGIN , CONTENT_TYPE } ;
10+ use axum:: http:: header:: { HeaderValue , ACCESS_CONTROL_ALLOW_ORIGIN , CACHE_CONTROL , CONTENT_TYPE } ;
911use axum:: http:: { Method , Request , Response , StatusCode , Uri } ;
1012use http_body_util:: BodyExt ;
1113use payjoin:: directory:: { ShortId , ShortIdError , ENCAPSULATED_MESSAGE_BYTES } ;
14+ use tokio:: sync:: RwLock ;
1215use tracing:: { debug, error, trace, warn} ;
1316
1417use crate :: db:: { Db , Error as DbError , SendableError } ;
@@ -28,6 +31,83 @@ const V1_VERSION_UNSUPPORTED_RES_JSON: &str =
2831
2932pub type BoxError = Box < dyn std:: error:: Error + Send + Sync > ;
3033
34+ // Two-slot OHTTP key set supporting rotation overlap.
35+ //
36+ // Key IDs alternate between 0 and 1. Both slots are always populated.
37+ // The current key is served to new clients; both slots are accepted
38+ // for decapsulation so that clients with a cached previous key still
39+ // work during the grace window after a switch.
40+ #[ derive( Debug ) ]
41+ pub ( crate ) struct KeySlot {
42+ pub ( crate ) server : ohttp:: Server ,
43+ }
44+
45+ #[ derive( Debug ) ]
46+ struct ActiveKey {
47+ key_id : u8 ,
48+ valid_until : Instant ,
49+ }
50+
51+ #[ derive( Debug ) ]
52+ pub struct KeyRotatingServer {
53+ keys : [ Box < RwLock < KeySlot > > ; 2 ] ,
54+ current : RwLock < ActiveKey > ,
55+ }
56+
57+ impl KeyRotatingServer {
58+ pub ( crate ) fn new (
59+ slot0 : KeySlot ,
60+ slot1 : KeySlot ,
61+ current_key_id : u8 ,
62+ valid_until : Instant ,
63+ ) -> Self {
64+ assert ! ( current_key_id <= 1 , "key_id must be 0 or 1" ) ;
65+ Self {
66+ keys : [ Box :: new ( RwLock :: new ( slot0) ) , Box :: new ( RwLock :: new ( slot1) ) ] ,
67+ current : RwLock :: new ( ActiveKey { key_id : current_key_id, valid_until } ) ,
68+ }
69+ }
70+
71+ pub async fn current_key_id ( & self ) -> u8 { self . current . read ( ) . await . key_id }
72+
73+ pub async fn valid_until ( & self ) -> Instant { self . current . read ( ) . await . valid_until }
74+
75+ // Look up the server matching the key_id in an OHTTP message and
76+ // decapsulate. The first byte of an OHTTP encapsulated request is the
77+ // key identifier (RFC 9458 Section 4.3).
78+ pub async fn decapsulate (
79+ & self ,
80+ ohttp_body : & [ u8 ] ,
81+ ) -> std:: result:: Result < ( Vec < u8 > , ohttp:: ServerResponse ) , ohttp:: Error > {
82+ let key_id = ohttp_body. first ( ) . copied ( ) . ok_or ( ohttp:: Error :: Truncated ) ?;
83+ match self . keys . get ( key_id as usize ) {
84+ Some ( slot) => slot. read ( ) . await . server . decapsulate ( ohttp_body) ,
85+ None => Err ( ohttp:: Error :: KeyId ) ,
86+ }
87+ }
88+
89+ // Encode the current key's config for serving to clients.
90+ pub async fn encode_current ( & self ) -> std:: result:: Result < Vec < u8 > , ohttp:: Error > {
91+ let id = self . current_key_id ( ) . await ;
92+ self . keys [ id as usize ] . read ( ) . await . server . config ( ) . encode ( )
93+ }
94+
95+ // Flip which key is advertised to new clients and stamp the new expiry.
96+ // Anchored to Instant::now() at the moment of the actual switch so that
97+ // the next rotation cycle is measured from when the key became active,
98+ pub async fn switch ( & self , interval : Duration ) {
99+ let mut current = self . current . write ( ) . await ;
100+ current. key_id = 1 - current. key_id ;
101+ current. valid_until = Instant :: now ( ) + interval;
102+ }
103+
104+ // Replace a slot with fresh key material.
105+ pub async fn overwrite ( & self , key_id : u8 , server : ohttp:: Server ) {
106+ assert ! ( key_id <= 1 , "key_id must be 0 or 1" ) ;
107+ * self . keys [ key_id as usize ] . write ( ) . await = KeySlot { server } ;
108+ }
109+ }
110+
31111/// Opaque blocklist of Bitcoin addresses stored as script pubkeys.
32112///
33113/// Addresses are converted to `ScriptBuf` at parse time so that
@@ -91,7 +171,8 @@ fn parse_address_lines(text: &str) -> std::collections::HashSet<bitcoin::ScriptB
91171#[ derive( Clone ) ]
92172pub struct Service < D : Db > {
93173 db : D ,
94- ohttp : ohttp:: Server ,
174+ ohttp : Arc < KeyRotatingServer > ,
175+ ohttp_keys_max_age : Option < Duration > ,
95176 sentinel_tag : SentinelTag ,
96177 v1 : Option < V1 > ,
97178}
@@ -117,10 +198,18 @@ where
117198}
118199
119200impl < D : Db > Service < D > {
120- pub fn new ( db : D , ohttp : ohttp:: Server , sentinel_tag : SentinelTag , v1 : Option < V1 > ) -> Self {
121- Self { db, ohttp, sentinel_tag, v1 }
201+ pub fn new (
202+ db : D ,
203+ ohttp : Arc < KeyRotatingServer > ,
204+ ohttp_keys_max_age : Option < Duration > ,
205+ sentinel_tag : SentinelTag ,
206+ v1 : Option < V1 > ,
207+ ) -> Self {
208+ Self { db, ohttp, ohttp_keys_max_age, sentinel_tag, v1 }
122209 }
123210
211+ pub fn ohttp_key_set ( & self ) -> & Arc < KeyRotatingServer > { & self . ohttp }
212+
124213 async fn serve_request < B > ( & self , req : Request < B > ) -> Result < Response < Body > >
125214 where
126215 B : axum:: body:: HttpBody < Data = Bytes > + Send + ' static ,
@@ -200,10 +289,10 @@ impl<D: Db> Service<D> {
200289 . map_err ( |e| HandlerError :: BadRequest ( anyhow:: anyhow!( e. into( ) ) ) ) ?
201290 . to_bytes ( ) ;
202291
203- // Decapsulate OHTTP request
204292 let ( bhttp_req, res_ctx) = self
205293 . ohttp
206294 . decapsulate ( & ohttp_body)
295+ . await
207296 . map_err ( |e| HandlerError :: OhttpKeyRejection ( e. into ( ) ) ) ?;
208297 let mut cursor = std:: io:: Cursor :: new ( bhttp_req) ;
209298 let req = bhttp:: Message :: read_bhttp ( & mut cursor)
@@ -380,11 +469,31 @@ impl<D: Db> Service<D> {
380469 async fn get_ohttp_keys ( & self ) -> Result < Response < Body > , HandlerError > {
381470 let ohttp_keys = self
382471 . ohttp
383- . config ( )
384- . encode ( )
472+ . encode_current ( )
473+ . await
385474 . map_err ( |e| HandlerError :: InternalServerError ( e. into ( ) ) ) ?;
386475 let mut res = Response :: new ( full ( ohttp_keys) ) ;
387476 res. headers_mut ( ) . insert ( CONTENT_TYPE , HeaderValue :: from_static ( "application/ohttp-keys" ) ) ;
477+ if let Some ( max_age) = self . ohttp_keys_max_age {
478+ // Subtract ROTATION_GRACE / 3 so clients refresh their cached key
479+ // slightly before the rotation boundary, staying well within the
480+ // grace window where the old key is still accepted.
481+ let remaining = self
482+ . ohttp
483+ . valid_until ( )
484+ . await
485+ . saturating_duration_since ( Instant :: now ( ) )
486+ . min ( max_age)
487+ . saturating_add ( ROTATION_GRACE / 3 ) ;
488+ res. headers_mut ( ) . insert (
489+ CACHE_CONTROL ,
490+ HeaderValue :: from_str ( & format ! (
491+ "public, s-maxage={}, immutable" ,
492+ remaining. as_secs( )
493+ ) )
494+ . expect ( "valid header value" ) ,
495+ ) ;
496+ }
388497 Ok ( res)
389498 }
390499
@@ -412,6 +521,66 @@ impl<D: Db> Service<D> {
412521 }
413522}
414523
524+ // Grace period after a switch during which the old key is still
525+ // accepted for decapsulation.
526+ const ROTATION_GRACE : Duration = Duration :: from_secs ( 30 ) ;
527+
528+ // Background task that rotates OHTTP keys on a fixed interval.
529+ //
530+ // 1. Sleep until the current key is about to expire (valid_until - ROTATION_GRACE/2).
531+ // 2. Switch to the standby slot; stamp valid_until = now + interval.
532+ // 3. Sleep until the old key's grace window has elapsed.
533+ // 4. Overwrite the old slot with fresh key material for the next cycle.
534+ pub fn spawn_key_rotation ( keyset : Arc < KeyRotatingServer > , keys_dir : PathBuf , interval : Duration ) {
535+ tokio:: spawn ( async move {
536+ loop {
537+ // Sleep until just before the current key expires.
538+ let valid_until = keyset. valid_until ( ) . await ;
539+ tracing:: info!( "Sleeping until {:?}" , valid_until) ;
540+ //let switch_at = valid_until.checked_sub(ROTATION_GRACE / 2).unwrap_or(valid_until);
541+ tokio:: time:: sleep_until ( valid_until. into ( ) ) . await ;
542+
543+ // Capture old key id before switching, then switch.
544+ let old_key_id = keyset. current_key_id ( ) . await ;
545+ let new_key_id = 1 - old_key_id;
546+
547+ tracing:: info!(
548+ "---------------------------------------------------------------------------"
549+ ) ;
550+
551+ // Touch the new active key file *after* overwriting the old slot so
552+ // its mtime is newest on disk. On restart,
553+ // and derives valid_until from its age.
554+ let active_path = keys_dir. join ( format ! ( "{new_key_id}.ikm" ) ) ;
555+ let times = std:: fs:: FileTimes :: new ( ) . set_modified ( std:: time:: SystemTime :: now ( ) ) ;
556+ match std:: fs:: File :: open ( & active_path) . and_then ( |f| f. set_times ( times) ) {
557+ Ok ( ( ) ) => { }
558+ Err ( e) => tracing:: warn!( "Failed to change mtime {}: {e}" , active_path. display( ) ) ,
559+ }
560+
561+ // `switch` stamps valid_until = Instant::now() + interval, anchored
562+ // to the actual moment the new key goes live.
563+ keyset. switch ( interval) . await ;
564+
565+ tracing:: info!( "Switched OHTTP serving: From key_id {old_key_id} -> TO {new_key_id}" ) ;
566+
567+ // Wait until the old key's grace window has fully elapsed before
568+ // overwriting it, so in-flight clients using the old key still succeed.
569+ tokio:: time:: sleep ( ROTATION_GRACE ) . await ;
570+
571+ let config = crate :: key_config:: gen_ohttp_server_config_with_id ( old_key_id)
572+ . expect ( "OHTTP key generation must not fail" ) ;
573+ let _ = tokio:: fs:: remove_file ( keys_dir. join ( format ! ( "{old_key_id}.ikm" ) ) ) . await ;
574+ crate :: key_config:: persist_key_config ( & config, & keys_dir)
575+ . await
576+ . expect ( "OHTTP key persistence must not fail" ) ;
577+
578+ keyset. overwrite ( old_key_id, config. into_server ( ) ) . await ;
579+ tracing:: info!( "Overwrote OHTTP key_id {old_key_id} with fresh material" ) ;
580+ }
581+ } ) ;
582+ }
583+
415584fn handle_peek < E : SendableError > (
416585 result : Result < Arc < Vec < u8 > > , DbError < E > > ,
417586 timeout_response : Response < Body > ,
@@ -485,8 +654,8 @@ impl HandlerError {
485654 }
486655 HandlerError :: OhttpKeyRejection ( e) => {
487656 const OHTTP_KEY_REJECTION_RES_JSON : & str = r#"{"type":"https://iana.org/assignments/http-problem-types#ohttp-key", "title": "key identifier unknown"}"# ;
488- warn ! ( "Bad request: Key configuration rejected: {}" , e) ;
489- * res. status_mut ( ) = StatusCode :: BAD_REQUEST ;
657+ warn ! ( "Key configuration rejected: {}" , e) ;
658+ * res. status_mut ( ) = StatusCode :: UNPROCESSABLE_ENTITY ;
490659 res. headers_mut ( )
491660 . insert ( CONTENT_TYPE , HeaderValue :: from_static ( "application/problem+json" ) ) ;
492661 * res. body_mut ( ) = full ( OHTTP_KEY_REJECTION_RES_JSON ) ;
@@ -592,9 +761,17 @@ mod tests {
592761 async fn test_service ( v1 : Option < V1 > ) -> Service < FilesDb > {
593762 let dir = tempfile:: tempdir ( ) . expect ( "tempdir" ) ;
594763 let db = FilesDb :: init ( Duration :: from_millis ( 100 ) , dir. keep ( ) ) . await . expect ( "db init" ) ;
595- let ohttp: ohttp:: Server =
596- crate :: key_config:: gen_ohttp_server_config ( ) . expect ( "ohttp config" ) . into ( ) ;
597- Service :: new ( db, ohttp, SentinelTag :: new ( [ 0u8 ; 32 ] ) , v1)
764+ let c0 = crate :: key_config:: gen_ohttp_server_config_with_id ( 0 ) . expect ( "ohttp config" ) ;
765+ let c1 = crate :: key_config:: gen_ohttp_server_config_with_id ( 1 ) . expect ( "ohttp config" ) ;
766+ // valid_until = now + a generous test interval so nothing rotates during tests
767+ let valid_until = Instant :: now ( ) + Duration :: from_secs ( 3600 ) ;
768+ let keyset = Arc :: new ( KeyRotatingServer :: new (
769+ KeySlot { server : c0. into_server ( ) } ,
770+ KeySlot { server : c1. into_server ( ) } ,
771+ 0 ,
772+ valid_until,
773+ ) ) ;
774+ Service :: new ( db, keyset, None , SentinelTag :: new ( [ 0u8 ; 32 ] ) , v1)
598775 }
599776
600777 /// A valid ShortId encoded as bech32 for use in URL paths.
@@ -826,9 +1003,16 @@ mod tests {
8261003 let dir = tempfile:: tempdir ( ) . expect ( "tempdir" ) ;
8271004 let db = FilesDb :: init ( Duration :: from_millis ( 100 ) , dir. keep ( ) ) . await . expect ( "db init" ) ;
8281005 let db = MetricsDb :: new ( db, metrics) ;
829- let ohttp: ohttp:: Server =
830- crate :: key_config:: gen_ohttp_server_config ( ) . expect ( "ohttp config" ) . into ( ) ;
831- let svc = Service :: new ( db, ohttp, SentinelTag :: new ( [ 0u8 ; 32 ] ) , None ) ;
1006+ let c0 = crate :: key_config:: gen_ohttp_server_config_with_id ( 0 ) . expect ( "ohttp config" ) ;
1007+ let c1 = crate :: key_config:: gen_ohttp_server_config_with_id ( 1 ) . expect ( "ohttp config" ) ;
1008+ let valid_until = Instant :: now ( ) + Duration :: from_secs ( 3600 ) ;
1009+ let keyset = Arc :: new ( KeyRotatingServer :: new (
1010+ KeySlot { server : c0. into_server ( ) } ,
1011+ KeySlot { server : c1. into_server ( ) } ,
1012+ 0 ,
1013+ valid_until,
1014+ ) ) ;
1015+ let svc = Service :: new ( db, keyset, None , SentinelTag :: new ( [ 0u8 ; 32 ] ) , None ) ;
8321016
8331017 let id = valid_short_id_path ( ) ;
8341018 let res = svc
@@ -849,7 +1033,7 @@ mod tests {
8491033 use opentelemetry:: KeyValue ;
8501034 use opentelemetry_sdk:: metrics:: data:: { AggregatedMetrics , MetricData } ;
8511035
852- // This checks that counter value is 1 as post_mailbox was called once
1036+ // This checks that counter value is 1 as post_mailbox was called once
8531037 // Also confirms the v2 label is recorded
8541038 match db_metric. data ( ) {
8551039 AggregatedMetrics :: U64 ( MetricData :: Sum ( sum) ) => {
0 commit comments