Skip to content

Commit 9c97198

Browse files
authored
Changes from ipy2 (#939)
* Don't insert __file__ to sys.path * Add bz2 test from ipy2 * Accept bytes in getaddrinfo * Clean up _ssl
1 parent f32bf73 commit 9c97198

4 files changed

Lines changed: 84 additions & 58 deletions

File tree

Src/IronPython.Modules/_socket.cs

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,10 +1104,10 @@ public static PythonList getaddrinfo(
11041104
CodeContext/*!*/ context,
11051105
string host,
11061106
object port,
1107-
int family= (int)AddressFamily.Unspecified,
1108-
int socktype=0,
1109-
int proto=(int)ProtocolType.IP,
1110-
int flags=(int)SocketFlags.None
1107+
int family = (int)AddressFamily.Unspecified,
1108+
int socktype = 0,
1109+
int proto = (int)ProtocolType.IP,
1110+
int flags = (int)SocketFlags.None
11111111
) {
11121112
int numericPort;
11131113

@@ -1117,28 +1117,25 @@ public static PythonList getaddrinfo(
11171117
numericPort = (int)port;
11181118
} else if (port is Extensible<int>) {
11191119
numericPort = ((Extensible<int>)port).Value;
1120+
} else if (port is Bytes) {
1121+
numericPort = ParsePort(context, ((Bytes)port).MakeString());
11201122
} else if (port is string) {
1121-
if (!Int32.TryParse((string)port, out numericPort)) {
1122-
try{
1123-
port = getservbyname(context,(string)port,null);
1124-
}
1125-
catch{
1126-
throw PythonExceptions.CreateThrowable(gaierror(context), "getaddrinfo failed");
1127-
}
1128-
}
1123+
numericPort = ParsePort(context, (string)port);
11291124
} else if (port is ExtensibleString) {
1130-
if (!Int32.TryParse(((ExtensibleString)port).Value, out numericPort)) {
1131-
try{
1132-
port = getservbyname(context, (string)port, null);
1133-
}
1134-
catch{
1135-
throw PythonExceptions.CreateThrowable(gaierror(context), "getaddrinfo failed");
1136-
}
1137-
}
1125+
numericPort = ParsePort(context, ((ExtensibleString)port).Value);
11381126
} else {
11391127
throw PythonExceptions.CreateThrowable(gaierror(context), "getaddrinfo failed");
11401128
}
11411129

1130+
static int ParsePort(CodeContext context, string port) {
1131+
if (int.TryParse(port, out var numericPort)) return numericPort;
1132+
try {
1133+
return getservbyname(context, port, null);
1134+
} catch {
1135+
throw PythonExceptions.CreateThrowable(gaierror(context), "getaddrinfo failed");
1136+
}
1137+
}
1138+
11421139
if (socktype != 0) {
11431140
// we just use this to validate; socketType isn't actually used
11441141
System.Net.Sockets.SocketType socketType = (System.Net.Sockets.SocketType)Enum.ToObject(typeof(System.Net.Sockets.SocketType), socktype);
@@ -1175,6 +1172,17 @@ public static PythonList getaddrinfo(
11751172
return results;
11761173
}
11771174

1175+
[Documentation("")]
1176+
public static PythonList getaddrinfo(
1177+
CodeContext/*!*/ context,
1178+
[NotNull] Bytes host,
1179+
object port,
1180+
int family = (int)AddressFamily.Unspecified,
1181+
int socktype = 0,
1182+
int proto = (int)ProtocolType.IP,
1183+
int flags = (int)SocketFlags.None
1184+
) => getaddrinfo(context, host.MakeString(), port, family, socktype, proto, flags);
1185+
11781186
private static PythonType gaierror(CodeContext/*!*/ context) {
11791187
return (PythonType)context.LanguageContext.GetModuleState("socketgaierror");
11801188
}

Src/IronPython.Modules/_ssl.cs

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ static PythonSsl() {
7777
}
7878

7979
[SpecialName]
80-
public static void PerformModuleReload(PythonContext/*!*/ context, PythonDictionary/*!*/ dict) {
80+
public static void PerformModuleReload(PythonContext/*!*/ context, PythonDictionary/*!*/ dict) {
8181
var sslError = context.EnsureModuleException("SSLError", PythonSocket.error, dict, "SSLError", "ssl");
8282
context.EnsureModuleException("SSLZeroReturnError", sslError, dict, "SSLZeroReturnError", "ssl");
8383
context.EnsureModuleException("SSLWantWriteError", sslError, dict, "SSLWantWriteError", "ssl");
@@ -112,7 +112,7 @@ public class _SSLContext {
112112
private string _cafile;
113113
private int _verify_mode = SSL_VERIFY_NONE;
114114

115-
public _SSLContext(CodeContext context, [DefaultParameterValue(PROTOCOL_SSLv23)] int protocol) {
115+
public _SSLContext(CodeContext context, int protocol = PROTOCOL_SSLv23) {
116116
if (protocol != PROTOCOL_SSLv2 && protocol != PROTOCOL_SSLv23 && protocol != PROTOCOL_SSLv3 &&
117117
protocol != PROTOCOL_TLSv1 && protocol != PROTOCOL_TLSv1_1 && protocol != PROTOCOL_TLSv1_2) {
118118
throw PythonOps.ValueError("invalid protocol version");
@@ -141,7 +141,7 @@ public int verify_mode {
141141
return _verify_mode;
142142
}
143143
set {
144-
if(_verify_mode != CERT_NONE && _verify_mode != CERT_OPTIONAL && _verify_mode != CERT_REQUIRED) {
144+
if (_verify_mode != CERT_NONE && _verify_mode != CERT_OPTIONAL && _verify_mode != CERT_REQUIRED) {
145145
throw PythonOps.ValueError("invalid value for verify_mode");
146146
}
147147
_verify_mode = value;
@@ -160,24 +160,24 @@ public void set_default_verify_paths(CodeContext context) {
160160

161161
}
162162

163-
public void load_cert_chain(string certfile, [DefaultParameterValue(null)] string keyfile, [DefaultParameterValue(null)] object password) {
163+
public void load_cert_chain(string certfile, string keyfile = null, object password = null) {
164164

165165
}
166166

167-
public void load_verify_locations(CodeContext context, [DefaultParameterValue(null)] string cafile, [DefaultParameterValue(null)] string capath, [DefaultParameterValue(null)] object cadata) {
168-
if(cafile == null && capath == null && cadata == null) {
167+
public void load_verify_locations(CodeContext context, string cafile = null, string capath = null, object cadata = null) {
168+
if (cafile == null && capath == null && cadata == null) {
169169
throw PythonOps.TypeError("cafile, capath and cadata cannot be all omitted");
170170
}
171171

172-
if(cafile != null) {
172+
if (cafile != null) {
173173
_cert_store.Add(ReadCertificate(context, cafile));
174174
_cafile = cafile;
175175
}
176176

177-
if(capath != null) {
177+
if (capath != null) {
178178
}
179179

180-
if(cadata != null) {
180+
if (cadata != null) {
181181
var cabuf = cadata as IBufferProtocol;
182182
if (cabuf != null) {
183183
int pos = 0;
@@ -196,25 +196,25 @@ public void load_verify_locations(CodeContext context, [DefaultParameterValue(nu
196196
}
197197
}
198198

199-
public object _wrap_socket(CodeContext context, [DefaultParameterValue(null)] PythonSocket.socket sock, [DefaultParameterValue(false)] bool server_side, [DefaultParameterValue(null)] string server_hostname, [DefaultParameterValue(null)] object ssl_sock) {
199+
public object _wrap_socket(CodeContext context, PythonSocket.socket sock = null, bool server_side = false, string server_hostname = null, object ssl_sock = null) {
200200
return new PythonSocket.ssl(context, sock, server_side, null, _cafile, verify_mode, protocol | options, null, _cert_store) { _serverHostName = server_hostname };
201201
}
202202
}
203203

204204
#endregion
205205

206206
public static PythonType SSLType = DynamicHelpers.GetPythonTypeFromType(typeof(PythonSocket.ssl));
207-
207+
208208
public static PythonSocket.ssl sslwrap(
209209
CodeContext context,
210-
PythonSocket.socket socket,
211-
bool server_side,
212-
string keyfile=null,
213-
string certfile=null,
214-
int certs_mode=PythonSsl.CERT_NONE,
215-
int protocol= (PythonSsl.PROTOCOL_SSLv23 | PythonSsl.OP_NO_SSLv2 | PythonSsl.OP_NO_SSLv3),
216-
string cacertsfile=null,
217-
object ciphers=null) {
210+
PythonSocket.socket socket,
211+
bool server_side,
212+
string keyfile = null,
213+
string certfile = null,
214+
int certs_mode = PythonSsl.CERT_NONE,
215+
int protocol = (PythonSsl.PROTOCOL_SSLv23 | PythonSsl.OP_NO_SSLv2 | PythonSsl.OP_NO_SSLv3),
216+
string cacertsfile = null,
217+
object ciphers = null) {
218218
return new PythonSocket.ssl(
219219
context,
220220
socket,
@@ -228,29 +228,28 @@ public static PythonSocket.ssl sslwrap(
228228
);
229229
}
230230

231-
public static object txt2obj(CodeContext context, string txt, [DefaultParameterValue(false)] object name) {
232-
bool nam = PythonOps.IsTrue(name); // if true, we also look at short name and long name
231+
public static object txt2obj(CodeContext context, string txt, bool name = false) {
233232
Asn1Object obj = null;
234-
if(nam) {
233+
if (name) {
235234
obj = _asn1Objects.Where(x => txt == x.OIDString || txt == x.ShortName || txt == x.LongName).FirstOrDefault();
236235
} else {
237236
obj = _asn1Objects.Where(x => txt == x.OIDString).FirstOrDefault();
238237
}
239238

240-
if(obj == null) {
239+
if (obj == null) {
241240
throw PythonOps.ValueError("unknown object '{0}'", txt);
242241
}
243242

244243
return obj.ToTuple();
245244
}
246245

247246
public static object nid2obj(CodeContext context, int nid) {
248-
if(nid < 0) {
247+
if (nid < 0) {
249248
throw PythonOps.ValueError("NID must be positive");
250249
}
251250

252251
var obj = _asn1Objects.Where(x => x.NID == nid).FirstOrDefault();
253-
if(obj == null) {
252+
if (obj == null) {
254253
throw PythonOps.ValueError("unknown NID {0}", nid);
255254
}
256255

@@ -267,7 +266,7 @@ public static PythonList enum_certificates(string store_name) {
267266
foreach (var cert in store.Certificates) {
268267
string format = cert.GetFormat();
269268

270-
switch(format) {
269+
switch (format) {
271270
case "X509":
272271
format = "x509_asn";
273272
break;
@@ -281,7 +280,7 @@ public static PythonList enum_certificates(string store_name) {
281280
foreach (var ext in cert.Extensions) {
282281
var keyUsage = ext as X509EnhancedKeyUsageExtension;
283282
if (keyUsage != null) {
284-
foreach(var oid in keyUsage.EnhancedKeyUsages) {
283+
foreach (var oid in keyUsage.EnhancedKeyUsages) {
285284
set.add(oid.Value);
286285
}
287286
found = true;
@@ -496,7 +495,7 @@ internal static X509Certificate2 ReadCertificate(CodeContext context, string fil
496495
if (key != null) {
497496
try {
498497
cert.PrivateKey = key;
499-
} catch(CryptographicException e) {
498+
} catch (CryptographicException e) {
500499
throw ErrorDecoding(context, filename, "cert and private key are incompatible", e);
501500
}
502501
}
@@ -566,8 +565,8 @@ private static RSACryptoServiceProvider ParsePkcs1DerEncodedPrivateKey(CodeConte
566565
parameters.DP = ReadUnivesalIntAsBytes(x, ref offset);
567566
parameters.DQ = ReadUnivesalIntAsBytes(x, ref offset);
568567
parameters.InverseQ = ReadUnivesalIntAsBytes(x, ref offset);
569-
570-
provider.ImportParameters(parameters);
568+
569+
provider.ImportParameters(parameters);
571570
return provider;
572571
}
573572

@@ -587,7 +586,7 @@ private static byte[] ReadUnivesalIntAsBytes(byte[] x, ref int offset) {
587586
for (int i = 0; i < res.Length; i++) {
588587
res[i] = x[offset++];
589588
}
590-
589+
591590
return res;
592591

593592
}
@@ -596,7 +595,7 @@ private static void ReadIntType(byte[] x, ref int offset) {
596595
int versionType = x[offset++];
597596
if (versionType != UniversalInteger) {
598597
throw new InvalidOperationException(String.Format("expected version, fonud {0}", versionType));
599-
}
598+
}
600599
}
601600
private static int ReadUnivesalInt(byte[] x, ref int offset) {
602601
ReadIntType(x, ref offset);
@@ -629,16 +628,16 @@ private static int ReadInt(byte[] x, ref int offset, int bytes) {
629628
/// BER encoding of an integer value is the number of bytes
630629
/// required to represent the integer followed by the bytes
631630
/// </summary>
632-
private static int ReadInt(byte[] x, ref int offset) {
631+
private static int ReadInt(byte[] x, ref int offset) {
633632
int bytes = x[offset++];
634-
633+
635634
return ReadInt(x, ref offset, bytes);
636635
}
637636

638637
private static string ReadToEnd(string[] lines, ref int start, string end) {
639638
StringBuilder key = new StringBuilder();
640639
for (start++; start < lines.Length; start++) {
641-
if (lines[start] == end) {
640+
if (lines[start] == end) {
642641
return key.ToString();
643642
}
644643
key.Append(lines[start]);
@@ -706,4 +705,5 @@ private static Exception ErrorDecoding(CodeContext context, params object[] args
706705
#endregion
707706
}
708707
}
708+
709709
#endif

Src/IronPython/Runtime/Importer.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -816,14 +816,13 @@ internal static bool TryImportMainFromZip(CodeContext/*!*/ context, string/*!*/
816816
return false;
817817
}
818818
importCache[path] = importer = FindImporterForPath(context, path);
819-
if (importer == null) {
819+
if (importer is null || importer is PythonImport.NullImporter) {
820820
return false;
821821
}
822822
// for consistency with cpython, insert zip as a first entry into sys.path
823823
var syspath = context.LanguageContext.GetSystemStateValue("path") as PythonList;
824824
syspath?.Insert(0, path);
825-
object dummy;
826-
return FindAndLoadModuleFromImporter(context, importer, "__main__", null, out dummy);
825+
return FindAndLoadModuleFromImporter(context, importer, "__main__", null, out _);
827826
}
828827

829828
private static object LoadFromDisk(CodeContext context, string name, string fullName, string str) {

Tests/test_bz2.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Licensed to the .NET Foundation under one or more agreements.
2+
# The .NET Foundation licenses this file to you under the Apache 2.0 License.
3+
# See the LICENSE file in the project root for more information.
4+
5+
import bz2
6+
import unittest
7+
8+
from iptest import run_test
9+
10+
class BZ2Test(unittest.TestCase):
11+
def test_bz2file(self):
12+
"""https://github.com/IronLanguages/ironpython2/pull/739"""
13+
14+
# BZ2File should not fail on invalid files, only on read
15+
with bz2.BZ2File(__file__, 'r') as f:
16+
with self.assertRaises(IOError):
17+
f.read()
18+
19+
run_test(__name__)

0 commit comments

Comments
 (0)