@@ -37,6 +37,8 @@ const V1_MAX_BUFFER_SIZE: usize = 65536;
3737const V1_REJECT_RES_JSON : & str =
3838 r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"# ;
3939const V1_UNAVAILABLE_RES_JSON : & str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"# ;
40+ const V1_VERSION_UNSUPPORTED_RES_JSON : & str =
41+ r#"{"errorCode": "version-unsupported", "supported": [2], "message": "V1 is not supported"}"# ;
4042
4143pub ( crate ) mod db;
4244
@@ -68,6 +70,7 @@ pub struct Service<D: Db> {
6870 db : D ,
6971 ohttp : ohttp:: Server ,
7072 sentinel_tag : SentinelTag ,
73+ enable_v1 : bool ,
7174}
7275
7376impl < D : Db , B > tower:: Service < Request < B > > for Service < D >
9194}
9295
9396impl < D : Db > Service < D > {
94- pub fn new ( db : D , ohttp : ohttp:: Server , sentinel_tag : SentinelTag ) -> Self {
95- Self { db, ohttp, sentinel_tag }
97+ pub fn new ( db : D , ohttp : ohttp:: Server , sentinel_tag : SentinelTag , enable_v1 : bool ) -> Self {
98+ Self { db, ohttp, sentinel_tag, enable_v1 }
9699 }
97100
98101 #[ cfg( feature = "_manual-tls" ) ]
@@ -214,7 +217,7 @@ impl<D: Db> Service<D> {
214217 self . handle_ohttp_gateway_get ( & query) . await ,
215218 ( Method :: POST , [ "" , "" ] ) => self . handle_ohttp_gateway ( body) . await ,
216219 ( Method :: GET , [ "" , "ohttp-keys" ] ) => self . get_ohttp_keys ( ) . await ,
217- ( Method :: POST , [ "" , id] ) => self . post_fallback_v1 ( id, query, body) . await ,
220+ ( Method :: POST , [ "" , id] ) => self . handle_post_v1 ( id, query, body) . await ,
218221 ( Method :: GET , [ "" , "health" ] ) => health_check ( ) . await ,
219222 ( Method :: GET , [ "" , "" ] ) => handle_directory_home_path ( ) . await ,
220223 _ => Ok ( not_found ( ) ) ,
@@ -227,6 +230,28 @@ impl<D: Db> Service<D> {
227230 Ok ( response)
228231 }
229232
233+ /// Route POST /{id}: forward to V1 fallback when enabled, otherwise reject.
234+ async fn handle_post_v1 < B > (
235+ & self ,
236+ id : & str ,
237+ query : String ,
238+ body : B ,
239+ ) -> Result < Response < BoxBody < Bytes , hyper:: Error > > , HandlerError >
240+ where
241+ B : Body < Data = Bytes > + Send + ' static ,
242+ B :: Error : Into < BoxError > ,
243+ {
244+ if self . enable_v1 {
245+ self . post_fallback_v1 ( id, query, body) . await
246+ } else {
247+ let _ = ( id, query, body) ;
248+ Ok ( Response :: builder ( )
249+ . status ( StatusCode :: BAD_REQUEST )
250+ . header ( CONTENT_TYPE , "application/json" )
251+ . body ( full ( V1_VERSION_UNSUPPORTED_RES_JSON ) ) ?)
252+ }
253+ }
254+
230255 /// Handle an encapsulated OHTTP request and return an encapsulated response
231256 async fn handle_ohttp_gateway < B > (
232257 & self ,
@@ -304,7 +329,7 @@ impl<D: Db> Service<D> {
304329 match ( parts. method , path_segments. as_slice ( ) ) {
305330 ( Method :: POST , & [ "" , id] ) => self . post_mailbox ( id, body) . await ,
306331 ( Method :: GET , & [ "" , id] ) => self . get_mailbox ( id) . await ,
307- ( Method :: PUT , & [ "" , id] ) => self . put_payjoin_v1 ( id, body) . await ,
332+ ( Method :: PUT , & [ "" , id] ) if self . enable_v1 => self . put_payjoin_v1 ( id, body) . await ,
308333 _ => Ok ( not_found ( ) ) ,
309334 }
310335 }
@@ -603,3 +628,87 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
603628fn full < T : Into < Bytes > > ( chunk : T ) -> BoxBody < Bytes , hyper:: Error > {
604629 Full :: new ( chunk. into ( ) ) . map_err ( |never| match never { } ) . boxed ( )
605630}
631+
632+ #[ cfg( test) ]
633+ mod tests {
634+ use std:: time:: Duration ;
635+
636+ use http_body_util:: BodyExt ;
637+ use hyper:: body:: Bytes ;
638+ use hyper:: { Method , Request , StatusCode } ;
639+ use ohttp_relay:: SentinelTag ;
640+ use payjoin:: directory:: ShortId ;
641+
642+ use super :: * ;
643+
644+ async fn test_service ( enable_v1 : bool ) -> Service < FilesDb > {
645+ let dir = tempfile:: tempdir ( ) . expect ( "tempdir" ) ;
646+ let db = FilesDb :: init ( Duration :: from_millis ( 100 ) , dir. keep ( ) ) . await . expect ( "db init" ) ;
647+ let ohttp: ohttp:: Server =
648+ key_config:: gen_ohttp_server_config ( ) . expect ( "ohttp config" ) . into ( ) ;
649+ Service :: new ( db, ohttp, SentinelTag :: new ( [ 0u8 ; 32 ] ) , enable_v1)
650+ }
651+
652+ /// A valid ShortId encoded as bech32 for use in URL paths.
653+ fn valid_short_id_path ( ) -> String {
654+ let id = ShortId ( [ 0u8 ; 8 ] ) ;
655+ id. to_string ( )
656+ }
657+
658+ async fn collect_body ( res : Response < BoxBody < Bytes , hyper:: Error > > ) -> ( StatusCode , String ) {
659+ let ( parts, body) = res. into_parts ( ) ;
660+ let bytes = body. collect ( ) . await . unwrap ( ) . to_bytes ( ) ;
661+ ( parts. status , String :: from_utf8 ( bytes. to_vec ( ) ) . unwrap ( ) )
662+ }
663+
664+ #[ tokio:: test]
665+ async fn post_v1_when_disabled_returns_version_unsupported ( ) {
666+ let mut svc = test_service ( false ) . await ;
667+ let id = valid_short_id_path ( ) ;
668+ let req = Request :: builder ( )
669+ . method ( Method :: POST )
670+ . uri ( format ! ( "http://localhost/{id}" ) )
671+ . body ( Full :: new ( Bytes :: from ( "base64-psbt" ) ) )
672+ . unwrap ( ) ;
673+
674+ let res = tower:: Service :: call ( & mut svc, req) . await . unwrap ( ) ;
675+ let ( status, body) = collect_body ( res) . await ;
676+
677+ assert_eq ! ( status, StatusCode :: BAD_REQUEST ) ;
678+ assert_eq ! ( body, V1_VERSION_UNSUPPORTED_RES_JSON ) ;
679+ }
680+
681+ #[ tokio:: test]
682+ async fn post_v1_with_invalid_body_returns_reject ( ) {
683+ let mut svc = test_service ( true ) . await ;
684+ let id = valid_short_id_path ( ) ;
685+ let req = Request :: builder ( )
686+ . method ( Method :: POST )
687+ . uri ( format ! ( "http://localhost/{id}" ) )
688+ . body ( Full :: new ( Bytes :: from ( vec ! [ 0xFF , 0xFE ] ) ) )
689+ . unwrap ( ) ;
690+
691+ let res = tower:: Service :: call ( & mut svc, req) . await . unwrap ( ) ;
692+ let ( status, body) = collect_body ( res) . await ;
693+
694+ assert_eq ! ( status, StatusCode :: BAD_REQUEST ) ;
695+ assert_eq ! ( body, V1_REJECT_RES_JSON ) ;
696+ }
697+
698+ #[ tokio:: test]
699+ async fn post_v1_with_no_receiver_returns_unavailable ( ) {
700+ let mut svc = test_service ( true ) . await ;
701+ let id = valid_short_id_path ( ) ;
702+ let req = Request :: builder ( )
703+ . method ( Method :: POST )
704+ . uri ( format ! ( "http://localhost/{id}" ) )
705+ . body ( Full :: new ( Bytes :: from ( "base64-psbt" ) ) )
706+ . unwrap ( ) ;
707+
708+ let res = tower:: Service :: call ( & mut svc, req) . await . unwrap ( ) ;
709+ let ( status, body) = collect_body ( res) . await ;
710+
711+ assert_eq ! ( status, StatusCode :: SERVICE_UNAVAILABLE ) ;
712+ assert_eq ! ( body, V1_UNAVAILABLE_RES_JSON ) ;
713+ }
714+ }
0 commit comments