Skip to content
This repository was archived by the owner on Mar 23, 2023. It is now read-only.

Commit d65582e

Browse files
corona10trotterdylan
authored andcommitted
Implement Complex.Add and Sub (#294)
1 parent 6841671 commit d65582e

3 files changed

Lines changed: 125 additions & 23 deletions

File tree

runtime/complex.go

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,43 @@ func (c *Complex) Value() complex128 {
4949
return c.value
5050
}
5151

52+
func complexAdd(f *Frame, v, w *Object) (*Object, *BaseException) {
53+
return complexArithmeticOp(f, "__add__", v, w, func(lhs, rhs complex128) complex128 {
54+
return lhs + rhs
55+
})
56+
}
57+
58+
func complexEq(f *Frame, v, w *Object) (*Object, *BaseException) {
59+
e, ok := complexCompare(toComplexUnsafe(v), w)
60+
if !ok {
61+
return NotImplemented, nil
62+
}
63+
return GetBool(e).ToObject(), nil
64+
}
65+
66+
func complexHash(f *Frame, o *Object) (*Object, *BaseException) {
67+
v := toComplexUnsafe(o).Value()
68+
hashCombined := hashFloat(real(v)) + 1000003*hashFloat(imag(v))
69+
if hashCombined == -1 {
70+
hashCombined = -2
71+
}
72+
return NewInt(hashCombined).ToObject(), nil
73+
}
74+
75+
func complexNE(f *Frame, v, w *Object) (*Object, *BaseException) {
76+
e, ok := complexCompare(toComplexUnsafe(v), w)
77+
if !ok {
78+
return NotImplemented, nil
79+
}
80+
return GetBool(!e).ToObject(), nil
81+
}
82+
83+
func complexRAdd(f *Frame, v, w *Object) (*Object, *BaseException) {
84+
return complexArithmeticOp(f, "__radd__", v, w, func(lhs, rhs complex128) complex128 {
85+
return lhs + rhs
86+
})
87+
}
88+
5289
func complexRepr(f *Frame, o *Object) (*Object, *BaseException) {
5390
c := toComplexUnsafe(o).Value()
5491
rs, is := "", ""
@@ -68,31 +105,31 @@ func complexRepr(f *Frame, o *Object) (*Object, *BaseException) {
68105
return NewStr(fmt.Sprintf("%s%s%s%sj%s", pre, rs, sign, is, post)).ToObject(), nil
69106
}
70107

108+
func complexRSub(f *Frame, v, w *Object) (*Object, *BaseException) {
109+
return complexArithmeticOp(f, "__rsub__", v, w, func(lhs, rhs complex128) complex128 {
110+
return rhs - lhs
111+
})
112+
}
113+
114+
func complexSub(f *Frame, v, w *Object) (*Object, *BaseException) {
115+
return complexArithmeticOp(f, "__rsub__", v, w, func(lhs, rhs complex128) complex128 {
116+
return lhs - rhs
117+
})
118+
}
119+
71120
func initComplexType(dict map[string]*Object) {
121+
ComplexType.slots.Add = &binaryOpSlot{complexAdd}
72122
ComplexType.slots.Eq = &binaryOpSlot{complexEq}
73123
ComplexType.slots.GE = &binaryOpSlot{complexCompareNotSupported}
74124
ComplexType.slots.GT = &binaryOpSlot{complexCompareNotSupported}
75125
ComplexType.slots.Hash = &unaryOpSlot{complexHash}
76126
ComplexType.slots.LE = &binaryOpSlot{complexCompareNotSupported}
77127
ComplexType.slots.LT = &binaryOpSlot{complexCompareNotSupported}
78128
ComplexType.slots.NE = &binaryOpSlot{complexNE}
129+
ComplexType.slots.RAdd = &binaryOpSlot{complexRAdd}
79130
ComplexType.slots.Repr = &unaryOpSlot{complexRepr}
80-
}
81-
82-
func complexEq(f *Frame, v, w *Object) (*Object, *BaseException) {
83-
e, ok := complexCompare(toComplexUnsafe(v), w)
84-
if !ok {
85-
return NotImplemented, nil
86-
}
87-
return GetBool(e).ToObject(), nil
88-
}
89-
90-
func complexNE(f *Frame, v, w *Object) (*Object, *BaseException) {
91-
e, ok := complexCompare(toComplexUnsafe(v), w)
92-
if !ok {
93-
return NotImplemented, nil
94-
}
95-
return GetBool(!e).ToObject(), nil
131+
ComplexType.slots.RSub = &binaryOpSlot{complexRSub}
132+
ComplexType.slots.Sub = &binaryOpSlot{complexSub}
96133
}
97134

98135
func complexCompare(v *Complex, w *Object) (bool, bool) {
@@ -132,11 +169,17 @@ func complexCoerce(o *Object) (complex128, bool) {
132169
return complex(floatO, 0.0), true
133170
}
134171

135-
func complexHash(f *Frame, o *Object) (*Object, *BaseException) {
136-
v := toComplexUnsafe(o).Value()
137-
hashCombined := hashFloat(real(v)) + 1000003*hashFloat(imag(v))
138-
if hashCombined == -1 {
139-
hashCombined = -2
172+
func complexArithmeticOp(f *Frame, method string, v, w *Object, fun func(v, w complex128) complex128) (*Object, *BaseException) {
173+
if w.isInstance(ComplexType) {
174+
return NewComplex(fun(toComplexUnsafe(v).Value(), toComplexUnsafe(w).Value())).ToObject(), nil
140175
}
141-
return NewInt(hashCombined).ToObject(), nil
176+
177+
floatW, ok := floatCoerce(w)
178+
if !ok {
179+
if math.IsInf(floatW, 0) {
180+
return nil, f.RaiseType(OverflowErrorType, "long int too large to convert to float")
181+
}
182+
return NotImplemented, nil
183+
}
184+
return NewComplex(fun(toComplexUnsafe(v).Value(), complex(floatW, 0))).ToObject(), nil
142185
}

runtime/complex_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ package grumpy
1616

1717
import (
1818
"math"
19+
"math/big"
20+
"math/cmplx"
1921
"testing"
2022
)
2123

@@ -42,6 +44,56 @@ func TestComplexEq(t *testing.T) {
4244
}
4345
}
4446

47+
func TestComplexBinaryOps(t *testing.T) {
48+
cases := []struct {
49+
fun func(f *Frame, v, w *Object) (*Object, *BaseException)
50+
v, w *Object
51+
want *Object
52+
wantExc *BaseException
53+
}{
54+
{Add, NewComplex(1 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(2 + 3i).ToObject(), nil},
55+
{Add, NewComplex(1 + 3i).ToObject(), NewFloat(-1).ToObject(), NewComplex(3i).ToObject(), nil},
56+
{Add, NewComplex(1 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(2 + 3i).ToObject(), nil},
57+
{Add, NewComplex(1 + 3i).ToObject(), NewComplex(-1 - 3i).ToObject(), NewComplex(0i).ToObject(), nil},
58+
{Add, NewFloat(math.Inf(1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(1), 3)).ToObject(), nil},
59+
{Add, NewFloat(math.Inf(-1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(-1), 3)).ToObject(), nil},
60+
{Add, NewFloat(math.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), 3)).ToObject(), nil},
61+
{Add, NewComplex(cmplx.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(cmplx.NaN()).ToObject(), nil},
62+
{Add, NewFloat(math.Inf(-1)).ToObject(), NewComplex(complex(math.Inf(+1), 3)).ToObject(), NewComplex(complex(math.NaN(), 3)).ToObject(), nil},
63+
{Add, NewComplex(1 + 3i).ToObject(), None, nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'complex' and 'NoneType'")},
64+
{Add, None, NewComplex(1 + 3i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for +: 'NoneType' and 'complex'")},
65+
{Add, NewInt(3).ToObject(), NewComplex(3i).ToObject(), NewComplex(3 + 3i).ToObject(), nil},
66+
{Add, NewLong(big.NewInt(9999999)).ToObject(), NewComplex(3i).ToObject(), NewComplex(9999999 + 3i).ToObject(), nil},
67+
{Add, NewFloat(3.5).ToObject(), NewComplex(3i).ToObject(), NewComplex(3.5 + 3i).ToObject(), nil},
68+
{Sub, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(0i).ToObject(), nil},
69+
{Sub, NewComplex(1 + 3i).ToObject(), NewComplex(3i).ToObject(), NewComplex(1).ToObject(), nil},
70+
{Sub, NewComplex(1 + 3i).ToObject(), NewFloat(1).ToObject(), NewComplex(3i).ToObject(), nil},
71+
{Sub, NewComplex(3i).ToObject(), NewFloat(1.2).ToObject(), NewComplex(-1.2 + 3i).ToObject(), nil},
72+
{Sub, NewComplex(1 + 3i).ToObject(), NewComplex(1 + 3i).ToObject(), NewComplex(0i).ToObject(), nil},
73+
{Sub, NewComplex(4 + 3i).ToObject(), NewInt(1).ToObject(), NewComplex(3 + 3i).ToObject(), nil},
74+
{Sub, NewComplex(4 + 3i).ToObject(), NewLong(big.NewInt(99994)).ToObject(), NewComplex(-99990 + 3i).ToObject(), nil},
75+
{Sub, NewFloat(math.Inf(1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(1), -3)).ToObject(), nil},
76+
{Sub, NewFloat(math.Inf(-1)).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.Inf(-1), -3)).ToObject(), nil},
77+
{Sub, NewComplex(1 + 3i).ToObject(), None, nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for -: 'complex' and 'NoneType'")},
78+
{Sub, None, NewComplex(1 + 3i).ToObject(), nil, mustCreateException(TypeErrorType, "unsupported operand type(s) for -: 'NoneType' and 'complex'")},
79+
{Sub, NewFloat(math.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(complex(math.NaN(), -3)).ToObject(), nil},
80+
{Sub, NewComplex(cmplx.NaN()).ToObject(), NewComplex(3i).ToObject(), NewComplex(cmplx.NaN()).ToObject(), nil},
81+
{Sub, NewFloat(math.Inf(-1)).ToObject(), NewComplex(complex(math.Inf(-1), 3)).ToObject(), NewComplex(complex(math.NaN(), -3)).ToObject(), nil},
82+
}
83+
84+
for _, cas := range cases {
85+
switch got, result := checkInvokeResult(wrapFuncForTest(cas.fun), []*Object{cas.v, cas.w}, cas.want, cas.wantExc); result {
86+
case checkInvokeResultExceptionMismatch:
87+
t.Errorf("%s(%v, %v) raised %v, want %v", getFuncName(cas.fun), cas.v, cas.w, got, cas.wantExc)
88+
case checkInvokeResultReturnValueMismatch:
89+
if got == nil || cas.want == nil || !got.isInstance(ComplexType) || !cas.want.isInstance(ComplexType) ||
90+
!complexesAreSame(toComplexUnsafe(got).Value(), toComplexUnsafe(cas.want).Value()) {
91+
t.Errorf("%s(%v, %v) = %v, want %v", getFuncName(cas.fun), cas.v, cas.w, got, cas.want)
92+
}
93+
}
94+
}
95+
}
96+
4597
func TestComplexCompareNotSupported(t *testing.T) {
4698
cases := []invokeTestCase{
4799
{args: wrapArgs(complex(1, 2), 1), wantExc: mustCreateException(TypeErrorType, "no ordering relation is defined for complex numbers")},
@@ -108,3 +160,11 @@ func TestComplexHash(t *testing.T) {
108160
}
109161
}
110162
}
163+
164+
func floatsAreSame(a, b float64) bool {
165+
return a == b || (math.IsNaN(a) && math.IsNaN(b))
166+
}
167+
168+
func complexesAreSame(a, b complex128) bool {
169+
return floatsAreSame(real(a), real(b)) && floatsAreSame(imag(a), imag(b))
170+
}

third_party/ouroboros/test/test_operator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ def test_dunder_is_original(self):
472472
if dunder:
473473
self.assertIs(dunder, orig)
474474

475-
@unittest.expectedFailure
476475
def test_complex_operator(self):
477476
self.assertRaises(TypeError, operator.lt, 1j, 2j)
478477
self.assertRaises(TypeError, operator.le, 1j, 2j)

0 commit comments

Comments
 (0)