Skip to content

Commit e454ea4

Browse files
committed
feat: use axum for ohttp gateway middleware
1 parent 500b7b3 commit e454ea4

8 files changed

Lines changed: 293 additions & 323 deletions

File tree

Cargo-minimal.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,6 +2411,8 @@ dependencies = [
24112411
name = "ohttp-relay"
24122412
version = "0.0.11"
24132413
dependencies = [
2414+
"bhttp",
2415+
"bitcoin-ohttp",
24142416
"byteorder",
24152417
"bytes",
24162418
"futures",

Cargo-recent.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,6 +2411,8 @@ dependencies = [
24112411
name = "ohttp-relay"
24122412
version = "0.0.11"
24132413
dependencies = [
2414+
"bhttp",
2415+
"bitcoin-ohttp",
24142416
"byteorder",
24152417
"bytes",
24162418
"futures",

ohttp-relay/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ tokio-util = { version = "0.7.16", features = ["net", "codec"] }
4747
tower = "0.5"
4848
tracing = "0.1.41"
4949
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
50+
ohttp = { package = "bitcoin-ohttp", version = "0.6" }
51+
bhttp = { version = "0.6.1", features = ["http"] }
5052

5153
[dev-dependencies]
5254
mockito = "1.7.0"

ohttp-relay/src/gateway_helpers.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
use http_body_util::combinators::BoxBody;
2+
use http_body_util::{BodyExt, Full};
3+
use hyper::body::Bytes;
4+
use hyper::{Request, Response, Uri};
5+
6+
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
7+
8+
pub const CHACHA20_POLY1305_NONCE_LEN: usize = 32;
9+
pub const POLY1305_TAG_SIZE: usize = 16;
10+
pub const ENCAPSULATED_MESSAGE_BYTES: usize = 65536;
11+
pub const BHTTP_REQ_BYTES: usize =
12+
ENCAPSULATED_MESSAGE_BYTES - (CHACHA20_POLY1305_NONCE_LEN + POLY1305_TAG_SIZE);
13+
14+
#[derive(Debug)]
15+
pub enum GatewayError {
16+
BadRequest(String),
17+
OhttpKeyRejection(String),
18+
InternalServerError(String),
19+
}
20+
21+
impl std::fmt::Display for GatewayError {
22+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23+
match self {
24+
GatewayError::BadRequest(msg) => write!(f, "Bad request: {}", msg),
25+
GatewayError::OhttpKeyRejection(msg) => write!(f, "OHTTP key rejection: {}", msg),
26+
GatewayError::InternalServerError(msg) => write!(f, "Internal server error: {}", msg),
27+
}
28+
}
29+
}
30+
31+
impl std::error::Error for GatewayError {}
32+
33+
pub async fn decapsulate_ohttp_request<B>(
34+
req: Request<B>,
35+
ohttp_server: &ohttp::Server,
36+
) -> Result<(Request<BoxBody<Bytes, hyper::Error>>, ohttp::ServerResponse), GatewayError>
37+
where
38+
B: hyper::body::Body<Data = Bytes> + Send + 'static,
39+
B::Error: Into<BoxError>,
40+
{
41+
// Collect OHTTP body
42+
let ohttp_body = req
43+
.into_body()
44+
.collect()
45+
.await
46+
.map_err(|e| GatewayError::BadRequest(format!("Failed to read body: {}", e.into())))?
47+
.to_bytes();
48+
49+
// Decapsulate using OHTTP server
50+
let (bhttp_req, res_ctx) = ohttp_server.decapsulate(&ohttp_body).map_err(|e| {
51+
GatewayError::OhttpKeyRejection(format!("OHTTP decapsulation failed: {}", e))
52+
})?;
53+
54+
// Parse BHTTP message
55+
let mut cursor = std::io::Cursor::new(bhttp_req);
56+
let bhttp_msg = bhttp::Message::read_bhttp(&mut cursor)
57+
.map_err(|e| GatewayError::BadRequest(format!("Invalid BHTTP: {}", e)))?;
58+
59+
let uri = Uri::builder()
60+
.scheme(bhttp_msg.control().scheme().unwrap_or_default())
61+
.authority(bhttp_msg.control().authority().unwrap_or_default())
62+
.path_and_query(bhttp_msg.control().path().unwrap_or_default())
63+
.build()
64+
.map_err(|e| GatewayError::BadRequest(format!("Invalid URI: {}", e)))?;
65+
66+
let body = bhttp_msg.content().to_vec();
67+
let mut http_req =
68+
Request::builder().uri(uri).method(bhttp_msg.control().method().unwrap_or_default());
69+
70+
for header in bhttp_msg.header().fields() {
71+
http_req = http_req.header(header.name(), header.value());
72+
}
73+
74+
let request = http_req.body(full(body)).map_err(|e| {
75+
GatewayError::InternalServerError(format!("Failed to build request: {}", e))
76+
})?;
77+
78+
Ok((request, res_ctx))
79+
}
80+
81+
pub async fn encapsulate_ohttp_response(
82+
response: Response<BoxBody<Bytes, hyper::Error>>,
83+
res_ctx: ohttp::ServerResponse,
84+
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, GatewayError> {
85+
let (parts, body) = response.into_parts();
86+
87+
let mut bhttp_res =
88+
bhttp::Message::response(bhttp::StatusCode::try_from(parts.status.as_u16()).map_err(
89+
|e| GatewayError::InternalServerError(format!("Invalid status code: {}", e)),
90+
)?);
91+
92+
for (name, value) in parts.headers.iter() {
93+
bhttp_res.put_header(name.as_str(), value.to_str().unwrap_or_default());
94+
}
95+
96+
let full_body = body
97+
.collect()
98+
.await
99+
.map_err(|e| GatewayError::InternalServerError(format!("Failed to collect body: {}", e)))?
100+
.to_bytes();
101+
bhttp_res.write_content(&full_body);
102+
103+
let mut bhttp_bytes = Vec::new();
104+
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).map_err(|e| {
105+
GatewayError::InternalServerError(format!("BHTTP serialization failed: {}", e))
106+
})?;
107+
108+
bhttp_bytes.resize(BHTTP_REQ_BYTES, 0);
109+
110+
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).map_err(|e| {
111+
GatewayError::InternalServerError(format!("OHTTP encapsulation failed: {}", e))
112+
})?;
113+
114+
assert!(
115+
ohttp_res.len() == ENCAPSULATED_MESSAGE_BYTES,
116+
"Unexpected OHTTP response size: {} != {}",
117+
ohttp_res.len(),
118+
ENCAPSULATED_MESSAGE_BYTES
119+
);
120+
121+
Ok(Response::new(full(ohttp_res)))
122+
}
123+
124+
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
125+
Full::new(chunk.into()).map_err(|never| match never {}).boxed()
126+
}

ohttp-relay/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ pub mod gateway_prober;
3636
mod gateway_uri;
3737
pub mod sentinel;
3838
pub use sentinel::SentinelTag;
39+
pub mod gateway_helpers;
40+
41+
pub use gateway_helpers::{
42+
decapsulate_ohttp_request, encapsulate_ohttp_response, BHTTP_REQ_BYTES,
43+
ENCAPSULATED_MESSAGE_BYTES,
44+
};
3945

4046
use crate::error::{BoxError, Error};
4147

payjoin-service/src/lib.rs

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,29 @@ use config::Config;
66
use ohttp_relay::SentinelTag;
77
use rand::Rng;
88
use tokio_listener::{Listener, SystemOptions, UserOptions};
9-
use tower::Service;
9+
use tower::{Service, ServiceExt};
1010
use tracing::info;
11-
pub mod ohttp;
12-
13-
use http_body_util::combinators::BoxBody;
14-
use hyper::body::Bytes;
15-
use hyper::{Request, StatusCode};
16-
use ohttp::{OhttpGatewayConfig, OhttpGatewayLayer};
17-
use tower::{ServiceBuilder, ServiceExt};
1811

1912
pub mod cli;
2013
pub mod config;
14+
pub mod ohttp;
15+
16+
use crate::ohttp::OhttpGatewayConfig;
2117

2218
#[derive(Clone)]
2319
struct Services {
2420
directory: payjoin_directory::Service<payjoin_directory::FilesDb>,
2521
relay: ohttp_relay::Service,
26-
sentinel_tag: SentinelTag,
22+
ohttp_config: OhttpGatewayConfig,
2723
}
2824

2925
pub async fn serve(config: Config) -> anyhow::Result<()> {
3026
let sentinel_tag = generate_sentinel_tag();
27+
let directory = init_directory(&config, sentinel_tag).await?;
28+
let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag);
3129

32-
let services = Services {
33-
directory: init_directory(&config, sentinel_tag).await?,
34-
relay: ohttp_relay::Service::new(sentinel_tag).await,
35-
sentinel_tag,
36-
};
30+
let services =
31+
Services { directory, relay: ohttp_relay::Service::new(sentinel_tag).await, ohttp_config };
3732
let app = Router::new().fallback(route_request).with_state(services);
3833

3934
let listener =
@@ -61,12 +56,15 @@ pub async fn serve_manual_tls(
6156
use std::net::SocketAddr;
6257

6358
let sentinel_tag = generate_sentinel_tag();
59+
let directory = init_directory(&config, sentinel_tag).await?;
60+
let ohttp_config = OhttpGatewayConfig::new(directory.ohttp.clone(), sentinel_tag);
6461

6562
let services = Services {
66-
directory: init_directory(&config, sentinel_tag).await?,
63+
directory,
6764
relay: ohttp_relay::Service::new_with_roots(root_store, sentinel_tag).await,
68-
sentinel_tag,
65+
ohttp_config,
6966
};
67+
7068
let app = Router::new().fallback(route_request).with_state(services);
7169

7270
let addr: SocketAddr = config
@@ -134,63 +132,49 @@ async fn route_request(State(services): State<Services>, req: axum::extract::Req
134132
let mut relay = services.relay.clone();
135133
match relay.call(req).await {
136134
Ok(res) => res.into_response(),
137-
Err(e) => (StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
135+
Err(e) => (axum::http::StatusCode::BAD_GATEWAY, e.to_string()).into_response(),
138136
}
139137
} else {
140138
handle_directory_request(services, req).await
141139
}
142140
}
143141

144142
async fn handle_directory_request(services: Services, req: axum::extract::Request) -> Response {
145-
let ohttp_server = services.directory.ohttp.clone();
146-
147-
let ohttp_config = OhttpGatewayConfig::new(ohttp_server, services.sentinel_tag);
148-
149-
let (parts, body) = req.into_parts();
150-
151-
use http_body_util::BodyExt as _;
152-
153-
let body_bytes = body
154-
.collect()
155-
.await
156-
.map_err(|_| "Failed to collect body")
157-
.expect("Failed to collect body")
158-
.to_bytes();
159-
160-
let boxed_body = BoxBody::new(http_body_util::Full::new(body_bytes));
161-
162-
let hyper_req = Request::from_parts(parts, boxed_body);
163-
164-
let directory_service = tower::service_fn({
165-
let directory = services.directory.clone();
166-
move |req: Request<BoxBody<Bytes, hyper::Error>>| {
167-
let mut dir = directory.clone();
168-
async move {
169-
dir.call(req).await.map_err(|e| {
170-
Box::new(std::io::Error::other(e.to_string()))
171-
as Box<dyn std::error::Error + Send + Sync>
172-
})
173-
}
174-
}
175-
});
176-
177-
let mut service_with_ohttp = ServiceBuilder::new()
178-
.layer(OhttpGatewayLayer::new(ohttp_config))
179-
.service(directory_service)
180-
.boxed_clone();
181-
182-
match service_with_ohttp.ready().await {
183-
Ok(ready_service) => match ready_service.call(hyper_req).await {
184-
Ok(response) => {
185-
let (parts, body) = response.into_parts();
186-
let axum_body = axum::body::Body::new(body);
187-
Response::from_parts(parts, axum_body).into_response()
188-
}
143+
let is_ohttp_request = matches!(
144+
(req.method(), req.uri().path()),
145+
(&Method::POST, "/.well-known/ohttp-gateway") | (&Method::POST, "/")
146+
);
147+
148+
if is_ohttp_request {
149+
let app = Router::new()
150+
.fallback(directory_handler)
151+
.layer(axum::middleware::from_fn_with_state(
152+
services.ohttp_config.clone(),
153+
crate::ohttp::ohttp_gateway,
154+
))
155+
.with_state(services.directory.clone());
156+
157+
match app.oneshot(req).await {
158+
Ok(response) => response,
189159
Err(e) =>
190-
(StatusCode::INTERNAL_SERVER_ERROR, format!("Service error: {}", e)).into_response(),
191-
},
160+
(axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("Service error: {}", e))
161+
.into_response(),
162+
}
163+
} else {
164+
directory_handler(State(services.directory), req).await
165+
}
166+
}
167+
168+
async fn directory_handler(
169+
State(directory): State<payjoin_directory::Service<payjoin_directory::FilesDb>>,
170+
req: axum::extract::Request,
171+
) -> Response {
172+
let mut dir = directory.clone();
173+
match dir.call(req).await {
174+
Ok(response) => response.into_response(),
192175
Err(e) =>
193-
(StatusCode::INTERNAL_SERVER_ERROR, format!("Service not ready: {}", e)).into_response(),
176+
(axum::http::StatusCode::INTERNAL_SERVER_ERROR, format!("Directory error: {}", e))
177+
.into_response(),
194178
}
195179
}
196180

0 commit comments

Comments
 (0)