Skip to content
Merged
22 changes: 22 additions & 0 deletions common/testing/protoassert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"go.temporal.io/api/temporalproto"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/testing/protocmp"
)

Expand All @@ -34,6 +35,18 @@ func ProtoEqual(t assert.TestingT, a proto.Message, b proto.Message) bool {
return true
}

// ProtoEqualIgnoreFields compares two proto messages for equality, ignoring the specified fields
// on the given message type. Fields are specified by their proto name (snake_case).
func ProtoEqualIgnoreFields(t assert.TestingT, a proto.Message, b proto.Message, msgType proto.Message, fields ...protoreflect.Name) bool {
Comment thread
fretz12 marked this conversation as resolved.
Outdated
if th, ok := t.(helper); ok {
th.Helper()
}
if diff := cmp.Diff(a, b, protocmp.Transform(), protocmp.IgnoreFields(msgType, fields...)); diff != "" {
return assert.Fail(t, fmt.Sprintf("Proto mismatch (-want +got):\n%v", diff))
}
return true
}

func NotProtoEqual(t assert.TestingT, a proto.Message, b proto.Message) bool {
if th, ok := t.(helper); ok {
th.Helper()
Expand Down Expand Up @@ -120,3 +133,12 @@ func (x ProtoAssertions) ProtoElementsMatch(a any, b any) bool {

return ProtoElementsMatch(x.t, a, b)
}

// ProtoEqualIgnoreFields compares two proto messages for equality, ignoring the specified fields
// on the given message type. Fields are specified by their proto name (snake_case).
func (x ProtoAssertions) ProtoEqualIgnoreFields(a proto.Message, b proto.Message, msgType proto.Message, fields ...protoreflect.Name) bool {
if th, ok := x.t.(helper); ok {
th.Helper()
}
return ProtoEqualIgnoreFields(x.t, a, b, msgType, fields...)
}
55 changes: 55 additions & 0 deletions common/testing/protoassert/assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,61 @@ type canHazProto struct {
B *commonpb.WorkflowExecution
}

func TestProtoEqualIgnoreFields(t *testing.T) {
a := &workflowpb.WorkflowExecutionInfo{
Execution: &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: myUUID,
},
Status: 1,
}
b := &workflowpb.WorkflowExecutionInfo{
Execution: &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: myUUID,
},
Status: 2, // different — will be ignored
}

result := protoassert.ProtoEqualIgnoreFields(
t,
a,
b,
&workflowpb.WorkflowExecutionInfo{},
"status",
)
if !result {
t.Error("expected equality when ignoring 'status' field")
}
}

func TestProtoEqualIgnoreFields_FailsWhenNonIgnoredFieldDiffers(t *testing.T) {
a := &workflowpb.WorkflowExecutionInfo{
Execution: &commonpb.WorkflowExecution{
WorkflowId: "wf-1",
RunId: myUUID,
},
}
b := &workflowpb.WorkflowExecutionInfo{
Execution: &commonpb.WorkflowExecution{
WorkflowId: "wf-DIFFERENT",
RunId: myUUID,
},
}

mockT := &testing.T{}
result := protoassert.ProtoEqualIgnoreFields(
mockT,
a,
b,
&workflowpb.WorkflowExecutionInfo{},
"status",
)
if result {
t.Error("expected failure when non-ignored field differs")
}
}

func TestProtoElementsMatch(t *testing.T) {
assert := protoassert.New(t)
for _, tc := range []struct {
Expand Down
21 changes: 21 additions & 0 deletions common/testing/protorequire/require.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/stretchr/testify/require"
"go.temporal.io/server/common/testing/protoassert"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

type helper interface {
Expand All @@ -27,6 +28,17 @@ func ProtoEqual(t require.TestingT, a proto.Message, b proto.Message) {
}
}

// ProtoEqualIgnoreFields compares two proto messages for equality, ignoring the specified fields
// on the given message type. Calls FailNow on mismatch.
func ProtoEqualIgnoreFields(t require.TestingT, a proto.Message, b proto.Message, msgType proto.Message, fields ...protoreflect.Name) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Food for thought.
There's some other interesting options we can potentially expose:

protocmp.SortRepeatedFields
protocmp.IgnoreEnums
cmpopts.EquateApproxTime
Also another approach is to make this a really generic function:
func ProtoEqualWithOptions(t assert.TestingT, a proto.Message, b proto.Message, opts ...cmp.Option) bool

and it can be used as such:
protorequire.ProtoEqualWithOptions(t, expected, actual, protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{}, "state_transition_count"), cmpopts.EquateApproxTime(5*time.Second), )

However, that leaks proto/cmp abstractions making the helper harder to use. I decided to make the helper focused on ignore fields, as that's our majority use case and the interface is very clean. I propose for other useful options, like EquateApproxTime, we create separate helpers. But open to thoughts

if th, ok := t.(helper); ok {
th.Helper()
}
if !protoassert.ProtoEqualIgnoreFields(t, a, b, msgType, fields...) {
t.FailNow()
}
}

func NotProtoEqual(t require.TestingT, a proto.Message, b proto.Message) {
if th, ok := t.(helper); ok {
th.Helper()
Expand Down Expand Up @@ -83,3 +95,12 @@ func (x ProtoAssertions) ProtoElementsMatch(a any, b any) bool {

return protoassert.ProtoElementsMatch(x.t, a, b)
}

func (x ProtoAssertions) ProtoEqualIgnoreFields(a proto.Message, b proto.Message, msgType proto.Message, fields ...protoreflect.Name) {
if th, ok := x.t.(helper); ok {
th.Helper()
}
if !protoassert.ProtoEqualIgnoreFields(x.t, a, b, msgType, fields...) {
x.t.FailNow()
}
}
49 changes: 19 additions & 30 deletions tests/standalone_activity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
activitypb "go.temporal.io/api/activity/v1"
Expand All @@ -29,7 +28,6 @@ import (
"go.temporal.io/server/common/testing/testvars"
"go.temporal.io/server/tests/testcore"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
)
Expand Down Expand Up @@ -2866,16 +2864,13 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_NoWait() {
UserMetadata: defaultUserMetadata,
}

diff := cmp.Diff(expected, respInfo,
protocmp.Transform(),
// Ignore non-deterministic fields. Validated separately.
protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{},
"execution_duration",
"schedule_time",
"state_transition_count",
),
// Ignore non-deterministic fields. Validated separately.
protorequire.ProtoEqualIgnoreFields(t, expected, respInfo,
&activitypb.ActivityExecutionInfo{},
"execution_duration",
"schedule_time",
"state_transition_count",
)
require.Empty(t, diff)
require.Equal(t, respInfo.GetExecutionDuration().AsDuration(), time.Duration(0)) // Never completed, so expect 0
require.Nil(t, describeResp.GetInfo().GetCloseTime())
require.Positive(t, respInfo.GetScheduleTime().AsTime().Unix())
Expand Down Expand Up @@ -2927,16 +2922,13 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState
Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING,
TaskQueue: taskQueue.Name,
}
diff := cmp.Diff(expected, firstDescribeResp.GetInfo(),
protocmp.Transform(),
// Ignore non-deterministic fields. Validated separately.
protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{},
"execution_duration",
"schedule_time",
"state_transition_count",
),
// Ignore non-deterministic fields. Validated separately.
protorequire.ProtoEqualIgnoreFields(t, expected, firstDescribeResp.GetInfo(),
&activitypb.ActivityExecutionInfo{},
"execution_duration",
"schedule_time",
"state_transition_count",
)
require.Empty(t, diff)

taskQueuePollErr := make(chan error, 1)
activityPollDone := make(chan struct{})
Expand Down Expand Up @@ -2985,17 +2977,14 @@ func (s *standaloneActivityTestSuite) TestDescribeActivityExecution_WaitAnyState
Status: enumspb.ACTIVITY_EXECUTION_STATUS_RUNNING,
TaskQueue: taskQueue.Name,
}
diff := cmp.Diff(expected, describeResp.GetInfo(),
protocmp.Transform(),
// Ignore non-deterministic fields. Validated separately.
protocmp.IgnoreFields(&activitypb.ActivityExecutionInfo{},
"execution_duration",
"last_started_time",
"schedule_time",
"state_transition_count",
),
// Ignore non-deterministic fields. Validated separately.
protorequire.ProtoEqualIgnoreFields(t, expected, describeResp.GetInfo(),
&activitypb.ActivityExecutionInfo{},
"execution_duration",
"last_started_time",
"schedule_time",
"state_transition_count",
)
require.Empty(t, diff)

protorequire.ProtoEqual(t, defaultInput, describeResp.Input)

Expand Down
Loading