diff --git a/cmd/simple/engine_test.go b/cmd/simple/engine_test.go index d5c33b4..dbbaab9 100644 --- a/cmd/simple/engine_test.go +++ b/cmd/simple/engine_test.go @@ -18,9 +18,9 @@ func TestParquetFile(t *testing.T) { _ = df.Show() df = df. - Filter(logicalplan.Eq( + Filter(logicalplan.Lt( logicalplan.ColumnExpr{Name: "c1"}, - logicalplan.LiteralInt64Expr{Val: 200}, + logicalplan.LiteralInt64Expr{Val: 300}, )). Project( logicalplan.ColumnExpr{Name: "c1"}, @@ -38,11 +38,11 @@ func TestParquetFile(t *testing.T) { logicalPlan, _ := df.LogicalPlan() fmt.Println(logicalplan.PrettyPrint(logicalPlan, 0)) - assert.Equal(t, "Aggregate: groupExpr=[#c1], aggregateExpr=[sum(#c2)]\n\tProjection: #c1, #c2\n\t\tFilter: #c1 = 200\n\t\t\tInput: ../../test/data/c1_c2_c3_int64.parquet; projExpr=None\n", logicalplan.PrettyPrint(logicalPlan, 0)) + assert.Equal(t, "Aggregate: groupExpr=[#c1], aggregateExpr=[sum(#c2)]\n\tProjection: #c1, #c2\n\t\tFilter: #c1 < 300\n\t\t\tInput: ../../test/data/c1_c2_c3_int64.parquet; projExpr=None\n", logicalplan.PrettyPrint(logicalPlan, 0)) logicalPlan, _ = df.OptimizedLogicalPlan() fmt.Println(logicalplan.PrettyPrint(logicalPlan, 0)) - assert.Equal(t, "Aggregate: groupExpr=[#c1], aggregateExpr=[sum(#c2)]\n\tProjection: #c1, #c2\n\t\tFilter: #c1 = 200\n\t\t\tInput: ../../test/data/c1_c2_c3_int64.parquet; projExpr=[c1 c2]\n", logicalplan.PrettyPrint(logicalPlan, 0)) + assert.Equal(t, "Aggregate: groupExpr=[#c1], aggregateExpr=[sum(#c2)]\n\tProjection: #c1, #c2\n\t\tFilter: #c1 < 300\n\t\t\tInput: ../../test/data/c1_c2_c3_int64.parquet; projExpr=[c1 c2]\n", logicalplan.PrettyPrint(logicalPlan, 0)) err = df.Show() if err != nil { diff --git a/pkg/d_physicalplan/a_eval_expr/physical_expr.go b/pkg/d_physicalplan/a_eval_expr/physical_expr.go index f93042f..98393ff 100644 --- a/pkg/d_physicalplan/a_eval_expr/physical_expr.go +++ b/pkg/d_physicalplan/a_eval_expr/physical_expr.go @@ -113,12 +113,15 @@ func (e BooleanBinaryExpr) Evaluate(input containers.IBatch) (containers.IVector func (e BooleanBinaryExpr) evaluate(l, r containers.IVector) (containers.IVector, error) { res := make([]any, 0) switch e.Op { - case "=": - for i := 0; i < l.Len(); i++ { - if l.GetValue(i) == r.GetValue(i) { - res = append(res, true) - } else { - res = append(res, false) + case "<": + switch l.DataType() { + case arrow.PrimitiveTypes.Int64: + for i := 0; i < l.Len(); i++ { + if l.GetValue(i).(int64) < r.GetValue(i).(int64) { + res = append(res, true) + } else { + res = append(res, false) + } } } return containers.NewVector(arrow.FixedWidthTypes.Boolean, res), nil