Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions arrow/compute/exprs/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,18 @@ func GetExtensionIDSet(ctx context.Context) ExtensionIDSet {
return v
}

func intervalYearToMonthDatum(mem memory.Allocator, years, months int32) (compute.Datum, error) {
bldr := array.NewInt32Builder(mem)
defer bldr.Release()

bldr.Append(years)
bldr.Append(months)
arr := bldr.NewArray()
defer arr.Release()
return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
scalar.NewFixedSizeListScalar(arr), intervalYear())}, nil
}

func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet) (compute.Datum, error) {
switch v := lit.(type) {
case *expr.PrimitiveLiteral[bool]:
Expand Down Expand Up @@ -329,6 +341,10 @@ func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet)

s, err := scalar.NewStructScalarWithNames(fields, names)
return compute.NewDatum(s), err
case expr.IntervalYearToMonthLiteral:
return intervalYearToMonthDatum(mem, v.Years, v.Months)
case *expr.IntervalYearToMonthLiteral:
return intervalYearToMonthDatum(mem, v.Years, v.Months)
case *expr.ProtoLiteral:
switch t := v.Type.(type) {
case *types.DecimalType:
Expand All @@ -353,19 +369,10 @@ func literalToDatum(mem memory.Allocator, lit expr.Literal, ext ExtensionIDSet)
&arrow.Decimal128Type{Precision: t.Precision, Scale: t.Scale})), nil
case *types.UserDefinedType: // not yet implemented
case *types.IntervalYearToMonthType:
bldr := array.NewInt32Builder(memory.DefaultAllocator)
defer bldr.Release()

val := v.Value.(*types.IntervalYearToMonth)
typ := intervalYear()
bldr.Append(val.Years)
bldr.Append(val.Months)
arr := bldr.NewArray()
defer arr.Release()
return &compute.ScalarDatum{Value: scalar.NewExtensionScalar(
scalar.NewFixedSizeListScalar(arr), typ)}, nil
return intervalYearToMonthDatum(mem, val.Years, val.Months)
case *types.IntervalDayType:
bldr := array.NewInt32Builder(memory.DefaultAllocator)
bldr := array.NewInt32Builder(mem)
defer bldr.Release()

val := v.Value.(*types.IntervalDayToSecond)
Expand Down
47 changes: 47 additions & 0 deletions arrow/compute/exprs/exec_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ import (
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/v8/expr"
"github.com/substrait-io/substrait-go/v8/extensions"
"github.com/substrait-io/substrait-go/v8/types"
)

var (
Expand Down Expand Up @@ -112,3 +115,47 @@ func TestMakeExecBatch(t *testing.T) {
})
}
}

func TestLiteralToDatumIntervalYearToMonth(t *testing.T) {
// memory.NewCheckedAllocator with AssertSize would fail here:
// *scalar.Extension does not implement Release() (see
// arrow/scalar/scalar.go), so an extension scalar's underlying
// storage is never released even when the wrapping Datum is.
mem := memory.DefaultAllocator

extSet := NewExtensionSetDefault(
expr.NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError()))

const (
years int32 = 3
months int32 = 7
)

protoLitType := types.NewIntervalYearToMonthType()
protoLit := &expr.ProtoLiteral{
Value: &types.IntervalYearToMonth{Years: years, Months: months},
Type: &protoLitType,
}
expected, err := literalToDatum(mem, protoLit, extSet)
require.NoError(t, err, "ProtoLiteral baseline failed")
defer expected.Release()

cases := []struct {
name string
lit expr.Literal
}{
{"value", expr.IntervalYearToMonthLiteral{Years: years, Months: months}},
{"pointer", &expr.IntervalYearToMonthLiteral{Years: years, Months: months}},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := literalToDatum(mem, tc.lit, extSet)
require.NoError(t, err)
defer got.Release()
assert.Truef(t, got.Equals(expected),
"IntervalYearToMonthLiteral (%s) datum did not match ProtoLiteral baseline\nexpected: %s\ngot: %s",
tc.name, expected, got)
})
}
}