Skip to content

Commit

Permalink
Add support for sorting by expressions in CQL.
Browse files Browse the repository at this point in the history
This change adds support for sorting by expressions in CQL. This is done by adding a new model type, SortByExpression, which represents an expression that can be used to sort the results of a query. The parser is updated to parse sort by expressions, and the interpreter is updated to evaluate them.

PiperOrigin-RevId: 671490955
  • Loading branch information
rbrush authored and copybara-github committed Sep 24, 2024
1 parent 5915f54 commit 76d6089
Show file tree
Hide file tree
Showing 11 changed files with 353 additions and 48 deletions.
4 changes: 4 additions & 0 deletions cql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func TestCQL(t *testing.T) {
retriever: enginetests.BuildRetriever(t),
wantResult: newOrFatal(t, result.List{Value: []result.Value{
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
},
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
}),
Expand All @@ -84,6 +85,7 @@ func TestCQL(t *testing.T) {
wantSourceValues: []result.Value{
newOrFatal(t, result.List{Value: []result.Value{
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
},
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
}),
Expand Down Expand Up @@ -302,6 +304,7 @@ func TestCQL_MultipleEvals(t *testing.T) {
wantResult: newOrFatal(t, result.List{
Value: []result.Value{
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
},
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
}),
Expand All @@ -324,6 +327,7 @@ func TestCQL_MultipleEvals(t *testing.T) {
newOrFatal(t, result.List{
Value: []result.Value{
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "1"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
newOrFatal(t, result.Named{Value: enginetests.RetrieveFHIRResource(t, "Encounter", "2"), RuntimeType: &types.Named{TypeName: "FHIR.Encounter"}}),
},
StaticType: &types.List{ElementType: &types.Named{TypeName: "FHIR.Encounter"}},
}),
Expand Down
30 changes: 30 additions & 0 deletions internal/reference/reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ type Resolver[T any, F any] struct {
// defined. Aliases live in the same namespace as definitions.
aliases []map[aliasKey]T

// scopedStructs hold the struct that are currently in scope for evaluation. For instance,
// an an expression like `[Encounter] O sort by start of period` places each encounter in scope,
// for the sorting criteria, and `period` is resolved against that encounter struct.
scopedStructs []T

// libs holds the qualified identifier of all named libraries that have been parsed.
libs map[namedLibKey]struct{}

Expand Down Expand Up @@ -405,6 +410,31 @@ func (r *Resolver[T, F]) ExitScope() {
}
}

// EnterStructScope starts a new scope for a struct.
func (r *Resolver[T, F]) EnterStructScope(q T) {
r.scopedStructs = append(r.scopedStructs, q)
}

// ExitStructScope clears the current struct scope.
func (r *Resolver[T, F]) ExitStructScope() {
if len(r.scopedStructs) > 0 {
r.scopedStructs = r.scopedStructs[:len(r.scopedStructs)-1]
}
}

// HasScopedStruct returns true if there is a struct in the current scope.
func (r *Resolver[T, F]) HasScopedStruct() bool {
return len(r.scopedStructs) > 0
}

// ScopedStruct returns the current struct scope.
func (r *Resolver[T, F]) ScopedStruct() (T, error) {
if len(r.scopedStructs) == 0 {
return zero[T](), fmt.Errorf("no scoped structs were set")
}
return r.scopedStructs[len(r.scopedStructs)-1], nil
}

// Alias creates a new alias within the current scope. When EndScope is called all aliases in the
// scope will be removed. Calling ResolveLocal with the same name will return the stored type t.
// Names must be unique within the CQL library.
Expand Down
51 changes: 51 additions & 0 deletions internal/reference/reference_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,57 @@ func TestParserAliasAndResolve(t *testing.T) {
}
}

func TestScopedStructs(t *testing.T) {
// Test scoping and de-scoping of structs in context.
r := NewResolver[result.Value, *model.FunctionDef]()

if r.HasScopedStruct() {
t.Errorf("HasScopedStruct() got true, want false")
}
_, err := r.ScopedStruct()
if err == nil {
t.Errorf("ScopedStruct() with no scope expected error but got success")
}

v1 := newOrFatal(1, t)
r.EnterStructScope(v1)
if !r.HasScopedStruct() {
t.Errorf("HasScopedStruct() got false when struct was in scope")
}

got, err := r.ScopedStruct()
if err != nil {
t.Fatalf("ScopedStruct() unexpected err: %v", err)
}
if diff := cmp.Diff(v1, got); diff != "" {
t.Errorf("ScopedStruct() diff (-want +got):\n%s", diff)
}

v2 := newOrFatal(2, t)
r.EnterStructScope(v2)
got, err = r.ScopedStruct()
if err != nil {
t.Fatalf("ScopedStruct() unexpected err: %v", err)
}
if diff := cmp.Diff(v2, got); diff != "" {
t.Errorf("ScopedStruct() diff (-want +got):\n%s", diff)
}

r.ExitStructScope()
got, err = r.ScopedStruct()
if err != nil {
t.Fatalf("ScopedStruct() unexpected err: %v", err)
}
if diff := cmp.Diff(v1, got); diff != "" {
t.Errorf("ScopedStruct() diff (-want +got):\n%s", diff)
}

r.ExitStructScope()
if r.HasScopedStruct() {
t.Errorf("HasScopedStruct() got true when no struct should be in scope")
}
}

func TestResolveIncludedLibrary(t *testing.T) {
// TEST SETUP - PREVIOUS PARSED LIBRARY
//
Expand Down
20 changes: 20 additions & 0 deletions interpreter/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func (i *interpreter) evalExpression(elem model.IExpression) (result.Value, erro
return i.evalQueryLetRef(elem)
case *model.AliasRef:
return i.evalAliasRef(elem)
case *model.IdentifierRef:
return i.evalIdentifierRef(elem)
case *model.CodeSystemRef:
return i.evalCodeSystemRef(elem)
case *model.ValuesetRef:
Expand Down Expand Up @@ -305,6 +307,24 @@ func (i *interpreter) evalAliasRef(a *model.AliasRef) (result.Value, error) {
return i.refs.ResolveLocal(a.Name)
}

func (i *interpreter) evalIdentifierRef(r *model.IdentifierRef) (result.Value, error) {
obj, err := i.refs.ScopedStruct()
if err != nil {
return result.Value{}, err
}

// Passing the static types here is likely unimportant, but we compute it for completeness.
aType, err := i.modelInfo.PropertyTypeSpecifier(obj.RuntimeType(), r.Name)
if err != nil {
return result.Value{}, err
}
ap, err := i.valueProperty(obj, r.Name, aType)
if err != nil {
return result.Value{}, err
}
return ap, nil
}

func (i *interpreter) evalOperandRef(a *model.OperandRef) (result.Value, error) {
return i.refs.ResolveLocal(a.Name)
}
Expand Down
73 changes: 41 additions & 32 deletions interpreter/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (i *interpreter) evalQuery(q *model.Query) (result.Value, error) {
return result.Value{}, err
}
} else {
i.sortByColumn(finalVals, q.Sort.ByItems)
err := i.sortByColumnOrExpression(finalVals, q.Sort.ByItems)
if err != nil {
return result.Value{}, err
}
Expand Down Expand Up @@ -490,49 +490,58 @@ func compareNumeralInt[t float64 | int64 | int32](left, right t) int {
}
}

func (i *interpreter) sortByColumn(objs []result.Value, sbis []model.ISortByItem) error {
// Validate sort column types.
for _, sortItems := range sbis {
// TODO(b/316984809): Is this validation in advance necessary? What if other values (beyond
// objs[0]) have a different runtime type for the property (e.g. if they're a choice type)?
// Consider validating types inline during the sort instead.
path := sortItems.(*model.SortByColumn).Path
propertyType, err := i.modelInfo.PropertyTypeSpecifier(objs[0].RuntimeType(), path)
func (i *interpreter) dateTimeOrError(v result.Value) (result.Value, error) {
switch sr := v.GolangValue().(type) {
case result.DateTime:
return v, nil
case result.Named:
if sr.RuntimeType.Equal(&types.Named{TypeName: "FHIR.dateTime"}) {
return i.protoProperty(sr, "value", types.DateTime)
}
}
return result.Value{}, fmt.Errorf("sorting only currently supported on DateTime columns")
}

// getSortValue returns the value to be used for the comparison-based sort. This
// is typically a field or expression on the structure being sorted.
func (i *interpreter) getSortValue(it model.ISortByItem, v result.Value) (result.Value, error) {
var rv result.Value
var err error
switch iv := it.(type) {
case *model.SortByColumn:
// Passing the static types here is likely unimportant, but we compute it for completeness.
t, err := i.modelInfo.PropertyTypeSpecifier(v.RuntimeType(), iv.Path)
if err != nil {
return err
return result.Value{}, err
}
columnVal, err := i.valueProperty(objs[0], path, propertyType)
rv, err = i.valueProperty(v, iv.Path, t)
if err != nil {
return err
return result.Value{}, err
}
// Strictly only allow DateTimes for now.
// TODO(b/316984809): add sorting support for other types.
if !columnVal.RuntimeType().Equal(types.DateTime) {
return fmt.Errorf("sort column of a query must evaluate to a date time, instead got %v", columnVal.RuntimeType())
case *model.SortByExpression:
i.refs.EnterStructScope(v)
defer i.refs.ExitStructScope()
rv, err = i.evalExpression(iv.SortExpression)
if err != nil {
return result.Value{}, err
}
default:
return result.Value{}, fmt.Errorf("internal error - unsupported sort by item type: %T", iv)
}

return i.dateTimeOrError(rv)
}

func (i *interpreter) sortByColumnOrExpression(objs []result.Value, sbis []model.ISortByItem) error {
var sortErr error = nil
slices.SortFunc(objs[:], func(a, b result.Value) int {
for _, sortItems := range sbis {
sortCol := sortItems.(*model.SortByColumn)
// Passing the static types here is likely unimportant, but we compute it for completeness.
aType, err := i.modelInfo.PropertyTypeSpecifier(a.RuntimeType(), sortCol.Path)
if err != nil {
sortErr = err
continue
}
ap, err := i.valueProperty(a, sortCol.Path, aType)
if err != nil {
sortErr = err
continue
}
bType, err := i.modelInfo.PropertyTypeSpecifier(b.RuntimeType(), sortCol.Path)
for _, sortItem := range sbis {
ap, err := i.getSortValue(sortItem, a)
if err != nil {
sortErr = err
continue
}
bp, err := i.valueProperty(b, sortCol.Path, bType)
bp, err := i.getSortValue(sortItem, b)
if err != nil {
sortErr = err
continue
Expand All @@ -544,7 +553,7 @@ func (i *interpreter) sortByColumn(objs []result.Value, sbis []model.ISortByItem
// TODO(b/308012659): Implement dateTime comparison that doesn't take a precision.
if av.Equal(bv) {
continue
} else if sortCol.SortByItem.Direction == model.DESCENDING {
} else if sortItem.SortDirection() == model.DESCENDING {
return bv.Compare(av)
}
return av.Compare(bv)
Expand Down
21 changes: 17 additions & 4 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ type ReturnClause struct {
// Follows format outlined in https://cql.hl7.org/elm/schema/expression.xsd.
type ISortByItem interface {
IElement
isSortByItem()
SortDirection() SortDirection
}

// SortByItem is the base abstract type for all query types.
Expand All @@ -492,20 +492,25 @@ type SortByItem struct {
Direction SortDirection
}

// SortDirection returns the direction of the sort, e.g. ASCENDING or DESCENDING.
func (s *SortByItem) SortDirection() SortDirection { return s.Direction }

// SortByDirection enables sorting non-tuple values by direction
type SortByDirection struct {
*SortByItem
}

func (c *SortByDirection) isSortByItem() {}

// SortByColumn enables sorting by a given column and direction.
type SortByColumn struct {
*SortByItem
Path string
}

func (c *SortByColumn) isSortByItem() {}
// SortByExpression enables sorting by an expression and direction.
type SortByExpression struct {
*SortByItem
SortExpression IExpression
}

// AliasedSource is a query source with an alias.
type AliasedSource struct {
Expand Down Expand Up @@ -1158,6 +1163,14 @@ type OperandRef struct {
Name string
}

// IdentifierRef defines a reference to an identifier within a defined scope, such as a sort by.
// This is distinct from other references since it not a defined name, but will typically reference
// a field for some structure in scope of a sort expression.
type IdentifierRef struct {
*Expression
Name string
}

// UNARY EXPRESSION GETNAME()

// GetName returns the name of the system operator.
Expand Down
21 changes: 21 additions & 0 deletions parser/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,27 @@ func (v *visitor) VisitQuantityContext(ctx cql.IQuantityContext) (model.Quantity
// visitor.
func (v *visitor) VisitReferentialIdentifier(ctx cql.IReferentialIdentifierContext) model.IExpression {
name := v.parseReferentialIdentifier(ctx)

if v.refs.HasScopedStruct() {
sourceFn, err := v.refs.ScopedStruct()
if err != nil {
return v.badExpression(err.Error(), ctx)
}

// If the query source has the expected property, return the identifier ref. Otherwise
// fall through to the resolution logic below.
source := sourceFn()
elementType := source.GetResultType().(*types.List).ElementType

ptype, err := v.modelInfo.PropertyTypeSpecifier(elementType, name)
if err == nil {
return &model.IdentifierRef{
Name: name,
Expression: model.ResultType(ptype),
}
}
}

if i := v.refs.ResolveInclude(name); i != nil {
return v.badExpression(fmt.Sprintf("internal error - referential identifier %v is a local identifier to an included library", name), ctx)
}
Expand Down
Loading

0 comments on commit 76d6089

Please sign in to comment.