Skip to content

Commit

Permalink
WIP: Multi root resolver.
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrszul committed Dec 19, 2024
1 parent e61efb8 commit 4f230c8
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ public interface DataRoot {
/**
 * Gets the resource type of this data root.
 *
 * @return the resource type, never {@code null}
 */
@Nonnull
ResourceType getResourceType();

/**
 * Gets the resource type of the parent of this data root.
 * <p>
 * In the implementations in this file, a {@code ResourceRoot} returns its own resource type and a
 * {@code JoinRoot} returns the resource type of its master.
 *
 * @return the parent resource type, never {@code null}
 */
@Nonnull
ResourceType getParentResourceType();

/**
 * Gets the depth of this data root. The default implementation returns 0.
 * <p>
 * NOTE(review): presumably join roots override this with their nesting level below the subject
 * resource - confirm against the implementations.
 *
 * @return the depth of this root
 */
default int depth() {
return 0;
}
Expand All @@ -37,6 +40,12 @@ class ResourceRoot implements DataRoot {
@Nonnull
ResourceType resourceType;

/**
 * A resource root has no master, so it reports its own resource type as the parent type.
 *
 * @return the resource type of this root
 */
@Nonnull
@Override
public ResourceType getParentResourceType() {
  return this.resourceType;
}

@Nonnull
@Override
public String getTag() {
Expand All @@ -47,6 +56,12 @@ public String getTag() {

interface JoinRoot extends DataRoot {

/**
 * The parent of a join root is its master, so the parent type is the master's resource type.
 *
 * @return the resource type of the master root
 */
@Nonnull
@Override
default ResourceType getParentResourceType() {
  final DataRoot parent = getMaster();
  return parent.getResourceType();
}

/**
 * Gets the master of this join root, that is the data root this join is attached to.
 *
 * @return the master data root, never {@code null}
 */
@Nonnull
DataRoot getMaster();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public class DataRootResolver {
ResourceType subjectResource;
FhirContext fhirContext;




@Nonnull
public Set<DataRoot> findDataRoots(@Nonnull final FhirPath path) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package au.csiro.pathling.fhirpath.execution;

import static java.util.stream.Collectors.mapping;

import au.csiro.pathling.fhirpath.execution.DataRoot.JoinRoot;
import jakarta.annotation.Nonnull;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Value;

/**
 * A node in a tree of joins: a master {@link DataRoot} together with the join sets attached
 * directly to it.
 * <p>
 * Every child of a join set must have a {@link JoinRoot} as its master, and that join root's own
 * master must be equal to the master of the enclosing set. Instances are created with the
 * {@code of} factory methods, which enforce this invariant.
 */
@Value()
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class JoinSet {

  /**
   * The data root that all children of this set are joined to.
   */
  @Nonnull
  DataRoot master;

  /**
   * The join sets attached directly to {@link #master}; an unmodifiable snapshot.
   */
  @Nonnull
  List<JoinSet> children;


  /**
   * Creates a join set with the given master and children.
   *
   * @param master the master data root of the set
   * @param children the join sets to attach to the master
   * @return the new join set
   * @throws IllegalArgumentException if the master of any child is not a {@link JoinRoot}, or is
   * a join root whose master is not equal to {@code master}
   */
  @Nonnull
  public static JoinSet of(@Nonnull final DataRoot master,
      @Nonnull final List<JoinSet> children) {
    children.forEach(child -> {
      if (child.getMaster() instanceof JoinRoot jr) {
        if (!jr.getMaster().equals(master)) {
          throw new IllegalArgumentException(
              "Cannot have a join set when child master is differnt than set parent");
        }
      } else {
        throw new IllegalArgumentException("Child must be a join root");
      }
    });
    // Snapshot the list so later changes to the caller's list cannot alter this value object.
    return new JoinSet(master, List.copyOf(children));
  }

  /**
   * Creates a join set with the given master and no children.
   *
   * @param master the master data root of the set
   * @return the new join set
   */
  @Nonnull
  public static JoinSet of(@Nonnull final DataRoot master) {
    return new JoinSet(master, List.of());
  }

  /**
   * Streams the chain of roots leading to the given root, outermost master first and the root
   * itself last.
   */
  @Nonnull
  private static Stream<DataRoot> toStream(@Nonnull final DataRoot dataRoot) {
    if (dataRoot instanceof JoinRoot jr) {
      return Stream.concat(toStream(jr.getMaster()), Stream.of(dataRoot));
    }
    return Stream.of(dataRoot);
  }

  /**
   * Converts a data root to the path of roots leading to it, by following
   * {@link JoinRoot#getMaster()} links.
   *
   * @param root the data root to convert
   * @return the path, starting at the outermost master and ending at {@code root}
   */
  @Nonnull
  static List<DataRoot> toPath(@Nonnull final DataRoot root) {
    return toStream(root).toList();
  }

  /**
   * Merges paths of data roots into join sets by recursively grouping them by common prefix.
   * <p>
   * Paths that share a first element end up in the same join set, and the remainders of those
   * paths are merged recursively into its children. Empty paths are ignored. Join sets are
   * returned in the order in which their head was first encountered in {@code paths}.
   *
   * @param paths the paths to merge, each ordered from the outermost master inwards
   * @return one join set per distinct first element of the non-empty paths
   * @throws IllegalArgumentException if a path element is not a valid child of its predecessor
   * (see {@link #of(DataRoot, List)})
   */
  @Nonnull
  public static List<JoinSet> mergeRoots(@Nonnull final List<List<DataRoot>> paths) {
    // A LinkedHashMap keeps the heads in encounter order, so the order of the resulting join
    // sets does not depend on hash iteration order.
    final Map<DataRoot, List<List<DataRoot>>> suffixesByHeads = paths.stream()
        .filter(Predicate.not(List::isEmpty))
        .collect(Collectors.groupingBy(
            path -> path.get(0),
            LinkedHashMap::new,
            mapping(path -> path.subList(1, path.size()), Collectors.toList())
        ));

    return suffixesByHeads.entrySet().stream()
        .map(entry -> JoinSet.of(entry.getKey(), mergeRoots(entry.getValue())))
        .toList();
  }

  /**
   * Merges a set of data roots into join sets.
   * <p>
   * Each root is first converted to its path with {@link #toPath(DataRoot)}; the paths are then
   * merged with {@link #mergeRoots(List)}.
   *
   * @param roots the data roots to merge
   * @return one join set per distinct outermost master among the roots
   */
  @Nonnull
  public static List<JoinSet> mergeRoots(@Nonnull final Set<DataRoot> roots) {
    return mergeRoots(roots.stream().map(JoinSet::toPath).toList());
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public static Column collect_map(@Nonnull final Column mapColumn) {
/**
 * Creates the dataset that evaluation starts from: the resource dataset of the subject resource.
 *
 * @return the result of {@code resourceDataset} for the subject resource
 */
public Dataset<Row> createInitialDataset() {
  final Dataset<Row> subjectDataset = resourceDataset(subjectResource);
  return subjectDataset;
}

@Override
public @NotNull CollectionDataset evaluate(@NotNull final String fhirpathExpression) {

Expand All @@ -86,11 +86,19 @@ public Dataset<Row> createInitialDataset() {
System.out.println("Join roots: " + joinRoots);
joinRoots.forEach(System.out::println);

Dataset<Row> resolvedDataset =
joinRoots.size() == 1
? resolveJoins((JoinRoot) joinRoots.iterator().next(),
createInitialDataset(), null)
: createInitialDataset();
// Dataset<Row> resolvedDataset =
// joinRoots.size() == 1
// ? resolveJoins((JoinRoot) joinRoots.iterator().next(),
// createInitialDataset(), null)
// : createInitialDataset();

Dataset<Row> resolvedDataset = resolveJoinsEx(
JoinSet.mergeRoots(joinRoots).iterator().next(),
createInitialDataset());

System.out.println("Resolved dataset:");
resolvedDataset.show();
System.out.println(resolvedDataset.queryExecution().executedPlan().toString());

final ResourceResolver resourceResolver = new DefaultResourceResolver();
final FhirPathContext fhirpathContext = FhirPathContext.ofResource(
Expand Down Expand Up @@ -182,10 +190,11 @@ private Dataset<Row> computeResolveJoin(@Nonnull final Dataset<Row> parentDatase
referenceCollection);

final Dataset<Row> joinedDataset = parentDataset.join(childResult,
parentRegKey.getColumnValue().equalTo(childResult.col(typedRoot.getChildKeyTag())),
parentRegKey.getColumnValue().equalTo(functions.col(typedRoot.getChildKeyTag())),
"left_outer")
.drop(typedRoot.getChildKeyTag());

System.out.println("Joined dataset:");
joinedDataset.show();
return joinedDataset;
} else {
Expand Down Expand Up @@ -333,7 +342,19 @@ private Dataset<Row> computeReverseJoin(@Nonnull final Dataset<Row> parentDatase
joinedDataset.show();
return joinedDataset;
}


/**
 * Joins the datasets of all children of the given join set onto the parent dataset.
 * <p>
 * Children are processed in list order. For each child, its own nested joins are resolved first,
 * starting from the resource dataset of the child's master resource type rather than from the
 * parent dataset. The resulting dataset is then joined onto the accumulated dataset with
 * {@code computeJoin}.
 *
 * @param joinSet the join set whose children are to be joined
 * @param parentDataset the dataset for the master of {@code joinSet}
 * @return the parent dataset with every child joined on; {@code parentDataset} itself when the
 * set has no children
 * @throws ClassCastException if the master of a child is not a {@link JoinRoot}
 */
@Nonnull
private Dataset<Row> resolveJoinsEx(@Nonnull final JoinSet joinSet,
    @Nonnull final Dataset<Row> parentDataset) {
  // A plain loop rather than Stream.reduce: this is an inherently sequential fold and has no
  // meaningful combiner for merging partial results.
  Dataset<Row> result = parentDataset;
  for (final JoinSet subset : joinSet.getChildren()) {
    // The sub-join starts from the child's own resource dataset, not from the parent dataset.
    final Dataset<Row> childDataset = resolveJoinsEx(subset,
        resourceDataset(subset.getMaster().getResourceType()));
    result = computeJoin(result, childDataset, (JoinRoot) subset.getMaster());
  }
  return result;
}

@Nonnull
private Dataset<Row> resolveJoins(@Nonnull final JoinRoot joinRoot,
@Nonnull final Dataset<Row> parentDataset, @Nullable final Dataset<Row> maybeChildDataset) {
Expand Down Expand Up @@ -465,7 +486,7 @@ Dataset<Row> resourceDataset(@Nonnull final ResourceType resourceType) {
.map(dataset::col).toArray(Column[]::new)
).alias(resourceType.toCode()));
}

public Dataset<Row> execute(@NotNull final FhirPath path,
@NotNull final Dataset<Row> subjectDataset) {
throw new UnsupportedOperationException("Not implemented");
Expand All @@ -492,5 +513,5 @@ public Set<DataRoot> findJoinsRoots(@Nonnull final FhirPath path) {
.sorted((r1, r2) -> Integer.compare(r2.depth(), r1.depth())).limit(1)
.collect(Collectors.toUnmodifiableSet());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.hl7.fhir.r4.model.Enumerations.ResourceType;
import org.jetbrains.annotations.NotNull;


@Value
Expand Down Expand Up @@ -107,12 +106,12 @@ Dataset<Row> resourceDataset(@Nonnull final ResourceType resourceType) {
dataset.col("id_versioned").alias("key"),
functions.struct(
Stream.of(dataset.columns()).filter(c -> !c.startsWith("_"))
.map(dataset::col).toArray(Column[]::new)
.map(functions::col).toArray(Column[]::new)
).alias(resourceType.toCode())
);
final Stream<Column> implicitColumns = Stream.of(dataset.columns())
.filter(c -> c.startsWith("_"))
.map(dataset::col);
.map(functions::col);

return dataset.select(Stream.concat(explicitColumns, implicitColumns)
.toArray(Column[]::new));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package au.csiro.pathling.fhirpath.execution;

import static org.junit.jupiter.api.Assertions.assertEquals;

import au.csiro.pathling.fhirpath.execution.DataRoot.ResolveRoot;
import au.csiro.pathling.fhirpath.execution.DataRoot.ResourceRoot;
import au.csiro.pathling.fhirpath.execution.DataRoot.ReverseResolveRoot;
import java.util.List;
import java.util.Set;
import org.hl7.fhir.r4.model.Enumerations.ResourceType;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for {@link JoinSet}: conversion of data roots to paths and merging of data roots
 * into join sets.
 * <p>
 * Sibling join sets are compared as sets, so the tests do not depend on the order in which
 * {@link JoinSet#mergeRoots(Set)} returns them.
 */
class JoinSetTest {

  @Test
  void testSingleResourcePath() {
    final List<DataRoot> result = JoinSet.toPath(
        ResourceRoot.of(ResourceType.PATIENT));
    assertEquals(List.of(ResourceRoot.of(ResourceType.PATIENT)), result);
  }

  @Test
  void testReverseResolvePath() {
    final List<DataRoot> result = JoinSet.toPath(
        ReverseResolveRoot.of(ResourceRoot.of(ResourceType.PATIENT), ResourceType.CONDITION,
            "subject"));
    assertEquals(List.of(
        ResourceRoot.of(ResourceType.PATIENT),
        ReverseResolveRoot.of(ResourceRoot.of(ResourceType.PATIENT), ResourceType.CONDITION,
            "subject")
    ), result);
  }

  @Test
  void singleResourceRoots() {
    final List<JoinSet> result = JoinSet.mergeRoots(
        Set.of(
            ResourceRoot.of(ResourceType.PATIENT),
            ResourceRoot.of(ResourceType.CONDITION)
        )
    );
    // Two unrelated resource roots produce two childless join sets.
    assertEquals(2, result.size());
    assertEquals(
        Set.of(
            JoinSet.of(ResourceRoot.of(ResourceType.PATIENT)),
            JoinSet.of(ResourceRoot.of(ResourceType.CONDITION))
        ),
        Set.copyOf(result));
  }

  @Test
  void singleReverseResolvePath() {
    final ResourceRoot patient = ResourceRoot.of(ResourceType.PATIENT);
    final DataRoot conditions = ReverseResolveRoot.of(patient, ResourceType.CONDITION,
        "subject");

    final List<JoinSet> result = JoinSet.mergeRoots(Set.of(conditions));

    // The master of the join root is added implicitly as the head of the join set.
    assertEquals(
        List.of(JoinSet.of(patient, List.of(JoinSet.of(conditions)))),
        result);
  }

  @Test
  void nestedRoots() {
    final ResourceRoot patient = ResourceRoot.of(ResourceType.PATIENT);
    final DataRoot conditions = ReverseResolveRoot.of(patient, ResourceType.CONDITION,
        "subject");
    final DataRoot encounters = ReverseResolveRoot.of(patient, ResourceType.ENCOUNTER,
        "subject");

    final List<JoinSet> result = JoinSet.mergeRoots(
        Set.of(patient, conditions, encounters));

    // Both joins share the same master, so they are merged into a single join set.
    assertEquals(1, result.size());
    final JoinSet patientSet = result.get(0);
    assertEquals(patient, patientSet.getMaster());
    assertEquals(2, patientSet.getChildren().size());
    assertEquals(
        Set.of(JoinSet.of(conditions), JoinSet.of(encounters)),
        Set.copyOf(patientSet.getChildren()));
  }

  @Test
  void complexNestedRoots() {
    final ResourceRoot mainRoot1 = ResourceRoot.of(ResourceType.PATIENT);
    final DataRoot root_1_1 = ReverseResolveRoot.of(mainRoot1, ResourceType.CONDITION, "subject");
    final DataRoot root_1_2 = ReverseResolveRoot.of(mainRoot1, ResourceType.ENCOUNTER, "subject");
    final DataRoot root_1_2_1 = ResolveRoot.of(root_1_2, ResourceType.OBSERVATION, "observations");

    // The path of the innermost root runs from the subject resource through each join.
    assertEquals(List.of(mainRoot1, root_1_2, root_1_2_1), JoinSet.toPath(root_1_2_1));

    final List<JoinSet> result = JoinSet.mergeRoots(
        Set.of(
            mainRoot1,
            root_1_1,
            root_1_2,
            root_1_2_1
        )
    );

    // Everything hangs off the single subject resource; the nested resolve is a child of the
    // encounter join rather than of the subject resource.
    assertEquals(1, result.size());
    final JoinSet mainSet = result.get(0);
    assertEquals(mainRoot1, mainSet.getMaster());
    assertEquals(2, mainSet.getChildren().size());
    assertEquals(
        Set.of(
            JoinSet.of(root_1_1),
            JoinSet.of(root_1_2, List.of(JoinSet.of(root_1_2_1)))
        ),
        Set.copyOf(mainSet.getChildren()));
  }

}


Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,10 @@ void multipleResolveToTheSameResourceOnDiffernetPaths() {
final Dataset<Row> resultDataset = evalExpression(dataSource,
ResourceType.ENCOUNTER,
"hospitalization.origin.resolve().ofType(Location).id"
+ " = hospitalization.destination.resolve().ofType(Location).id"
+ "=hospitalization.destination.resolve().ofType(Location).id"
);
System.out.println(resultDataset.queryExecution().executedPlan().toString());

resultDataset.show();
new DatasetAssert(resultDataset)
.hasRowsUnordered(
Expand Down

0 comments on commit 4f230c8

Please sign in to comment.