Skip to content

Commit a4daf74

Browse files
authored
Support visiting repeated Payload (#216)
* Support repeated Payload, add godoc comment regarding not supporting Payload directly * Forgot to commit a change, remove slices import
1 parent 8b109c8 commit a4daf74

3 files changed

Lines changed: 67 additions & 8 deletions

File tree

cmd/proxygenerator/interceptor.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ type VisitPayloadsOptions struct {
8080
}
8181
8282
// VisitPayloads calls the options.Visitor function for every Payload proto within msg.
83+
//
84+
// Note: Directly visiting *common.Payload is not supported. Payloads must be passed through
85+
// a parent proto.
8386
func VisitPayloads(ctx context.Context, msg proto.Message, options VisitPayloadsOptions) error {
8487
visitCtx := VisitPayloadsContext{Context: ctx, Parent: msg}
8588
@@ -315,6 +318,14 @@ func visitPayloads(
315318
return err
316319
}
317320
}
321+
case []*common.Payload:
322+
for ix, x := range o {
323+
if nx, err := visitPayload(ctx, options, parent, x); err != nil {
324+
return err
325+
} else {
326+
o[ix] = nx
327+
}
328+
}
318329
case *anypb.Any:
319330
if o == nil {
320331
continue

proxy/interceptor.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

proxy/interceptor_test.go

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import (
2929
"fmt"
3030
"log"
3131
"net"
32-
"slices"
3332
"strings"
3433
"testing"
3534
"time"
@@ -189,6 +188,25 @@ func TestVisitPayloads_NestedParent(t *testing.T) {
189188
require.IsType(t, &command.StartChildWorkflowExecutionCommandAttributes{}, inputParent)
190189
}
191190

191+
func TestVisitPayloads_RepeatedPayload(t *testing.T) {
192+
root := &workflowservice.CountWorkflowExecutionsResponse_AggregationGroup{GroupValues: []*common.Payload{{Data: []byte("orig-val")}}}
193+
194+
var count int
195+
err := VisitPayloads(context.Background(), root, VisitPayloadsOptions{
196+
Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) {
197+
count += 1
198+
// Only mutate if the payloads has orig-val
199+
if len(p) == 1 && string(p[0].Data) == "orig-val" {
200+
return []*common.Payload{{Data: []byte("new-val")}}, nil
201+
}
202+
return p, nil
203+
},
204+
})
205+
require.NoError(t, err)
206+
require.Equal(t, 1, count)
207+
require.Equal(t, []*common.Payload{{Data: []byte("new-val")}}, root.GroupValues)
208+
}
209+
192210
func TestVisitPayloads_Any(t *testing.T) {
193211
// Due to us not visiting protos inside Any, this test used to fail
194212
msg1, err := anypb.New(&update.Request{Input: &update.Input{Args: &common.Payloads{
@@ -945,6 +963,29 @@ func TestVisitPayloads_MapCount(t *testing.T) {
945963
require.Equal(44, totalCount)
946964
}
947965

966+
func TestVisitPayloads_CountWorkflowExecutionsResponse(t *testing.T) {
967+
require := require.New(t)
968+
969+
var messageType protoreflect.MessageType
970+
var totalCount, count int
971+
972+
protoregistry.GlobalTypes.RangeMessages(func(mt protoreflect.MessageType) bool {
973+
if string(mt.Descriptor().FullName()) == "temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse" {
974+
messageType = mt
975+
}
976+
return true
977+
})
978+
979+
// Create empty instance and populate with test values
980+
msg1 := messageType.New().Interface().(proto.Message)
981+
totalCount = 0
982+
count = 0
983+
populatePayload(&msg1, msg1, require, &totalCount, &count)
984+
985+
require.Equal(0, count)
986+
require.Equal(1, totalCount)
987+
}
988+
948989
func TestVisitPayloads_ResponseCount(t *testing.T) {
949990
require := require.New(t)
950991

@@ -989,14 +1030,10 @@ func TestVisitPayloads_Everything(t *testing.T) {
9891030
require := require.New(t)
9901031

9911032
var messageType []protoreflect.MessageType
992-
skipList := []string{
993-
"temporal.api.common.v1.Payload",
994-
// TODO: https://github.com/temporalio/sdk-go/issues/1865
995-
"temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse",
996-
"temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse.AggregationGroup",
997-
}
9981033
protoregistry.GlobalTypes.RangeMessages(func(mt protoreflect.MessageType) bool {
999-
if strings.HasPrefix(string(mt.Descriptor().FullName()), "temporal.api.") && !slices.Contains(skipList, string(mt.Descriptor().FullName())) {
1034+
// The base case of passing Payload into the visitor is not supported.
1035+
// See godoc for VisitPayloads
1036+
if strings.HasPrefix(string(mt.Descriptor().FullName()), "temporal.api.") && string(mt.Descriptor().FullName()) != "temporal.api.common.v1.Payload" {
10001037
messageType = append(messageType, mt)
10011038
}
10021039
return true

0 commit comments

Comments
 (0)