Skip to content
This repository was archived by the owner on Apr 14, 2022. It is now read-only.

Commit c885338

Browse files
author
Mikhail Arkhipov
authored
Prevent SO on documentation fetch (#1106)
* Prevent SO on documentation fetch * Remove disposed exception * Add Pop * PR feedback * Add lock to push/pop * Add test * Remove lock
1 parent 57eb7d6 commit c885338

5 files changed

Lines changed: 117 additions & 87 deletions

File tree

src/Analysis/Ast/Impl/Modules/PythonModule.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ public virtual string Documentation {
135135
_documentation = m.TryGetConstant<string>(out var s) ? s : string.Empty;
136136
if (string.IsNullOrEmpty(_documentation)) {
137137
m = GetMember($"_{Name}");
138-
_documentation = m?.GetPythonType()?.Documentation;
138+
var t = m?.GetPythonType();
139+
_documentation = t != null && !t.Equals(this) ? t.Documentation : null;
139140
if (string.IsNullOrEmpty(_documentation)) {
140141
_documentation = TryGetDocFromModuleInitFile();
141142
}
@@ -496,7 +497,7 @@ private string TryGetDocFromModuleInitFile() {
496497
// Also, handle quadruple+ quotes.
497498
line = line.Trim();
498499
line = line.All(c => c == quote[0]) ? quote : line;
499-
if (line.EndsWithOrdinal(quote) && line.IndexOf(quote, StringComparison.Ordinal) < line.LastIndexOf(quote, StringComparison.Ordinal)) {
500+
if (line.EndsWithOrdinal(quote) && line.IndexOf(quote, StringComparison.Ordinal) < line.LastIndexOf(quote, StringComparison.Ordinal)) {
500501
return line.Substring(quote.Length, line.Length - 2 * quote.Length).Trim();
501502
}
502503
var sb = new StringBuilder();

src/Analysis/Ast/Impl/Types/PythonClassType.cs

Lines changed: 81 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
using System.Collections.Generic;
1818
using System.Diagnostics;
1919
using System.Linq;
20-
using System.Threading;
2120
using Microsoft.Python.Analysis.Modules;
2221
using Microsoft.Python.Analysis.Specializations.Typing;
2322
using Microsoft.Python.Analysis.Types.Collections;
@@ -33,11 +32,11 @@ namespace Microsoft.Python.Analysis.Types {
3332
[DebuggerDisplay("Class {Name}")]
3433
internal class PythonClassType : PythonType, IPythonClassType, IPythonTemplateType, IEquatable<IPythonClassType> {
3534
private static readonly string[] _classMethods = { "mro", "__dict__", @"__weakref__" };
36-
private readonly object _lock = new object();
37-
private readonly AsyncLocal<IPythonClassType> _processing = new AsyncLocal<IPythonClassType>();
35+
private IPythonClassType _processing;
3836
private List<IPythonType> _bases;
3937
private IReadOnlyList<IPythonType> _mro;
4038
private Dictionary<string, IPythonType> _genericParameters;
39+
private string _documentation;
4140

4241
// For tests
4342
internal PythonClassType(string name, Location location)
@@ -59,34 +58,31 @@ public PythonClassType(
5958

6059
public override IEnumerable<string> GetMemberNames() {
6160
var names = new HashSet<string>();
62-
lock (_lock) {
63-
names.UnionWith(Members.Keys);
64-
}
61+
names.UnionWith(Members.Keys);
6562
foreach (var m in Mro.Skip(1)) {
6663
names.UnionWith(m.GetMemberNames());
6764
}
6865
return DeclaringModule.Interpreter.LanguageVersion.Is3x() ? names.Concat(_classMethods).Distinct() : names;
6966
}
7067

7168
public override IMember GetMember(string name) {
72-
IMember member;
73-
lock (_lock) {
74-
if (Members.TryGetValue(name, out member)) {
75-
return member;
76-
}
69+
// Push/Pop should be lock protected.
70+
if (Members.TryGetValue(name, out var member)) {
71+
return member;
72+
}
7773

78-
// Special case names that we want to add to our own Members dict
79-
var is3x = DeclaringModule.Interpreter.LanguageVersion.Is3x();
80-
switch (name) {
81-
case "__mro__":
82-
case "mro":
83-
return is3x ? PythonCollectionType.CreateList(DeclaringModule.Interpreter, Mro) : UnknownType;
84-
case "__dict__":
85-
return is3x ? DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Dict) : UnknownType;
86-
case @"__weakref__":
87-
return is3x ? DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Object) : UnknownType;
88-
}
74+
// Special case names that we want to add to our own Members dict
75+
var is3x = DeclaringModule.Interpreter.LanguageVersion.Is3x();
76+
switch (name) {
77+
case "__mro__":
78+
case "mro":
79+
return is3x ? PythonCollectionType.CreateList(DeclaringModule.Interpreter, Mro) : UnknownType;
80+
case "__dict__":
81+
return is3x ? DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Dict) : UnknownType;
82+
case @"__weakref__":
83+
return is3x ? DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Object) : UnknownType;
8984
}
85+
9086
if (Push(this)) {
9187
try {
9288
foreach (var m in Mro.Reverse()) {
@@ -104,19 +100,33 @@ public override IMember GetMember(string name) {
104100

105101
public override string Documentation {
106102
get {
107-
// Try doc from the type (class definition AST node).
108-
var doc = base.Documentation;
109-
// Try docs __init__.
110-
if (string.IsNullOrEmpty(doc)) {
111-
var init = GetMember("__init__") as IPythonFunctionType;
112-
doc = init?.DeclaringType == this ? init.Documentation : null;
103+
if (!string.IsNullOrEmpty(_documentation)) {
104+
return _documentation;
113105
}
114-
// Try bases.
115-
if (string.IsNullOrEmpty(doc) && Bases != null) {
116-
var o = DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Object);
117-
doc = Bases.Except(Enumerable.Repeat(o, 1)).FirstOrDefault(b => !string.IsNullOrEmpty(b?.Documentation))?.Documentation;
106+
// Make sure we do not cycle through bases back here.
107+
if (!Push(this)) {
108+
return null;
118109
}
119-
return doc;
110+
try {
111+
// Try doc from the type first (class definition AST node).
112+
_documentation = base.Documentation;
113+
if (string.IsNullOrEmpty(_documentation)) {
114+
// If not present, try docs __init__. IPythonFunctionType handles
115+
// __init__ in a special way so there is no danger of call coming
116+
// back here and causing stack overflow.
117+
_documentation = (GetMember("__init__") as IPythonFunctionType)?.Documentation;
118+
}
119+
120+
if (string.IsNullOrEmpty(_documentation) && Bases != null) {
121+
// If still not found, try bases.
122+
var o = DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Object);
123+
_documentation = Bases.FirstOrDefault(b => b != o && !string.IsNullOrEmpty(b?.Documentation))
124+
?.Documentation;
125+
}
126+
} finally {
127+
Pop();
128+
}
129+
return _documentation;
120130
}
121131
}
122132

@@ -158,17 +168,15 @@ public override IMember Index(IPythonInstance instance, object index) {
158168

159169
public IReadOnlyList<IPythonType> Mro {
160170
get {
161-
lock (_lock) {
162-
if (_mro != null) {
163-
return _mro;
164-
}
165-
if (_bases == null) {
166-
return new IPythonType[] { this };
167-
}
168-
_mro = new IPythonType[] { this };
169-
_mro = CalculateMro(this);
171+
if (_mro != null) {
170172
return _mro;
171173
}
174+
if (_bases == null) {
175+
return new IPythonType[] { this };
176+
}
177+
_mro = new IPythonType[] { this };
178+
_mro = CalculateMro(this);
179+
return _mro;
172180
}
173181
}
174182

@@ -177,36 +185,34 @@ public IReadOnlyDictionary<string, IPythonType> GenericParameters
177185
#endregion
178186

179187
internal void SetBases(IEnumerable<IPythonType> bases) {
180-
lock (_lock) {
181-
if (_bases != null) {
182-
return; // Already set
183-
}
184-
185-
bases = bases != null ? bases.Where(b => !b.GetPythonType().IsUnknown()).ToArray() : Array.Empty<IPythonType>();
186-
// For Python 3+ attach object as a base class by default except for the object class itself.
187-
if (DeclaringModule.Interpreter.LanguageVersion.Is3x() && DeclaringModule.ModuleType != ModuleType.Builtins) {
188-
var objectType = DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Object);
189-
// During processing of builtins module some types may not be available yet.
190-
// Specialization will attach proper base at the end.
191-
Debug.Assert(!objectType.IsUnknown());
192-
if (!bases.Any(b => objectType.Equals(b))) {
193-
bases = bases.Concat(Enumerable.Repeat(objectType, 1));
194-
}
195-
}
188+
if (_bases != null) {
189+
return; // Already set
190+
}
196191

197-
_bases = bases.ToList();
198-
if (_bases.Count > 0) {
199-
AddMember("__base__", _bases[0], true);
200-
}
201-
// Invalidate MRO
202-
_mro = null;
203-
if (DeclaringModule is BuiltinsPythonModule) {
204-
// TODO: If necessary, we can set __bases__ on builtins when the module is fully analyzed.
205-
return;
192+
bases = bases != null ? bases.Where(b => !b.GetPythonType().IsUnknown()).ToArray() : Array.Empty<IPythonType>();
193+
// For Python 3+ attach object as a base class by default except for the object class itself.
194+
if (DeclaringModule.Interpreter.LanguageVersion.Is3x() && DeclaringModule.ModuleType != ModuleType.Builtins) {
195+
var objectType = DeclaringModule.Interpreter.GetBuiltinType(BuiltinTypeId.Object);
196+
// During processing of builtins module some types may not be available yet.
197+
// Specialization will attach proper base at the end.
198+
Debug.Assert(!objectType.IsUnknown());
199+
if (!bases.Any(b => objectType.Equals(b))) {
200+
bases = bases.Concat(Enumerable.Repeat(objectType, 1));
206201
}
202+
}
207203

208-
AddMember("__bases__", PythonCollectionType.CreateList(DeclaringModule.Interpreter, _bases), true);
204+
_bases = bases.ToList();
205+
if (_bases.Count > 0) {
206+
AddMember("__base__", _bases[0], true);
207+
}
208+
// Invalidate MRO
209+
_mro = null;
210+
if (DeclaringModule is BuiltinsPythonModule) {
211+
// TODO: If necessary, we can set __bases__ on builtins when the module is fully analyzed.
212+
return;
209213
}
214+
215+
AddMember("__bases__", PythonCollectionType.CreateList(DeclaringModule.Interpreter, _bases), true);
210216
}
211217

212218
/// <summary>
@@ -263,15 +269,17 @@ internal static IReadOnlyList<IPythonType> CalculateMro(IPythonType type, HashSe
263269
}
264270
}
265271

272+
#region Reentrancy guards
266273
private bool Push(IPythonClassType cls) {
267-
if (_processing.Value == null) {
268-
_processing.Value = cls;
274+
if (_processing == null) {
275+
_processing = cls;
269276
return true;
270277
}
271278
return false;
272279
}
280+
private void Pop() => _processing = null;
281+
#endregion
273282

274-
private void Pop() => _processing.Value = null;
275283
public bool Equals(IPythonClassType other)
276284
=> Name == other?.Name && DeclaringModule.Equals(other?.DeclaringModule);
277285

@@ -375,7 +383,7 @@ public IPythonType CreateSpecificType(IArgumentSet args) {
375383
// Prevent reentrancy when resolving generic class where
376384
// method may be returning instance of type of the same class.
377385
if (!Push(classType)) {
378-
return _processing.Value;
386+
return _processing;
379387
}
380388

381389
try {
@@ -400,6 +408,7 @@ public IPythonType CreateSpecificType(IArgumentSet args) {
400408
}
401409
}
402410
}
411+
403412
// Set specific class bases
404413
classType.SetBases(bases.Concat(newBases));
405414
// Transfer members from generic to specific type.

src/Analysis/Ast/Impl/Types/PythonFunctionType.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,12 @@ Location location
8181
FunctionDefinition = fd;
8282
DeclaringType = declaringType;
8383

84+
// For __init__ documentation may either come from the function node of the the declaring
85+
// type. Note that if there is no documentation on the class node, the class will try and
86+
// get documentation from its __init__ function, delegating down to this type. So we need
87+
// to set documentation statically for __init__ here or we may end up/ with stack overflows.
8488
if (fd.Name == "__init__") {
85-
_documentation = declaringType?.Documentation;
89+
_documentation = declaringType?.Documentation ?? fd.Documentation;
8690
}
8791
ProcessDecorators(fd);
8892
}
@@ -139,7 +143,7 @@ internal void Specialize(string[] dependencies) {
139143

140144
internal ImmutableArray<string> Dependencies { get; private set; } = ImmutableArray<string>.Empty;
141145

142-
internal void AddOverload(IPythonFunctionOverload overload)
146+
internal void AddOverload(IPythonFunctionOverload overload)
143147
=> _overloads = _overloads.Count > 0 ? _overloads.Add(overload) : ImmutableArray<IPythonFunctionOverload>.Create(overload);
144148

145149
internal IPythonFunctionType ToUnbound() => new PythonUnboundMethod(this);

src/Analysis/Ast/Test/ClassesTests.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,5 +553,24 @@ def a(self):
553553
analysis.Should().HaveVariable("a").OfType(BuiltinTypeId.Int);
554554
}
555555

556+
[TestMethod, Priority(0)]
557+
public async Task DocumentationWithCycleInBases() {
558+
const string code = @"
559+
class A(C):
560+
'''class A doc'''
561+
562+
class B(A):
563+
def __init__(self):
564+
'''class B doc'''
565+
pass
566+
567+
class C(B):
568+
'''class C doc'''
569+
";
570+
var analysis = await GetAnalysisAsync(code);
571+
analysis.Should().HaveClass("A").Which.Should().HaveDocumentation("class A doc");
572+
analysis.Should().HaveClass("B").Which.Should().HaveDocumentation("class B doc");
573+
analysis.Should().HaveClass("C").Which.Should().HaveDocumentation("class C doc");
574+
}
556575
}
557576
}

src/Core/Impl/Extensions/IOExtensions.cs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,25 +114,22 @@ public static string ReadTextWithRetry(this IFileSystem fs, string file) {
114114
for (var retries = 100; retries > 0; --retries) {
115115
try {
116116
return fs.ReadAllText(file);
117-
} catch (UnauthorizedAccessException) {
118-
Thread.Sleep(10);
119-
} catch (IOException) {
117+
} catch (Exception ex) when (ex is UnauthorizedAccessException || ex is IOException || ex is ObjectDisposedException) {
120118
Thread.Sleep(10);
121119
}
122120
}
123121
return null;
124122
}
125123

126124
public static void WriteTextWithRetry(this IFileSystem fs, string filePath, string text) {
127-
try {
128-
using (var stream = fs.OpenWithRetry(filePath, FileMode.Create, FileAccess.Write, FileShare.Read)) {
129-
if (stream != null) {
130-
var bytes = Encoding.UTF8.GetBytes(text);
131-
stream.Write(bytes, 0, bytes.Length);
132-
return;
133-
}
125+
for (var retries = 100; retries > 0; --retries) {
126+
try {
127+
fs.WriteAllText(filePath, text);
128+
return;
129+
} catch (Exception ex) when (ex is IOException || ex is UnauthorizedAccessException) {
130+
Thread.Sleep(10);
134131
}
135-
} catch (IOException) { } catch (UnauthorizedAccessException) { }
132+
}
136133

137134
try {
138135
fs.DeleteFile(filePath);

0 commit comments

Comments
 (0)