Skip to content

Commit

Permalink
Implement PopulationStdDev functional operator.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666922731
  • Loading branch information
evan-gordon authored and copybara-github committed Nov 15, 2024
1 parent 61eae1e commit c2bdd83
Show file tree
Hide file tree
Showing 7 changed files with 322 additions and 2 deletions.
108 changes: 107 additions & 1 deletion interpreter/operator_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package interpreter

import (
"fmt"
"math"
"sort"

"github.com/google/cql/model"
Expand Down Expand Up @@ -144,7 +145,7 @@ func (i *interpreter) evalAvg(m model.IUnaryExpression, operand result.Value) (r

// Count(argument List<T>) Integer
// https://cql.hl7.org/09-b-cqlreference.html#count
func (i *interpreter) evalCount(m model.IUnaryExpression, operand result.Value) (result.Value, error) {
func (i *interpreter) evalCount(_ model.IUnaryExpression, operand result.Value) (result.Value, error) {
if result.IsNull(operand) {
return result.New(0)
}
Expand Down Expand Up @@ -348,6 +349,111 @@ func calculateMedianFloat64(values []float64) float64 {
return values[mid]
}

// PopulationStdDev(argument List<Decimal>) Decimal
// sqrt(sum((v - mean)^2) / count)
// https://cql.hl7.org/09-b-cqlreference.html#population-stddev
func (i *interpreter) evalPopulationStdDevDecimal(m model.IUnaryExpression, operand result.Value) (result.Value, error) {
if result.IsNull(operand) {
return result.New(nil)
}
l, err := result.ToSlice(operand)
if err != nil {
return result.Value{}, err
}

countValue, err := i.evalCount(m, operand)
if err != nil {
return result.Value{}, err
}
if result.IsNull(countValue) {
return result.New(nil)
}
count, err := result.ToInt32(countValue)
if err != nil {
return result.Value{}, err
}
if count == 0 {
return result.New(nil)
}
meanValue, err := i.evalAvg(m, operand)
if err != nil {
return result.Value{}, err
}
if result.IsNull(meanValue) {
return result.New(nil)
}
mean, err := result.ToFloat64(meanValue)
if err != nil {
return result.Value{}, err
}
var sum float64
for _, elem := range l {
if result.IsNull(elem) {
continue
}
v, err := result.ToFloat64(elem)
if err != nil {
return result.Value{}, err
}
sum += (v - mean) * (v - mean)
}
return result.New(math.Sqrt(sum / float64(count)))
}

// PopulationStdDev(argument List<Quantity>) Quantity
// sqrt(sum((v - mean)^2) / count)
// https://cql.hl7.org/09-b-cqlreference.html#population-stddev
func (i *interpreter) evalPopulationStdDevQuantity(m model.IUnaryExpression, operand result.Value) (result.Value, error) {
if result.IsNull(operand) {
return result.New(nil)
}
l, err := result.ToSlice(operand)
if err != nil {
return result.Value{}, err
}

countValue, err := i.evalCount(m, operand)
if err != nil {
return result.Value{}, err
}
if result.IsNull(countValue) {
return result.New(nil)
}
count, err := result.ToInt32(countValue)
if err != nil {
return result.Value{}, err
}
if count == 0 {
return result.New(nil)
}
meanValue, err := i.evalAvg(m, operand)
if err != nil {
return result.Value{}, err
}
if result.IsNull(meanValue) {
return result.New(nil)
}
mean, err := result.ToQuantity(meanValue)
if err != nil {
return result.Value{}, err
}
var sum float64
for _, elem := range l {
if result.IsNull(elem) {
continue
}
v, err := result.ToQuantity(elem)
if err != nil {
return result.Value{}, err
}
if v.Unit != mean.Unit {
return result.Value{}, fmt.Errorf("PopulationStdDev(List<Quantity>) operand has different units which is not supported, got %v and %v", v.Unit, mean.Unit)
}
sum += (v.Value - mean.Value) * (v.Value - mean.Value)
}
return result.New(result.Quantity{Value: math.Sqrt(sum / float64(count)), Unit: mean.Unit})
}

// Sum(argument List<Decimal>) Decimal
// Sum(argument List<Integer>) Integer
// Sum(argument List<Long>) Long
Expand Down
11 changes: 11 additions & 0 deletions interpreter/operator_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,17 @@ func (i *interpreter) unaryOverloads(m model.IUnaryExpression) ([]convert.Overlo
Result: i.evalMedianQuantity,
},
}, nil
case *model.PopulationStdDev:
return []convert.Overload[evalUnarySignature]{
{
Operands: []types.IType{&types.List{ElementType: types.Decimal}},
Result: i.evalPopulationStdDevDecimal,
},
{
Operands: []types.IType{&types.List{ElementType: types.Quantity}},
Result: i.evalPopulationStdDevQuantity,
},
}, nil
default:
return nil, fmt.Errorf("unsupported Unary Expression %v", m.GetName())
}
Expand Down
9 changes: 9 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,12 @@ type Sum struct{ *UnaryExpression }
// far as we can tell.
type Median struct{ *UnaryExpression }

// PopulationStdDev ELM expression from https://cql.hl7.org/09-b-cqlreference.html#population-stddev
// TODO: b/347346351 - In ELM it's modeled as an AggregateExpression, but for now we model it as an
// UnaryExpression since there is no way to set the AggregateExpression's "path" property for CQL as
// far as we can tell.
type PopulationStdDev struct{ *UnaryExpression }

// CalculateAge CQL expression type
type CalculateAge struct {
*UnaryExpression
Expand Down Expand Up @@ -1437,3 +1443,6 @@ func (i *Indexer) GetName() string { return "Indexer" }

// GetName returns the name of the system operator.
func (m *Median) GetName() string { return "Median" }

// GetName returns the name of the system operator.
func (m *PopulationStdDev) GetName() string { return "PopulationStdDev" }
15 changes: 15 additions & 0 deletions parser/operators.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ func (v *visitor) resolveFunction(libraryName, funcName string, operands []model
case *model.Median:
listType := resolved.WrappedOperands[0].GetResultType().(*types.List)
t.Expression = model.ResultType(listType.ElementType)
case *model.PopulationStdDev:
listType := resolved.WrappedOperands[0].GetResultType().(*types.List)
t.Expression = model.ResultType(listType.ElementType)
}

// Set Operands.
Expand Down Expand Up @@ -1936,6 +1939,18 @@ func (p *Parser) loadSystemOperators() error {
}
},
},
{
name: "PopulationStdDev",
operands: [][]types.IType{
{&types.List{ElementType: types.Decimal}},
{&types.List{ElementType: types.Quantity}},
},
model: func() model.IExpression {
return &model.PopulationStdDev{
UnaryExpression: &model.UnaryExpression{},
}
},
},
}

for _, b := range systemOperators {
Expand Down
27 changes: 27 additions & 0 deletions parser/operators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,33 @@ func TestBuiltInFunctions(t *testing.T) {
},
},
},
{
name: "PopulationStdDev Decimal",
cql: "PopulationStdDev({1.0, 2.0, 3.0})",
want: &model.PopulationStdDev{
UnaryExpression: &model.UnaryExpression{
Operand: model.NewList([]string{"1.0", "2.0", "3.0"}, types.Decimal),
Expression: model.ResultType(types.Decimal),
},
},
},
{
name: "PopulationStdDev Quantity",
cql: "PopulationStdDev({1.0 'cm', 2.0 'cm', 3.0 'cm'})",
want: &model.PopulationStdDev{
UnaryExpression: &model.UnaryExpression{
Operand: &model.List{
List: []model.IExpression{
&model.Quantity{Value: 1.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
&model.Quantity{Value: 2.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
&model.Quantity{Value: 3.0, Unit: "cm", Expression: model.ResultType(types.Quantity)},
},
Expression: model.ResultType(&types.List{ElementType: types.Quantity}),
},
Expression: model.ResultType(types.Quantity),
},
},
},
{
name: "Count",
cql: "Count({1, 2, 3})",
Expand Down
151 changes: 151 additions & 0 deletions tests/enginetests/operator_aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -705,3 +705,154 @@ func TestMedian_Error(t *testing.T) {
})
}
}

func TestPopulationStdDev(t *testing.T) {
tests := []struct {
name string
cql string
wantModel model.IExpression
wantResult result.Value
}{
// Decimal cases - Round is added to the cql to avoid float point comparison issues.
{
name: "PopulationStdDev({1.5, 2.5, 3.5, 4.5})",
cql: "Round(PopulationStdDev({1.5, 2.5, 3.5, 4.5}), 3)",
wantModel: &model.Round{
NaryExpression: &model.NaryExpression{
Operands: []model.IExpression{
&model.PopulationStdDev{
UnaryExpression: &model.UnaryExpression{
Operand: model.NewList([]string{"1.5", "2.5", "3.5", "4.5"}, types.Decimal),
Expression: model.ResultType(types.Decimal),
},
},
model.NewLiteral("3", types.Integer),
},
Expression: model.ResultType(types.Decimal),
},
},
wantResult: newOrFatal(t, 1.118),
},
{
name: "PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0})",
cql: "Round(PopulationStdDev({1.0, 2.0, 3.0, 4.0, 5.0}), 3)",
wantResult: newOrFatal(t, 1.414),
},
{
name: "Unordered Decimal list: PopulationStdDev({2.5, 3.5, 1.5, 4.5})",
cql: "Round(PopulationStdDev({2.5, 3.5, 1.5, 4.5}), 3)",
wantResult: newOrFatal(t, 1.118),
},
// Quantity cases - Round is added and quantity values is unwrapped to avoid float point
// comparison issues.
{
name: "PopulationStdDev({1 'cm', 2 'cm', 3 'cm', 4 'cm', 5 'cm'})",
cql: "Round(PopulationStdDev({1 'cm', 2 'cm', 3 'cm', 4 'cm', 5 'cm'}).value, 3)",
wantResult: newOrFatal(t, 1.414),
},
{
name: "PopulationStdDev({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'})",
cql: "Round(PopulationStdDev({1.5 'g', 2.5 'g', 3.5 'g', 4.5 'g'}).value, 3)",
wantResult: newOrFatal(t, 1.118),
},
{
name: "Unordered Quantity list: PopulationStdDev({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'})",
cql: "Round(PopulationStdDev({2.5 'g', 3.5 'g', 1.5 'g', 4.5 'g'}).value, 3)",
wantResult: newOrFatal(t, 1.118),
},
{
name: "PopulationStdDev({1 'cm', 3 'cm'}) simplified case with no rounding",
cql: "PopulationStdDev({1 'cm', 3 'cm'})",
wantResult: newOrFatal(t, result.Quantity{Value: 1.0, Unit: "cm"}),
},
{
name: "PopulationStdDev(List<Decimal>{})",
cql: "PopulationStdDev(List<Decimal>{})",
wantResult: newOrFatal(t, nil),
},
{
name: "PopulationStdDev({null as Decimal})",
cql: "PopulationStdDev({null as Decimal})",
wantResult: newOrFatal(t, nil),
},
{
name: "PopulationStdDev(null as List<Decimal>)",
cql: "PopulationStdDev(null as List<Decimal>)",
wantResult: newOrFatal(t, nil),
},
{
name: "PopulationStdDev(List<Quantity>{})",
cql: "PopulationStdDev(List<Quantity>{})",
wantResult: newOrFatal(t, nil),
},
{
name: "PopulationStdDev({null as Quantity})",
cql: "PopulationStdDev({null as Quantity})",
wantResult: newOrFatal(t, nil),
},
{
name: "PopulationStdDev(null as List<Quantity>)",
cql: "PopulationStdDev(null as List<Quantity>)",
wantResult: newOrFatal(t, nil),
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := newFHIRParser(t)
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
if err != nil {
t.Fatalf("Parse returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
t.Errorf("Parse diff (-want +got):\n%s", diff)
}

results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
if err != nil {
t.Fatalf("Eval returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" {
t.Errorf("Eval diff (-want +got)\n%v", diff)
}
})
}
}

func TestPopulationStdDev_Error(t *testing.T) {
tests := []struct {
name string
cql string
wantModel model.IExpression
wantErrContains string
}{
{
name: "PopulationStdDev({1 'cm', 2 'g'})",
cql: "PopulationStdDev({1 'cm', 2 'g'})",
wantErrContains: "operand has different units which is not supported",
},
{
name: "PopulationStdDev({1 '', 2 'g'})",
cql: "PopulationStdDev({1 '', 2 'g'})",
wantErrContains: "operand has different units which is not supported",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := newFHIRParser(t)
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
if err != nil {
t.Fatalf("Parse returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
t.Errorf("Parse diff (-want +got):\n%s", diff)
}

_, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
if err == nil || !strings.Contains(err.Error(), tc.wantErrContains) {
t.Errorf("Eval returned unexpected error: %v, want error containing %q", err, tc.wantErrContains)
}
})
}
}
Loading

0 comments on commit c2bdd83

Please sign in to comment.