diff --git a/interpreter/operator_arithmetic.go b/interpreter/operator_arithmetic.go index 63e3d60..f7eb0e4 100644 --- a/interpreter/operator_arithmetic.go +++ b/interpreter/operator_arithmetic.go @@ -603,6 +603,55 @@ func evalNegateQuantity(m model.IUnaryExpression, obj result.Value) (result.Valu return result.New(val) } +// Round(argument Decimal) Decimal +// Round(argument Decimal, argument Integer) Decimal +// https://cql.hl7.org/09-b-cqlreference.html#round +// If a precision is specified but is null then the default precision is 0. +// If a precision is specified but is negative then an error is returned. This is technically +// undefined behavior in the CQL spec, but we choose to throw an error here. +func evalRound(_ model.INaryExpression, operands []result.Value) (result.Value, error) { + // if len(operands) == 1 { + // return roundValue(operands[0]) + // } + decimalVal := operands[0] + var precisionVal result.Value + var err error + // retrieve the precision if it exists, otherwise default to 0. + if len(operands) == 2 { + precisionVal = operands[1] + } else { + precisionVal, err = result.New(0) + if err != nil { + return result.Value{}, err + } + } + if result.IsNull(decimalVal) { + return result.New(nil) + } + + p, err := result.ToInt32(precisionVal) + if err != nil { + p = 0 + } + if p < 0 { + return result.Value{}, fmt.Errorf("internal error - precision must be non-negative, got %v", p) + } + d, err := result.ToFloat64(decimalVal) + if err != nil { + return result.Value{}, err + } + ratio := math.Pow10(int(p)) + // CQL currently implements its own special version of rounding for now (which will be changed in + // the future). For now if the value is negative we round towards zero. + ratioedDecimal := d * ratio + _, frac := math.Modf(ratioedDecimal) + if frac == -0.5 { + // force go to round towards zero + ratioedDecimal += 0.1 + } + return result.New(math.Round(ratioedDecimal) / ratio) +} + // predecessor of(obj T) T // https://cql.hl7.org/09-b-cqlreference.html#predecessor func (i *interpreter) evalPredecessor(m model.IUnaryExpression, obj result.Value) (result.Value, error) { diff --git a/interpreter/operator_dispatcher.go b/interpreter/operator_dispatcher.go index a9924fc..5a83031 100644 --- a/interpreter/operator_dispatcher.go +++ b/interpreter/operator_dispatcher.go @@ -1049,6 +1049,17 @@ func (i *interpreter) naryOverloads(m model.INaryExpression) ([]convert.Overload Result: i.evalCombine, }, }, nil + case *model.Round: + return []convert.Overload[evalNarySignature]{ + { + Operands: []types.IType{types.Decimal}, + Result: evalRound, + }, + { + Operands: []types.IType{types.Decimal, types.Integer}, + Result: evalRound, + }, + }, nil default: return nil, fmt.Errorf("unsupported Nary Expression %v", m.GetName()) } diff --git a/model/model.go b/model/model.go index 0bf6320..7fe825a 100644 --- a/model/model.go +++ b/model/model.go @@ -1067,6 +1067,9 @@ type DateTime struct{ *NaryExpression } // one of those at that point. type Now struct{ *NaryExpression } +// Round ELM Expression https://cql.hl7.org/04-logicalspecification.html#round +type Round struct{ *NaryExpression } + // TimeOfDay is https://cql.hl7.org/04-logicalspecification.html#timeofday // Note: in the future we may implement the OperatorExpression, and should convert this to // one of those at that point. @@ -1325,6 +1328,9 @@ func (a *Modulo) GetName() string { return "Modulo" } // GetName returns the name of the system operator. func (a *Power) GetName() string { return "Power" } +// GetName returns the name of the system operator. +func (a *Round) GetName() string { return "Round" } + // GetName returns the name of the system operator. func (a *TruncatedDivide) GetName() string { return "TruncatedDivide" } diff --git a/parser/operators.go b/parser/operators.go index ebdf398..819e7fd 100644 --- a/parser/operators.go +++ b/parser/operators.go @@ -19,10 +19,10 @@ import ( "fmt" "strings" - "github.com/antlr4-go/antlr/v4" "github.com/google/cql/internal/convert" "github.com/google/cql/model" "github.com/google/cql/types" + "github.com/antlr4-go/antlr/v4" ) // parseFunction uses the reference resolver to resolve the function, visits the operands, and sets @@ -881,6 +881,28 @@ func (p *Parser) loadSystemOperators() error { } }, }, + { + name: "Round", + operands: [][]types.IType{{types.Decimal, types.Integer}}, + model: func() model.IExpression { + return &model.Round{ + NaryExpression: &model.NaryExpression{ + Expression: model.ResultType(types.Decimal), + }, + } + }, + }, + { + name: "Round", + operands: [][]types.IType{{types.Decimal}}, + model: func() model.IExpression { + return &model.Round{ + NaryExpression: &model.NaryExpression{ + Expression: model.ResultType(types.Decimal), + }, + } + }, + }, { name: "Successor", operands: [][]types.IType{ diff --git a/parser/operators_test.go b/parser/operators_test.go index 2307906..955e22c 100644 --- a/parser/operators_test.go +++ b/parser/operators_test.go @@ -615,6 +615,31 @@ func TestBuiltInFunctions(t *testing.T) { }, }, }, + { + name: "Round 1.42", + cql: "Round(1.42)", + want: &model.Round{ + NaryExpression: &model.NaryExpression{ + Operands: []model.IExpression{ + model.NewLiteral("1.42", types.Decimal), + }, + Expression: model.ResultType(types.Decimal), + }, + }, + }, + { + name: "Round 3.14159 to 3 decimal places", + cql: "Round(3.14159, 3)", + want: &model.Round{ + NaryExpression: &model.NaryExpression{ + Operands: []model.IExpression{ + model.NewLiteral("3.14159", types.Decimal), + model.NewLiteral("3", types.Integer), + }, + Expression: model.ResultType(types.Decimal), + }, + }, + }, { name: "Predecessor for Date", cql: "Predecessor(@2023)", diff --git a/tests/enginetests/operator_arithmetic_test.go b/tests/enginetests/operator_arithmetic_test.go index e5d7a6d..4404951 100644 --- a/tests/enginetests/operator_arithmetic_test.go +++ b/tests/enginetests/operator_arithmetic_test.go @@ -973,6 +973,147 @@ func TestMultiply(t *testing.T) { } } +func TestRound(t *testing.T) { + tests := []struct { + name string + cql string + wantModel model.IExpression + wantResult result.Value + }{ + { + name: "Simple", + cql: "Round(42.101)", + wantResult: newOrFatal(t, 42.0), + }, + { + name: "Negative decimal", + cql: "Round(-101.42)", + wantResult: newOrFatal(t, -101.0), + }, + { + name: "Integers", + cql: "Round(2)", + wantModel: &model.Round{ + NaryExpression: &model.NaryExpression{ + Operands: []model.IExpression{ + &model.ToDecimal{ + UnaryExpression: &model.UnaryExpression{ + Operand: model.NewLiteral("2", types.Integer), + Expression: model.ResultType(types.Decimal), + }, + }, + }, + Expression: model.ResultType(types.Decimal), + }, + }, + wantResult: newOrFatal(t, 2.0), + }, + { + name: "Negative, round up", + cql: "Round(-0.5)", + wantResult: newOrFatal(t, 0.0), + }, + { + name: "Negative, round down", + cql: "Round(-0.6)", + wantResult: newOrFatal(t, -1.0), + }, + { + name: "Zero", + cql: "Round(0.0)", + wantResult: newOrFatal(t, 0.0), + }, + { + name: "Null", + cql: "Round(null as Decimal)", + wantResult: newOrFatal(t, nil), + }, + // With precision + { + name: "Simple with precision", + cql: "Round(42.101, 1)", + wantResult: newOrFatal(t, 42.1), + }, + { + name: "Negative decimal with precision round up", + cql: "Round(-101.45, 1)", + wantResult: newOrFatal(t, -101.4), + }, + { + name: "Negative decimal with precision round down", + cql: "Round(-101.46, 1)", + wantResult: newOrFatal(t, -101.5), + }, + { + name: "Precision is 0", + cql: "Round(2.123, 0)", + wantResult: newOrFatal(t, 2.0), + }, + { + name: "Precision is null", + cql: "Round(2.123, null)", + wantResult: newOrFatal(t, 2.0), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{}) + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" { + t.Errorf("Parse diff (-want +got):\n%s", diff) + } + + results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if err != nil { + t.Fatalf("Eval returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" { + t.Errorf("Eval diff (-want +got)\n%v", diff) + } + }) + } +} + +func TestRound_EvalErrors(t *testing.T) { + tests := []struct { + name string + cql string + wantModel model.IExpression + wantEvalErrContains string + }{ + { + name: "Round with a negative precision", + cql: "Round(2.123, -1)", + wantEvalErrContains: "precision must be non-negative", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newFHIRParser(t) + parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{}) + if err != nil { + t.Fatalf("Parse returned unexpected error: %v", err) + } + if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" { + t.Errorf("Parse diff (-want +got):\n%s", diff) + } + + _, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p)) + if err == nil { + t.Fatalf("Evaluate Expression expected an error to be returned, got nil instead") + } + if !strings.Contains(err.Error(), tc.wantEvalErrContains) { + t.Errorf("Unexpected evaluation error contents got (%v) want (%v)", err.Error(), tc.wantEvalErrContains) + } + }) + } +} + func TestTruncate(t *testing.T) { tests := []struct { name string diff --git a/tests/spectests/exclusions/exclusions.go b/tests/spectests/exclusions/exclusions.go index 993e9c2..8bc3a03 100644 --- a/tests/spectests/exclusions/exclusions.go +++ b/tests/spectests/exclusions/exclusions.go @@ -60,17 +60,11 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions { "HighBoundary", "Log", "LowBoundary", - "Round", }, NamesExcludes: []string{ // TODO: b/342061715 - Unsupported operator. - "Divide103", "Multiply1CMBy2CM", "Power2DToNeg2DEquivalence", - "Exp1", // Require Round support. - "ExpNeg1", // Require Round support. - "Ln1000", // Require Round support. - "Ln1000D", // Require Round support. // TODO: b/342061606 - Unit conversion is not supported. "Divide1Q1", "Divide10Q5I",