From 786c3c23c9f6a3b0f52a984977644ed2e585d361 Mon Sep 17 00:00:00 2001 From: Piotr Szul Date: Fri, 20 Dec 2024 11:28:47 +1000 Subject: [PATCH] Adding merging of any duplicate map columns in resolve() joins and correct handling of merging with NULL map columns. --- .../execution/MultiFhirPathEvaluator.java | 150 ++++++++++-------- .../fhirpathe/execution/FhirpathTest.java | 96 ++++++++++- 2 files changed, 172 insertions(+), 74 deletions(-) diff --git a/fhirpath/src/main/java/au/csiro/pathling/fhirpath/execution/MultiFhirPathEvaluator.java b/fhirpath/src/main/java/au/csiro/pathling/fhirpath/execution/MultiFhirPathEvaluator.java index 9c4cb7af77..df7c0536dc 100644 --- a/fhirpath/src/main/java/au/csiro/pathling/fhirpath/execution/MultiFhirPathEvaluator.java +++ b/fhirpath/src/main/java/au/csiro/pathling/fhirpath/execution/MultiFhirPathEvaluator.java @@ -29,7 +29,6 @@ import au.csiro.pathling.fhirpath.context.ViewEvaluationContext; import au.csiro.pathling.fhirpath.execution.DataRoot.JoinRoot; import au.csiro.pathling.fhirpath.execution.DataRoot.ResolveRoot; -import au.csiro.pathling.fhirpath.execution.DataRoot.ResourceRoot; import au.csiro.pathling.fhirpath.execution.DataRoot.ReverseResolveRoot; import au.csiro.pathling.fhirpath.function.registry.FunctionRegistry; import au.csiro.pathling.fhirpath.parser.Parser; @@ -40,8 +39,10 @@ import jakarta.annotation.Nonnull; import jakarta.annotation.Nullable; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; import java.util.stream.Stream; import lombok.Value; import org.apache.spark.sql.Column; @@ -58,6 +59,14 @@ @Value public class MultiFhirPathEvaluator implements FhirPathEvaluator { + + @Nonnull + static Column ns_map_concat(@Nonnull final Column left, @Nonnull final Column right) { + return functions.when(left.isNull(), right) + .when(right.isNull(), left) + .otherwise(functions.map_concat(left, right)); + } + // TODO: Move somewhere else @Nonnull @@ -68,7 +77,7 @@ public static Column collect_map(@Nonnull final Column mapColumn) { return functions.reduce( functions.collect_list(mapColumn), functions.any_value(mapColumn), - (acc, elem) -> functions.when(acc.isNull(), elem).otherwise(functions.map_concat(acc, elem)) + (acc, elem) -> functions.when(acc.isNull(), elem).otherwise(ns_map_concat(acc, elem)) ); } @@ -92,7 +101,7 @@ public Dataset createInitialDataset() { // createInitialDataset(), null) // : createInitialDataset(); - Dataset resolvedDataset = resolveJoinsEx( + Dataset resolvedDataset = resolveJoins( JoinSet.mergeRoots(joinRoots).iterator().next(), createInitialDataset()); @@ -148,6 +157,7 @@ private Dataset computeResolveJoin(@Nonnull final Dataset parentDatase + typedRoot.getResourceType()); } + Dataset resultDataset; if (referenceCollection.isToOneReference()) { // TODO: this should be replaced with call to evalPath() with not grouping context @@ -190,25 +200,20 @@ private Dataset computeResolveJoin(@Nonnull final Dataset parentDatase final Collection parentRegKey = parentExecutor.evaluate(new Traversal("reference"), referenceCollection); - final boolean needsMerging = List.of(parentDataset.columns()) - .contains(typedRoot.getValueTag()); final Dataset joinedDataset = parentDataset.join(childResult, parentRegKey.getColumnValue().equalTo(functions.col(typedRoot.getChildKeyTag())), "left_outer") - .withColumn(typedRoot.getValueTag(), - needsMerging - ? 
functions.map_concat(
-                functions.col(typedRoot.getValueTag()),
-                functions.col(uniqueValueTag))
-            : functions.col(uniqueValueTag)
-        )
-        .drop(typedRoot.getChildKeyTag(), uniqueValueTag);
-
-    System.out.println("Joined dataset:");
-    joinedDataset.show();
-    return joinedDataset;
+        .drop(typedRoot.getChildKeyTag());
+
+    final Dataset<Row> finalDataset = mergeMapColumns(joinedDataset, typedRoot.getValueTag(),
+        uniqueValueTag);
+
+    System.out.println("Final dataset:");
+    finalDataset.show();
+    return finalDataset;
   } else {
+    final String uniqueValueTag = typedRoot.getValueTag() + "_unique";
     // TODO: this should be replaced with call to evalPath() with not grouping context
     final FhirPathExecutor childExecutor = createExecutor(
         typedRoot.getForeignResourceType(),
@@ -236,12 +241,15 @@ private Dataset computeResolveJoin(@Nonnull final Dataset parentDatase
             functions.array(childDataset.col("key")),
             // maybe need to be wrapped in another array
             functions.array(childResource.getColumnValue())
-        ).alias(typedRoot.getValueTag())
+        ).alias(uniqueValueTag)
     );

-    final Dataset<Row> childResult = childDataset.select(
-        Streams.concat(keyValuesColumns, childPassThroughColumns)
-            .toArray(Column[]::new));
+    final Dataset<Row> childResult =
+        childDataset.select(
+            Streams.concat(keyValuesColumns, childPassThroughColumns)
+                .toArray(Column[]::new));
+
+    // we also need to map_concat the child maps into the current join map if it exists

     // and now join to the parent reference

@@ -251,10 +259,9 @@ private Dataset computeResolveJoin(@Nonnull final Dataset parentDatase
     final Dataset<Row> expandedParent = parentDataset.withColumn(typedRoot.getParentKeyTag(),
         functions.explode_outer(parentRegKey.getColumnValue()));

-    final Dataset<Row> joinedDataset = expandedParent.join(childResult,
-        expandedParent.col(typedRoot.getParentKeyTag())
-            .equalTo(childResult.col(typedRoot.getChildKeyTag())),
-        "left_outer")
+    final Dataset<Row> joinedDataset = joinWithMapMerge(expandedParent, childResult,
+        expandedParent.col(typedRoot.getParentKeyTag())
+            .equalTo(childResult.col(typedRoot.getChildKeyTag())))
         .drop(typedRoot.getChildKeyTag(), typedRoot.getParentKeyTag());

     joinedDataset.show();
@@ -271,7 +278,7 @@ private Dataset computeResolveJoin(@Nonnull final Dataset parentDatase
     final Column[] allPassColumns = Stream.concat(
         parentColumns,
         Stream.concat(Stream.of(
-                collect_map(functions.col(typedRoot.getValueTag())).alias(typedRoot.getValueTag())),
+                collect_map(functions.col(uniqueValueTag)).alias(uniqueValueTag)),
             passThroughColumns))
         .toArray(Column[]::new);

@@ -281,12 +288,57 @@ private Dataset computeResolveJoin(@Nonnull final Dataset parentDatase
         allPassColumns
     );

-    regroupedDataset.show();
-    return regroupedDataset;
-
+    resultDataset = mergeMapColumns(regroupedDataset, typedRoot.getValueTag(),
+        uniqueValueTag);
+    resultDataset.show();
+    return resultDataset;
   }
 }

+  // We need a smarter join here: the common map columns are merged
+  // and the other columns are passed through unchanged
+
+  @Nonnull
+  private Dataset<Row> joinWithMapMerge(@Nonnull final Dataset<Row> leftDataset,
+      @Nonnull final Dataset<Row> rightDataset,
+      @Nonnull final Column on) {
+
+    // deduplicate columns:
+    // for @map columns use map_concat
+    // for the others keep the left side
+    final Set<String> commonColumns = new HashSet<>(List.of(leftDataset.columns()));
+    commonColumns.retainAll(Set.of(rightDataset.columns()));
+
+    final Set<String> commonMapColumns = commonColumns.stream()
+        .filter(c -> c.contains("@"))
+        .collect(Collectors.toUnmodifiableSet());
+
+    final Column[] uniqueSelection = Stream.concat(
+        
Stream.of(leftDataset.columns()) + .map(c -> commonMapColumns.contains(c) + ? ns_map_concat(leftDataset.col(c), rightDataset.col(c)).alias(c) + : leftDataset.col(c)), + Stream.of(rightDataset.columns()) + .filter(c -> !commonColumns.contains(c)) + .map(rightDataset::col) + + ).toArray(Column[]::new); + return leftDataset.join(rightDataset, on, "left_outer") + .select(uniqueSelection); + } + + + @Nonnull + private Dataset mergeMapColumns(@Nonnull final Dataset dataset, + @Nonnull final String finalColumn, @Nonnull final String tempColumn) { + if (List.of(dataset.columns()).contains(finalColumn)) { + return dataset.withColumn(finalColumn, + ns_map_concat(functions.col(finalColumn), functions.col(tempColumn))) + .drop(tempColumn); + } else { + return dataset.withColumnRenamed(tempColumn, finalColumn); + } + } @Nonnull private Dataset computeReverseJoin(@Nonnull final Dataset parentDataset, @@ -354,7 +406,7 @@ private Dataset computeReverseJoin(@Nonnull final Dataset parentDatase } @Nonnull - private Dataset resolveJoinsEx(@Nonnull final JoinSet joinSet, + private Dataset resolveJoins(@Nonnull final JoinSet joinSet, @Nonnull final Dataset parentDataset) { // now just reduce current children @@ -362,44 +414,10 @@ private Dataset resolveJoinsEx(@Nonnull final JoinSet joinSet, .reduce(parentDataset, (dataset, subset) -> // the parent dataset for subjoin should be different computeJoin(dataset, - resolveJoinsEx(subset, resourceDataset(subset.getMaster().getResourceType())), + resolveJoins(subset, resourceDataset(subset.getMaster().getResourceType())), (JoinRoot) subset.getMaster()), (dataset1, dataset2) -> dataset1); } - - @Nonnull - private Dataset resolveJoins(@Nonnull final JoinRoot joinRoot, - @Nonnull final Dataset parentDataset, @Nullable final Dataset maybeChildDataset) { - - // - - // for nested joins we need to do it recursively aggregating all the existing join map columns - // the problem is that I need to do it in reverse order - - // Ineed to to do deep unnestign - - if (joinRoot.getMaster() instanceof ResourceRoot) { - return computeJoin(parentDataset, maybeChildDataset, joinRoot); - } else if (joinRoot.getMaster() instanceof JoinRoot jr) { - - final Dataset childDataset = computeJoin( - createExecutor(joinRoot.getMaster().getResourceType(), - dataSource).createInitialDataset(), - maybeChildDataset, joinRoot); - - return resolveJoins(jr, parentDataset, childDataset); - // compute the dataset for the current resolve root with the master type as the subject - // and default master resource as parent dataset - - // resulting dataset then needs to be passed as the child result for the deeper join - // if there is no child result then we create one from the default child dataset (so initially it's null) - - } else { - throw new UnsupportedOperationException( - "Not implemented - unknown root type: " + joinRoot); - } - } - - + @Nonnull private FhirPathExecutor createExecutor(final ResourceType subjectResourceType, final DataSource dataSource) { diff --git a/fhirpath/src/test/java/au/csiro/pathling/fhirpathe/execution/FhirpathTest.java b/fhirpath/src/test/java/au/csiro/pathling/fhirpathe/execution/FhirpathTest.java index 1219ef424b..3bfa5e7cc4 100644 --- a/fhirpath/src/test/java/au/csiro/pathling/fhirpathe/execution/FhirpathTest.java +++ b/fhirpath/src/test/java/au/csiro/pathling/fhirpathe/execution/FhirpathTest.java @@ -19,6 +19,7 @@ import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; +import 
org.hl7.fhir.r4.model.Appointment;
 import org.hl7.fhir.r4.model.CodeableConcept;
 import org.hl7.fhir.r4.model.Coding;
 import org.hl7.fhir.r4.model.Condition;
@@ -588,10 +589,9 @@ void resolveBackFromReverseResolve() {
             RowFactory.create("3", null)
         );
   }
-
-
+
   @Test
-  void multipleResolveToTheSameResourceOnDiffernetPaths() {
+  void multipleResolveToOneToTheSameResourceOnDifferentPaths() {
     final ObjectDataSource dataSource =
         new ObjectDataSource(spark, encoders,
             List.of(
                 new Encounter()
                     .setHospitalization(new Encounter.EncounterHospitalizationComponent()
                         .setOrigin(new Reference("Location/1"))
                         .setDestination(new Reference("Location/2"))
                     )
                     .setId("Encounter/1"),
                 new Encounter()
                     .setHospitalization(new Encounter.EncounterHospitalizationComponent()
                         .setOrigin(new Reference("Location/2"))
                         .setDestination(new Reference("Location/3"))
                     )
                     .setId("Encounter/2"),
-
+                new Encounter()
+                    .setHospitalization(new Encounter.EncounterHospitalizationComponent()
+                        .setOrigin(new Reference("Location/3"))
+                    )
+                    .setId("Encounter/3"),
                 new Location().setId("Location/1"),
                 new Location().setId("Location/2"),
                 new Location().setId("Location/3")
             ));

     final Dataset<Row> resultDataset = evalExpression(dataSource,
         ResourceType.ENCOUNTER,
-        "hospitalization.origin.resolve().ofType(Location).id"
-            + "=hospitalization.destination.resolve().ofType(Location).id"
+        "hospitalization.origin.resolve().ofType(Location).count()"
+            + "=hospitalization.destination.resolve().ofType(Location).count()"
     );
     System.out.println(resultDataset.queryExecution().executedPlan().toString());
     resultDataset.show();
     new DatasetAssert(resultDataset)
         .hasRowsUnordered(
-            RowFactory.create("1", false),
-            RowFactory.create("2", true)
+            RowFactory.create("1", true),
+            RowFactory.create("2", true),
+            RowFactory.create("3", false)
         );
   }
+
+  @Test
+  void multipleResolveToManyToTheSameResourceInSubresolve() {
+    final ObjectDataSource dataSource =
+        new ObjectDataSource(spark, encoders,
+            List.of(
+                new Encounter()
+                    .addAppointment(new Reference("Appointment/1.1"))
+                    .addReasonReference(new Reference("Condition/1.1"))
+                    .setId("Encounter/1"),
+                new Appointment()
+                    .addReasonReference(new Reference("Condition/1.1"))
+                    .setId("Appointment/1.1"),
+                new Condition().setId("Condition/1.1"),
+                new Encounter()
+                    .addAppointment(new Reference("Appointment/2.1"))
+                    .addReasonReference(new Reference("Condition/2.1"))
+                    .setId("Encounter/2"),
+                new Appointment()
+                    .addReasonReference(new Reference("Condition/2.2"))
+                    .setId("Appointment/2.1"),
+                new Condition().setId("Condition/2.1"),
+                new Condition().setId("Condition/2.2")
+            ));
+
+    final Dataset<Row> resultDataset = evalExpression(dataSource,
+        ResourceType.ENCOUNTER,
+        "appointment.resolve().reasonReference.resolve().ofType(Condition).id.first()"
+            + " = reasonReference.resolve().ofType(Condition).id.first()"
+    );
+    resultDataset.show();
+    new DatasetAssert(resultDataset)
+        .hasRowsUnordered(
+            RowFactory.create("1", true),
+            RowFactory.create("2", false)
+        );
+  }
+
+
+  @Test
+  void chainedReferenceToSameResource() {
+    final ObjectDataSource dataSource =
+        new ObjectDataSource(spark, encoders,
+            List.of(
+                new Observation()
+                    .addHasMember(new Reference("Observation/2"))
+                    .setId("Observation/1"),
+                new Observation()
+                    .addHasMember(new Reference("Observation/3"))
+                    .setId("Observation/2"),
+                new Observation()
+                    .addHasMember(new Reference("Observation/4"))
+                    .setId("Observation/3"),
+                new Observation().setId("Observation/4")
+            ));
+
+    final Dataset<Row> resultDataset = evalExpression(dataSource,
+        ResourceType.OBSERVATION,
+        "hasMember.resolve().ofType(Observation)"
+            + ".hasMember.resolve().ofType(Observation)"
+            + ".hasMember.resolve().ofType(Observation)"
+            + ".count()"
+            + " + 
hasMember.resolve().ofType(Observation).count()" + + " + hasMember.resolve().ofType(Observation).hasMember.resolve().ofType(Observation).count()" + ); + resultDataset.show(); + new DatasetAssert(resultDataset) + .hasRowsUnordered( + RowFactory.create("1", 3), + RowFactory.create("2", 2), + RowFactory.create("3", 1), + RowFactory.create("4", 0) ); } + }
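
Note on the null-safe map merge this patch introduces: ns_map_concat falls back to the non-null side when either map column is NULL and only calls map_concat when both sides are present, which is what lets duplicate @map columns from repeated resolve() joins be folded together. A minimal standalone sketch of the same semantics (the demo class, column names and literal data below are illustrative, not part of the patch):

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.map_concat;
import static org.apache.spark.sql.functions.when;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class NullSafeMapConcatDemo {

  // Same shape as ns_map_concat in the patch: if either side is NULL, keep the
  // other side; otherwise concatenate the two maps.
  static Column nsMapConcat(final Column left, final Column right) {
    return when(left.isNull(), right)
        .when(right.isNull(), left)
        .otherwise(map_concat(left, right));
  }

  public static void main(final String[] args) {
    final SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // One row where the right-hand map is NULL, one row where both maps are present.
    final Dataset<Row> df = spark.sql(
        "SELECT map('a', 1) AS left_map, CAST(NULL AS MAP<STRING, INT>) AS right_map "
            + "UNION ALL SELECT map('a', 1), map('b', 2)");

    // A plain map_concat yields NULL for the first row; the null-safe variant
    // keeps the non-null side instead.
    df.select(
        map_concat(col("left_map"), col("right_map")).alias("plain"),
        nsMapConcat(col("left_map"), col("right_map")).alias("null_safe")
    ).show(false);

    spark.stop();
  }
}

One design point worth keeping in mind: with Spark's default spark.sql.mapKeyDedupPolicy (EXCEPTION), map_concat fails if the two maps share a key, so this merge strategy assumes the @map columns being combined are keyed by disjoint resource keys.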