Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,6 @@ workspace = true

[dev-dependencies]
futures-util = { version = "0.3", default-features = false }
rstest = "0.26"
tempfile = "3.21"
tokio = { version = "1.47", features = ["macros", "rt"] }
217 changes: 214 additions & 3 deletions crates/common/src/data_converters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ impl PayloadConverter {
pub fn serde_json() -> Self {
Self::Serde(Arc::new(SerdeJsonPayloadConverter))
}
// TODO [rust-sdk-branch]: Proto binary, other standard built-ins
}

impl Default for PayloadConverter {
Expand Down Expand Up @@ -618,17 +617,20 @@ pub trait ErasedSerdePayloadConverter: Send + Sync {

/// Wrapper for protobuf messages that implements [`TemporalSerializable`]/[`TemporalDeserializable`]
/// using `binary/protobuf` encoding.
///
/// Use this when you want compact binary protobuf wire format for your proto types.
pub struct ProstSerializable<T: prost::Message>(pub T);
impl<T> TemporalSerializable for ProstSerializable<T>
where
T: prost::Message + Default + 'static,
T: prost::Message + prost::Name + Default + 'static,
{
fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
let as_proto = prost::Message::encode_to_vec(&self.0);
Ok(Payload {
metadata: {
let mut hm = HashMap::new();
hm.insert("encoding".to_string(), b"binary/protobuf".to_vec());
hm.insert("messageType".to_string(), T::full_name().into_bytes());
hm
},
data: as_proto,
Expand All @@ -638,7 +640,7 @@ where
}
impl<T> TemporalDeserializable for ProstSerializable<T>
where
T: prost::Message + Default + 'static,
T: prost::Message + prost::Name + Default + 'static,
{
fn from_payload(
_: &SerializationContext<'_>,
Expand All @@ -656,6 +658,74 @@ where
.map_err(|e| PayloadConversionError::EncodingError(Box::new(e)))
}
}
impl<T: prost::Message> From<T> for ProstSerializable<T> {
fn from(value: T) -> Self {
Self(value)
}
}
impl<T: prost::Message> ProstSerializable<T> {
/// Consumes the wrapper, returning the inner protobuf message.
pub fn into_inner(self) -> T {
self.0
}
}

/// Wrapper for protobuf messages that implements [`TemporalSerializable`]/[`TemporalDeserializable`]
/// using `json/protobuf` encoding via pbjson-generated serde impls.
///
/// Use this when you want human-readable JSON protobuf format (proto3 JSON mapping with
/// camelCase field names and string enums).
pub struct ProstJsonSerializable<T: prost::Message>(pub T);
impl<T> TemporalSerializable for ProstJsonSerializable<T>
where
T: prost::Message + prost::Name + serde::Serialize + Default + 'static,
{
fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
let as_json = serde_json::to_vec(&self.0)
.map_err(|e| PayloadConversionError::EncodingError(e.into()))?;
Comment on lines +684 to +685
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused about how this will reliably use the proto JSON encoding vs. normal serde-derive based encodings. In build.rs we apply serde derive to basically everything, and then pbjson is used for

            ".temporal.api.failure",
            ".temporal.api.common",
            ".temporal.api.enums",

Currently, but not everything else?

Ok(Payload {
metadata: {
let mut hm = HashMap::new();
hm.insert("encoding".to_string(), b"json/protobuf".to_vec());
hm.insert("messageType".to_string(), T::full_name().into_bytes());
hm
},
data: as_json,
external_payloads: vec![],
})
}
}
impl<T> TemporalDeserializable for ProstJsonSerializable<T>
where
T: prost::Message + prost::Name + serde::de::DeserializeOwned + Default + 'static,
{
fn from_payload(
_: &SerializationContext<'_>,
p: Payload,
) -> Result<Self, PayloadConversionError>
where
Self: Sized,
{
let encoding = p.metadata.get("encoding").map(|v| v.as_slice());
if encoding != Some(b"json/protobuf".as_slice()) {
return Err(PayloadConversionError::WrongEncoding);
}
serde_json::from_slice(&p.data)
.map(ProstJsonSerializable)
.map_err(|e| PayloadConversionError::EncodingError(e.into()))
}
}
impl<T: prost::Message> From<T> for ProstJsonSerializable<T> {
fn from(value: T) -> Self {
Self(value)
}
}
impl<T: prost::Message> ProstJsonSerializable<T> {
/// Consumes the wrapper, returning the inner protobuf message.
pub fn into_inner(self) -> T {
self.0
}
}

/// A payload converter that delegates to an ordered list of inner converters.
#[derive(Clone)]
Expand Down Expand Up @@ -769,6 +839,9 @@ impl_multi_args!(MultiArgs6; 6; 0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
#[cfg(test)]
mod tests {
use super::*;
use crate::protos::temporal::api::common::v1::WorkflowExecution;
use prost::Name;
use rstest::rstest;

#[test]
fn test_empty_payloads_as_unit_type() {
Expand Down Expand Up @@ -866,4 +939,142 @@ mod tests {
let args: MultiArgs2<String, i32> = ("hello".to_string(), 42i32).into();
assert_eq!(args, MultiArgs2("hello".to_string(), 42));
}

#[derive(Clone, Copy)]
enum EncodingFormat {
Binary,
Json,
}

impl EncodingFormat {
fn expected_encoding(self) -> &'static [u8] {
match self {
Self::Binary => b"binary/protobuf",
Self::Json => b"json/protobuf",
}
}
}

fn test_wf_exec() -> WorkflowExecution {
WorkflowExecution {
workflow_id: "wf-123".into(),
run_id: "run-456".into(),
}
}

fn serialize_as(
format: EncodingFormat,
wf: &WorkflowExecution,
ctx: &SerializationContext<'_>,
) -> Payload {
match format {
EncodingFormat::Binary => ProstSerializable(wf.clone()).to_payload(ctx).unwrap(),
EncodingFormat::Json => ProstJsonSerializable(wf.clone()).to_payload(ctx).unwrap(),
}
}

fn deserialize_as(
format: EncodingFormat,
payload: Payload,
ctx: &SerializationContext<'_>,
) -> Result<WorkflowExecution, PayloadConversionError> {
match format {
EncodingFormat::Binary => {
ProstSerializable::from_payload(ctx, payload).map(|w: ProstSerializable<_>| w.0)
}
EncodingFormat::Json => ProstJsonSerializable::from_payload(ctx, payload)
.map(|w: ProstJsonSerializable<_>| w.0),
}
}

#[rstest]
#[case::binary(EncodingFormat::Binary)]
#[case::json(EncodingFormat::Json)]
fn prost_round_trip(#[case] format: EncodingFormat) {
let converter = PayloadConverter::default();
let ctx = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &converter,
};
let wf = test_wf_exec();

let payload = serialize_as(format, &wf, &ctx);
assert_eq!(
payload.metadata.get("encoding").unwrap().as_slice(),
format.expected_encoding(),
);
assert_eq!(
payload.metadata.get("messageType").unwrap().as_slice(),
WorkflowExecution::full_name().as_bytes(),
);

let result = deserialize_as(format, payload, &ctx).unwrap();
assert_eq!(result, wf);
}

#[test]
fn prost_json_produces_proto3_json_format() {
let converter = PayloadConverter::default();
let ctx = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &converter,
};

let payload = ProstJsonSerializable(test_wf_exec())
.to_payload(&ctx)
.unwrap();
let json_str = std::str::from_utf8(&payload.data).unwrap();
assert!(
json_str.contains("workflowId"),
"expected camelCase field names in proto3 JSON, got: {json_str}"
);
assert!(json_str.contains("runId"));
}

#[rstest]
#[case::binary_rejects_json(EncodingFormat::Binary, EncodingFormat::Json)]
#[case::json_rejects_binary(EncodingFormat::Json, EncodingFormat::Binary)]
fn prost_rejects_wrong_encoding(
#[case] decode_format: EncodingFormat,
#[case] encode_format: EncodingFormat,
) {
let converter = PayloadConverter::default();
let ctx = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &converter,
};

let payload = serialize_as(encode_format, &test_wf_exec(), &ctx);
let result = deserialize_as(decode_format, payload, &ctx);
assert!(matches!(result, Err(PayloadConversionError::WrongEncoding)));
}

#[rstest]
#[case::binary(EncodingFormat::Binary)]
#[case::json(EncodingFormat::Json)]
fn prost_through_composite_converter(#[case] format: EncodingFormat) {
let converter = PayloadConverter::default();
let ctx = SerializationContext {
data: &SerializationContextData::Workflow,
converter: &converter,
};
let wf = test_wf_exec();

let payloads = match format {
EncodingFormat::Binary => converter
.to_payloads(&ctx, &ProstSerializable(wf.clone()))
.unwrap(),
EncodingFormat::Json => converter
.to_payloads(&ctx, &ProstJsonSerializable(wf.clone()))
.unwrap(),
};
assert_eq!(payloads.len(), 1);
assert_eq!(
payloads[0].metadata.get("encoding").unwrap().as_slice(),
format.expected_encoding(),
);

let result = deserialize_as(format, payloads.into_iter().next().unwrap(), &ctx).unwrap();
assert_eq!(result, wf);
}
}
Loading
Loading