From d231c682c738ce619f0d12d6e371dd5a9a3ee302 Mon Sep 17 00:00:00 2001
From: Evan Gordon
Date: Fri, 15 Nov 2024 10:14:44 -0800
Subject: [PATCH] Implement PopulationStdDev functional operator.

PiperOrigin-RevId: 696925769
---
 interpreter/operator_aggregate.go            | 108 ++++++++++++-
 interpreter/operator_dispatcher.go           |  11 ++
 model/model.go                               |   9 ++
 parser/operators.go                          |  15 ++
 parser/operators_test.go                     |  27 ++++
 tests/enginetests/operator_aggregate_test.go | 151 +++++++++++++++++++
 tests/spectests/exclusions/exclusions.go     |   3 +-
 7 files changed, 322 insertions(+), 2 deletions(-)

diff --git a/interpreter/operator_aggregate.go b/interpreter/operator_aggregate.go
index b23b272..5d51a29 100644
--- a/interpreter/operator_aggregate.go
+++ b/interpreter/operator_aggregate.go
@@ -16,6 +16,7 @@ package interpreter
 
 import (
 	"fmt"
+	"math"
 	"sort"
 
 	"github.com/google/cql/model"
@@ -144,7 +145,7 @@ func (i *interpreter) evalAvg(m model.IUnaryExpression, operand result.Value) (r
 
 // Count(argument List<T>) Integer
 // https://cql.hl7.org/09-b-cqlreference.html#count
-func (i *interpreter) evalCount(m model.IUnaryExpression, operand result.Value) (result.Value, error) {
+func (i *interpreter) evalCount(_ model.IUnaryExpression, operand result.Value) (result.Value, error) {
 	if result.IsNull(operand) {
 		return result.New(0)
 	}
@@ -348,6 +349,111 @@ func calculateMedianFloat64(values []float64) float64 {
 	return values[mid]
 }
 
+// PopulationStdDev(argument List<Decimal>) Decimal
+// sqrt(sum((v - mean)^2) / count)
+// https://cql.hl7.org/09-b-cqlreference.html#population-stddev
+func (i *interpreter) evalPopulationStdDevDecimal(m model.IUnaryExpression, operand result.Value) (result.Value, error) {
+	if result.IsNull(operand) {
+		return result.New(nil)
+	}
+	l, err := result.ToSlice(operand)
+	if err != nil {
+		return result.Value{}, err
+	}
+
+	countValue, err := i.evalCount(m, operand)
+	if err != nil {
+		return result.Value{}, err
+	}
+	if result.IsNull(countValue) {
+		return result.New(nil)
+	}
+	count, err := result.ToInt32(countValue)
+	if err != nil {
+		return result.Value{}, err
+	}
+	if count == 0 {
+		return result.New(nil)
+	}
+	meanValue, err := i.evalAvg(m, operand)
+	if err != nil {
+		return result.Value{}, err
+	}
+	if result.IsNull(meanValue) {
+		return result.New(nil)
+	}
+	mean, err := result.ToFloat64(meanValue)
+	if err != nil {
+		return result.Value{}, err
+	}
+	var sum float64
+	for _, elem := range l {
+		if result.IsNull(elem) {
+			continue
+		}
+		v, err := result.ToFloat64(elem)
+		if err != nil {
+			return result.Value{}, err
+		}
+		sum += (v - mean) * (v - mean)
+	}
+	return result.New(math.Sqrt(sum / float64(count)))
+}
+
+// PopulationStdDev(argument List<Quantity>) Quantity
+// sqrt(sum((v - mean)^2) / count)
+// https://cql.hl7.org/09-b-cqlreference.html#population-stddev
+func (i *interpreter) evalPopulationStdDevQuantity(m model.IUnaryExpression, operand result.Value) (result.Value, error) {
+	if result.IsNull(operand) {
+		return result.New(nil)
+	}
+	l, err := result.ToSlice(operand)
+	if err != nil {
+		return result.Value{}, err
+	}
+
+	countValue, err := i.evalCount(m, operand)
+	if err != nil {
+		return result.Value{}, err
+	}
+	if result.IsNull(countValue) {
+		return result.New(nil)
+	}
+	count, err := result.ToInt32(countValue)
+	if err != nil {
+		return result.Value{}, err
+	}
+	if count == 0 {
+		return result.New(nil)
+	}
+	meanValue, err := i.evalAvg(m, operand)
+	if err != nil {
+		return result.Value{}, err
+	}
+	if result.IsNull(meanValue) {
+		return result.New(nil)
+	}
+	mean, err := result.ToQuantity(meanValue)
+	if err != nil {
+		return result.Value{}, err
+	}
+	var sum float64
+	for _, elem := range l {
+		if result.IsNull(elem) {
+			continue
+		}
+		v, err := result.ToQuantity(elem)
+		if err != nil {
+			return result.Value{}, err
+		}
+		if v.Unit != mean.Unit {
+			return result.Value{}, fmt.Errorf("PopulationStdDev(List<Quantity>) operand has different units which is not supported, got %v and %v", v.Unit, mean.Unit)
+		}
+		sum += (v.Value - mean.Value) * (v.Value - mean.Value)
+	}
+	return result.New(result.Quantity{Value: math.Sqrt(sum / float64(count)), Unit: mean.Unit})
+}
+
 // Sum(argument List<Decimal>) Decimal
 // Sum(argument List<Integer>) Integer
 // Sum(argument List<Long>) Long
diff --git a/interpreter/operator_dispatcher.go b/interpreter/operator_dispatcher.go
index 5a83031..6113225 100644
--- a/interpreter/operator_dispatcher.go
+++ b/interpreter/operator_dispatcher.go
@@ -576,6 +576,17 @@ func (i *interpreter) unaryOverloads(m model.IUnaryExpression) ([]convert.Overlo
 				Result:   i.evalMedianQuantity,
 			},
 		}, nil
+	case *model.PopulationStdDev:
+		return []convert.Overload[evalUnarySignature]{
+			{
+				Operands: []types.IType{&types.List{ElementType: types.Decimal}},
+				Result:   i.evalPopulationStdDevDecimal,
+			},
+			{
+				Operands: []types.IType{&types.List{ElementType: types.Quantity}},
+				Result:   i.evalPopulationStdDevQuantity,
+			},
+		}, nil
 	default:
 		return nil, fmt.Errorf("unsupported Unary Expression %v", m.GetName())
 	}
diff --git a/model/model.go b/model/model.go
index 7fe825a..415872d 100644
--- a/model/model.go
+++ b/model/model.go
@@ -820,6 +820,12 @@ type Sum struct{ *UnaryExpression }
 // far as we can tell.
 type Median struct{ *UnaryExpression }
 
+// PopulationStdDev ELM expression from https://cql.hl7.org/09-b-cqlreference.html#population-stddev
+// TODO: b/347346351 - In ELM it's modeled as an AggregateExpression, but for now we model it as an
+// UnaryExpression since there is no way to set the AggregateExpression's "path" property for CQL as
+// far as we can tell.
+type PopulationStdDev struct{ *UnaryExpression }
+
 // CalculateAge CQL expression type
 type CalculateAge struct {
 	*UnaryExpression
@@ -1437,3 +1443,6 @@ func (i *Indexer) GetName() string { return "Indexer" }
 
 // GetName returns the name of the system operator.
 func (m *Median) GetName() string { return "Median" }
+
+// GetName returns the name of the system operator.
+func (m *PopulationStdDev) GetName() string { return "PopulationStdDev" }
diff --git a/parser/operators.go b/parser/operators.go
index 819e7fd..82978ce 100644
--- a/parser/operators.go
+++ b/parser/operators.go
@@ -187,6 +187,9 @@ func (v *visitor) resolveFunction(libraryName, funcName string, operands []model
 	case *model.Median:
 		listType := resolved.WrappedOperands[0].GetResultType().(*types.List)
 		t.Expression = model.ResultType(listType.ElementType)
+	case *model.PopulationStdDev:
+		listType := resolved.WrappedOperands[0].GetResultType().(*types.List)
+		t.Expression = model.ResultType(listType.ElementType)
 	}
 
 	// Set Operands.
@@ -1936,6 +1939,18 @@ func (p *Parser) loadSystemOperators() error {
 				}
 			},
 		},
+		{
+			name: "PopulationStdDev",
+			operands: [][]types.IType{
+				{&types.List{ElementType: types.Decimal}},
+				{&types.List{ElementType: types.Quantity}},
+			},
+			model: func() model.IExpression {
+				return &model.PopulationStdDev{
+					UnaryExpression: &model.UnaryExpression{},
+				}
+			},
+		},
 	}
 
 	for _, b := range systemOperators {
diff --git a/parser/operators_test.go b/parser/operators_test.go
index 955e22c..4b150d9 100644
--- a/parser/operators_test.go
+++ b/parser/operators_test.go
@@ -1279,6 +1279,33 @@ func TestBuiltInFunctions(t *testing.T) {
 				},
 			},
 		},
+		{
+			name: "PopulationStdDev Decimal",
+			cql:  "PopulationStdDev({1.0, 2.0, 3.0})",
+			want: &model.PopulationStdDev{
+				UnaryExpression: &model.UnaryExpression{
+					Operand:    model.NewList([]string{"1.0", "2.0", "3.0"}, types.Decimal),
+					Expression: model.ResultType(types.Decimal),
+				},
+			},
+		},
+		{
+			name: "PopulationStdDev Quantity",
+			cql:  "PopulationStdDev({1.0 'cm', 2.0 'cm', 3.0 'cm'})",
+			want: &model.PopulationStdDev{
+				UnaryExpression: &model.UnaryExpression{
+					Operand: &model.List{
+						List: []model.IExpression{
+							&model.Quantity{Value: 1.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
+							&model.Quantity{Value: 2.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
+							&model.Quantity{Value: 3.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
+						},
+						Expression: model.ResultType(&types.List{ElementType: types.Quantity}),
+					},
+					Expression: model.ResultType(types.Quantity),
+				},
+			},
+		},
 		{
 			name: "Count",
 			cql:  "Count({1, 2, 3})",
diff --git a/tests/enginetests/operator_aggregate_test.go b/tests/enginetests/operator_aggregate_test.go
index 05f24fb..8b30ea3 100644
--- a/tests/enginetests/operator_aggregate_test.go
+++ b/tests/enginetests/operator_aggregate_test.go
@@ -705,3 +705,154 @@ func TestMedian_Error(t *testing.T) {
 		})
 	}
 }
+
+func TestPopulationStdDev(t *testing.T) {
+	tests := []struct {
+		name       string
+		cql        string
+		wantModel  model.IExpression
+		wantResult result.Value
+	}{
+		// Decimal cases - Round is added to the cql to avoid float point comparison issues.
+		{
+			name: "PopulationStdDev({1.5, 2.5, 3.5, 4.5})",
+			cql:  "Round(PopulationStdDev({1.5, 2.5, 3.5, 4.5}), 3)",
+			wantModel: &model.Round{
+				NaryExpression: &model.NaryExpression{
+					Operands: []model.IExpression{
+						&model.PopulationStdDev{
+							UnaryExpression: &model.UnaryExpression{
+								Operand:    model.NewList([]string{"1.5", "2.5", "3.5", "4.5"}, types.Decimal),
+								Expression: model.ResultType(types.Decimal),
+							},
+						},
+						model.NewLiteral("3", types.Integer),
+					},
+					Expression: model.ResultType(types.Decimal),
+				},
+			},
+			wantResult: newOrFatal(t, 1.118),
+		},
+		{
+			name:       "PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0})",
+			cql:        "Round(PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0}), 3)",
+			wantResult: newOrFatal(t, 1.414),
+		},
+		{
+			name:       "Unordered Decimal list: PopulationStdDev({2.5, 3.5, 1.5, 4.5})",
+			cql:        "Round(PopulationStdDev({2.5, 3.5, 1.5, 4.5}), 3)",
+			wantResult: newOrFatal(t, 1.118),
+		},
+		// Quantity cases - Round is added and quantity values is unwrapped to avoid float point
+		// comparison issues.
+		{
+			name:       "PopulationStdDev({1 'cm', 2 'cm', 3 'cm', 4 'cm', 5 'cm'})",
+			cql:        "Round(PopulationStdDev({1 'cm', 2 'cm', 3 'cm', 4 'cm', 5 'cm'}).value, 3)",
+			wantResult: newOrFatal(t, 1.414),
+		},
+		{
+			name:       "PopulationStdDev({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'})",
+			cql:        "Round(PopulationStdDev({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'}).value, 3)",
+			wantResult: newOrFatal(t, 1.118),
+		},
+		{
+			name:       "Unordered Quantity list: PopulationStdDev({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})",
+			cql:        "Round(PopulationStdDev({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'}).value, 3)",
+			wantResult: newOrFatal(t, 1.118),
+		},
+		{
+			name:       "PopulationStdDev({1 'cm', 3 'cm'}) simplified case with no rounding",
+			cql:        "PopulationStdDev({1 'cm', 3 'cm'})",
+			wantResult: newOrFatal(t, result.Quantity{Value: 1.0, Unit: "cm"}),
+		},
+		{
+			name:       "PopulationStdDev(List<Decimal>{})",
+			cql:        "PopulationStdDev(List<Decimal>{})",
+			wantResult: newOrFatal(t, nil),
+		},
+		{
+			name:       "PopulationStdDev({null as Decimal})",
+			cql:        "PopulationStdDev({null as Decimal})",
+			wantResult: newOrFatal(t, nil),
+		},
+		{
+			name:       "PopulationStdDev(null as List<Decimal>)",
+			cql:        "PopulationStdDev(null as List<Decimal>)",
+			wantResult: newOrFatal(t, nil),
+		},
+		{
+			name:       "PopulationStdDev(List<Quantity>{})",
+			cql:        "PopulationStdDev(List<Quantity>{})",
+			wantResult: newOrFatal(t, nil),
+		},
+		{
+			name:       "PopulationStdDev({null as Quantity})",
+			cql:        "PopulationStdDev({null as Quantity})",
+			wantResult: newOrFatal(t, nil),
+		},
+		{
+			name:       "PopulationStdDev(null as List<Quantity>)",
+			cql:        "PopulationStdDev(null as List<Quantity>)",
+			wantResult: newOrFatal(t, nil),
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			p := newFHIRParser(t)
+			parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
+			if err != nil {
+				t.Fatalf("Parse returned unexpected error: %v", err)
+			}
+			if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
+				t.Errorf("Parse diff (-want +got):\n%s", diff)
+			}
+
+			results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
+			if err != nil {
+				t.Fatalf("Eval returned unexpected error: %v", err)
+			}
+			if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" {
+				t.Errorf("Eval diff (-want +got)\n%v", diff)
+			}
+		})
+	}
+}
+
+func TestPopulationStdDev_Error(t *testing.T) {
+	tests := []struct {
+		name            string
+		cql             string
+		wantModel       model.IExpression
+		wantErrContains string
+	}{
+		{
+			name:            "PopulationStdDev({1 'cm', 2 'g'})",
+			cql:             "PopulationStdDev({1 'cm', 2 'g'})",
+			wantErrContains: "operand has different units which is not supported",
+		},
+		{
+			name:            "PopulationStdDev({1 '', 2 'g'})",
+			cql:             "PopulationStdDev({1 '', 2 'g'})",
+			wantErrContains: "operand has different units which is not supported",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			p := newFHIRParser(t)
+			parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
+			if err != nil {
+				t.Fatalf("Parse returned unexpected error: %v", err)
+			}
+			if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
+				t.Errorf("Parse diff (-want +got):\n%s", diff)
+			}
+
+			_, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
+			if err == nil || !strings.Contains(err.Error(), tc.wantErrContains) {
+				t.Errorf("Eval returned unexpected error: %v, want error containing %q", err, tc.wantErrContains)
+			}
+		})
+	}
+}
diff --git a/tests/spectests/exclusions/exclusions.go b/tests/spectests/exclusions/exclusions.go
index 8bc3a03..012c620 100644
--- a/tests/spectests/exclusions/exclusions.go
+++ b/tests/spectests/exclusions/exclusions.go
@@ -31,7 +31,6 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions {
 			GroupExcludes: []string{
 				// TODO: b/342061715 - unsupported operators.
 				"Mode",
-				"PopulationStdDev",
 				"PopulationVariance",
 				"StdDev",
 				"Variance",
@@ -45,6 +44,8 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions {
 				"MinTestInteger",
 				"MinTestString",
 				"MinTestTime",
+				// TODO: b/342061783 - Operator is supported but the test assertion uses a rounded value.
+				"PopStdDevTest1",
 			},
 		},
 		"CqlAggregateTest.xml": XMLTestFileExclusions{