Skip to content

Commit

Permalink
SONARPY-2235: UnionType creation should have more guards (#2210)
Browse files Browse the repository at this point in the history
  • Loading branch information
Seppli11 authored Dec 5, 2024
1 parent 5390bf6 commit 4bd143a
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class AmbiguousDescriptorToPythonTypeConverter implements DescriptorToPyt

public PythonType convert(ConversionContext ctx, AmbiguousDescriptor from) {
var candidates = from.alternatives().stream().map(ctx::convert).collect(Collectors.toSet());
return new UnionType(candidates);
return UnionType.or(candidates);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@

import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.sonar.plugins.python.api.tree.Expression;
import org.sonar.plugins.python.api.tree.FunctionDef;
import org.sonar.plugins.python.api.tree.Name;
import org.sonar.plugins.python.api.tree.QualifiedExpression;
import org.sonar.python.semantic.v2.TypeTable;
import org.sonar.python.tree.NameImpl;
import org.sonar.python.types.v2.PythonType;
import org.sonar.python.types.v2.UnionType;
import org.sonar.python.types.v2.TypeUtils;

/**
* Used in FlowSensitiveTypeInference to update name types based on program state
Expand All @@ -44,7 +43,7 @@ public void visitName(Name name) {
Optional.ofNullable(name.symbolV2()).ifPresent(symbol -> {
Set<PythonType> pythonTypes = state.getTypes(symbol);
if (!pythonTypes.isEmpty()) {
((NameImpl) name).typeV2(union(pythonTypes.stream()));
((NameImpl) name).typeV2(union(pythonTypes));
}
});
super.visitName(name);
Expand All @@ -66,7 +65,7 @@ public void visitQualifiedExpression(QualifiedExpression qualifiedExpression) {
}
}

private static PythonType union(Stream<PythonType> types) {
return types.reduce(UnionType::or).orElse(PythonType.UNKNOWN);
private static PythonType union(Set<PythonType> types) {
return types.stream().collect(TypeUtils.toUnionType());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

public class LazyUnionType implements PythonType, ResolvableType {

Set<PythonType> candidates;
private final Set<PythonType> candidates = new HashSet<>();

public LazyUnionType(Set<PythonType> candidates) {
this.candidates = candidates;
this.candidates.addAll(candidates);
}

public PythonType resolve() {
Expand All @@ -35,6 +35,6 @@ public PythonType resolve() {
}
resolvedCandidates.add(candidate);
}
return new UnionType(resolvedCandidates);
return UnionType.or(resolvedCandidates);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.sonar.python.types.v2;

import java.util.function.UnaryOperator;
import java.util.stream.Collector;
import java.util.stream.Collectors;

public class TypeUtils {

Expand All @@ -40,9 +42,13 @@ public static PythonType ensureWrappedObjectType(PythonType pythonType) {

public static PythonType map(PythonType type, UnaryOperator<PythonType> mapper) {
if (type instanceof UnionType unionType) {
return unionType.candidates().stream().map(mapper).reduce(UnionType::or).orElse(PythonType.UNKNOWN);
return unionType.candidates().stream().map(mapper).collect(toUnionType());
} else {
return mapper.apply(type);
}
}

public static Collector<PythonType, ?, PythonType> toUnionType() {
return Collectors.collectingAndThen(Collectors.toSet(), UnionType::or);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,26 @@
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.sonar.api.Beta;

@Beta
public record UnionType(Set<PythonType> candidates) implements PythonType {
public class UnionType implements PythonType {

private final Set<PythonType> candidates = new HashSet<>();

private UnionType(Set<PythonType> candidates) {
this.candidates.addAll(candidates);
}

public Set<PythonType> candidates() {
return candidates;
}

@Override
public Optional<String> displayName() {
Expand Down Expand Up @@ -69,46 +81,56 @@ public TypeSource typeSource() {
.orElse(TypeSource.EXACT);
}

@Beta
public static PythonType or(Collection<PythonType> candidates) {
ensureCandidatesAreNotLazyTypes(candidates);
if (candidates.isEmpty()) {
return PythonType.UNKNOWN;
}
return candidates
.stream()
.reduce(new UnionType(new HashSet<>()), UnionType::or);
@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
UnionType unionType = (UnionType) o;
return Objects.equals(candidates, unionType.candidates);
}

@Beta
public static PythonType or(@Nullable PythonType type1, @Nullable PythonType type2) {
if (type1 == null) {
return type2;
}
if (type2 == null) {
return type1;
@Override
public int hashCode() {
return Objects.hashCode(candidates);
}

@Override
public String toString() {
return displayName().orElse(super.toString());
}

public static PythonType or(@Nullable PythonType type1, @Nullable PythonType type2, @Nullable PythonType ...types) {
if(types == null) {
types = new PythonType[0];
}
if (type1 == PythonType.UNKNOWN || type2 == PythonType.UNKNOWN) {
Set<PythonType> typeSet = new HashSet<>();
typeSet.add(type1);
typeSet.add(type2);
typeSet.addAll(Set.of(types));
return or(typeSet);
}

public static PythonType or(Collection<PythonType> types) {
types = types.stream().filter(Objects::nonNull).collect(Collectors.toSet());
if(types.isEmpty()) {
return PythonType.UNKNOWN;
}
if (type1.equals(type2)) {
return type1;
}
Set<PythonType> types = new HashSet<>();
addTypes(type1, types);
addTypes(type2, types);
if (types.size() == 1) {
if(types.size() == 1) {
return types.iterator().next();
}
ensureCandidatesAreNotLazyTypes(types);
return new UnionType(types);

Set<PythonType> flatTypes = types.stream().flatMap(UnionType::flattenPythonType).collect(Collectors.toSet());
if(flatTypes.stream().anyMatch(type -> type == PythonType.UNKNOWN)) {
return PythonType.UNKNOWN;
}
ensureCandidatesAreNotLazyTypes(flatTypes);
return new UnionType(flatTypes);
}

private static void addTypes(PythonType type, Set<PythonType> types) {
if (type instanceof UnionType unionType) {
types.addAll(unionType.candidates());
private static Stream<PythonType> flattenPythonType(PythonType type) {
if(type instanceof UnionType unionType) {
return unionType.candidates().stream();
} else {
types.add(type);
return Stream.of(type);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2789,7 +2789,9 @@ void return_type_of_call_expression_inconsistent() {
CallExpression callExpressionSpy = Mockito.spy(callExpression);

// Inconsistent union type, should not happen
UnionType unionType = new UnionType(Set.of(PythonType.UNKNOWN));
UnionType unionType = Mockito.mock(UnionType.class);
Mockito.when(unionType.candidates()).thenReturn(Set.of(PythonType.UNKNOWN));

Name mock = Mockito.mock(Name.class);
Mockito.when(mock.typeV2()).thenReturn(unionType);
Mockito.doReturn(mock).when(callExpressionSpy).callee();
Expand All @@ -2808,7 +2810,9 @@ void return_type_of_call_expression_inconsistent_2() {
CallExpression callExpressionSpy = Mockito.spy(callExpression);

// Inconsistent union type, should not happen
UnionType unionType = new UnionType(Set.of());
UnionType unionType = Mockito.mock(UnionType.class);
Mockito.when(unionType.candidates()).thenReturn(Set.of());

Name mock = Mockito.mock(Name.class);
Mockito.when(mock.typeV2()).thenReturn(unionType);
Mockito.doReturn(mock).when(callExpressionSpy).callee();
Expand All @@ -2827,7 +2831,9 @@ void return_type_of_call_expression_inconsistent_3() {
CallExpression callExpressionSpy = Mockito.spy(callExpression);

// Inconsistent union type, should not happen
UnionType unionType = new UnionType(Set.of(INT_TYPE));
UnionType unionType = Mockito.mock(UnionType.class);
Mockito.when(unionType.candidates()).thenReturn(Set.of(INT_TYPE));

Name mock = Mockito.mock(Name.class);
Mockito.when(mock.typeV2()).thenReturn(unionType);
Mockito.doReturn(mock).when(callExpressionSpy).callee();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ void testConvertOtherType() {
void testConvertUnionType() {
ClassType classType = new ClassType("classType", "my_package.classType", Set.of(new Member("aMember", intTypeWrapper.type())), List.of(), List.of(floatTypeWrapper), List.of(intTypeWrapper.type()), true, false, location);
ClassType anotherClassType = new ClassType("classType", "my_package.classType", Set.of(new Member("aMember", intTypeWrapper.type())), List.of(), List.of(floatTypeWrapper), List.of(intTypeWrapper.type()), true, false, location);
UnionType unionType = new UnionType(Set.of(classType, anotherClassType));
PythonType unionType = UnionType.or(classType, anotherClassType);
Descriptor descriptor = converter.convert("foo", new SymbolV2("myUnionType"), Set.of(unionType));

assertThat(descriptor).isInstanceOf(AmbiguousDescriptor.class);
Expand Down Expand Up @@ -221,7 +221,7 @@ void testConvertManyTypesWithUnionType() {
ClassType classType = new ClassType("classType", "my_package.classType", Set.of(new Member("aMember", intTypeWrapper.type())), List.of(), List.of(floatTypeWrapper), List.of(intTypeWrapper.type()), true, false, location);
ClassType anotherClassType = new ClassType("classType", "my_package.classType", Set.of(new Member("aMember", intTypeWrapper.type())), List.of(), List.of(floatTypeWrapper), List.of(intTypeWrapper.type()), true, false, location);

UnionType unionType = new UnionType(Set.of(classType, anotherClassType));
PythonType unionType = UnionType.or(classType, anotherClassType);
Descriptor descriptor = converter.convert("foo", new SymbolV2("myUnionType"), Set.of(unionType, classType));

assertThat(descriptor).isInstanceOf(AmbiguousDescriptor.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.sonar.python.types.v2;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.assertj.core.api.InstanceOfAssertFactories;
Expand All @@ -31,6 +32,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.sonar.python.types.v2.TypesTestUtils.BOOL_TYPE;
import static org.sonar.python.types.v2.TypesTestUtils.FLOAT_TYPE;
import static org.sonar.python.types.v2.TypesTestUtils.INT_TYPE;
import static org.sonar.python.types.v2.TypesTestUtils.STR_TYPE;
import static org.sonar.python.types.v2.TypesTestUtils.parseAndInferTypes;
Expand All @@ -44,7 +46,7 @@ void basicUnion() {
PythonType intType = ((ExpressionStatement) fileInput.statements().statements().get(0)).expressions().get(0).typeV2();
PythonType strType = ((ExpressionStatement) fileInput.statements().statements().get(1)).expressions().get(0).typeV2();

UnionType unionType = new UnionType(Set.of(intType, strType));
PythonType unionType = UnionType.or(intType, strType);

assertThat(unionType.isCompatibleWith(intType)).isTrue();
assertThat(unionType.isCompatibleWith(strType)).isTrue();
Expand All @@ -60,7 +62,7 @@ void unionWithUnknown() {
FileInput fileInput = parseAndInferTypes("42;foo()");
PythonType intType = ((ExpressionStatement) fileInput.statements().statements().get(0)).expressions().get(0).typeV2();
PythonType strType = ((ExpressionStatement) fileInput.statements().statements().get(1)).expressions().get(0).typeV2();
UnionType unionType = new UnionType(Set.of(intType, strType));
PythonType unionType = UnionType.or(intType, strType);

assertThat(unionType.displayName()).isEmpty();
assertThat(unionType.instanceDisplayName()).isEmpty();
Expand All @@ -85,6 +87,9 @@ void or_with_null() {
assertThat(type).isEqualTo(INT_TYPE);
type = UnionType.or(null, INT_TYPE);
assertThat(type).isEqualTo(INT_TYPE);

type = UnionType.or(null, INT_TYPE, (PythonType[]) null);
assertThat(type).isEqualTo(INT_TYPE);
}

@Test
Expand Down Expand Up @@ -115,6 +120,19 @@ void or_unresolevdImportType() {
assertThat(((UnionType) unionType).candidates()).containsExactlyInAnyOrder(unresolvedImportType, unresolvedImportType2);
}

@Test
void or_emptySet() {
assertThat(UnionType.or(Collections.emptyList())).isEqualTo(PythonType.UNKNOWN);
}

@Test
void or_singletonSet() {
assertThat(UnionType.or(Set.of(INT_TYPE))).isSameAs(INT_TYPE);

var union = UnionType.or(INT_TYPE, FLOAT_TYPE);
assertThat(UnionType.or(Set.of(union))).isSameAs(union);
}

@Test
void hasMemberUnionType() {
FileInput fileInput = parseAndInferTypes("""
Expand Down Expand Up @@ -167,4 +185,13 @@ void noLazyTypeInUnionType() {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("UnionType cannot contain Lazy types");
}

@Test
void testEquality() {
var union1 = UnionType.or(INT_TYPE, FLOAT_TYPE);
var union2 = UnionType.or(INT_TYPE, FLOAT_TYPE);
assertThat(union1)
.isEqualTo(union2)
.hasSameHashCodeAs(union2);
}
}

0 comments on commit 4bd143a

Please sign in to comment.