Adding merging of any duplicate map columns in resolve() joins and correct handling of merging with NULL map columns.
piotrszul committed Dec 20, 2024
1 parent 50dfe36 commit 786c3c2
Showing 2 changed files with 172 additions and 74 deletions.
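
The core of the change is a null-safe variant of Spark's map_concat: plain map_concat yields NULL as soon as either input map is NULL, which breaks merging whenever a resolve() join produces no match. The following standalone sketch is not part of the commit; the class name, the nsMapConcat helper, the column names, and the local Spark session are all illustrative, but it mirrors the ns_map_concat helper added in the diff below.

import static org.apache.spark.sql.functions.*;

import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Standalone illustration of the null-safe map merge introduced by this commit.
public class NullSafeMapConcatSketch {

  // Same idea as ns_map_concat in the diff below: fall back to the non-null side.
  static Column nsMapConcat(final Column left, final Column right) {
    return when(left.isNull(), right)
        .when(right.isNull(), left)
        .otherwise(map_concat(left, right));
  }

  public static void main(final String[] args) {
    final SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    final Dataset<Row> df = spark.sql(
        "SELECT map('Location/1', 'a') AS left_map, "
            + "CAST(NULL AS MAP<STRING,STRING>) AS right_map");

    // map_concat(left, NULL) evaluates to NULL; the null-safe variant keeps the left map.
    df.select(
        map_concat(col("left_map"), col("right_map")).alias("plain_concat"),
        nsMapConcat(col("left_map"), col("right_map")).alias("null_safe_concat")
    ).show(false);

    spark.stop();
  }
}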
@@ -29,7 +29,6 @@
import au.csiro.pathling.fhirpath.context.ViewEvaluationContext;
import au.csiro.pathling.fhirpath.execution.DataRoot.JoinRoot;
import au.csiro.pathling.fhirpath.execution.DataRoot.ResolveRoot;
import au.csiro.pathling.fhirpath.execution.DataRoot.ResourceRoot;
import au.csiro.pathling.fhirpath.execution.DataRoot.ReverseResolveRoot;
import au.csiro.pathling.fhirpath.function.registry.FunctionRegistry;
import au.csiro.pathling.fhirpath.parser.Parser;
@@ -40,8 +39,10 @@
import jakarta.annotation.Nonnull;
import jakarta.annotation.Nullable;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.Value;
import org.apache.spark.sql.Column;
@@ -58,6 +59,14 @@
@Value
public class MultiFhirPathEvaluator implements FhirPathEvaluator {


@Nonnull
static Column ns_map_concat(@Nonnull final Column left, @Nonnull final Column right) {
return functions.when(left.isNull(), right)
.when(right.isNull(), left)
.otherwise(functions.map_concat(left, right));
}

// TODO: Move somewhere else

@Nonnull
@@ -68,7 +77,7 @@ public static Column collect_map(@Nonnull final Column mapColumn) {
return functions.reduce(
functions.collect_list(mapColumn),
functions.any_value(mapColumn),
(acc, elem) -> functions.when(acc.isNull(), elem).otherwise(functions.map_concat(acc, elem))
(acc, elem) -> functions.when(acc.isNull(), elem).otherwise(ns_map_concat(acc, elem))
);
}

@@ -92,7 +101,7 @@ public Dataset<Row> createInitialDataset() {
// createInitialDataset(), null)
// : createInitialDataset();

Dataset<Row> resolvedDataset = resolveJoinsEx(
Dataset<Row> resolvedDataset = resolveJoins(
JoinSet.mergeRoots(joinRoots).iterator().next(),
createInitialDataset());

@@ -148,6 +157,7 @@ private Dataset<Row> computeResolveJoin(@Nonnull final Dataset<Row> parentDatase
+ typedRoot.getResourceType());
}

Dataset<Row> resultDataset;
if (referenceCollection.isToOneReference()) {

// TODO: this should be replaced with a call to evalPath() with a non-grouping context
@@ -190,25 +200,20 @@ private Dataset<Row> computeResolveJoin(@Nonnull final Dataset<Row> parentDatase
final Collection parentRegKey = parentExecutor.evaluate(new Traversal("reference"),
referenceCollection);

final boolean needsMerging = List.of(parentDataset.columns())
.contains(typedRoot.getValueTag());
final Dataset<Row> joinedDataset = parentDataset.join(childResult,
parentRegKey.getColumnValue().equalTo(functions.col(typedRoot.getChildKeyTag())),
"left_outer")
.withColumn(typedRoot.getValueTag(),
needsMerging
? functions.map_concat(
functions.col(typedRoot.getValueTag()),
functions.col(uniqueValueTag))
: functions.col(uniqueValueTag)
)
.drop(typedRoot.getChildKeyTag(), uniqueValueTag);

System.out.println("Joined dataset:");
joinedDataset.show();
return joinedDataset;
.drop(typedRoot.getChildKeyTag());

final Dataset<Row> finalDataset = mergeMapColumns(joinedDataset, typedRoot.getValueTag(),
uniqueValueTag);

System.out.println("Final dataset:");
finalDataset.show();
return finalDataset;
} else {

final String uniqueValueTag = typedRoot.getValueTag() + "_unique";
// TODO: this should be replaced with a call to evalPath() with a non-grouping context
final FhirPathExecutor childExecutor = createExecutor(
typedRoot.getForeignResourceType(),
@@ -236,12 +241,15 @@ private Dataset<Row> computeResolveJoin(@Nonnull final Dataset<Row> parentDatase
functions.array(childDataset.col("key")),
// maybe need to be wrapped in another array
functions.array(childResource.getColumnValue())
).alias(typedRoot.getValueTag())
).alias(uniqueValueTag)
);

final Dataset<Row> childResult = childDataset.select(
Streams.concat(keyValuesColumns, childPassThroughColumns)
.toArray(Column[]::new));
final Dataset<Row> childResult =
childDataset.select(
Streams.concat(keyValuesColumns, childPassThroughColumns)
.toArray(Column[]::new));

// but we also need to map_concat the child maps into the current join if it exists

// and now join to the parent reference

@@ -251,10 +259,9 @@ private Dataset<Row> computeResolveJoin(@Nonnull final Dataset<Row> parentDatase
final Dataset<Row> expandedParent = parentDataset.withColumn(typedRoot.getParentKeyTag(),
functions.explode_outer(parentRegKey.getColumnValue()));

final Dataset<Row> joinedDataset = expandedParent.join(childResult,
expandedParent.col(typedRoot.getParentKeyTag())
.equalTo(childResult.col(typedRoot.getChildKeyTag())),
"left_outer")
final Dataset<Row> joinedDataset = joinWithMapMerge(expandedParent, childResult,
expandedParent.col(typedRoot.getParentKeyTag())
.equalTo(childResult.col(typedRoot.getChildKeyTag())))
.drop(typedRoot.getChildKeyTag(), typedRoot.getParentKeyTag());
joinedDataset.show();

@@ -271,7 +278,7 @@ private Dataset<Row> computeResolveJoin(@Nonnull final Dataset<Row> parentDatase
final Column[] allPassColumns = Stream.concat(
parentColumns,
Stream.concat(Stream.of(
collect_map(functions.col(typedRoot.getValueTag())).alias(typedRoot.getValueTag())),
collect_map(functions.col(uniqueValueTag)).alias(uniqueValueTag)),
passThroughColumns))
.toArray(Column[]::new);

@@ -281,12 +288,57 @@ private Dataset<Row> computeResolveJoin(@Nonnull final Dataset<Row> parentDatase
allPassColumns
);

regroupedDataset.show();
return regroupedDataset;

resultDataset = mergeMapColumns(regroupedDataset, typedRoot.getValueTag(),
uniqueValueTag);
resultDataset.show();
return resultDataset;
}
}

// I need to be able to make a smart join where the map columns are merged
// and the other columns are passed through

@Nonnull
private Dataset<Row> joinWithMapMerge(@Nonnull final Dataset<Row> leftDataset,
@Nonnull final Dataset<Row> rightDataset,
@Nonnull final Column on) {

// deduplicate columns:
// for @map columns, map_concat
// for others, keep the left side
final Set<String> commonColumns = new HashSet<>(List.of(leftDataset.columns()));
commonColumns.retainAll(Set.of(rightDataset.columns()));

final Set<String> commonMapColumns = commonColumns.stream()
.filter(c -> c.contains("@"))
.collect(Collectors.toUnmodifiableSet());

final Column[] uniqueSelection = Stream.concat(
Stream.of(leftDataset.columns())
.map(c -> commonMapColumns.contains(c)
? ns_map_concat(leftDataset.col(c), rightDataset.col(c)).alias(c)
: leftDataset.col(c)),
Stream.of(rightDataset.columns())
.filter(c -> !commonColumns.contains(c))
.map(rightDataset::col)

).toArray(Column[]::new);
return leftDataset.join(rightDataset, on, "left_outer")
.select(uniqueSelection);
}
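
For reference, here is a toy, self-contained sketch of the join behaviour implemented by joinWithMapMerge above. The MapMergeJoinSketch class, the table contents, and the "@Location" column name are made up for illustration; the point is that columns present on both sides whose names contain "@" are merged with the null-safe concat, other shared columns keep the left-hand value, and right-only columns pass through.

import static org.apache.spark.sql.functions.*;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class MapMergeJoinSketch {

  public static void main(final String[] args) {
    final SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // Left side: one row with a map fragment, one row with a NULL map.
    final Dataset<Row> left = spark.sql(
        "SELECT 'Encounter/1' AS id, map('Location/1', 'a') AS `@Location` "
            + "UNION ALL "
            + "SELECT 'Encounter/2' AS id, CAST(NULL AS MAP<STRING,STRING>) AS `@Location`");

    // Right side shares 'id' and '@Location' and contributes an extra column.
    final Dataset<Row> right = spark.sql(
        "SELECT 'Encounter/1' AS id, map('Location/2', 'b') AS `@Location`, 42 AS extra");

    // The shared '@' column is map-merged, the shared 'id' is taken from the left,
    // and the right-only 'extra' column is passed through.
    final Dataset<Row> joined = left
        .join(right, left.col("id").equalTo(right.col("id")), "left_outer")
        .select(
            left.col("id"),
            when(left.col("@Location").isNull(), right.col("@Location"))
                .when(right.col("@Location").isNull(), left.col("@Location"))
                .otherwise(map_concat(left.col("@Location"), right.col("@Location")))
                .alias("@Location"),
            right.col("extra"));

    joined.show(false);
    spark.stop();
  }
}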


@Nonnull
private Dataset<Row> mergeMapColumns(@Nonnull final Dataset<Row> dataset,
@Nonnull final String finalColumn, @Nonnull final String tempColumn) {
if (List.of(dataset.columns()).contains(finalColumn)) {
return dataset.withColumn(finalColumn,
ns_map_concat(functions.col(finalColumn), functions.col(tempColumn)))
.drop(tempColumn);
} else {
return dataset.withColumnRenamed(tempColumn, finalColumn);
}
}
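
A compact sketch of the two branches of mergeMapColumns, again standalone and not from the commit (the column names and the crossJoin are purely illustrative): the first resolve() of a resource simply renames the temporary map column, while a later resolve() to the same resource finds the target column already present and folds the new map into it with the null-safe concat.

import static org.apache.spark.sql.functions.*;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class MergeMapColumnsSketch {

  public static void main(final String[] args) {
    final SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

    // First join: no "@Location" column exists yet, so the temporary column is renamed.
    final Dataset<Row> first = spark
        .sql("SELECT 1 AS key, map('Location/1', 'a') AS `@Location_unique`")
        .withColumnRenamed("@Location_unique", "@Location");

    // Second join to the same resource type: "@Location" already exists, so the new
    // map is merged into it and the temporary column is dropped.
    final Dataset<Row> second = first
        .crossJoin(spark.sql("SELECT map('Location/2', 'b') AS `@Location_unique`"))
        .withColumn("@Location",
            when(col("@Location").isNull(), col("@Location_unique"))
                .when(col("@Location_unique").isNull(), col("@Location"))
                .otherwise(map_concat(col("@Location"), col("@Location_unique"))))
        .drop("@Location_unique");

    second.show(false);
    spark.stop();
  }
}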

@Nonnull
private Dataset<Row> computeReverseJoin(@Nonnull final Dataset<Row> parentDataset,
@@ -354,52 +406,18 @@ private Dataset<Row> computeReverseJoin(@Nonnull final Dataset<Row> parentDatase
}

@Nonnull
private Dataset<Row> resolveJoinsEx(@Nonnull final JoinSet joinSet,
private Dataset<Row> resolveJoins(@Nonnull final JoinSet joinSet,
@Nonnull final Dataset<Row> parentDataset) {

// now just reduce current children
return joinSet.getChildren().stream()
.reduce(parentDataset, (dataset, subset) ->
// the parent dataset for subjoin should be different
computeJoin(dataset,
resolveJoinsEx(subset, resourceDataset(subset.getMaster().getResourceType())),
resolveJoins(subset, resourceDataset(subset.getMaster().getResourceType())),
(JoinRoot) subset.getMaster()), (dataset1, dataset2) -> dataset1);
}

@Nonnull
private Dataset<Row> resolveJoins(@Nonnull final JoinRoot joinRoot,
@Nonnull final Dataset<Row> parentDataset, @Nullable final Dataset<Row> maybeChildDataset) {

//

// for nested joins we need to do it recursively aggregating all the existing join map columns
// the problem is that I need to do it in reverse order

// I need to do deep unnesting

if (joinRoot.getMaster() instanceof ResourceRoot) {
return computeJoin(parentDataset, maybeChildDataset, joinRoot);
} else if (joinRoot.getMaster() instanceof JoinRoot jr) {

final Dataset<Row> childDataset = computeJoin(
createExecutor(joinRoot.getMaster().getResourceType(),
dataSource).createInitialDataset(),
maybeChildDataset, joinRoot);

return resolveJoins(jr, parentDataset, childDataset);
// compute the dataset for the current resolve root with the master type as the subject
// and default master resource as parent dataset

// resulting dataset then needs to be passed as the child result for the deeper join
// if there is no child result then we create one from the default child dataset (so initially it's null)

} else {
throw new UnsupportedOperationException(
"Not implemented - unknown root type: " + joinRoot);
}
}



@Nonnull
private FhirPathExecutor createExecutor(final ResourceType subjectResourceType,
final DataSource dataSource) {
@@ -19,6 +19,7 @@
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.hl7.fhir.r4.model.Appointment;
import org.hl7.fhir.r4.model.CodeableConcept;
import org.hl7.fhir.r4.model.Coding;
import org.hl7.fhir.r4.model.Condition;
@@ -588,10 +589,9 @@ void resolveBackFromReverseResolve() {
RowFactory.create("3", null)
);
}



@Test
void multipleResolveToTheSameResourceOnDiffernetPaths() {
void multipleResolveToOneToTheSameResourceOnDiffernetPaths() {
final ObjectDataSource dataSource =
new ObjectDataSource(spark, encoders,
List.of(
@@ -607,24 +607,104 @@ void multipleResolveToTheSameResourceOnDiffernetPaths() {
.setDestination(new Reference("Location/3"))
)
.setId("Encounter/2"),

new Encounter()
.setHospitalization(new Encounter.EncounterHospitalizationComponent()
.setOrigin(new Reference("Location/3"))
)
.setId("Encounter/3"),
new Location().setId("Location/1"),
new Location().setId("Location/2"),
new Location().setId("Location/3")
));

final Dataset<Row> resultDataset = evalExpression(dataSource,
ResourceType.ENCOUNTER,
"hospitalization.origin.resolve().ofType(Location).id"
+ "=hospitalization.destination.resolve().ofType(Location).id"
"hospitalization.origin.resolve().ofType(Location).count()"
+ "=hospitalization.destination.resolve().ofType(Location).count()"
);
System.out.println(resultDataset.queryExecution().executedPlan().toString());

resultDataset.show();
new DatasetAssert(resultDataset)
.hasRowsUnordered(
RowFactory.create("1", false),
RowFactory.create("2", true)
RowFactory.create("1", true),
RowFactory.create("2", true),
RowFactory.create("3", false)
);
}

@Test
void multipleResolveToManyTheSameResourceInSubresolve() {
final ObjectDataSource dataSource =
new ObjectDataSource(spark, encoders,
List.of(
new Encounter()
.addAppointment(new Reference("Appointment/1.1"))
.addReasonReference(new Reference("Condition/1.1"))
.setId("Encounter/1"),
new Appointment()
.addReasonReference(new Reference("Condition/1.1"))
.setId("Appointment/1.1"),
new Condition().setId("Condition/1.1"),
new Encounter()
.addAppointment(new Reference("Appointment/2.1"))
.addReasonReference(new Reference("Condition/2.1"))
.setId("Encounter/2"),
new Appointment()
.addReasonReference(new Reference("Condition/2.2"))
.setId("Appointment/2.1"),
new Condition().setId("Condition/2.1"),
new Condition().setId("Condition/2.2")
));

final Dataset<Row> resultDataset = evalExpression(dataSource,
ResourceType.ENCOUNTER,
"appointment.resolve().reasonReference.resolve().ofType(Condition).id.first()"
+ " = reasonReference.resolve().ofType(Condition).id.first()"
);
resultDataset.show();
new DatasetAssert(resultDataset)
.hasRowsUnordered(
RowFactory.create("1", true),
RowFactory.create("2", false)
);
}


@Test
void chainedReferenceToSameResouece() {
final ObjectDataSource dataSource =
new ObjectDataSource(spark, encoders,
List.of(
new Observation()
.addHasMember(new Reference("Observation/2"))
.setId("Observation/1"),
new Observation()
.addHasMember(new Reference("Observation/3"))
.setId("Observation/2"),
new Observation()
.addHasMember(new Reference("Observation/4"))
.setId("Observation/3"),
new Observation().setId("Observation/4")
));

final Dataset<Row> resultDataset = evalExpression(dataSource,
ResourceType.OBSERVATION,
"hasMember.resolve().ofType(Observation)"
+ ".hasMember.resolve().ofType(Observation)"
+ ".hasMember.resolve().ofType(Observation)"
+ ".count()"
+ " + hasMember.resolve().ofType(Observation).count()"
+ " + hasMember.resolve().ofType(Observation).hasMember.resolve().ofType(Observation).count()"
);
resultDataset.show();
new DatasetAssert(resultDataset)
.hasRowsUnordered(
RowFactory.create("1", 3),
RowFactory.create("2", 2),
RowFactory.create("3", 1),
RowFactory.create("4", 0)
);
}

}
