Skip to content

Commit d0b20a6

Browse files
Merge pull request #171 from SimonRichardson/errors-is-checker
error.Is testing checker
2 parents 9af10d6 + 448af9c commit d0b20a6

2 files changed

Lines changed: 121 additions & 0 deletions

File tree

checkers/error.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright 2023 Canonical Ltd.
2+
// Licensed under the LGPLv3, see LICENCE file for details.
3+
4+
package checkers
5+
6+
import (
7+
"errors"
8+
"fmt"
9+
"reflect"
10+
11+
gc "gopkg.in/check.v1"
12+
)
13+
14+
type errorIsChecker struct {
15+
*gc.CheckerInfo
16+
}
17+
18+
// ErrorIs checks whether a value is an error that matches the other
19+
// argument.
20+
var ErrorIs gc.Checker = &errorIsChecker{
21+
CheckerInfo: &gc.CheckerInfo{
22+
Name: "ErrorIs",
23+
Params: []string{"obtained", "error"},
24+
},
25+
}
26+
27+
var (
28+
errType = reflect.TypeOf((*error)(nil)).Elem()
29+
)
30+
31+
func (checker *errorIsChecker) Check(params []interface{}, names []string) (result bool, err string) {
32+
if params[1] == nil || params[0] == nil {
33+
return params[1] == params[0], ""
34+
}
35+
36+
f := reflect.ValueOf(params[1])
37+
ft := f.Type()
38+
if !ft.Implements(errType) {
39+
return false, fmt.Sprintf("wrong error target type, got: %s", ft)
40+
}
41+
42+
v := reflect.ValueOf(params[0])
43+
vt := v.Type()
44+
if !v.IsValid() {
45+
return false, fmt.Sprintf("wrong argument type %s for %s", vt, ft)
46+
}
47+
if !vt.Implements(errType) {
48+
return false, fmt.Sprintf("wrong argument type %s for %s", vt, ft)
49+
}
50+
51+
return errors.Is(v.Interface().(error), f.Interface().(error)), ""
52+
}

checkers/error_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package checkers_test
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/juju/errors"
7+
jc "github.com/juju/testing/checkers"
8+
gc "gopkg.in/check.v1"
9+
)
10+
11+
type ErrorSuite struct{}
12+
13+
var _ = gc.Suite(&ErrorSuite{})
14+
15+
var errorIsTests = []struct {
16+
arg interface{}
17+
target interface{}
18+
result bool
19+
msg string
20+
}{{
21+
arg: fmt.Errorf("bar"),
22+
target: nil,
23+
result: false,
24+
}, {
25+
arg: nil,
26+
target: fmt.Errorf("bar"),
27+
result: false,
28+
}, {
29+
arg: nil,
30+
target: nil,
31+
result: true,
32+
}, {
33+
arg: fmt.Errorf("bar"),
34+
target: fmt.Errorf("foo"),
35+
result: false,
36+
}, {
37+
arg: errors.ConstError("bar"),
38+
target: errors.ConstError("foo"),
39+
result: false,
40+
}, {
41+
arg: errors.ConstError("foo"),
42+
target: errors.ConstError("foo"),
43+
result: true,
44+
}, {
45+
arg: errors.Trace(errors.ConstError("foo")),
46+
target: errors.ConstError("foo"),
47+
result: true,
48+
}, {
49+
arg: errors.ConstError("foo"),
50+
target: "blah",
51+
msg: "wrong error target type, got: string",
52+
}, {
53+
arg: "blah",
54+
target: errors.ConstError("foo"),
55+
msg: "wrong argument type string for errors.ConstError",
56+
}, {
57+
arg: (*error)(nil),
58+
target: errors.ConstError("foo"),
59+
msg: "wrong argument type *error for errors.ConstError",
60+
}}
61+
62+
func (s *ErrorSuite) TestErrorIs(c *gc.C) {
63+
for i, test := range errorIsTests {
64+
c.Logf("test %d. %T %T", i, test.arg, test.target)
65+
result, msg := jc.ErrorIs.Check([]interface{}{test.arg, test.target}, nil)
66+
c.Check(result, gc.Equals, test.result)
67+
c.Check(msg, gc.Equals, test.msg)
68+
}
69+
}

0 commit comments

Comments
 (0)