Skip to content

Commit ab2fa2a

Browse files
gfmcknightslozier
authored andcommitted
Remove comparison operator fallbacks (#695)
* Remove comparison operator fallbacks Remove type comparison fallbacks for the '<', '>', '<=', and '>=' operators and throw a TypeError instead when no suitable comparator is found between two types. * Update test__weakref to an expected failure and make ToIDictionary return tuples. * Make tuples of the keys and values together.
1 parent 7848ca5 commit ab2fa2a

11 files changed

Lines changed: 65 additions & 104 deletions

File tree

Src/IronPython/Runtime/Binding/PythonProtocol.Operations.cs

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,12 +1286,28 @@ private static void MakeCompareTest(PythonOperationKind op, ConditionalBuilder/*
12861286
}
12871287

12881288
private static Expression/*!*/ MakeFallbackCompare(DynamicMetaObjectBinder/*!*/ binder, PythonOperationKind op, DynamicMetaObject[] types) {
1289-
return Ast.Call(
1290-
GetComparisonFallbackMethod(op),
1291-
PythonContext.GetCodeContext(binder),
1292-
AstUtils.Convert(types[0].Expression, typeof(object)),
1293-
AstUtils.Convert(types[1].Expression, typeof(object))
1294-
);
1289+
if (op == PythonOperationKind.Equal || op == PythonOperationKind.NotEqual ||
1290+
op == PythonOperationKind.Compare) {
1291+
return Ast.Call(
1292+
GetComparisonFallbackMethod(op),
1293+
PythonContext.GetCodeContext(binder),
1294+
AstUtils.Convert(types[0].Expression, typeof(object)),
1295+
AstUtils.Convert(types[1].Expression, typeof(object))
1296+
);
1297+
}
1298+
1299+
return MakeBinaryThrow(binder, op, types).Expression;
1300+
1301+
static MethodInfo/*!*/ GetComparisonFallbackMethod(PythonOperationKind op) {
1302+
string name;
1303+
switch (op) {
1304+
case PythonOperationKind.Equal: name = nameof(PythonOps.CompareTypesEqual); break;
1305+
case PythonOperationKind.NotEqual: name = nameof(PythonOps.CompareTypesNotEqual); break;
1306+
case PythonOperationKind.Compare: name = nameof(PythonOps.CompareTypes); break;
1307+
default: throw new InvalidOperationException();
1308+
}
1309+
return typeof(PythonOps).GetMethod(name);
1310+
}
12951311
}
12961312

12971313
private static Expression GetCompareTest(PythonOperationKind op, Expression expr, bool reverse) {
@@ -1876,21 +1892,6 @@ public static PythonOperationKind OperatorToReverseOperator(PythonOperationKind
18761892
return BindingHelpers.AddPythonBoxing(res);
18771893
}
18781894

1879-
private static MethodInfo/*!*/ GetComparisonFallbackMethod(PythonOperationKind op) {
1880-
string name;
1881-
switch (op) {
1882-
case PythonOperationKind.Equal: name = nameof(PythonOps.CompareTypesEqual); break;
1883-
case PythonOperationKind.NotEqual: name = nameof(PythonOps.CompareTypesNotEqual); break;
1884-
case PythonOperationKind.GreaterThan: name = nameof(PythonOps.CompareTypesGreaterThan); break;
1885-
case PythonOperationKind.LessThan: name = nameof(PythonOps.CompareTypesLessThan); break;
1886-
case PythonOperationKind.GreaterThanOrEqual: name = nameof(PythonOps.CompareTypesGreaterThanOrEqual); break;
1887-
case PythonOperationKind.LessThanOrEqual: name = nameof(PythonOps.CompareTypesLessThanOrEqual); break;
1888-
case PythonOperationKind.Compare: name = nameof(PythonOps.CompareTypes); break;
1889-
default: throw new InvalidOperationException();
1890-
}
1891-
return typeof(PythonOps).GetMethod(name);
1892-
}
1893-
18941895
internal static Expression/*!*/ CheckMissing(Expression/*!*/ toCheck) {
18951896
if (toCheck.Type == typeof(MissingParameter)) {
18961897
return AstUtils.Constant(null);
@@ -1990,7 +1991,7 @@ public static PythonOperationKind OperatorToReverseOperator(PythonOperationKind
19901991
action.Throw(
19911992
Ast.Call(
19921993
typeof(PythonOps).GetMethod(nameof(PythonOps.TypeErrorForBinaryOp)),
1993-
AstUtils.Constant(Symbols.OperatorToSymbol(op)),
1994+
AstUtils.Constant(GetOperatorDisplay(op)),
19941995
AstUtils.Convert(args[0].Expression, typeof(object)),
19951996
AstUtils.Convert(args[1].Expression, typeof(object))
19961997
),

Src/IronPython/Runtime/Bytes.cs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -735,31 +735,19 @@ private static Bytes MultiplyWorker(Bytes self, int count) {
735735
throw PythonOps.TypeErrorForUnIndexableObject(count);
736736
}
737737

738-
public static bool operator >(Bytes/*!*/ x, Bytes/*!*/ y) {
739-
if (y == null) {
740-
return true;
741-
}
738+
public static bool operator >([NotNull]Bytes/*!*/ x, [NotNull]Bytes/*!*/ y) {
742739
return x._bytes.Compare(y._bytes) > 0;
743740
}
744741

745-
public static bool operator <(Bytes/*!*/ x, Bytes/*!*/ y) {
746-
if (y == null) {
747-
return false;
748-
}
742+
public static bool operator <([NotNull]Bytes/*!*/ x, [NotNull]Bytes/*!*/ y) {
749743
return x._bytes.Compare(y._bytes) < 0;
750744
}
751745

752-
public static bool operator >=(Bytes/*!*/ x, Bytes/*!*/ y) {
753-
if (y == null) {
754-
return true;
755-
}
746+
public static bool operator >=([NotNull]Bytes/*!*/ x, [NotNull]Bytes/*!*/ y) {
756747
return x._bytes.Compare(y._bytes) >= 0;
757748
}
758749

759-
public static bool operator <=(Bytes/*!*/ x, Bytes/*!*/ y) {
760-
if (y == null) {
761-
return false;
762-
}
750+
public static bool operator <=([NotNull]Bytes/*!*/ x, [NotNull]Bytes/*!*/ y) {
763751
return x._bytes.Compare(y._bytes) <= 0;
764752
}
765753

Src/IronPython/Runtime/Operations/PythonOps.cs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -607,22 +607,6 @@ public static bool CompareTypesNotEqual(CodeContext/*!*/ context, object x, obje
607607
return PythonOps.CompareTypesWorker(context, false, x, y) != 0;
608608
}
609609

610-
public static bool CompareTypesGreaterThan(CodeContext/*!*/ context, object x, object y) {
611-
return PythonOps.CompareTypes(context, x, y) > 0;
612-
}
613-
614-
public static bool CompareTypesLessThan(CodeContext/*!*/ context, object x, object y) {
615-
return PythonOps.CompareTypes(context, x, y) < 0;
616-
}
617-
618-
public static bool CompareTypesGreaterThanOrEqual(CodeContext/*!*/ context, object x, object y) {
619-
return PythonOps.CompareTypes(context, x, y) >= 0;
620-
}
621-
622-
public static bool CompareTypesLessThanOrEqual(CodeContext/*!*/ context, object x, object y) {
623-
return PythonOps.CompareTypes(context, x, y) <= 0;
624-
}
625-
626610
public static int CompareTypesWorker(CodeContext/*!*/ context, bool shouldWarn, object x, object y) {
627611
if (x == null && y == null) return 0;
628612
if (x == null) return -1;

Src/IronPythonTest/Conversions.cs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
33
// See the LICENSE file in the project root for more information.
44

5+
using IronPython.Runtime;
56
using System.Collections;
67
using System.Collections.Generic;
78

@@ -69,11 +70,7 @@ public class DictConversion {
6970
public static IList<object> ToIDictionary(IDictionary dict) {
7071
List<object> res = new List<object>();
7172
foreach (DictionaryEntry de in dict) {
72-
res.Add(de.Key);
73-
}
74-
75-
foreach (DictionaryEntry de in dict) {
76-
res.Add(de.Value);
73+
res.Add(PythonTuple.MakeTuple(de.Key, de.Value));
7774
}
7875

7976
return res;

Tests/modules/misc/test__weakref.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def __eq__(self, *args, **kwargs): return True
140140
self.assertTrue(not x==3)
141141
self.assertRaises(ReferenceError, lambda: x==y)
142142

143+
# https://github.com/IronLanguages/ironpython3/issues/697
144+
@unittest.expectedFailure
143145
def test_equals(self):
144146
global called
145147
class C:
@@ -155,12 +157,6 @@ def %s(self, *args, **kwargs):
155157
x = _weakref.proxy(a)
156158
for op in ('==', '>', '<', '>=', '<=', '!='):
157159
self.assertEqual(eval('a ' + op + ' 3'), True); self.assertEqual(called, op); called = None
158-
if op == '==' or op == '!=':
159-
self.assertEqual(eval('x ' + op + ' 3'), op == '!='); self.assertEqual(called, None)
160-
self.assertEqual(eval('3 ' + op + ' x'), op == '!='); self.assertEqual(called, None)
161-
else:
162-
res1, res2 = eval('x ' + op + ' 3'), eval('3 ' + op + ' x')
163-
self.assertEqual(called, None)
164-
self.assertTrue((res1 == True and res2 == False) or (res1 == False and res2 == True))
160+
self.assertEqual(eval('x ' + op + ' 3'), True); self.assertEqual(called, op); called = None
165161

166162
run_test(__name__)

Tests/modules/system_related/test_nt.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -586,48 +586,44 @@ def test_stat_result(self):
586586
self.assertEqual(x + x, tuple(range(10))*2)
587587

588588
#> (list/object)
589-
if is_cli:
590-
self.assertTrue(nt.stat_result(range(10)) > None)
591-
self.assertTrue(nt.stat_result(range(10)) > 1)
592-
self.assertTrue(nt.stat_result(range(10)) > range(10))
593589
self.assertTrue(nt.stat_result([1 for x in range(10)]) > nt.stat_result(range(10)))
594590
self.assertTrue(not nt.stat_result(range(10)) > nt.stat_result(range(10)))
595591
self.assertTrue(not nt.stat_result(range(10)) > nt.stat_result(range(11)))
596592
self.assertTrue(not nt.stat_result(range(10)) > nt.stat_result([1 for x in range(10)]))
597593
self.assertTrue(not nt.stat_result(range(11)) > nt.stat_result(range(10)))
594+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) > None)
595+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) > 1)
596+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) > range(10))
598597

599598
#< (list/object)
600-
if is_cli:
601-
self.assertTrue(not nt.stat_result(range(10)) < None)
602-
self.assertTrue(not nt.stat_result(range(10)) < 1)
603-
self.assertTrue(not nt.stat_result(range(10)) < range(10))
604599
self.assertTrue(not nt.stat_result([1 for x in range(10)]) < nt.stat_result(range(10)))
605600
self.assertTrue(not nt.stat_result(range(10)) < nt.stat_result(range(10)))
606601
self.assertTrue(not nt.stat_result(range(10)) < nt.stat_result(range(11)))
607602
self.assertTrue(nt.stat_result(range(10)) < nt.stat_result([1 for x in range(10)]))
608603
self.assertTrue(not nt.stat_result(range(11)) < nt.stat_result(range(10)))
604+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) < None)
605+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) < 1)
606+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) < range(10))
609607

610608
#>= (list/object)
611-
if is_cli:
612-
self.assertTrue(nt.stat_result(range(10)) >= None)
613-
self.assertTrue(nt.stat_result(range(10)) >= 1)
614-
self.assertTrue(nt.stat_result(range(10)) >= range(10))
615609
self.assertTrue(nt.stat_result([1 for x in range(10)]) >= nt.stat_result(range(10)))
616610
self.assertTrue(nt.stat_result(range(10)) >= nt.stat_result(range(10)))
617611
self.assertTrue(nt.stat_result(range(10)) >= nt.stat_result(range(11)))
618612
self.assertTrue(not nt.stat_result(range(10)) >= nt.stat_result([1 for x in range(10)]))
619613
self.assertTrue(nt.stat_result(range(11)) >= nt.stat_result(range(10)))
614+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) >= None)
615+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) >= 1)
616+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) >= range(10))
620617

621618
#<= (list/object)
622-
if is_cli:
623-
self.assertTrue(not nt.stat_result(range(10)) <= None)
624-
self.assertTrue(not nt.stat_result(range(10)) <= 1)
625-
self.assertTrue(not nt.stat_result(range(10)) <= range(10))
626619
self.assertTrue(not nt.stat_result([1 for x in range(10)]) <= nt.stat_result(range(10)))
627620
self.assertTrue(nt.stat_result(range(10)) <= nt.stat_result(range(10)))
628621
self.assertTrue(nt.stat_result(range(10)) <= nt.stat_result(range(11)))
629622
self.assertTrue(nt.stat_result(range(10)) <= nt.stat_result([1 for x in range(10)]))
630623
self.assertTrue(nt.stat_result(range(11)) <= nt.stat_result(range(10)))
624+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) <= None)
625+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) <= 1)
626+
self.assertRaises(TypeError, lambda: nt.stat_result(range(10)) <= range(10))
631627

632628
#* (size/stat_result)
633629
x = nt.stat_result(range(10))

Tests/test_attr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def CheckModule(mod):
5656
CheckDictionary(mod.__dict__)
5757

5858
mod.__dict__[1] = '1'
59-
self.assertEqual(dir(mod).__contains__(1), True)
59+
self.assertRaises(TypeError, lambda: dir(mod).__contains__(1))
6060
del mod.__dict__[1]
6161

6262
# Try to replace __dict__

Tests/test_builtinfunc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,15 +339,14 @@ def test_max(self):
339339
self.assertTrue(max((), default=1) == 1)
340340
self.assertTrue(max([], default=None) is None)
341341

342-
self.assertEqual(max((1,2), None), (1, 2))
343-
344342
self.assertEqual(max(1, 2, 3.0), 3.0)
345343
self.assertEqual(max(1, 2.0, 3), 3)
346344
self.assertEqual(max(1.0, 2, 3), 3)
347345

348346
self.assertRaises(TypeError, max)
349347
self.assertRaises(TypeError, max, 42)
350348
self.assertRaises(ValueError, max, ())
349+
self.assertRaises(TypeError, max, (1, 2), None)
351350
class BadSeq:
352351
def __getitem__(self, index):
353352
raise ValueError

Tests/test_bytes.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,14 +1047,14 @@ def test_compares(self):
10471047

10481048
self.assertEqual(testType(ab) == [], False)
10491049

1050-
self.assertEqual(testType(a) > None, True)
1051-
self.assertEqual(testType(a) < None, False)
1052-
self.assertEqual(testType(a) <= None, False)
1053-
self.assertEqual(testType(a) >= None, True)
1054-
self.assertEqual(None > testType(a), False)
1055-
self.assertEqual(None < testType(a), True)
1056-
self.assertEqual(None <= testType(a), True)
1057-
self.assertEqual(None >= testType(a), False)
1050+
self.assertRaises(TypeError, lambda: testType(a) > None)
1051+
self.assertRaises(TypeError, lambda: testType(a) < None)
1052+
self.assertRaises(TypeError, lambda: testType(a) <= None)
1053+
self.assertRaises(TypeError, lambda: testType(a) >= None)
1054+
self.assertRaises(TypeError, lambda: None > testType(a))
1055+
self.assertRaises(TypeError, lambda: None < testType(a))
1056+
self.assertRaises(TypeError, lambda: None <= testType(a))
1057+
self.assertRaises(TypeError, lambda: None >= testType(a))
10581058

10591059

10601060
def test_bytearray(self):

Tests/test_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ class KNewDerived(KNew): pass
492492
]
493493

494494
for temp_dict in test_dicts:
495-
expected = list(temp_dict.keys()) + list(temp_dict.values())
495+
expected = list((key, temp_dict[key]) for key in temp_dict.keys())
496496
expected.sort()
497497

498498
to_idict = list(DictConversion.ToIDictionary(temp_dict))

0 commit comments

Comments
 (0)