Skip to content

Commit 1908139

Browse files
authored
Gate V1 protocol behind runtime feature flag (#1336)
2 parents 945d6d8 + 96b9edc commit 1908139

6 files changed

Lines changed: 132 additions & 8 deletions

File tree

payjoin-directory/src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ pub struct Config {
1616
pub timeout: Duration,
1717
pub storage_dir: PathBuf,
1818
pub ohttp_keys: PathBuf, // TODO OhttpConfig struct with rotation params, etc
19+
#[serde(default)]
20+
pub enable_v1: bool,
1921
#[cfg(feature = "acme")]
2022
pub acme: Option<AcmeConfig>,
2123
}
@@ -54,6 +56,7 @@ impl Config {
5456
timeout: Duration::from_secs(built_config.get("timeout")?),
5557
storage_dir: built_config.get("storage_dir")?,
5658
ohttp_keys: built_config.get("ohttp_keys")?,
59+
enable_v1: built_config.get("enable_v1").unwrap_or(false),
5760
#[cfg(feature = "acme")]
5861
acme: if built_config.get_table("acme").is_ok() {
5962
Some(AcmeConfig {

payjoin-directory/src/lib.rs

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ const V1_MAX_BUFFER_SIZE: usize = 65536;
3737
const V1_REJECT_RES_JSON: &str =
3838
r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#;
3939
const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#;
40+
const V1_VERSION_UNSUPPORTED_RES_JSON: &str =
41+
r#"{"errorCode": "version-unsupported", "supported": [2], "message": "V1 is not supported"}"#;
4042

4143
pub(crate) mod db;
4244

@@ -68,6 +70,7 @@ pub struct Service<D: Db> {
6870
db: D,
6971
ohttp: ohttp::Server,
7072
sentinel_tag: SentinelTag,
73+
enable_v1: bool,
7174
}
7275

7376
impl<D: Db, B> tower::Service<Request<B>> for Service<D>
@@ -91,8 +94,8 @@ where
9194
}
9295

9396
impl<D: Db> Service<D> {
94-
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag) -> Self {
95-
Self { db, ohttp, sentinel_tag }
97+
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, enable_v1: bool) -> Self {
98+
Self { db, ohttp, sentinel_tag, enable_v1 }
9699
}
97100

98101
#[cfg(feature = "_manual-tls")]
@@ -214,7 +217,7 @@ impl<D: Db> Service<D> {
214217
self.handle_ohttp_gateway_get(&query).await,
215218
(Method::POST, ["", ""]) => self.handle_ohttp_gateway(body).await,
216219
(Method::GET, ["", "ohttp-keys"]) => self.get_ohttp_keys().await,
217-
(Method::POST, ["", id]) => self.post_fallback_v1(id, query, body).await,
220+
(Method::POST, ["", id]) => self.handle_post_v1(id, query, body).await,
218221
(Method::GET, ["", "health"]) => health_check().await,
219222
(Method::GET, ["", ""]) => handle_directory_home_path().await,
220223
_ => Ok(not_found()),
@@ -227,6 +230,28 @@ impl<D: Db> Service<D> {
227230
Ok(response)
228231
}
229232

233+
/// Route POST /{id}: forward to V1 fallback when enabled, otherwise reject.
234+
async fn handle_post_v1<B>(
235+
&self,
236+
id: &str,
237+
query: String,
238+
body: B,
239+
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError>
240+
where
241+
B: Body<Data = Bytes> + Send + 'static,
242+
B::Error: Into<BoxError>,
243+
{
244+
if self.enable_v1 {
245+
self.post_fallback_v1(id, query, body).await
246+
} else {
247+
let _ = (id, query, body);
248+
Ok(Response::builder()
249+
.status(StatusCode::BAD_REQUEST)
250+
.header(CONTENT_TYPE, "application/json")
251+
.body(full(V1_VERSION_UNSUPPORTED_RES_JSON))?)
252+
}
253+
}
254+
230255
/// Handle an encapsulated OHTTP request and return an encapsulated response
231256
async fn handle_ohttp_gateway<B>(
232257
&self,
@@ -304,7 +329,7 @@ impl<D: Db> Service<D> {
304329
match (parts.method, path_segments.as_slice()) {
305330
(Method::POST, &["", id]) => self.post_mailbox(id, body).await,
306331
(Method::GET, &["", id]) => self.get_mailbox(id).await,
307-
(Method::PUT, &["", id]) => self.put_payjoin_v1(id, body).await,
332+
(Method::PUT, &["", id]) if self.enable_v1 => self.put_payjoin_v1(id, body).await,
308333
_ => Ok(not_found()),
309334
}
310335
}
@@ -603,3 +628,87 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
603628
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
604629
Full::new(chunk.into()).map_err(|never| match never {}).boxed()
605630
}
631+
632+
#[cfg(test)]
633+
mod tests {
634+
use std::time::Duration;
635+
636+
use http_body_util::BodyExt;
637+
use hyper::body::Bytes;
638+
use hyper::{Method, Request, StatusCode};
639+
use ohttp_relay::SentinelTag;
640+
use payjoin::directory::ShortId;
641+
642+
use super::*;
643+
644+
async fn test_service(enable_v1: bool) -> Service<FilesDb> {
645+
let dir = tempfile::tempdir().expect("tempdir");
646+
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
647+
let ohttp: ohttp::Server =
648+
key_config::gen_ohttp_server_config().expect("ohttp config").into();
649+
Service::new(db, ohttp, SentinelTag::new([0u8; 32]), enable_v1)
650+
}
651+
652+
/// A valid ShortId encoded as bech32 for use in URL paths.
653+
fn valid_short_id_path() -> String {
654+
let id = ShortId([0u8; 8]);
655+
id.to_string()
656+
}
657+
658+
async fn collect_body(res: Response<BoxBody<Bytes, hyper::Error>>) -> (StatusCode, String) {
659+
let (parts, body) = res.into_parts();
660+
let bytes = body.collect().await.unwrap().to_bytes();
661+
(parts.status, String::from_utf8(bytes.to_vec()).unwrap())
662+
}
663+
664+
#[tokio::test]
665+
async fn post_v1_when_disabled_returns_version_unsupported() {
666+
let mut svc = test_service(false).await;
667+
let id = valid_short_id_path();
668+
let req = Request::builder()
669+
.method(Method::POST)
670+
.uri(format!("http://localhost/{id}"))
671+
.body(Full::new(Bytes::from("base64-psbt")))
672+
.unwrap();
673+
674+
let res = tower::Service::call(&mut svc, req).await.unwrap();
675+
let (status, body) = collect_body(res).await;
676+
677+
assert_eq!(status, StatusCode::BAD_REQUEST);
678+
assert_eq!(body, V1_VERSION_UNSUPPORTED_RES_JSON);
679+
}
680+
681+
#[tokio::test]
682+
async fn post_v1_with_invalid_body_returns_reject() {
683+
let mut svc = test_service(true).await;
684+
let id = valid_short_id_path();
685+
let req = Request::builder()
686+
.method(Method::POST)
687+
.uri(format!("http://localhost/{id}"))
688+
.body(Full::new(Bytes::from(vec![0xFF, 0xFE])))
689+
.unwrap();
690+
691+
let res = tower::Service::call(&mut svc, req).await.unwrap();
692+
let (status, body) = collect_body(res).await;
693+
694+
assert_eq!(status, StatusCode::BAD_REQUEST);
695+
assert_eq!(body, V1_REJECT_RES_JSON);
696+
}
697+
698+
#[tokio::test]
699+
async fn post_v1_with_no_receiver_returns_unavailable() {
700+
let mut svc = test_service(true).await;
701+
let id = valid_short_id_path();
702+
let req = Request::builder()
703+
.method(Method::POST)
704+
.uri(format!("http://localhost/{id}"))
705+
.body(Full::new(Bytes::from("base64-psbt")))
706+
.unwrap();
707+
708+
let res = tower::Service::call(&mut svc, req).await.unwrap();
709+
let (status, body) = collect_body(res).await;
710+
711+
assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
712+
assert_eq!(body, V1_UNAVAILABLE_RES_JSON);
713+
}
714+
}

payjoin-directory/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async fn main() -> Result<(), BoxError> {
2929
.await
3030
.expect("Failed to initialize persistent storage");
3131

32-
let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]));
32+
let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]), config.enable_v1);
3333

3434
let listener = TcpListener::bind(config.listen_addr).await?;
3535

payjoin-mailroom/src/config.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub struct Config {
1212
pub storage_dir: PathBuf,
1313
#[serde(deserialize_with = "deserialize_duration_secs")]
1414
pub timeout: Duration,
15+
pub enable_v1: bool,
1516
#[cfg(feature = "telemetry")]
1617
pub telemetry: Option<TelemetryConfig>,
1718
#[cfg(feature = "acme")]
@@ -58,6 +59,7 @@ impl Default for Config {
5859
listener: "[::]:8080".parse().expect("valid default listener address"),
5960
storage_dir: PathBuf::from("./data"),
6061
timeout: Duration::from_secs(30),
62+
enable_v1: false,
6163
#[cfg(feature = "telemetry")]
6264
telemetry: None,
6365
#[cfg(feature = "acme")]
@@ -75,11 +77,17 @@ where
7577
}
7678

7779
impl Config {
78-
pub fn new(listener: ListenerAddress, storage_dir: PathBuf, timeout: Duration) -> Self {
80+
pub fn new(
81+
listener: ListenerAddress,
82+
storage_dir: PathBuf,
83+
timeout: Duration,
84+
enable_v1: bool,
85+
) -> Self {
7986
Self {
8087
listener,
8188
storage_dir,
8289
timeout,
90+
enable_v1,
8391
#[cfg(feature = "telemetry")]
8492
telemetry: None,
8593
#[cfg(feature = "acme")]

payjoin-mailroom/src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ async fn init_directory(
167167
let ohttp_keys_dir = config.storage_dir.join("ohttp-keys");
168168
let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?;
169169

170-
Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag))
170+
Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag, config.enable_v1))
171171
}
172172

173173
fn init_ohttp_config(
@@ -260,6 +260,7 @@ mod tests {
260260
"[::]:0".parse().expect("valid listener address"),
261261
tempdir.path().to_path_buf(),
262262
Duration::from_secs(2),
263+
false,
263264
);
264265

265266
let mut root_store = RootCertStore::empty();
@@ -284,7 +285,7 @@ mod tests {
284285

285286
// Make a request through the relay that targets this same instance's directory.
286287
// The path format is /{gateway_url} where gateway_url points back to ourselves.
287-
let ohttp_req_url = format!("{}/{}", base_url, base_url);
288+
let ohttp_req_url = format!("{base_url}/{base_url}");
288289

289290
let response = client
290291
.post(&ohttp_req_url)
@@ -354,6 +355,7 @@ mod tests {
354355
"[::]:0".parse().expect("valid listener address"),
355356
tempdir.path().to_path_buf(),
356357
Duration::from_secs(2),
358+
false,
357359
);
358360

359361
let sentinel_tag = generate_sentinel_tag();

payjoin-test-utils/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ pub async fn init_directory(
121121
"[::]:0".parse().expect("valid listener address"),
122122
tempdir.path().to_path_buf(),
123123
Duration::from_secs(2),
124+
true,
124125
);
125126

126127
let tls_config = RustlsConfig::from_der(vec![local_cert_key.0], local_cert_key.1).await?;
@@ -148,6 +149,7 @@ async fn init_ohttp_relay(
148149
"[::]:0".parse().expect("valid listener address"),
149150
tempdir.path().to_path_buf(),
150151
Duration::from_secs(2),
152+
false,
151153
);
152154

153155
let (port, handle) = payjoin_mailroom::serve_manual_tls(config, None, root_store)

0 commit comments

Comments
 (0)