Skip to content

Commit d9100a1

Browse files
in-code-i-trustslozier
authored andcommitted
Get super() and __class__ both works in non-global context. (#721)
1 parent ecb0d94 commit d9100a1

4 files changed

Lines changed: 75 additions & 20 deletions

File tree

Src/IronPython/Compiler/Ast/PythonNameBinder.cs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -821,10 +821,28 @@ public override bool Walk(CallExpression node) {
821821
node.Parent = _currentScope;
822822

823823
if (node.Target is NameExpression nameExpr && nameExpr.Name == "super" && _currentScope is FunctionDefinition func) {
824-
_currentScope.Reference("__class__");
825824
if (node.Args.Length == 0 && func.ParameterNames.Length > 0) {
826-
node.ImplicitArgs.Add(new Arg(new NameExpression("__class__")));
827-
node.ImplicitArgs.Add(new Arg(new NameExpression(func.ParameterNames[0])));
825+
if (ShouldExpandSuperSyntaxSugar(node)) {
826+
// if `super()` is referenced in a class method.
827+
_currentScope.Reference(node.Parent.Parent.Name);
828+
node.ImplicitArgs.Add(new Arg(new NameExpression(node.Parent.Parent.Name)));
829+
node.ImplicitArgs.Add(new Arg(new NameExpression(func.ParameterNames[0])));
830+
} else {
831+
// otherwise, fallback to default implementation.
832+
_currentScope.Reference("__class__");
833+
node.ImplicitArgs.Add(new Arg(new NameExpression("__class__")));
834+
node.ImplicitArgs.Add(new Arg(new NameExpression(func.ParameterNames[0])));
835+
}
836+
}
837+
838+
bool ShouldExpandSuperSyntaxSugar(CallExpression node) {
839+
if (!(node.Parent is FunctionDefinition)) {
840+
return false;
841+
}
842+
if (!(node.Parent.Parent is ClassDefinition)) {
843+
return false;
844+
}
845+
return true;
828846
}
829847
}
830848
return base.Walk(node);

Src/IronPython/Compiler/Parser.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1962,8 +1962,28 @@ private Expression ParseAtom() {
19621962
NextToken();
19631963
string name = (string)t.Value;
19641964
_sink?.StartName(GetSourceSpan(), name);
1965-
ret = new NameExpression(FixName(name));
1965+
var fixedName = FixName(name);
1966+
if (ShouldReplaceClassKeyword(fixedName)) {
1967+
// if `__class__` variable is used in a class method, replace with `name_of_the_class`.
1968+
ret = new NameExpression(CurrentClass.Name);
1969+
} else {
1970+
ret = new NameExpression(fixedName);
1971+
}
19661972
ret.SetLoc(_globalParent, GetStart(), GetEnd());
1973+
1974+
bool ShouldReplaceClassKeyword(string fixedName) {
1975+
if (CurrentClass == null || CurrentFunction == null) {
1976+
return false;
1977+
}
1978+
if (PeekToken(TokenKind.Assign)) {
1979+
// avoid replacing statement like `__class__ = getproperty(...)`.
1980+
return false;
1981+
}
1982+
if (fixedName != "__class__") {
1983+
return false;
1984+
}
1985+
return true;
1986+
}
19671987
return ret;
19681988
case TokenKind.Constant: // literal
19691989
NextToken();

Tests/test_dict.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,10 @@ def __setitem__(self, *args):
5555
super(MyDict, self).__setitem__(*args)
5656

5757
a = MyDict()
58-
with self.assertRaises(SystemError): # TODO: remove assertRaises when https://github.com/IronLanguages/ironpython3/issues/456 is fixed
59-
a[0] = 'abc'
60-
self.assertEqual(a[0], 'abc')
61-
with self.assertRaises(SystemError): # TODO: remove assertRaises when https://github.com/IronLanguages/ironpython3/issues/456 is fixed
62-
a[None] = 3
63-
self.assertEqual(a[None], 3)
58+
a[0] = 'abc'
59+
self.assertEqual(a[0], 'abc')
60+
a[None] = 3
61+
self.assertEqual(a[None], 3)
6462

6563
class MyDict(dict):
6664
def __setitem__(self, *args):

Tests/test_regressions.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -996,16 +996,7 @@ def __init__(self):
996996
super(two, self).__init__()
997997
self.cnt += 1
998998

999-
if is_cli:
1000-
try:
1001-
self.assertEqual(two().cnt, 3)
1002-
except SystemError:
1003-
# https://github.com/IronLanguages/ironpython3/issues/451
1004-
pass
1005-
else:
1006-
self.fail("delete the try/except when https://github.com/IronLanguages/ironpython3/issues/451 is fixed")
1007-
else:
1008-
self.assertEqual(two().cnt, 3)
999+
self.assertEqual(two().cnt, 3)
10091000

10101001
def test_ipy2_gh292(self):
10111002
"""https://github.com/IronLanguages/ironpython2/issues/292"""
@@ -1449,4 +1440,32 @@ def run(self):
14491440
self.assertEqual(thread.thread_executed, 1)
14501441
test_thread()
14511442

1443+
def test_ipy3_gh451(self):
1444+
"""https://github.com/IronLanguages/ironpython3/issues/451"""
1445+
def test():
1446+
class two(object):
1447+
def __init__(self):
1448+
super().__init__()
1449+
self.cnt = 1
1450+
return two().cnt
1451+
self.assertEqual(test(), 1)
1452+
1453+
class test__class__keyword(object):
1454+
def __new__(cls):
1455+
return super().__new__(cls)
1456+
def __init__(self):
1457+
super().__init__()
1458+
def get_self_class(self):
1459+
return self.__class__
1460+
def get_class(self):
1461+
return __class__
1462+
@classmethod
1463+
def get_class_class(cls):
1464+
return cls
1465+
1466+
o = test__class__keyword()
1467+
self.assertEqual(o.get_self_class(), o.get_class())
1468+
self.assertEqual(o.get_class(), o.get_class_class())
1469+
1470+
14521471
run_test(__name__)

0 commit comments

Comments
 (0)