diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml
index f88c5c8ab..bd128ccc1 100644
--- a/crates/common/Cargo.toml
+++ b/crates/common/Cargo.toml
@@ -105,5 +105,6 @@ workspace = true
 
 [dev-dependencies]
 futures-util = { version = "0.3", default-features = false }
+rstest = "0.26"
 tempfile = "3.21"
 tokio = { version = "1.47", features = ["macros", "rt"] }
diff --git a/crates/common/src/data_converters.rs b/crates/common/src/data_converters.rs
index e58f37dad..1157a97df 100644
--- a/crates/common/src/data_converters.rs
+++ b/crates/common/src/data_converters.rs
@@ -158,7 +158,6 @@ impl PayloadConverter {
     pub fn serde_json() -> Self {
         Self::Serde(Arc::new(SerdeJsonPayloadConverter))
     }
-    // TODO [rust-sdk-branch]: Proto binary, other standard built-ins
 }
 
 impl Default for PayloadConverter {
@@ -618,10 +617,12 @@ pub trait ErasedSerdePayloadConverter: Send + Sync {
 
 /// Wrapper for protobuf messages that implements [`TemporalSerializable`]/[`TemporalDeserializable`]
 /// using `binary/protobuf` encoding.
+///
+/// Use this when you want compact binary protobuf wire format for your proto types.
 pub struct ProstSerializable<T>(pub T);
 impl<T> TemporalSerializable for ProstSerializable<T>
 where
-    T: prost::Message + Default + 'static,
+    T: prost::Message + prost::Name + Default + 'static,
 {
     fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
         let as_proto = prost::Message::encode_to_vec(&self.0);
@@ -629,6 +630,7 @@ where
             metadata: {
                 let mut hm = HashMap::new();
                 hm.insert("encoding".to_string(), b"binary/protobuf".to_vec());
+                hm.insert("messageType".to_string(), T::full_name().into_bytes());
                 hm
             },
             data: as_proto,
@@ -638,7 +640,7 @@ where
 }
 impl<T> TemporalDeserializable for ProstSerializable<T>
 where
-    T: prost::Message + Default + 'static,
+    T: prost::Message + prost::Name + Default + 'static,
 {
     fn from_payload(
         _: &SerializationContext<'_>,
@@ -656,6 +658,74 @@ where
         .map_err(|e| PayloadConversionError::EncodingError(Box::new(e)))
     }
 }
+impl<T> From<T> for ProstSerializable<T> {
+    fn from(value: T) -> Self {
+        Self(value)
+    }
+}
+impl<T> ProstSerializable<T> {
+    /// Consumes the wrapper, returning the inner protobuf message.
+    pub fn into_inner(self) -> T {
+        self.0
+    }
+}
+
+/// Wrapper for protobuf messages that implements [`TemporalSerializable`]/[`TemporalDeserializable`]
+/// using `json/protobuf` encoding via pbjson-generated serde impls.
+///
+/// Use this when you want human-readable JSON protobuf format (proto3 JSON mapping with
+/// camelCase field names and string enums).
+pub struct ProstJsonSerializable<T>(pub T);
+impl<T> TemporalSerializable for ProstJsonSerializable<T>
+where
+    T: prost::Message + prost::Name + serde::Serialize + Default + 'static,
+{
+    fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
+        let as_json = serde_json::to_vec(&self.0)
+            .map_err(|e| PayloadConversionError::EncodingError(e.into()))?;
+        Ok(Payload {
+            metadata: {
+                let mut hm = HashMap::new();
+                hm.insert("encoding".to_string(), b"json/protobuf".to_vec());
+                hm.insert("messageType".to_string(), T::full_name().into_bytes());
+                hm
+            },
+            data: as_json,
+            external_payloads: vec![],
+        })
+    }
+}
+impl<T> TemporalDeserializable for ProstJsonSerializable<T>
+where
+    T: prost::Message + prost::Name + serde::de::DeserializeOwned + Default + 'static,
+{
+    fn from_payload(
+        _: &SerializationContext<'_>,
+        p: Payload,
+    ) -> Result<Self, PayloadConversionError>
+    where
+        Self: Sized,
+    {
+        let encoding = p.metadata.get("encoding").map(|v| v.as_slice());
+        if encoding != Some(b"json/protobuf".as_slice()) {
+            return Err(PayloadConversionError::WrongEncoding);
+        }
+        serde_json::from_slice(&p.data)
+            .map(ProstJsonSerializable)
+            .map_err(|e| PayloadConversionError::EncodingError(e.into()))
+    }
+}
+impl<T> From<T> for ProstJsonSerializable<T> {
+    fn from(value: T) -> Self {
+        Self(value)
+    }
+}
+impl<T> ProstJsonSerializable<T> {
+    /// Consumes the wrapper, returning the inner protobuf message.
+    pub fn into_inner(self) -> T {
+        self.0
+    }
+}
 
 /// A payload converter that delegates to an ordered list of inner converters.
#[derive(Clone)] @@ -769,6 +839,9 @@ impl_multi_args!(MultiArgs6; 6; 0: A, 1: B, 2: C, 3: D, 4: E, 5: F); #[cfg(test)] mod tests { use super::*; + use crate::protos::temporal::api::common::v1::WorkflowExecution; + use prost::Name; + use rstest::rstest; #[test] fn test_empty_payloads_as_unit_type() { @@ -866,4 +939,142 @@ mod tests { let args: MultiArgs2 = ("hello".to_string(), 42i32).into(); assert_eq!(args, MultiArgs2("hello".to_string(), 42)); } + + #[derive(Clone, Copy)] + enum EncodingFormat { + Binary, + Json, + } + + impl EncodingFormat { + fn expected_encoding(self) -> &'static [u8] { + match self { + Self::Binary => b"binary/protobuf", + Self::Json => b"json/protobuf", + } + } + } + + fn test_wf_exec() -> WorkflowExecution { + WorkflowExecution { + workflow_id: "wf-123".into(), + run_id: "run-456".into(), + } + } + + fn serialize_as( + format: EncodingFormat, + wf: &WorkflowExecution, + ctx: &SerializationContext<'_>, + ) -> Payload { + match format { + EncodingFormat::Binary => ProstSerializable(wf.clone()).to_payload(ctx).unwrap(), + EncodingFormat::Json => ProstJsonSerializable(wf.clone()).to_payload(ctx).unwrap(), + } + } + + fn deserialize_as( + format: EncodingFormat, + payload: Payload, + ctx: &SerializationContext<'_>, + ) -> Result { + match format { + EncodingFormat::Binary => { + ProstSerializable::from_payload(ctx, payload).map(|w: ProstSerializable<_>| w.0) + } + EncodingFormat::Json => ProstJsonSerializable::from_payload(ctx, payload) + .map(|w: ProstJsonSerializable<_>| w.0), + } + } + + #[rstest] + #[case::binary(EncodingFormat::Binary)] + #[case::json(EncodingFormat::Json)] + fn prost_round_trip(#[case] format: EncodingFormat) { + let converter = PayloadConverter::default(); + let ctx = SerializationContext { + data: &SerializationContextData::Workflow, + converter: &converter, + }; + let wf = test_wf_exec(); + + let payload = serialize_as(format, &wf, &ctx); + assert_eq!( + payload.metadata.get("encoding").unwrap().as_slice(), + 
format.expected_encoding(), + ); + assert_eq!( + payload.metadata.get("messageType").unwrap().as_slice(), + WorkflowExecution::full_name().as_bytes(), + ); + + let result = deserialize_as(format, payload, &ctx).unwrap(); + assert_eq!(result, wf); + } + + #[test] + fn prost_json_produces_proto3_json_format() { + let converter = PayloadConverter::default(); + let ctx = SerializationContext { + data: &SerializationContextData::Workflow, + converter: &converter, + }; + + let payload = ProstJsonSerializable(test_wf_exec()) + .to_payload(&ctx) + .unwrap(); + let json_str = std::str::from_utf8(&payload.data).unwrap(); + assert!( + json_str.contains("workflowId"), + "expected camelCase field names in proto3 JSON, got: {json_str}" + ); + assert!(json_str.contains("runId")); + } + + #[rstest] + #[case::binary_rejects_json(EncodingFormat::Binary, EncodingFormat::Json)] + #[case::json_rejects_binary(EncodingFormat::Json, EncodingFormat::Binary)] + fn prost_rejects_wrong_encoding( + #[case] decode_format: EncodingFormat, + #[case] encode_format: EncodingFormat, + ) { + let converter = PayloadConverter::default(); + let ctx = SerializationContext { + data: &SerializationContextData::Workflow, + converter: &converter, + }; + + let payload = serialize_as(encode_format, &test_wf_exec(), &ctx); + let result = deserialize_as(decode_format, payload, &ctx); + assert!(matches!(result, Err(PayloadConversionError::WrongEncoding))); + } + + #[rstest] + #[case::binary(EncodingFormat::Binary)] + #[case::json(EncodingFormat::Json)] + fn prost_through_composite_converter(#[case] format: EncodingFormat) { + let converter = PayloadConverter::default(); + let ctx = SerializationContext { + data: &SerializationContextData::Workflow, + converter: &converter, + }; + let wf = test_wf_exec(); + + let payloads = match format { + EncodingFormat::Binary => converter + .to_payloads(&ctx, &ProstSerializable(wf.clone())) + .unwrap(), + EncodingFormat::Json => converter + .to_payloads(&ctx, 
&ProstJsonSerializable(wf.clone())) + .unwrap(), + }; + assert_eq!(payloads.len(), 1); + assert_eq!( + payloads[0].metadata.get("encoding").unwrap().as_slice(), + format.expected_encoding(), + ); + + let result = deserialize_as(format, payloads.into_iter().next().unwrap(), &ctx).unwrap(); + assert_eq!(result, wf); + } } diff --git a/crates/sdk-core/tests/integ_tests/data_converter_tests.rs b/crates/sdk-core/tests/integ_tests/data_converter_tests.rs index 5b1a1442c..72684e753 100644 --- a/crates/sdk-core/tests/integ_tests/data_converter_tests.rs +++ b/crates/sdk-core/tests/integ_tests/data_converter_tests.rs @@ -11,10 +11,13 @@ use temporalio_client::{Client, ClientOptions, UntypedWorkflow, WorkflowStartOpt use temporalio_common::{ data_converters::{ DataConverter, DefaultFailureConverter, MultiArgs2, PayloadCodec, PayloadConversionError, - PayloadConverter, SerializationContext, SerializationContextData, TemporalDeserializable, - TemporalSerializable, + PayloadConverter, ProstJsonSerializable, ProstSerializable, SerializationContext, + SerializationContextData, TemporalDeserializable, TemporalSerializable, + }, + protos::temporal::api::{ + common::v1::{Payload, WorkflowExecution}, + history::v1::history_event::Attributes, }, - protos::temporal::api::{common::v1::Payload, history::v1::history_event::Attributes}, worker::WorkerTaskTypes, }; use temporalio_macros::{activities, workflow, workflow_methods}; @@ -379,3 +382,233 @@ async fn codec_encodes_and_decodes_payloads() { "Codec should have decoded payloads, but decode_count was 0" ); } + +struct ProtoActivities; +#[activities] +impl ProtoActivities { + #[activity] + async fn echo_binary_proto( + _ctx: ActivityContext, + input: ProstSerializable, + ) -> Result, ActivityError> { + let mut wf = input.0; + wf.run_id = format!("activity-{}", wf.run_id); + Ok(ProstSerializable(wf)) + } + + #[activity] + async fn echo_json_proto( + _ctx: ActivityContext, + input: ProstJsonSerializable, + ) -> Result, ActivityError> { + 
let mut wf = input.0; + wf.run_id = format!("activity-{}", wf.run_id); + Ok(ProstJsonSerializable(wf)) + } +} + +#[workflow] +#[derive(Default)] +struct ProtoBinaryWorkflow; +#[workflow_methods] +impl ProtoBinaryWorkflow { + #[run] + async fn run( + ctx: &mut WorkflowContext, + input: ProstSerializable, + ) -> WorkflowResult> { + let output = ctx + .start_activity( + ProtoActivities::echo_binary_proto, + input, + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + ..Default::default() + }, + ) + .await + .map_err(|e| anyhow::anyhow!("{e}"))?; + Ok(output) + } +} + +#[tokio::test] +async fn prost_binary_round_trips_through_workflow() { + let wf_name = ProtoBinaryWorkflow::name(); + let mut starter = CoreWfStarter::new(wf_name); + starter.sdk_config.register_activities(ProtoActivities); + starter + .sdk_config + .register_workflow::(); + let mut worker = starter.worker().await; + + let input = WorkflowExecution { + workflow_id: "test-wf".into(), + run_id: "test-run".into(), + }; + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + ProtoBinaryWorkflow::run, + ProstSerializable(input), + WorkflowStartOptions::new(task_queue, wf_name.to_owned()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + let output = handle.get_result(Default::default()).await.unwrap().0; + assert_eq!( + output, + WorkflowExecution { + workflow_id: "test-wf".into(), + run_id: "activity-test-run".into(), + } + ); + + // Verify the history payloads use binary/protobuf encoding + let client = starter.get_client().await; + let events = client + .get_workflow_handle::(wf_name) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + let started = events + .iter() + .find_map(|e| { + if let Attributes::WorkflowExecutionStartedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + .expect("Should find WorkflowExecutionStarted 
event"); + + let input_payload = &started.input.as_ref().unwrap().payloads[0]; + assert_eq!( + input_payload.metadata.get("encoding").unwrap().as_slice(), + b"binary/protobuf", + "Workflow input should be encoded as binary/protobuf" + ); + assert_eq!( + input_payload + .metadata + .get("messageType") + .unwrap() + .as_slice(), + b"temporal.api.common.v1.WorkflowExecution" + ); + + // Decode the raw payload back into the proto type to verify the bytes are valid + let decoded = ::decode(input_payload.data.as_slice()) + .expect("History payload should decode as a valid WorkflowExecution"); + assert_eq!(decoded.workflow_id, "test-wf"); + assert_eq!(decoded.run_id, "test-run"); +} + +#[workflow] +#[derive(Default)] +struct ProtoJsonWorkflow; +#[workflow_methods] +impl ProtoJsonWorkflow { + #[run] + async fn run( + ctx: &mut WorkflowContext, + input: ProstJsonSerializable, + ) -> WorkflowResult> { + let output = ctx + .start_activity( + ProtoActivities::echo_json_proto, + input, + ActivityOptions { + start_to_close_timeout: Some(Duration::from_secs(5)), + ..Default::default() + }, + ) + .await + .map_err(|e| anyhow::anyhow!("{e}"))?; + Ok(output) + } +} + +#[tokio::test] +async fn prost_json_round_trips_through_workflow() { + let wf_name = ProtoJsonWorkflow::name(); + let mut starter = CoreWfStarter::new(wf_name); + starter.sdk_config.register_activities(ProtoActivities); + starter.sdk_config.register_workflow::(); + let mut worker = starter.worker().await; + + let input = ProstJsonSerializable(WorkflowExecution { + workflow_id: "test-wf".into(), + run_id: "test-run".into(), + }); + + let task_queue = starter.get_task_queue().to_owned(); + let handle = worker + .submit_workflow( + ProtoJsonWorkflow::run, + input, + WorkflowStartOptions::new(task_queue, wf_name.to_owned()).build(), + ) + .await + .unwrap(); + worker.run_until_done().await.unwrap(); + + let output = handle.get_result(Default::default()).await.unwrap().0; + assert_eq!(output.workflow_id, "test-wf"); + 
assert_eq!(output.run_id, "activity-test-run"); + + // Verify the history payloads use json/protobuf encoding + let client = starter.get_client().await; + let events = client + .get_workflow_handle::(wf_name) + .fetch_history(Default::default()) + .await + .unwrap() + .into_events(); + + let started = events + .iter() + .find_map(|e| { + if let Attributes::WorkflowExecutionStartedEventAttributes(attrs) = + e.attributes.as_ref().unwrap() + { + Some(attrs) + } else { + None + } + }) + .expect("Should find WorkflowExecutionStarted event"); + + let input_payload = &started.input.as_ref().unwrap().payloads[0]; + assert_eq!( + input_payload.metadata.get("encoding").unwrap().as_slice(), + b"json/protobuf", + "Workflow input should be encoded as json/protobuf" + ); + assert_eq!( + input_payload + .metadata + .get("messageType") + .unwrap() + .as_slice(), + b"temporal.api.common.v1.WorkflowExecution" + ); + + // Decode the raw payload back into the proto type to verify the JSON is valid + let decoded: WorkflowExecution = serde_json::from_slice(&input_payload.data) + .expect("History payload should decode as a valid WorkflowExecution"); + assert_eq!( + decoded, + WorkflowExecution { + workflow_id: "test-wf".into(), + run_id: "test-run".into(), + } + ); +}