@@ -53,7 +53,9 @@ import (
5353 "fmt"
5454
5555 "google.golang.org/grpc"
56+ "google.golang.org/grpc/status"
5657 "google.golang.org/protobuf/proto"
58+ "google.golang.org/protobuf/protoadapt"
5759)
5860
5961// VisitPayloadsContext provides Payload context for visitor functions.
@@ -93,6 +95,9 @@ type PayloadVisitorInterceptorOptions struct {
9395}
9496
9597// NewPayloadVisitorInterceptor creates a new gRPC interceptor for workflowservice messages.
98+ //
99+ // Note: Failure converters should come before payload codec converts, to allow the
100+ // payloads generated by the failure convert to be intercepted by the payload codec converters.
96101func NewPayloadVisitorInterceptor(options PayloadVisitorInterceptorOptions) (grpc.UnaryClientInterceptor, error) {
97102 return func(ctx context.Context, method string, req, response interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
98103 if reqMsg, ok := req.(proto.Message); ok && options.Outbound != nil {
@@ -104,6 +109,30 @@ func NewPayloadVisitorInterceptor(options PayloadVisitorInterceptorOptions) (grp
104109
105110 err := invoker(ctx, method, req, response, cc, opts...)
106111 if err != nil {
112+ if options.Inbound != nil {
113+ stat, ok := status.FromError(err)
114+ if ok {
115+ // user provided payloads can sometimes end up in the status details of
116+ // gRPC errors, make sure to visit those as well
117+ newStatus := status.New(stat.Code(), stat.Message())
118+ var newDetails []protoadapt.MessageV1
119+ for _, detail := range stat.Details() {
120+ detailAny, ok := detail.(*anypb.Any)
121+ if ok && ({{ range $i, $name := .GrpcPayload }}{{ if $i }} || {{ end }}detailAny.MessageName() == "{{$name}}"{{ end }}) {
122+ err = VisitPayloads(ctx, detailAny, *options.Inbound)
123+ if err != nil {
124+ return err
125+ }
126+ }
127+ newDetails = append(newDetails, detailAny)
128+ }
129+ newStatus, err = newStatus.WithDetails(newDetails...)
130+ if err != nil {
131+ return err
132+ }
133+ return newStatus.Err()
134+ }
135+ }
107136 return err
108137 }
109138
@@ -148,6 +177,9 @@ type FailureVisitorInterceptorOptions struct {
148177}
149178
150179// NewFailureVisitorInterceptor creates a new gRPC interceptor for workflowservice messages.
180+ //
181+ // Note: Failure converters should come before payload codec converts, to allow the
182+ // payloads generated by the failure convert to be intercepted by the payload codec converters.
151183func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grpc.UnaryClientInterceptor, error) {
152184 return func(ctx context.Context, method string, req, response interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
153185 if reqMsg, ok := req.(proto.Message); ok && options.Outbound != nil {
@@ -159,6 +191,30 @@ func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grp
159191
160192 err := invoker(ctx, method, req, response, cc, opts...)
161193 if err != nil {
194+ if options.Inbound != nil {
195+ stat, ok := status.FromError(err)
196+ if ok {
197+ // user provided payloads can sometimes end up in the status details of
198+ // gRPC errors, make sure to visit those as well
199+ newStatus := status.New(stat.Code(), stat.Message())
200+ var newDetails []protoadapt.MessageV1
201+ for _, detail := range stat.Details() {
202+ detailAny, ok := detail.(*anypb.Any)
203+ if ok && ({{ range $i, $name := .GrpcFailure }}{{ if $i }} || {{ end }}detailAny.MessageName() == "{{$name}}"{{ end }}) {
204+ err = VisitFailures(ctx, detailAny, *options.Inbound)
205+ if err != nil {
206+ return err
207+ }
208+ }
209+ newDetails = append(newDetails, detailAny)
210+ }
211+ newStatus, err = newStatus.WithDetails(newDetails...)
212+ if err != nil {
213+ return err
214+ }
215+ return newStatus.Err()
216+ }
217+ }
162218 return err
163219 }
164220
@@ -369,6 +425,13 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj
369425
370426var interceptorTemplate = template .Must (template .New ("interceptor" ).Parse (InterceptorTemplateText ))
371427
428+ type TemplateInput struct {
429+ PayloadTypes map [string ]* TypeRecord
430+ FailureTypes map [string ]* TypeRecord
431+ GrpcPayload []string
432+ GrpcFailure []string
433+ }
434+
372435// TypeRecord holds the state for a type referred to by the workflow service
373436type TypeRecord struct {
374437 Methods []string // List of methods on this type that can eventually lead to Payload(s)
@@ -776,11 +839,31 @@ func generateInterceptor(cfg config) error {
776839 return err
777840 }
778841
842+ // gather a list of errordetails that can contain user payloads when included in
843+ // gRPC error messages
844+ var grpcPayload []string
845+ for _ , msg := range allPayloadContainingMessages {
846+ if strings .Contains (string (msg .FullName ()), "temporal.api.errordetails.v1." ) && strings .Count (string (msg .FullName ()), "." ) == 4 {
847+ grpcPayload = append (grpcPayload , string (msg .FullName ()))
848+ }
849+ }
850+
851+ var grpcFailure []string
852+ for _ , msg := range allFailureContainingMessages {
853+ if strings .Contains (string (msg .FullName ()), "temporal.api.errordetails.v1." ) && strings .Count (string (msg .FullName ()), "." ) == 4 {
854+ grpcFailure = append (grpcFailure , string (msg .FullName ()))
855+ }
856+ }
857+
779858 buf := & bytes.Buffer {}
780859 fmt .Fprint (buf , cfg .license )
781860
782- err = interceptorTemplate .Execute (buf , map [string ]map [string ]* TypeRecord {
783- "PayloadTypes" : payloadRecords , "FailureTypes" : failureRecords })
861+ err = interceptorTemplate .Execute (buf , TemplateInput {
862+ PayloadTypes : payloadRecords ,
863+ FailureTypes : failureRecords ,
864+ GrpcFailure : grpcFailure ,
865+ GrpcPayload : grpcPayload ,
866+ })
784867 if err != nil {
785868 return err
786869 }
0 commit comments