Skip to content

Commit

Permalink
Add support for the Round() functional operator.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677032631
  • Loading branch information
evan-gordon authored and copybara-github committed Oct 1, 2024
1 parent 880391f commit 3ca3556
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 7 deletions.
49 changes: 49 additions & 0 deletions interpreter/operator_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,55 @@ func evalNegateQuantity(m model.IUnaryExpression, obj result.Value) (result.Valu
return result.New(val)
}

// Round(argument Decimal) Decimal
// Round(argument Decimal, argument Integer) Decimal
// https://cql.hl7.org/09-b-cqlreference.html#round
// If a precision is specified but is null then the default precision is 0.
// If a precision is specified but is negative then an error is returned. This is technically
// undefined behavior in the CQL spec, but we choose to throw an error here.
func evalRound(_ model.INaryExpression, operands []result.Value) (result.Value, error) {
// if len(operands) == 1 {
// return roundValue(operands[0])
// }
decimalVal := operands[0]
var precisionVal result.Value
var err error
// retrieve the precision if it exists, otherwise default to 0.
if len(operands) == 2 {
precisionVal = operands[1]
} else {
precisionVal, err = result.New(0)
if err != nil {
return result.Value{}, err
}
}
if result.IsNull(decimalVal) {
return result.New(nil)
}

p, err := result.ToInt32(precisionVal)
if err != nil {
p = 0
}
if p < 0 {
return result.Value{}, fmt.Errorf("internal error - precision must be non-negative, got %v", p)
}
d, err := result.ToFloat64(decimalVal)
if err != nil {
return result.Value{}, err
}
ratio := math.Pow10(int(p))
// CQL currently implements its own special version of rounding for now (which will be changed in
// the future). For now if the value is negative we round towards zero.
ratioedDecimal := d * ratio
_, frac := math.Modf(ratioedDecimal)
if frac == -0.5 {
// force go to round towards zero
ratioedDecimal += 0.1
}
return result.New(math.Round(ratioedDecimal) / ratio)
}

// predecessor of<T>(obj T) T
// https://cql.hl7.org/09-b-cqlreference.html#predecessor
func (i *interpreter) evalPredecessor(m model.IUnaryExpression, obj result.Value) (result.Value, error) {
Expand Down
11 changes: 11 additions & 0 deletions interpreter/operator_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,17 @@ func (i *interpreter) naryOverloads(m model.INaryExpression) ([]convert.Overload
Result: i.evalCombine,
},
}, nil
case *model.Round:
return []convert.Overload[evalNarySignature]{
{
Operands: []types.IType{types.Decimal},
Result: evalRound,
},
{
Operands: []types.IType{types.Decimal, types.Integer},
Result: evalRound,
},
}, nil
default:
return nil, fmt.Errorf("unsupported Nary Expression %v", m.GetName())
}
Expand Down
6 changes: 6 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,9 @@ type DateTime struct{ *NaryExpression }
// one of those at that point.
type Now struct{ *NaryExpression }

// Round ELM Expression https://cql.hl7.org/04-logicalspecification.html#round
type Round struct{ *NaryExpression }

// TimeOfDay is https://cql.hl7.org/04-logicalspecification.html#timeofday
// Note: in the future we may implement the OperatorExpression, and should convert this to
// one of those at that point.
Expand Down Expand Up @@ -1325,6 +1328,9 @@ func (a *Modulo) GetName() string { return "Modulo" }
// GetName returns the name of the system operator.
func (a *Power) GetName() string { return "Power" }

// GetName returns the name of the system operator.
func (a *Round) GetName() string { return "Round" }

// GetName returns the name of the system operator.
func (a *TruncatedDivide) GetName() string { return "TruncatedDivide" }

Expand Down
24 changes: 23 additions & 1 deletion parser/operators.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (
"fmt"
"strings"

"github.com/antlr4-go/antlr/v4"
"github.com/google/cql/internal/convert"
"github.com/google/cql/model"
"github.com/google/cql/types"
"github.com/antlr4-go/antlr/v4"
)

// parseFunction uses the reference resolver to resolve the function, visits the operands, and sets
Expand Down Expand Up @@ -881,6 +881,28 @@ func (p *Parser) loadSystemOperators() error {
}
},
},
{
name: "Round",
operands: [][]types.IType{{types.Decimal, types.Integer}},
model: func() model.IExpression {
return &model.Round{
NaryExpression: &model.NaryExpression{
Expression: model.ResultType(types.Decimal),
},
}
},
},
{
name: "Round",
operands: [][]types.IType{{types.Decimal}},
model: func() model.IExpression {
return &model.Round{
NaryExpression: &model.NaryExpression{
Expression: model.ResultType(types.Decimal),
},
}
},
},
{
name: "Successor",
operands: [][]types.IType{
Expand Down
25 changes: 25 additions & 0 deletions parser/operators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,31 @@ func TestBuiltInFunctions(t *testing.T) {
},
},
},
{
name: "Round 1.42",
cql: "Round(1.42)",
want: &model.Round{
NaryExpression: &model.NaryExpression{
Operands: []model.IExpression{
model.NewLiteral("1.42", types.Decimal),
},
Expression: model.ResultType(types.Decimal),
},
},
},
{
name: "Round 3.14159 to 3 decimal places",
cql: "Round(3.14159, 3)",
want: &model.Round{
NaryExpression: &model.NaryExpression{
Operands: []model.IExpression{
model.NewLiteral("3.14159", types.Decimal),
model.NewLiteral("3", types.Integer),
},
Expression: model.ResultType(types.Decimal),
},
},
},
{
name: "Predecessor for Date",
cql: "Predecessor(@2023)",
Expand Down
141 changes: 141 additions & 0 deletions tests/enginetests/operator_arithmetic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,147 @@ func TestMultiply(t *testing.T) {
}
}

func TestRound(t *testing.T) {
tests := []struct {
name string
cql string
wantModel model.IExpression
wantResult result.Value
}{
{
name: "Simple",
cql: "Round(42.101)",
wantResult: newOrFatal(t, 42.0),
},
{
name: "Negative decimal",
cql: "Round(-101.42)",
wantResult: newOrFatal(t, -101.0),
},
{
name: "Integers",
cql: "Round(2)",
wantModel: &model.Round{
NaryExpression: &model.NaryExpression{
Operands: []model.IExpression{
&model.ToDecimal{
UnaryExpression: &model.UnaryExpression{
Operand: model.NewLiteral("2", types.Integer),
Expression: model.ResultType(types.Decimal),
},
},
},
Expression: model.ResultType(types.Decimal),
},
},
wantResult: newOrFatal(t, 2.0),
},
{
name: "Negative, round up",
cql: "Round(-0.5)",
wantResult: newOrFatal(t, 0.0),
},
{
name: "Negative, round down",
cql: "Round(-0.6)",
wantResult: newOrFatal(t, -1.0),
},
{
name: "Zero",
cql: "Round(0.0)",
wantResult: newOrFatal(t, 0.0),
},
{
name: "Null",
cql: "Round(null as Decimal)",
wantResult: newOrFatal(t, nil),
},
// With precision
{
name: "Simple with precision",
cql: "Round(42.101, 1)",
wantResult: newOrFatal(t, 42.1),
},
{
name: "Negative decimal with precision round up",
cql: "Round(-101.45, 1)",
wantResult: newOrFatal(t, -101.4),
},
{
name: "Negative decimal with precision round down",
cql: "Round(-101.46, 1)",
wantResult: newOrFatal(t, -101.5),
},
{
name: "Precision is 0",
cql: "Round(2.123, 0)",
wantResult: newOrFatal(t, 2.0),
},
{
name: "Precision is null",
cql: "Round(2.123, null)",
wantResult: newOrFatal(t, 2.0),
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := newFHIRParser(t)
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
if err != nil {
t.Fatalf("Parse returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
t.Errorf("Parse diff (-want +got):\n%s", diff)
}

results, err := interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
if err != nil {
t.Fatalf("Eval returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantResult, getTESTRESULT(t, results), protocmp.Transform()); diff != "" {
t.Errorf("Eval diff (-want +got)\n%v", diff)
}
})
}
}

func TestRound_EvalErrors(t *testing.T) {
tests := []struct {
name string
cql string
wantModel model.IExpression
wantEvalErrContains string
}{
{
name: "Round with a negative precision",
cql: "Round(2.123, -1)",
wantEvalErrContains: "precision must be non-negative",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := newFHIRParser(t)
parsedLibs, err := p.Libraries(context.Background(), wrapInLib(t, tc.cql), parser.Config{})
if err != nil {
t.Fatalf("Parse returned unexpected error: %v", err)
}
if diff := cmp.Diff(tc.wantModel, getTESTRESULTModel(t, parsedLibs)); tc.wantModel != nil && diff != "" {
t.Errorf("Parse diff (-want +got):\n%s", diff)
}

_, err = interpreter.Eval(context.Background(), parsedLibs, defaultInterpreterConfig(t, p))
if err == nil {
t.Fatalf("Evaluate Expression expected an error to be returned, got nil instead")
}
if !strings.Contains(err.Error(), tc.wantEvalErrContains) {
t.Errorf("Unexpected evaluation error contents got (%v) want (%v)", err.Error(), tc.wantEvalErrContains)
}
})
}
}

func TestTruncate(t *testing.T) {
tests := []struct {
name string
Expand Down
6 changes: 0 additions & 6 deletions tests/spectests/exclusions/exclusions.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,11 @@ func XMLTestFileExclusionDefinitions() map[string]XMLTestFileExclusions {
"HighBoundary",
"Log",
"LowBoundary",
"Round",
},
NamesExcludes: []string{
// TODO: b/342061715 - Unsupported operator.
"Divide103",
"Multiply1CMBy2CM",
"Power2DToNeg2DEquivalence",
"Exp1", // Require Round support.
"ExpNeg1", // Require Round support.
"Ln1000", // Require Round support.
"Ln1000D", // Require Round support.
// TODO: b/342061606 - Unit conversion is not supported.
"Divide1Q1",
"Divide10Q5I",
Expand Down

0 comments on commit 3ca3556

Please sign in to comment.