Skip to content

Commit de789a8

Browse files
Seppli11sonartech
authored andcommitted
SONARPY-3774 Improve SelfType Collapsing (#841)
GitOrigin-RevId: f1ec4e555e18d698887fa8ede94cd726d8b75cfc
1 parent f2da51b commit de789a8

4 files changed

Lines changed: 160 additions & 12 deletions

File tree

python-frontend/src/main/java/org/sonar/python/semantic/v2/types/TrivialTypeInferenceVisitor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,7 @@ private PythonType resolveTypeAnnotationExpressionType(Expression expression, @N
780780
.expressions()
781781
.stream()
782782
.map(Expression::typeV2)
783+
.map(type -> resolveSelfType(type, enclosingClassType))
783784
.map(type -> ObjectType.Builder.fromType(type).withTypeSource(TypeSource.TYPE_HINT).build())
784785
.collect(Collectors.toList());
785786

python-frontend/src/main/java/org/sonar/python/semantic/v2/types/typecalculator/CallReturnTypeCalculator.java

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ private static PythonType collapseSelfTypeIfNeeded(CallExpression callExpr, Pyth
8585
return returnType;
8686
}
8787

88-
if (isInstanceOrClassMethodCall(callExpr, typePredicateContext)) {
88+
PythonType calleeType = callExpr.callee().typeV2();
89+
if (isInstanceOrClassMethodCall(calleeType, typePredicateContext)) {
8990
PythonType receiverType = getReceiverType(callExpr);
9091
if (receiverType == PythonType.UNKNOWN) {
9192
return PythonType.UNKNOWN;
@@ -99,19 +100,22 @@ private static PythonType collapseSelfTypeIfNeeded(CallExpression callExpr, Pyth
99100
private static boolean containsSelfType(PythonType type, TypePredicateContext typePredicateContext) {
100101
if (type instanceof SelfType || type.unwrappedType() instanceof SelfType) {
101102
return true;
102-
}
103-
if (type instanceof ObjectType objectType) {
104-
return objectType.attributes().stream().anyMatch(t -> containsSelfType(t, typePredicateContext));
103+
} else if (type instanceof ObjectType objectType) {
104+
return objectType.attributes().stream().anyMatch(t -> containsSelfType(t, typePredicateContext)) || containsSelfType(objectType.unwrappedType(), typePredicateContext);
105+
} else if (type instanceof UnionType unionType) {
106+
return unionType.candidates().stream().anyMatch(t -> containsSelfType(t, typePredicateContext));
105107
}
106108

107109
return false;
108110
}
109111

110-
private static boolean isInstanceOrClassMethodCall(CallExpression callExpr, TypePredicateContext typePredicateContext) {
111-
PythonType calleeType = callExpr.callee().typeV2();
112+
private static boolean isInstanceOrClassMethodCall(PythonType calleeType, TypePredicateContext typePredicateContext) {
112113
if (calleeType instanceof FunctionType functionType) {
113114
return functionType.isInstanceMethod() || functionType.isClassMethod();
115+
} else if (calleeType instanceof UnionType unionType) {
116+
return unionType.candidates().stream().allMatch(t -> isInstanceOrClassMethodCall(t, typePredicateContext));
114117
}
118+
115119
return hasCallableMember(calleeType, typePredicateContext);
116120
}
117121

@@ -140,17 +144,16 @@ private static PythonType getReceiverType(CallExpression callExpr) {
140144

141145
private static PythonType collapseSelfType(PythonType returnType, PythonType receiverType) {
142146
if (returnType instanceof ObjectType objectType) {
143-
var objectBuilder = ObjectType.Builder.fromType(objectType);
144-
if (objectType.type() instanceof SelfType) {
145-
objectBuilder.withType(receiverType);
146-
}
147-
return objectBuilder
147+
return ObjectType.Builder.fromType(objectType)
148+
.withType(collapseSelfType(objectType.unwrappedType(), receiverType))
148149
.withAttributes(objectType.attributes().stream().map(t -> collapseSelfType(t, receiverType)).toList())
149150
.build();
150151
} else if (returnType instanceof SelfType) {
151152
return receiverType;
153+
} else if (returnType instanceof UnionType unionType) {
154+
return UnionType.or(unionType.candidates().stream().map(t -> collapseSelfType(t, receiverType)).toList());
152155
}
153-
return PythonType.UNKNOWN;
156+
return returnType;
154157
}
155158

156159
private static TypeSource computeTypeSource(PythonType calleeType, CallExpression callExpr) {

python-frontend/src/test/java/org/sonar/python/semantic/v2/TypeInferenceV2Test.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4714,6 +4714,38 @@ def foo(self) -> Self | None:
47144714
assertThat(noneCandidate).isPresent();
47154715
}
47164716

4717+
@Test
4718+
void selfInSubscriptionExpressionReturnType() {
4719+
FileInput root = inferTypes("""
4720+
from typing import Self
4721+
class A:
4722+
def foo(self) -> list[Self] | None:
4723+
...
4724+
""");
4725+
4726+
ClassDef classDef = getFirstDescendant(root, t -> t instanceof ClassDef cd && "A".equals(cd.name().name()));
4727+
var classType = (ClassType) classDef.name().typeV2();
4728+
FunctionDef functionDef = getFirstDescendant(root, t -> t instanceof FunctionDef fd && "foo".equals(fd.name().name()));
4729+
var functionType = (FunctionType) functionDef.name().typeV2();
4730+
4731+
var returnType = functionType.returnType();
4732+
assertThat(returnType).isInstanceOf(ObjectType.class);
4733+
var unwrappedReturnType = returnType.unwrappedType();
4734+
assertThat(unwrappedReturnType).isInstanceOf(UnionType.class);
4735+
4736+
var unionType = (UnionType) unwrappedReturnType;
4737+
var candidates = unionType.candidates();
4738+
4739+
assertThat(candidates)
4740+
.satisfiesOnlyOnce(type -> assertThat(type)
4741+
.is(TypesTestUtils.objectTypeOf(LIST_TYPE))
4742+
.asInstanceOf(type(ObjectType.class))
4743+
.extracting(ObjectType::attributes)
4744+
.isEqualTo(List.of(ObjectType.fromType(SelfType.of(classType)))))
4745+
.satisfiesOnlyOnce(type -> assertThat(type).is(TypesTestUtils.objectTypeOf(NONE_TYPE)))
4746+
.hasSize(2);
4747+
}
4748+
47174749
@Test
47184750
void selfInUnionReturnTypeWithoutEnclosingClass() {
47194751
FileInput root = inferTypes("""

python-frontend/src/test/java/org/sonar/python/semantic/v2/types/typecalculator/CallReturnTypeCalculatorTest.java

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.sonar.python.semantic.v2.types.typecalculator;
1818

19+
import java.util.List;
20+
import java.util.Set;
1921
import org.junit.jupiter.api.Test;
2022
import org.sonar.plugins.python.api.tree.CallExpression;
2123
import org.sonar.plugins.python.api.tree.ClassDef;
@@ -33,8 +35,11 @@
3335
import org.sonar.python.types.v2.matchers.TypePredicateContext;
3436

3537
import static org.assertj.core.api.Assertions.assertThat;
38+
import static org.assertj.core.api.InstanceOfAssertFactories.type;
3639
import static org.sonar.python.PythonTestUtils.getFirstDescendant;
3740
import static org.sonar.python.PythonTestUtils.getLastDescendant;
41+
import static org.sonar.python.types.v2.TypesTestUtils.INT_TYPE;
42+
import static org.sonar.python.types.v2.TypesTestUtils.LIST_TYPE;
3843
import static org.sonar.python.types.v2.TypesTestUtils.PROJECT_LEVEL_TYPE_TABLE;
3944
import static org.sonar.python.types.v2.TypesTestUtils.parseAndInferTypes;
4045

@@ -181,6 +186,96 @@ def foo(self) -> Self: ...
181186
assertThat(result).isSameAs(PythonType.UNKNOWN);
182187
}
183188

189+
@Test
190+
void computeCallExpressionType_selfType_unionWithSelfType() {
191+
FileInput fileInput = parseAndInferTypes("""
192+
from typing import Self
193+
class A:
194+
def foo(self) -> Self | int: ...
195+
196+
a = A()
197+
a.foo()
198+
""");
199+
200+
ClassDef classDef = getFirstDescendant(fileInput, ClassDef.class::isInstance);
201+
ClassType classType = (ClassType) classDef.name().typeV2();
202+
203+
CallExpressionImpl callExpr = getLastDescendant(fileInput, CallExpression.class::isInstance);
204+
205+
PythonType result = CallReturnTypeCalculator.computeCallExpressionType(callExpr, typePredicateContext);
206+
207+
Set<PythonType> candidates = assertThat(result)
208+
.isInstanceOf(ObjectType.class)
209+
.extracting(PythonType::unwrappedType)
210+
.asInstanceOf(type(UnionType.class))
211+
.extracting(UnionType::candidates)
212+
.actual();
213+
214+
assertThat(candidates)
215+
.satisfiesOnlyOnce(type -> assertThat(type).is(TypesTestUtils.objectTypeOf(classType)))
216+
.satisfiesOnlyOnce(type -> assertThat(type).is(TypesTestUtils.objectTypeOf(INT_TYPE)))
217+
.hasSize(2);
218+
}
219+
220+
@Test
221+
void computeCallExpressionType_selfType_unionWithSelfTypeInList() {
222+
FileInput fileInput = parseAndInferTypes("""
223+
from typing import Self
224+
class A:
225+
def foo(self) -> list[Self] | int: ...
226+
227+
a = A()
228+
a.foo()
229+
""");
230+
231+
ClassDef classDef = getFirstDescendant(fileInput, ClassDef.class::isInstance);
232+
ClassType classType = (ClassType) classDef.name().typeV2();
233+
234+
CallExpressionImpl callExpr = getLastDescendant(fileInput, CallExpression.class::isInstance);
235+
236+
PythonType result = CallReturnTypeCalculator.computeCallExpressionType(callExpr, typePredicateContext);
237+
238+
Set<PythonType> candidates = assertThat(result)
239+
.isInstanceOf(ObjectType.class)
240+
.extracting(PythonType::unwrappedType)
241+
.asInstanceOf(type(UnionType.class))
242+
.extracting(UnionType::candidates)
243+
.actual();
244+
assertThat(candidates)
245+
.satisfiesOnlyOnce(type -> assertThat(type).is(TypesTestUtils.objectTypeOf(INT_TYPE)))
246+
.satisfiesOnlyOnce(type -> assertThat(type)
247+
.is(TypesTestUtils.objectTypeOf(LIST_TYPE))
248+
.asInstanceOf(type(ObjectType.class))
249+
.extracting(ObjectType::attributes)
250+
.isEqualTo(List.of(ObjectType.fromType(classType))))
251+
.hasSize(2);
252+
}
253+
254+
@Test
255+
void computeCallExpressionType_selfType_listWithSelfType() {
256+
FileInput fileInput = parseAndInferTypes("""
257+
from typing import Self
258+
class A:
259+
def foo(self) -> list[Self]: ...
260+
261+
a = A()
262+
a.foo()
263+
""");
264+
265+
ClassDef classDef = getFirstDescendant(fileInput, ClassDef.class::isInstance);
266+
ClassType classType = (ClassType) classDef.name().typeV2();
267+
268+
CallExpressionImpl callExpr = getLastDescendant(fileInput, CallExpression.class::isInstance);
269+
270+
PythonType result = CallReturnTypeCalculator.computeCallExpressionType(callExpr, typePredicateContext);
271+
272+
assertThat(result)
273+
.asInstanceOf(type(ObjectType.class))
274+
.is(TypesTestUtils.objectTypeOf(LIST_TYPE))
275+
.extracting(ObjectType::attributes)
276+
.isEqualTo(List.of(ObjectType.fromType(classType)));
277+
}
278+
184279
@Test
185280
void computeCallExpressionType_selfTypeInGenerator_typeshedFunction() {
186281
PythonType generatorType = PROJECT_LEVEL_TYPE_TABLE.getType("typing.Generator");
@@ -224,6 +319,23 @@ void computeCallExpressionType_selfType_typeshedFunction() {
224319
assertThat(objectType.attributes()).isEmpty();
225320
}
226321

322+
@Test
323+
void computeCallExpressionType_selfTypeInOverloadedFunction_typeshedFunction() {
324+
PythonType moduleType = PROJECT_LEVEL_TYPE_TABLE.getType("torch.nn.modules.Module");
325+
assertThat(moduleType).isNotEqualTo(PythonType.UNKNOWN);
326+
327+
FileInput fileInput = parseAndInferTypes("""
328+
from torch.nn.modules import Module
329+
module = Module()
330+
toModule = module.to(...)
331+
""");
332+
333+
Name toModuleName = PythonTestUtils.getLastDescendant(fileInput, tree -> tree instanceof Name name && "toModule".equals(name.name()));
334+
PythonType toModuleType = toModuleName.typeV2();
335+
336+
assertThat(toModuleType).is(TypesTestUtils.objectTypeOf(moduleType));
337+
}
338+
227339
@Test
228340
void computeCallExpressionType_methodCall() {
229341
FileInput fileInput = parseAndInferTypes("""

0 commit comments

Comments
 (0)