Skip to content

Commit 5e825f1

Browse files
authored
change require.protoEquals to support compare options (#9937)
## What changed? Added ProtoEqualIgnoreFields to protoassert and protorequire packages. Refactored 3 call sites in standalone_activity_test.go to use the new helper. ## Why? Comparing proto messages while ignoring specific fields required a verbose pattern (cmp.Diff + protocmp.Transform() + protocmp.IgnoreFields() + require.Empty). The new option extends ProtoEqual with a clean call site that infers the message type automatically. The Option func type and config struct are designed to support additional comparison options in the future without changing the ProtoEqual signature. ## How did you test it? - [X] built - [X] run locally and tested manually - [X] covered by existing tests - [X] added new unit test(s) - [ ] added new functional test(s)
1 parent f854ea9 commit 5e825f1

3 files changed

Lines changed: 108 additions & 24 deletions

File tree

common/testing/protorequire/require.go

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,37 @@
11
package protorequire
22

33
import (
4+
"fmt"
5+
6+
"github.com/google/go-cmp/cmp"
47
"github.com/stretchr/testify/require"
58
"go.temporal.io/server/common/testing/protoassert"
69
"google.golang.org/protobuf/proto"
10+
"google.golang.org/protobuf/reflect/protoreflect"
11+
"google.golang.org/protobuf/testing/protocmp"
712
)
813

914
type helper interface {
1015
Helper()
1116
}
1217

18+
// config holds settings populated by Option functions. Add new fields here for custom comparison behaviors beyond cmp
19+
// options.
20+
type config struct {
21+
cmpOpts []cmp.Option
22+
}
23+
24+
// Option configures how proto comparison behaves.
25+
type Option func(msg proto.Message, cfg *config)
26+
27+
// IgnoreFields returns an Option that ignores the specified fields (by proto name) on the top-level message type
28+
// when comparing.
29+
func IgnoreFields(fields ...protoreflect.Name) Option {
30+
return func(msg proto.Message, cfg *config) {
31+
cfg.cmpOpts = append(cfg.cmpOpts, protocmp.IgnoreFields(msg, fields...))
32+
}
33+
}
34+
1335
type ProtoAssertions struct {
1436
t require.TestingT
1537
}
@@ -18,12 +40,19 @@ func New(t require.TestingT) ProtoAssertions {
1840
return ProtoAssertions{t}
1941
}
2042

21-
func ProtoEqual(t require.TestingT, a proto.Message, b proto.Message) {
43+
// ProtoEqual compares two proto messages for equality using proto semantics. Options can be passed to customize
44+
// comparison behavior, e.g. protorequire.IgnoreFields to exclude specific fields.
45+
func ProtoEqual(t require.TestingT, a proto.Message, b proto.Message, opts ...Option) {
2246
if th, ok := t.(helper); ok {
2347
th.Helper()
2448
}
25-
if !protoassert.ProtoEqual(t, a, b) {
26-
t.FailNow()
49+
cfg := &config{}
50+
for _, opt := range opts {
51+
opt(a, cfg)
52+
}
53+
cmpOpts := append([]cmp.Option{protocmp.Transform()}, cfg.cmpOpts...)
54+
if diff := cmp.Diff(a, b, cmpOpts...); diff != "" {
55+
require.Fail(t, fmt.Sprintf("Proto mismatch (-want +got):\n%v", diff))
2756
}
2857
}
2958

@@ -49,13 +78,11 @@ func ProtoSliceEqual[T proto.Message](t require.TestingT, a []T, b []T) {
4978
}
5079
}
5180

52-
func (x ProtoAssertions) ProtoEqual(a proto.Message, b proto.Message) {
81+
func (x ProtoAssertions) ProtoEqual(a proto.Message, b proto.Message, opts ...Option) {
5382
if th, ok := x.t.(helper); ok {
5483
th.Helper()
5584
}
56-
if !protoassert.ProtoEqual(x.t, a, b) {
57-
x.t.FailNow()
58-
}
85+
ProtoEqual(x.t, a, b, opts...)
5986
}
6087

6188
func (x ProtoAssertions) NotProtoEqual(a proto.Message, b proto.Message) {
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package protorequire_test
2+
3+
import (
4+
"testing"
5+
6+
commonpb "go.temporal.io/api/common/v1"
7+
workflowpb "go.temporal.io/api/workflow/v1"
8+
"go.temporal.io/server/common/testing/protorequire"
9+
)
10+
11+
const myUUID = "deb7b204-b384-4fde-85c6-e5a56c42336a"
12+
13+
type mockT struct {
14+
failed bool
15+
}
16+
17+
func (m *mockT) Errorf(string, ...any) {
18+
m.failed = true
19+
}
20+
21+
func (m *mockT) FailNow() {
22+
m.failed = true
23+
}
24+
25+
func TestProtoEqualIgnoreFields(t *testing.T) {
26+
a := &workflowpb.WorkflowExecutionInfo{
27+
Execution: &commonpb.WorkflowExecution{
28+
WorkflowId: "wf-1",
29+
RunId: myUUID,
30+
},
31+
Status: 1,
32+
TaskQueue: "queue-a",
33+
}
34+
b := &workflowpb.WorkflowExecutionInfo{
35+
Execution: &commonpb.WorkflowExecution{
36+
WorkflowId: "wf-1",
37+
RunId: myUUID,
38+
},
39+
Status: 2,
40+
TaskQueue: "queue-b",
41+
}
42+
43+
t.Run("all differing fields ignored", func(t *testing.T) {
44+
protorequire.ProtoEqual(t, a, b,
45+
protorequire.IgnoreFields(
46+
"status",
47+
"task_queue",
48+
),
49+
)
50+
})
51+
52+
t.Run("partial ignore still fails", func(t *testing.T) {
53+
mt := &mockT{}
54+
protorequire.ProtoEqual(mt, a, b,
55+
protorequire.IgnoreFields(
56+
"status",
57+
),
58+
)
59+
if !mt.failed {
60+
t.Fatal("expected comparison to fail when not all differing fields are ignored")
61+
}
62+
})
63+
}

tests/standalone_activity_test.go

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"testing"
1010
"time"
1111

12-
"github.com/google/go-cmp/cmp"
1312
"github.com/nexus-rpc/sdk-go/nexus"
1413
"github.com/stretchr/testify/require"
1514
"github.com/stretchr/testify/suite"
@@ -37,7 +36,6 @@ import (
3736
"go.temporal.io/server/components/callbacks"
3837
"go.temporal.io/server/tests/testcore"
3938
"google.golang.org/grpc/codes"
40-
"google.golang.org/protobuf/testing/protocmp"
4139
"google.golang.org/protobuf/types/known/durationpb"
4240
"google.golang.org/protobuf/types/known/timestamppb"
4341
)
@@ -3044,17 +3042,15 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_NoWait() {
30443042
UserMetadata: defaultUserMetadata,
30453043
}
30463044

3047-
diff := cmp.Diff(expected, respInfo,
3048-
protocmp.Transform(),
3049-
// Ignore non-deterministic fields. Validated separately.
3050-
protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{},
3045+
// Ignore non-deterministic fields. Validated separately.
3046+
protorequire.ProtoEqual(t, expected, respInfo,
3047+
protorequire.IgnoreFields(
30513048
"execution_duration",
30523049
"schedule_time",
30533050
"state_size_bytes",
30543051
"state_transition_count",
30553052
),
30563053
)
3057-
require.Empty(t, diff)
30583054
require.Equal(t, respInfo.GetExecutionDuration().AsDuration(), time.Duration(0)) // Never completed, so expect 0
30593055
require.Nil(t, describeResp.GetInfo().GetCloseTime())
30603056
require.Positive(t, respInfo.GetScheduleTime().AsTime().Unix())
@@ -3107,17 +3103,16 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState
31073103
Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING,
31083104
TaskQueue: taskQueue.Name,
31093105
}
3110-
diff := cmp.Diff(expected, firstDescribeResp.GetInfo(),
3111-
protocmp.Transform(),
3112-
// Ignore non-deterministic fields. Validated separately.
3113-
protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{},
3106+
3107+
// Ignore non-deterministic fields. Validated separately.
3108+
protorequire.ProtoEqual(t, expected, firstDescribeResp.GetInfo(),
3109+
protorequire.IgnoreFields(
31143110
"execution_duration",
31153111
"schedule_time",
31163112
"state_size_bytes",
31173113
"state_transition_count",
31183114
),
31193115
)
3120-
require.Empty(t, diff)
31213116
require.Positive(t, firstDescribeResp.GetInfo().GetStateSizeBytes())
31223117

31233118
taskQueuePollErr := make(chan error, 1)
@@ -3167,18 +3162,17 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState
31673162
Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING,
31683163
TaskQueue: taskQueue.Name,
31693164
}
3170-
diff := cmp.Diff(expected, describeResp.GetInfo(),
3171-
protocmp.Transform(),
3172-
// Ignore non-deterministic fields. Validated separately.
3173-
protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{},
3165+
3166+
// Ignore non-deterministic fields. Validated separately.
3167+
protorequire.ProtoEqual(t, expected, describeResp.GetInfo(),
3168+
protorequire.IgnoreFields(
31743169
"execution_duration",
31753170
"last_started_time",
31763171
"schedule_time",
31773172
"state_size_bytes",
31783173
"state_transition_count",
31793174
),
31803175
)
3181-
require.Empty(t, diff)
31823176
require.Positive(t, describeResp.GetInfo().GetStateSizeBytes())
31833177

31843178
protorequire.ProtoEqual(t, defaultInput, describeResp.Input)

0 commit comments

Comments
 (0)