@@ -29,7 +29,6 @@ import (
2929 "fmt"
3030 "log"
3131 "net"
32- "slices"
3332 "strings"
3433 "testing"
3534 "time"
@@ -189,6 +188,25 @@ func TestVisitPayloads_NestedParent(t *testing.T) {
189188 require .IsType (t , & command.StartChildWorkflowExecutionCommandAttributes {}, inputParent )
190189}
191190
191+ func TestVisitPayloads_RepeatedPayload (t * testing.T ) {
192+ root := & workflowservice.CountWorkflowExecutionsResponse_AggregationGroup {GroupValues : []* common.Payload {{Data : []byte ("orig-val" )}}}
193+
194+ var count int
195+ err := VisitPayloads (context .Background (), root , VisitPayloadsOptions {
196+ Visitor : func (ctx * VisitPayloadsContext , p []* common.Payload ) ([]* common.Payload , error ) {
197+ count += 1
198+ // Only mutate if the payloads has orig-val
199+ if len (p ) == 1 && string (p [0 ].Data ) == "orig-val" {
200+ return []* common.Payload {{Data : []byte ("new-val" )}}, nil
201+ }
202+ return p , nil
203+ },
204+ })
205+ require .NoError (t , err )
206+ require .Equal (t , 1 , count )
207+ require .Equal (t , []* common.Payload {{Data : []byte ("new-val" )}}, root .GroupValues )
208+ }
209+
192210func TestVisitPayloads_Any (t * testing.T ) {
193211 // Due to us not visiting protos inside Any, this test used to fail
194212 msg1 , err := anypb .New (& update.Request {Input : & update.Input {Args : & common.Payloads {
@@ -945,6 +963,29 @@ func TestVisitPayloads_MapCount(t *testing.T) {
945963 require .Equal (44 , totalCount )
946964}
947965
966+ func TestVisitPayloads_CountWorkflowExecutionsResponse (t * testing.T ) {
967+ require := require .New (t )
968+
969+ var messageType protoreflect.MessageType
970+ var totalCount , count int
971+
972+ protoregistry .GlobalTypes .RangeMessages (func (mt protoreflect.MessageType ) bool {
973+ if string (mt .Descriptor ().FullName ()) == "temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse" {
974+ messageType = mt
975+ }
976+ return true
977+ })
978+
979+ // Create empty instance and populate with test values
980+ msg1 := messageType .New ().Interface ().(proto.Message )
981+ totalCount = 0
982+ count = 0
983+ populatePayload (& msg1 , msg1 , require , & totalCount , & count )
984+
985+ require .Equal (0 , count )
986+ require .Equal (1 , totalCount )
987+ }
988+
948989func TestVisitPayloads_ResponseCount (t * testing.T ) {
949990 require := require .New (t )
950991
@@ -989,14 +1030,10 @@ func TestVisitPayloads_Everything(t *testing.T) {
9891030 require := require .New (t )
9901031
9911032 var messageType []protoreflect.MessageType
992- skipList := []string {
993- "temporal.api.common.v1.Payload" ,
994- // TODO: https://github.com/temporalio/sdk-go/issues/1865
995- "temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse" ,
996- "temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse.AggregationGroup" ,
997- }
9981033 protoregistry .GlobalTypes .RangeMessages (func (mt protoreflect.MessageType ) bool {
999- if strings .HasPrefix (string (mt .Descriptor ().FullName ()), "temporal.api." ) && ! slices .Contains (skipList , string (mt .Descriptor ().FullName ())) {
1034+ // The base case of passing Payload into the visitor is not supported.
1035+ // See godoc for VisitPayloads
1036+ if strings .HasPrefix (string (mt .Descriptor ().FullName ()), "temporal.api." ) && string (mt .Descriptor ().FullName ()) != "temporal.api.common.v1.Payload" {
10001037 messageType = append (messageType , mt )
10011038 }
10021039 return true
0 commit comments