Skip to content

Commit

Permalink
Merge pull request #62 from suyashkumar:s/median
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662683013
  • Loading branch information
copybara-github committed Aug 13, 2024
2 parents 72ad288 + 738226b commit 540a79c
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 2 deletions.
84 changes: 84 additions & 0 deletions interpreter/operator_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package interpreter

import (
"fmt"
"sort"

"github.com/google/cql/model"
"github.com/google/cql/result"
Expand Down Expand Up @@ -264,6 +265,89 @@ func (i *interpreter) evalMinDateTime(m model.IUnaryExpression, operand result.V
return result.New(dt)
}

// Median(argument List<Decimal>) Decimal
// https://cql.hl7.org/09-b-cqlreference.html#median
func (i *interpreter) evalMedianDecimal(_ model.IUnaryExpression, operand result.Value) (result.Value, error) {
if result.IsNull(operand) {
return result.New(nil)
}

l, err := result.ToSlice(operand)
if err != nil {
return result.Value{}, err
}

values := make([]float64, 0, len(l))
for _, elem := range l {
if result.IsNull(elem) {
continue
}
v, err := result.ToFloat64(elem)
if err != nil {
return result.Value{}, err
}
values = append(values, v)
}
if len(values) == 0 {
return result.New(nil)
}

median := calculateMedianFloat64(values)
return result.New(median)
}

// Median(argument List<Quantity>) Quantity
// https://cql.hl7.org/09-b-cqlreference.html#median
func (i *interpreter) evalMedianQuantity(_ model.IUnaryExpression, operand result.Value) (result.Value, error) {
if result.IsNull(operand) {
return result.New(nil)
}

l, err := result.ToSlice(operand)
if err != nil {
return result.Value{}, err
}

values := make([]float64, 0, len(l))
var unit model.Unit
for idx, elem := range l {
if result.IsNull(elem) {
continue
}
v, err := result.ToQuantity(elem)
if err != nil {
return result.Value{}, err
}
// We only support List<Quantity> where all the elements have the exact same unit, since we
// do not support mixed unit Quantity math in our engine yet.
if idx == 0 {
unit = v.Unit
}
if unit != v.Unit {
// TODO: b/342061715 - technically we should treat '' unit and '1' unit as the same, but
// for now we don't (and we should apply this globally).
return result.Value{}, fmt.Errorf("Median(List<Quantity>) operand has different units which is not supported, got %v and %v", unit, v.Unit)
}
values = append(values, v.Value)
}
if len(values) == 0 {
return result.New(nil)
}
median := calculateMedianFloat64(values)
return result.New(result.Quantity{Value: median, Unit: unit})
}

// calculateMedianFloat64 calculates the median of a slice of float64 values.
// This modifies the values slice in place while sorting it.
func calculateMedianFloat64(values []float64) float64 {
sort.Float64s(values)
mid := len(values) / 2
if len(values)%2 == 0 {
return (values[mid-1] + values[mid]) / 2
}
return values[mid]
}

// Sum(argument List<Decimal>) Decimal
// Sum(argument List<Integer>) Integer
// Sum(argument List<Long>) Long
Expand Down
11 changes: 11 additions & 0 deletions interpreter/operator_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,17 @@ func (i *interpreter) unaryOverloads(m model.IUnaryExpression) ([]convert.Overlo
Result: i.evalSum,
},
}, nil
case *model.Median:
return []convert.Overload[evalUnarySignature]{
{
Operands: []types.IType{&types.List{ElementType: types.Decimal}},
Result: i.evalMedianDecimal,
},
{
Operands: []types.IType{&types.List{ElementType: types.Quantity}},
Result: i.evalMedianQuantity,
},
}, nil
default:
return nil, fmt.Errorf("unsupported Unary Expression %v", m.GetName())
}
Expand Down
9 changes: 9 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,12 @@ type Min struct{ *UnaryExpression }
// far as we can tell.
type Sum struct{ *UnaryExpression }

// Median ELM expression from https://cql.hl7.org/09-b-cqlreference.html#median
// TODO: b/347346351 - In ELM it's modeled as an AggregateExpression, but for now we model it as an
// UnaryExpression since there is no way to set the AggregateExpression's "path" property for CQL as
// far as we can tell.
type Median struct{ *UnaryExpression }

// CalculateAge CQL expression type
type CalculateAge struct {
*UnaryExpression
Expand Down Expand Up @@ -1409,3 +1415,6 @@ func (a *Combine) GetName() string { return "Combine" }

// GetName returns the name of the system operator.
func (i *Indexer) GetName() string { return "Indexer" }

// GetName returns the name of the system operator.
func (m *Median) GetName() string { return "Median" }
17 changes: 16 additions & 1 deletion parser/operators.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (
"fmt"
"strings"

"github.com/antlr4-go/antlr/v4"
"github.com/google/cql/internal/convert"
"github.com/google/cql/model"
"github.com/google/cql/types"
"github.com/antlr4-go/antlr/v4"
)

// parseFunction uses the reference resolver to resolve the function, visits the operands, and sets
Expand Down Expand Up @@ -184,6 +184,9 @@ func (v *visitor) resolveFunction(libraryName, funcName string, operands []model
// The operands should be AgeInYearsAt(convertedBirthDate)
resolved.WrappedOperands = []model.IExpression{res.WrappedOperand, resolved.WrappedOperands[0]}
}
case *model.Median:
listType := resolved.WrappedOperands[0].GetResultType().(*types.List)
t.Expression = model.ResultType(listType.ElementType)
}

// Set Operands.
Expand Down Expand Up @@ -1899,6 +1902,18 @@ func (p *Parser) loadSystemOperators() error {
return &model.Message{}
},
},
{
name: "Median",
operands: [][]types.IType{
{&types.List{ElementType: types.Decimal}},
{&types.List{ElementType: types.Quantity}},
},
model: func() model.IExpression {
return &model.Median{
UnaryExpression: &model.UnaryExpression{},
}
},
},
}

for _, b := range systemOperators {
Expand Down
27 changes: 27 additions & 0 deletions parser/operators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,33 @@ func TestBuiltInFunctions(t *testing.T) {
},
},
// AGGREGATE FUNCTIONS - https://cql.hl7.org/09-b-cqlreference.html#aggregate-functions
{
name: "Median Decimal",
cql: "Median({1.0, 2.0, 3.0})",
want: &model.Median{
UnaryExpression: &model.UnaryExpression{
Operand: model.NewList([]string{"1.0", "2.0", "3.0"}, types.Decimal),
Expression: model.ResultType(types.Decimal),
},
},
},
{
name: "Median Quantity",
cql: "Median({1.0 'cm', 2.0 'cm', 3.0 'cm'})",
want: &model.Median{
UnaryExpression: &model.UnaryExpression{
Operand: &model.List{
List: []model.IExpression{
&model.Quantity{Value: 1.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
&model.Quantity{Value: 2.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
&model.Quantity{Value: 3.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
},
Expression: model.ResultType(&types.List{ElementType: types.Quantity}),
},
Expression: model.ResultType(types.Quantity),
},
},
},
{
name: "Count",
cql: "Count({1, 2, 3})",
Expand Down
140 changes: 140 additions & 0 deletions tests/enginetests/operator_aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,143 @@ func TestSum_Error(t *testing.T) {
})
}
}

func TestMedian(t *testing.T) {
tests := []struct {
name string
cql string
wantModel model.IExpression
wantResult result.Value
}{
{
name: "Median({1.5, 2.5, 3.5, 4.5})",
cql: "Median({1.5, 2.5, 3.5, 4.5})",
wantModel: &model.Median{
UnaryExpression: &model.UnaryExpression{
Operand: model.NewList([]string{"1.5", "2.5", "3.5", "4.5"}, types.Decimal),
Expression: model.ResultType(types.Decimal),
},
},
wantResult: newOrFatal(t, 3.0),
},
{
name: "Median({1 'cm', 2 'cm', 3 'cm'})",
cql: "Median({1 'cm', 2 'cm', 3 'cm'})",
wantResult: newOrFatal(t, result.Quantity{Value: 2.0, Unit: "cm"}),
},
{
name: "Median({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'})",
cql: "Median({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'})",
wantResult: newOrFatal(t, result.Quantity{Value: 3.0, Unit: "g"}),
},
{
name: "Unordered Quantity list: Median({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})",
cql: "Median({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})",
wantResult: newOrFatal(t, result.Quantity{Value: 3.0, Unit: "g"}),
},
{
name: "Median({1.0, 2.0, 3.0})",
cql: "Median({1.0, 2.0, 3.0})",
wantResult: newOrFatal(t, 2.0),
},
{
name: "Median({1.5, 2.5, 3.5, 4.5})",
cql: "Median({1.5, 2.5, 3.5, 4.5})",
wantResult: newOrFatal(t, 3.0),
},
{
name: "Unordered Decimal list: Median({2.5, 3.5, 1.5, 4.5})",
cql: "Median({2.5, 3.5, 1.5, 4.5})",
wantResult: newOrFatal(t, 3.0),
},
{
name: "Median(List<Decimal>{})",
cql: "Median(List<Decimal>{})",
wantResult: newOrFatal(t, nil),
},
{
name: "Median({null as Decimal})",
cql: "Median({null as Decimal})",
wantResult: newOrFatal(t, nil),
},
{
name: "Median(null as List<Decimal>)",
cql: "Median(null as List<Decimal>)",
wantResult: newOrFatal(t, nil),
},
{
name: "Median(List<Quantity>{})",
cql: "Median(List<Quantity>{})",
wantResult: newOrFatal(t, nil),
},
{
name: "Median({null as Quantity})",
cql: "Median({null as Quantity})",
wantResult: newOrFatal(t, nil),
},
{
name: "Median(null as List<Quantity>)",
cql: "Median(null as List<Quantity>)",
wantResult: newOrFatal(t, nil),
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := newFHIRParser(t)
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
if err != nil {
t.Fatalf("Parse returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
t.Errorf("Parse diff (-want +got):\n%s", diff)
}

results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
if err != nil {
t.Fatalf("Eval returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" {
t.Errorf("Eval diff (-want +got)\n%v", diff)
}
})
}
}

func TestMedian_Error(t *testing.T) {
tests := []struct {
name string
cql string
wantModel model.IExpression
wantErrContains string
}{
{
name: "Median({1 'cm', 2 'g'})",
cql: "Median({1 'cm', 2 'g'})",
wantErrContains: "Median(List<Quantity>) operand has different units which is not supported",
},
{
name: "Median({1 '', 2 'g'})",
cql: "Median({1 '', 2 'g'})",
wantErrContains: "Median(List<Quantity>) operand has different units which is not supported",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := newFHIRParser(t)
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
if err != nil {
t.Fatalf("Parse returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
t.Errorf("Parse diff (-want +got):\n%s", diff)
}

_, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
if err == nil || !strings.Contains(err.Error(), tc.wantErrContains) {
t.Errorf("Eval returned unexpected error: %v, want error containing %q", err, tc.wantErrContains)
}
})
}
}
1 change: 0 additions & 1 deletion tests/spectests/exclusions/exclusions.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions {
"CqlAggregateFunctionsTest.xml": XMLTestFileExclusions{
GroupExcludes: []string{
// TODO: b/342061715 - unsupported operators.
"Median",
"Mode",
"PopulationStdDev",
"PopulationVariance",
Expand Down

0 comments on commit 540a79c

Please sign in to comment.