Skip to content

Commit c98e0d5

Browse files
authored
Misc changes (#1384)
* Update mappingproxy * Improve perf on BufferedReader.readline * _datetime changes * Fix CVE-2022-0391
1 parent 474a4aa commit c98e0d5

6 files changed

Lines changed: 200 additions & 39 deletions

File tree

Src/IronPython.Modules/_datetime.cs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,14 +670,21 @@ public object __ne__(object other) {
670670
return string.Format("datetime.date({0}, {1}, {2})", _dateTime.Year, _dateTime.Month, _dateTime.Day);
671671
}
672672

673-
public virtual string __format__(CodeContext/*!*/ context, string dateFormat){
673+
public virtual string __format__(CodeContext/*!*/ context, [NotNull] string dateFormat){
674674
if (string.IsNullOrEmpty(dateFormat)) {
675675
return PythonOps.ToString(context, this);
676676
} else {
677677
return strftime(context, dateFormat);
678678
}
679679
}
680680

681+
// overload to make test_datetime happy
682+
public string __format__(CodeContext/*!*/ context, object spec) {
683+
if (spec is string s) return __format__(context, s);
684+
if (spec is Extensible<string> es) return __format__(context, es.Value);
685+
throw PythonOps.TypeError("__format__() argument 1 must be str, not {0}", PythonOps.GetPythonTypeName(spec));
686+
}
687+
681688
#endregion
682689
}
683690

@@ -1294,6 +1301,7 @@ public PythonTuple __reduce__() {
12941301
);
12951302
}
12961303

1304+
// TODO: get rid of __bool__ in 3.5
12971305
public bool __bool__() {
12981306
return this.UtcTime.TimeSpan.Ticks != 0 || this.UtcTime.LostMicroseconds != 0;
12991307
}
@@ -1474,7 +1482,7 @@ private int CompareTo(object other) {
14741482

14751483
#endregion
14761484

1477-
public object __format__(CodeContext/*!*/ context, string dateFormat) {
1485+
public object __format__(CodeContext/*!*/ context, [NotNull] string dateFormat) {
14781486
if (string.IsNullOrEmpty(dateFormat)) {
14791487
return PythonOps.ToString(context, this);
14801488
}
@@ -1490,6 +1498,13 @@ public object __format__(CodeContext/*!*/ context, string dateFormat) {
14901498
}
14911499
}
14921500

1501+
// overload to make test_datetime happy
1502+
public object __format__(CodeContext/*!*/ context, object spec) {
1503+
if (spec is string s) return __format__(context, s);
1504+
if (spec is Extensible<string> es) return __format__(context, es.Value);
1505+
throw PythonOps.TypeError("__format__() argument 1 must be str, not {0}", PythonOps.GetPythonTypeName(spec));
1506+
}
1507+
14931508
private class UnifiedTime {
14941509
public TimeSpan TimeSpan;
14951510
public int LostMicroseconds;

Src/IronPython/Modules/_io.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,79 @@ public override Bytes read1(CodeContext/*!*/ context, int length=0) {
886886
}
887887
}
888888

889+
public override object readline(CodeContext context, int limit) {
890+
_checkClosed();
891+
892+
if (limit == 0) {
893+
return Bytes.Empty;
894+
}
895+
896+
lock (this) {
897+
bool limited = limit > 0;
898+
899+
List<Bytes> chunks = null;
900+
int cnt = 0;
901+
while (true) {
902+
var buf = _readBuf.AsSpan().Slice(_readBufPos);
903+
if (buf.Length > 0) {
904+
// we hit the limit so we're done
905+
bool done = false;
906+
if (limited && buf.Length > limit - cnt) {
907+
buf = buf.Slice(0, limit - cnt);
908+
done = true;
909+
}
910+
911+
// we found the eol so we're done
912+
var idx = buf.IndexOf((byte)'\n');
913+
if (idx != -1) {
914+
buf = buf.Slice(0, idx + 1);
915+
done = true;
916+
}
917+
918+
if (done) {
919+
_readBufPos += buf.Length;
920+
if (_readBufPos == _readBuf.Count) {
921+
ResetReadBuf();
922+
}
923+
var bytes = Bytes.Make(buf.ToArray());
924+
if (chunks is null) {
925+
return bytes;
926+
}
927+
chunks.Add(bytes);
928+
cnt += buf.Length;
929+
return Bytes.Concat(chunks, cnt);
930+
}
931+
932+
(chunks ??= new List<Bytes>()).Add(ResetReadBuf());
933+
cnt += buf.Length;
934+
}
935+
936+
// end of file
937+
if (!TryReadNextChunk(context)) {
938+
if (chunks is null) {
939+
return Bytes.Empty;
940+
}
941+
Debug.Assert(cnt > 0);
942+
return Bytes.Concat(chunks, cnt);
943+
}
944+
}
945+
}
946+
947+
bool TryReadNextChunk(CodeContext context) {
948+
object chunkObj;
949+
if (_rawIO != null) {
950+
chunkObj = _rawIO.read(context, _bufSize);
951+
} else {
952+
chunkObj = PythonOps.Invoke(context, _raw, "read", _bufSize);
953+
}
954+
955+
Bytes chunk = chunkObj != null ? GetBytes(chunkObj, "read()") : Bytes.Empty;
956+
957+
_readBuf = chunk;
958+
return chunk.Count != 0;
959+
}
960+
}
961+
889962
public override BigInteger tell(CodeContext/*!*/ context) {
890963
BigInteger res = _rawIO != null ?
891964
_rawIO.tell(context) :

Src/IronPython/Runtime/Types/MappingProxy.cs

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
// The .NET Foundation licenses this file to you under the Apache 2.0 License.
33
// See the LICENSE file in the project root for more information.
44

5+
#nullable enable
6+
7+
#if NETCOREAPP3_1 // IDictionary<object?, object?> is incorrectly annotated with TKey : notnull
8+
#pragma warning disable CS8714 // The type cannot be used as type parameter in the generic type or method. Nullability of type argument doesn't match 'notnull' constraint.
9+
#endif
10+
511
using System;
612
using System.Collections;
713
using System.Collections.Generic;
@@ -13,36 +19,52 @@
1319

1420
namespace IronPython.Runtime.Types {
1521
[PythonType("mappingproxy")]
16-
public sealed class MappingProxy : IDictionary<object, object>, IDictionary {
17-
internal PythonDictionary GetDictionary(CodeContext context) => dictionary ?? type.GetMemberDictionary(context, false);
22+
public sealed class MappingProxy : IDictionary<object?, object?>, IDictionary {
23+
internal PythonDictionary GetDictionary(CodeContext context) => dictionary ?? type!.GetMemberDictionary(context, false);
1824

19-
private readonly PythonDictionary dictionary;
20-
private readonly PythonType type;
25+
private readonly PythonDictionary? dictionary;
26+
private readonly PythonType? type;
2127

2228
internal MappingProxy(CodeContext context, PythonType/*!*/ dt) {
2329
Debug.Assert(dt != null);
2430
type = dt;
2531
}
2632

27-
public MappingProxy([NotNull]PythonDictionary dict) {
33+
public MappingProxy([NotNull] PythonDictionary dict) {
2834
dictionary = dict;
2935
}
3036

3137
#region Python Public API Surface
3238

3339
public int __len__(CodeContext context) => GetDictionary(context).Count;
3440

35-
public bool __contains__(CodeContext/*!*/ context, object value) => GetDictionary(context).TryGetValue(value, out _);
41+
public bool __contains__(CodeContext/*!*/ context, object? value) => GetDictionary(context).TryGetValue(value, out _);
3642

3743
public string/*!*/ __str__(CodeContext/*!*/ context) => DictionaryOps.__repr__(context, this);
3844

39-
public object get(CodeContext/*!*/ context, [NotNull]object k, object d=null) {
40-
object res;
41-
if (!GetDictionary(context).TryGetValue(k, out res)) {
42-
res = d;
45+
public string __repr__(CodeContext/*!*/ context) {
46+
var dict = GetDictionary(context);
47+
List<object>? infinite = PythonOps.GetAndCheckInfinite(this);
48+
if (infinite == null) {
49+
return "mappingproxy({...})";
4350
}
4451

45-
return res;
52+
int infiniteIndex = infinite.Count;
53+
infinite.Add(this);
54+
try {
55+
return $"mappingproxy({PythonOps.Repr(context, dict)})";
56+
} finally {
57+
System.Diagnostics.Debug.Assert(infiniteIndex == infinite.Count - 1);
58+
infinite.RemoveAt(infiniteIndex);
59+
}
60+
}
61+
62+
public object? get(CodeContext/*!*/ context, object? k, object? d = null) {
63+
if (GetDictionary(context).TryGetValue(k, out object? res)) {
64+
return res;
65+
}
66+
67+
return d;
4668
}
4769

4870
public object keys(CodeContext context) {
@@ -66,11 +88,11 @@ public object items(CodeContext context) {
6688
return items;
6789
}
6890

69-
public PythonDictionary copy(CodeContext/*!*/ context) => new PythonDictionary(context, this);
91+
public PythonDictionary copy(CodeContext/*!*/ context) => GetDictionary(context).copy(context);
7092

7193
public const object __hash__ = null;
7294

73-
public object __eq__(CodeContext/*!*/ context, object other) {
95+
public object __eq__(CodeContext/*!*/ context, object? other) {
7496
if (other is MappingProxy proxy) {
7597
if (type == null) {
7698
return __eq__(context, proxy.GetDictionary(context));
@@ -90,7 +112,7 @@ public object __eq__(CodeContext/*!*/ context, object other) {
90112

91113
#region IDictionary Members
92114

93-
public object this[object key] {
115+
public object? this[object? key] {
94116
get => GetDictionary(DefaultContext.Default)[key];
95117
[PythonHidden]
96118
set => throw PythonOps.TypeError("'mappingproxy' object does not support item assignment");
@@ -109,7 +131,7 @@ public object this[object key] {
109131
#region IDictionary Members
110132

111133
[PythonHidden]
112-
public void Add(object key, object value) {
134+
public void Add(object? key, object? value) {
113135
this[key] = value;
114136
}
115137

@@ -149,11 +171,7 @@ ICollection IDictionary.Values {
149171

150172
#region ICollection Members
151173

152-
void ICollection.CopyTo(Array array, int index) {
153-
foreach (DictionaryEntry de in (IDictionary)this) {
154-
array.SetValue(de, index++);
155-
}
156-
}
174+
void ICollection.CopyTo(Array array, int index) => throw new NotImplementedException("The method or operation is not implemented.");
157175

158176
int ICollection.Count => __len__(DefaultContext.Default);
159177

@@ -165,43 +183,43 @@ void ICollection.CopyTo(Array array, int index) {
165183

166184
#region IDictionary<object,object> Members
167185

168-
bool IDictionary<object, object>.ContainsKey(object key) => __contains__(DefaultContext.Default, key);
186+
bool IDictionary<object?, object?>.ContainsKey(object? key) => __contains__(DefaultContext.Default, key);
169187

170-
ICollection<object> IDictionary<object, object>.Keys => GetDictionary(DefaultContext.Default).Keys;
188+
ICollection<object?> IDictionary<object?, object?>.Keys => GetDictionary(DefaultContext.Default).Keys;
171189

172-
bool IDictionary<object, object>.Remove(object key) => throw new InvalidOperationException("mappingproxy is read-only");
190+
bool IDictionary<object?, object?>.Remove(object? key) => throw new InvalidOperationException("mappingproxy is read-only");
173191

174-
bool IDictionary<object, object>.TryGetValue(object key, out object value) => GetDictionary(DefaultContext.Default).TryGetValue(key, out value);
192+
bool IDictionary<object?, object?>.TryGetValue(object? key, out object? value) => GetDictionary(DefaultContext.Default).TryGetValue(key, out value);
175193

176-
ICollection<object> IDictionary<object, object>.Values => GetDictionary(DefaultContext.Default).Values;
194+
ICollection<object?> IDictionary<object?, object?>.Values => GetDictionary(DefaultContext.Default).Values;
177195

178196
#endregion
179197

180198
#region ICollection<KeyValuePair<object,object>> Members
181199

182-
void ICollection<KeyValuePair<object, object>>.Add(KeyValuePair<object, object> item) {
200+
void ICollection<KeyValuePair<object?, object?>>.Add(KeyValuePair<object?, object?> item) {
183201
this[item.Key] = item.Value;
184202
}
185203

186-
bool ICollection<KeyValuePair<object, object>>.Contains(KeyValuePair<object, object> item) => __contains__(DefaultContext.Default, item.Key);
204+
bool ICollection<KeyValuePair<object?, object?>>.Contains(KeyValuePair<object?, object?> item) => __contains__(DefaultContext.Default, item.Key);
187205

188-
void ICollection<KeyValuePair<object, object>>.CopyTo(KeyValuePair<object, object>[] array, int arrayIndex) {
189-
foreach (KeyValuePair<object, object> de in (IEnumerable<KeyValuePair<object, object>>)this) {
206+
void ICollection<KeyValuePair<object?, object?>>.CopyTo(KeyValuePair<object?, object?>[] array, int arrayIndex) {
207+
foreach (KeyValuePair<object?, object?> de in (IEnumerable<KeyValuePair<object?, object?>>)this) {
190208
array.SetValue(de, arrayIndex++);
191209
}
192210
}
193211

194-
int ICollection<KeyValuePair<object, object>>.Count => __len__(DefaultContext.Default);
212+
int ICollection<KeyValuePair<object?, object?>>.Count => __len__(DefaultContext.Default);
195213

196-
bool ICollection<KeyValuePair<object, object>>.IsReadOnly => true;
214+
bool ICollection<KeyValuePair<object?, object?>>.IsReadOnly => true;
197215

198-
bool ICollection<KeyValuePair<object, object>>.Remove(KeyValuePair<object, object> item) => ((IDictionary<object, object>)this).Remove(item.Key);
216+
bool ICollection<KeyValuePair<object?, object?>>.Remove(KeyValuePair<object?, object?> item) => ((IDictionary<object?, object?>)this).Remove(item.Key);
199217

200218
#endregion
201219

202220
#region IEnumerable<KeyValuePair<object,object>> Members
203221

204-
IEnumerator<KeyValuePair<object, object>> IEnumerable<KeyValuePair<object, object>>.GetEnumerator() => GetDictionary(DefaultContext.Default).GetEnumerator();
222+
IEnumerator<KeyValuePair<object?, object?>> IEnumerable<KeyValuePair<object?, object?>>.GetEnumerator() => GetDictionary(DefaultContext.Default).GetEnumerator();
205223

206224
#endregion
207225
}

Src/StdLib/Lib/test/test_urlparse.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,54 @@ def test_urlsplit_attributes(self):
538538
p = urllib.parse.urlsplit(url)
539539
self.assertEqual(p.port, None)
540540

541+
def test_urlsplit_remove_unsafe_bytes(self):
542+
# Remove ASCII tabs and newlines from input, for http common case scenario.
543+
url = "h\nttp://www.python\n.org\t/java\nscript:\talert('msg\r\n')/?query\n=\tsomething#frag\nment"
544+
p = urllib.parse.urlsplit(url)
545+
self.assertEqual(p.scheme, "http")
546+
self.assertEqual(p.netloc, "www.python.org")
547+
self.assertEqual(p.path, "/javascript:alert('msg')/")
548+
self.assertEqual(p.query, "query=something")
549+
self.assertEqual(p.fragment, "fragment")
550+
self.assertEqual(p.username, None)
551+
self.assertEqual(p.password, None)
552+
self.assertEqual(p.hostname, "www.python.org")
553+
self.assertEqual(p.port, None)
554+
self.assertEqual(p.geturl(), "http://www.python.org/javascript:alert('msg')/?query=something#fragment")
555+
556+
# Remove ASCII tabs and newlines from input as bytes, for http common case scenario.
557+
url = b"h\nttp://www.python\n.org\t/java\nscript:\talert('msg\r\n')/?query\n=\tsomething#frag\nment"
558+
p = urllib.parse.urlsplit(url)
559+
self.assertEqual(p.scheme, b"http")
560+
self.assertEqual(p.netloc, b"www.python.org")
561+
self.assertEqual(p.path, b"/javascript:alert('msg')/")
562+
self.assertEqual(p.query, b"query=something")
563+
self.assertEqual(p.fragment, b"fragment")
564+
self.assertEqual(p.username, None)
565+
self.assertEqual(p.password, None)
566+
self.assertEqual(p.hostname, b"www.python.org")
567+
self.assertEqual(p.port, None)
568+
self.assertEqual(p.geturl(), b"http://www.python.org/javascript:alert('msg')/?query=something#fragment")
569+
570+
# any scheme
571+
url = "x-new-scheme\t://www.python\n.org\t/java\nscript:\talert('msg\r\n')/?query\n=\tsomething#frag\nment"
572+
p = urllib.parse.urlsplit(url)
573+
self.assertEqual(p.geturl(), "x-new-scheme://www.python.org/javascript:alert('msg')/?query=something#fragment")
574+
575+
# Remove ASCII tabs and newlines from input as bytes, any scheme.
576+
url = b"x-new-scheme\t://www.python\n.org\t/java\nscript:\talert('msg\r\n')/?query\n=\tsomething#frag\nment"
577+
p = urllib.parse.urlsplit(url)
578+
self.assertEqual(p.geturl(), b"x-new-scheme://www.python.org/javascript:alert('msg')/?query=something#fragment")
579+
580+
# Unsafe bytes is not returned from urlparse cache.
581+
# scheme is stored after parsing, sending an scheme with unsafe bytes *will not* return an unsafe scheme
582+
url = "https://www.python\n.org\t/java\nscript:\talert('msg\r\n')/?query\n=\tsomething#frag\nment"
583+
scheme = "htt\nps"
584+
for _ in range(2):
585+
p = urllib.parse.urlsplit(url, scheme=scheme)
586+
self.assertEqual(p.scheme, "https")
587+
self.assertEqual(p.geturl(), "https://www.python.org/javascript:alert('msg')/?query=something#fragment")
588+
541589
def test_attributes_bad_port(self):
542590
"""Check handling of non-integer ports."""
543591
p = urllib.parse.urlsplit("http://www.example.net:foo")

Src/StdLib/Lib/urllib/parse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@
6565
'0123456789'
6666
'+-.')
6767

68+
# Unsafe bytes to be removed per WHATWG spec
69+
_UNSAFE_URL_BYTES_TO_REMOVE = ['\t', '\r', '\n']
70+
6871
# XXX: Consider replacing with functools.lru_cache
6972
MAX_CACHE_SIZE = 20
7073
_parse_cache = {}
@@ -331,13 +334,20 @@ def _checknetloc(netloc):
331334
raise ValueError("netloc '" + netloc2 + "' contains invalid " +
332335
"characters under NFKC normalization")
333336

337+
def _remove_unsafe_bytes_from_url(url):
338+
for b in _UNSAFE_URL_BYTES_TO_REMOVE:
339+
url = url.replace(b, "")
340+
return url
341+
334342
def urlsplit(url, scheme='', allow_fragments=True):
335343
"""Parse a URL into 5 components:
336344
<scheme>://<netloc>/<path>?<query>#<fragment>
337345
Return a 5-tuple: (scheme, netloc, path, query, fragment).
338346
Note that we don't break the components up in smaller bits
339347
(e.g. netloc is a single string) and we don't expand % escapes."""
340348
url, scheme, _coerce_result = _coerce_args(url, scheme)
349+
url = _remove_unsafe_bytes_from_url(url)
350+
scheme = _remove_unsafe_bytes_from_url(scheme)
341351
allow_fragments = bool(allow_fragments)
342352
key = url, scheme, allow_fragments, type(url), type(scheme)
343353
cached = _parse_cache.get(key, None)

0 commit comments

Comments
 (0)