diff --git a/experimental/iterators/include/iterators-c/Runtime/Arrow.h b/experimental/iterators/include/iterators-c/Runtime/Arrow.h new file mode 100644 index 000000000000..992fa7e8b38b --- /dev/null +++ b/experimental/iterators/include/iterators-c/Runtime/Arrow.h @@ -0,0 +1,152 @@ +//===-- Arrow.h - Runtime for handling Apache Arrow data --------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a shallow runtime around the C interfaces of Apache +// Arrow, namely the Arrow C data interface and the Arrow C stream interface +// (see https://arrow.apache.org/docs/format/CDataInterface.html and +// https://arrow.apache.org/docs/format/CStreamInterface.html). While these +// interfaces are already very simple and low-level, the goal of this runtime is +// to simplify its usage even further by doing all error handling and +// verification of current limitations. +// +//===----------------------------------------------------------------------===// + +#ifndef ITERATORS_C_RUNTIME_ARROW_H +#define ITERATORS_C_RUNTIME_ARROW_H + +#include "mlir-c/Support.h" + +struct ArrowArray; +struct ArrowArrayStream; +struct ArrowSchema; + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// Arrow Array (aka Arrow RecordBatch). +//===----------------------------------------------------------------------===// + +/// Returns the number of rows in the given Arrow array. +MLIR_CAPI_EXPORTED +int64_t mlirIteratorsArrowArrayGetSize(ArrowArray *array); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is an int8 column. +MLIR_CAPI_EXPORTED const int8_t * +mlirIteratorsArrowArrayGetInt8Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is a uint8 column. +MLIR_CAPI_EXPORTED const uint8_t * +mlirIteratorsArrowArrayGetUInt8Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is an int16 column. +MLIR_CAPI_EXPORTED const int16_t * +mlirIteratorsArrowArrayGetInt16Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is a uint16 column. +MLIR_CAPI_EXPORTED const uint16_t * +mlirIteratorsArrowArrayGetUInt16Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is an int32 column. +MLIR_CAPI_EXPORTED const int32_t * +mlirIteratorsArrowArrayGetInt32Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is a uint32 column. +MLIR_CAPI_EXPORTED const uint32_t * +mlirIteratorsArrowArrayGetUInt32Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is an int64 column. +MLIR_CAPI_EXPORTED const int64_t * +mlirIteratorsArrowArrayGetInt64Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is a uint64 column. +MLIR_CAPI_EXPORTED const uint64_t * +mlirIteratorsArrowArrayGetUInt64Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is a float16 column. +MLIR_CAPI_EXPORTED const uint16_t * +mlirIteratorsArrowArrayGetFloat16Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is a float32 column. +MLIR_CAPI_EXPORTED const float * +mlirIteratorsArrowArrayGetFloat32Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Returns the raw data pointer to the buffer of the i-th column of the given +/// Arrow array, ensuring that that column is a float64 column. +MLIR_CAPI_EXPORTED const double * +mlirIteratorsArrowArrayGetFloat64Column(ArrowArray *array, ArrowSchema *schema, + int64_t i); + +/// Releases the memory owned by the given Arrow array (by calling its release +/// function). Unlike the lower-level release function from the Arrow C +/// interface, this function may be called on already released structs, in which +/// case the release function is not called. +MLIR_CAPI_EXPORTED +void mlirIteratorsArrowArrayRelease(ArrowArray *array); + +//===----------------------------------------------------------------------===// +// ArrowSchema. +//===----------------------------------------------------------------------===// + +/// Releases the memory owned by the given schema (by calling its release +/// function). Unlike the lower-level release function from the Arrow C +/// interface, this function may be called on already released structs, in which +/// case the release function is not called. +MLIR_CAPI_EXPORTED +void mlirIteratorsArrowSchemaRelease(ArrowSchema *schema); + +//===----------------------------------------------------------------------===// +// ArrowArrayStream (aka RecordBatchReader). +//===----------------------------------------------------------------------===// + +/// Attempts to extract the next record batch from the given stream. Stores the +/// returned batch in the given result pointer and returns true iff the stream +/// did return a batch. If an error occurs, prints a message and exits. +MLIR_CAPI_EXPORTED +bool mlirIteratorsArrowArrayStreamGetNext(ArrowArrayStream *stream, + ArrowArray *result); + +/// Gets the schema for the stream and stores it in the result pointer. If an +/// error occurs, prints a message and exits. +MLIR_CAPI_EXPORTED +void mlirIteratorsArrowArrayStreamGetSchema(ArrowArrayStream *stream, + ArrowSchema *result); + +/// Releases the memory owned by the given schema (by calling its release +/// function). Unlike the lower-level release function from the Arrow C +/// interface, this function may be called on already released structs, in which +/// case the release function is not called. +MLIR_CAPI_EXPORTED +void mlirIteratorsArrowArrayStreamRelease(ArrowArrayStream *stream); + +#ifdef __cplusplus +} +#endif + +#endif // ITERATORS_C_RUNTIME_ARROW_H diff --git a/experimental/iterators/include/iterators-c/Runtime/ArrowInterfaces.h b/experimental/iterators/include/iterators-c/Runtime/ArrowInterfaces.h new file mode 100644 index 000000000000..51a6f19ed90f --- /dev/null +++ b/experimental/iterators/include/iterators-c/Runtime/ArrowInterfaces.h @@ -0,0 +1,76 @@ +//===- ArrowCInterfaces.h - Arrow C data and stream interfaces ----*- C -*-===// +// +// This file is licensed under the Apache License v2.0. +// See https://www.apache.org/licenses/LICENSE-2.0 for license information. +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// This file contains the struct definitions of the Apache Arrow C ABI from +// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions +// and +// https://arrow.apache.org/docs/format/CStreamInterface.html#structure-definition. +// +//===----------------------------------------------------------------------===// + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#include + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char *format; + const char *name; + const char *metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema **children; + struct ArrowSchema *dictionary; + + // Release callback + void (*release)(struct ArrowSchema *); + // Opaque producer-specific data + void *private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void **buffers; + struct ArrowArray **children; + struct ArrowArray *dictionary; + + // Release callback + void (*release)(struct ArrowArray *); + // Opaque producer-specific data + void *private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + +#ifndef ARROW_C_STREAM_INTERFACE +#define ARROW_C_STREAM_INTERFACE + +struct ArrowArrayStream { + // Callbacks providing stream functionality + int (*get_schema)(struct ArrowArrayStream *, struct ArrowSchema *out); + int (*get_next)(struct ArrowArrayStream *, struct ArrowArray *out); + const char *(*get_last_error)(struct ArrowArrayStream *); + + // Release callback + void (*release)(struct ArrowArrayStream *); + + // Opaque producer-specific data + void *private_data; +}; + +#endif // ARROW_C_STREAM_INTERFACE diff --git a/experimental/iterators/include/iterators/Dialect/Iterators/IR/ArrowUtils.h b/experimental/iterators/include/iterators/Dialect/Iterators/IR/ArrowUtils.h new file mode 100644 index 000000000000..c74e94243edf --- /dev/null +++ b/experimental/iterators/include/iterators/Dialect/Iterators/IR/ArrowUtils.h @@ -0,0 +1,40 @@ +//===-- ArrowUtils.h - IR utils related to Apache Arrow --------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H +#define ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H + +namespace mlir { +class MLIRContext; +namespace LLVM { +class LLVMStructType; +} // namespace LLVM +} // namespace mlir + +namespace mlir { +namespace iterators { + +/// Returns the LLVM struct type for Arrow arrays of the Arrow C data interface. +/// For the official definition of the type, see +/// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions. +LLVM::LLVMStructType getArrowArrayType(MLIRContext *context); + +/// Returns the LLVM struct type for Arrow schemas of the Arrow C data +/// interface. For the official definition of the type, see +/// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions. +LLVM::LLVMStructType getArrowSchemaType(MLIRContext *context); + +/// Returns the LLVM struct type for Arrow streams of the Arrow C stream +/// interface. For the official definition of the type, see +/// https://arrow.apache.org/docs/format/CStreamInterface.html#structure-definition. +LLVM::LLVMStructType getArrowArrayStreamType(MLIRContext *context); + +} // namespace iterators +} // namespace mlir + +#endif // ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H diff --git a/experimental/iterators/include/iterators/Dialect/Iterators/IR/Iterators.h b/experimental/iterators/include/iterators/Dialect/Iterators/IR/Iterators.h index 31501a9cd8b7..6a303a1d0946 100644 --- a/experimental/iterators/include/iterators/Dialect/Iterators/IR/Iterators.h +++ b/experimental/iterators/include/iterators/Dialect/Iterators/IR/Iterators.h @@ -18,10 +18,17 @@ #include "iterators/Dialect/Iterators/IR/IteratorsOpsDialect.h.inc" +namespace mlir { +namespace LLVM { +class LLVMPointerType; +} // namespace LLVM +} // namespace mlir + namespace mlir { namespace iterators { #include "iterators/Dialect/Iterators/IR/IteratorsOpInterfaces.h.inc" #include "iterators/Dialect/Iterators/IR/IteratorsTypeInterfaces.h.inc" + } // namespace iterators } // namespace mlir diff --git a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td index e8e47af68533..4aeb461594a0 100644 --- a/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td +++ b/experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td @@ -192,6 +192,54 @@ def Iterators_FilterOp : Iterators_Op<"filter", }]; } +def Iterators_FromArrowArrayStreamOp : Iterators_Op<"from_arrow_array_stream", [ + TypesMatchWith<"type of operand must be an LLVM 'struct ArrowArrayStream'", + "result", "arrowStream", + "mlir::LLVM::LLVMPointerType::get(" + " mlir::iterators::getArrowArrayStreamType(" + " $_self.getContext()))">, + DeclareOpInterfaceMethods]> { + let summary = "Converts an arrow array stream into a stream of tabular views"; + let description = [{ + Wraps an Apache Arrow Array Stream using the Arrow C stream interface in + an `Iterators_Op`, i.e., converts the `ArrowArrayStream`, which produces a + stream of `ArrowArrays`, into an `Iterator_Op` that produces a stream of + `Tabular_TabularView`. This allows to read from any producer of Apache Arrow + Arrays (a.k.a "record batches"), including file readers for CSV, JSON, + Parquet, and others, libraries and languages such as pyarrow (and thus + Python and pandas), remote processes via Arrow IPC, and many others. + + **Limitations:** Currently, only a few numeric data types, only ArrowArrays + without offset, and no nullable or nested types are supported. These + limitations are inherited in part from the runtime and in part from + TabularView. Furthermore, returning TabularViews, which are references, + limits the lifetime of each returned element until when the next element is + produced, so it is up to consuming iterator ops to copy any buffers that + have to live longer. Finally, Arrow array streams have no "reset"/"re-open" + method, i.e., they can only consumed once, so the lowering to + open/next/close, in which that is possible, breaks if re-opening is + attempted. + + Example: + ```mlir + func.func @main(%external_input: !llvm.ptr) -> + !iterators.stream> { + %tabular_view_stream = iterators.from_arrow_array_stream %external_input + to !iterators.stream> + ``` + }]; + let arguments = (ins LLVM_PointerTo:$arrowStream); + let results = (outs Iterators_StreamOf:$result); + let assemblyFormat = "$arrowStream attr-dict `to` qualified(type($result))"; + let extraClassDefinition = [{ + /// Implement OpAsmOpInterface. + void $cppClass::getAsmResultNames( + llvm::function_ref setNameFn) { + setNameFn(getResult(), "fromarrowstream"); + } + }]; +} + def Iterators_MapOp : Iterators_Op<"map", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { diff --git a/experimental/iterators/lib/CMakeLists.txt b/experimental/iterators/lib/CMakeLists.txt index fdc1d6eed337..a0aed5cbb4c1 100644 --- a/experimental/iterators/lib/CMakeLists.txt +++ b/experimental/iterators/lib/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(CAPI) add_subdirectory(Conversion) add_subdirectory(Dialect) +add_subdirectory(Runtime) add_subdirectory(Utils) diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp new file mode 100644 index 000000000000..61ece40b5819 --- /dev/null +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.cpp @@ -0,0 +1,116 @@ +//===-- ArrowUtils.cpp - Utils for converting Arrow to LLVM -----*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ArrowUtils.h" + +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" + +using namespace mlir; +using namespace mlir::iterators; +using namespace mlir::LLVM; + +namespace mlir { +namespace iterators { + +LLVMFuncOp lookupOrInsertArrowArrayGetSize(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type i64 = IntegerType::get(context, 64); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayGetSize", {arrayPtr}, + i64); +} + +LLVMFuncOp lookupOrInsertArrowArrayGetColumn(ModuleOp module, + Type elementType) { + assert(elementType.isIntOrFloat() && + "only int or float types supported currently"); + MLIRContext *context = module.getContext(); + + // Assemble types for signature. + Type elementPtr = LLVMPointerType::get(elementType); + Type i64 = IntegerType::get(context, 64); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + + // Assemble function name. + StringRef typeNameBase; + if (elementType.isSignedInteger() || elementType.isSignlessInteger()) + typeNameBase = "Int"; + else if (elementType.isUnsignedInteger()) + typeNameBase = "UInt"; + else { + assert(elementType.isF16() || elementType.isF32() || elementType.isF64()); + typeNameBase = "Float"; + } + std::string typeWidth = std::to_string(elementType.getIntOrFloatBitWidth()); + std::string funcName = + ("mlirIteratorsArrowArrayGet" + typeNameBase + typeWidth + "Column") + .str(); + + // Lookup or insert function. + return lookupOrCreateFn(module, funcName, {arrayPtr, schemaPtr, i64}, + elementPtr); +} + +LLVMFuncOp lookupOrInsertArrowArrayRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayRelease", {arrayPtr}, + voidType); +} + +LLVMFuncOp lookupOrInsertArrowSchemaRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowSchemaRelease", + {schemaPtr}, voidType); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamGetSchema(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type i32 = IntegerType::get(context, 32); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetSchema", + {arrayStreamPtr, schemaPtr}, i32); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamGetNext(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type i1 = IntegerType::get(context, 1); + Type stream = getArrowArrayStreamType(context); + Type streamPtr = LLVMPointerType::get(stream); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamGetNext", + {streamPtr, arrayPtr}, i1); +} + +LLVMFuncOp lookupOrInsertArrowArrayStreamRelease(ModuleOp module) { + MLIRContext *context = module.getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type voidType = LLVMVoidType::get(context); + return lookupOrCreateFn(module, "mlirIteratorsArrowArrayStreamRelease", + {arrayStreamPtr}, voidType); +} + +} // namespace iterators +} // namespace mlir diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h new file mode 100644 index 000000000000..e1b8a7e5722d --- /dev/null +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/ArrowUtils.h @@ -0,0 +1,60 @@ +//===-- ArrowUtils.h - Utils for converting Arrow to LLVM -------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H +#define LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H + +namespace mlir { +class ModuleOp; +class Type; +namespace LLVM { +class LLVMFuncOp; +} // namespace LLVM +} // namespace mlir + +namespace mlir { +namespace iterators { + +/// Ensures that the runtime function `mlirIteratorsArrowArrayGetSize` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayGetSize(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayGet*Column` +/// corresponding to the given type is present in the current module and returns +/// the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayGetColumn(mlir::ModuleOp module, + mlir::Type elementType); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowArrayRelease(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowSchemaRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp lookupOrInsertArrowSchemaRelease(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetSchema` +/// is present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamGetSchema(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamGetNext` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamGetNext(mlir::ModuleOp module); + +/// Ensures that the runtime function `mlirIteratorsArrowArrayStreamRelease` is +/// present in the current module and returns the corresponding LLVM func op. +mlir::LLVM::LLVMFuncOp +lookupOrInsertArrowArrayStreamRelease(mlir::ModuleOp module); + +} // namespace iterators +} // namespace mlir + +#endif // LIB_CONVERSION_ITERATORSTOLLVM_ARROWUTILS_H diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt b/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt index 2e3f8f1407d9..89446c7e0944 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(MLIRIteratorsToLLVM + ArrowUtils.cpp IteratorsToLLVM.cpp IteratorAnalysis.cpp diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp index acf620b37e10..3ae23645a0a4 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp @@ -1,13 +1,16 @@ #include "IteratorAnalysis.h" +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" #include "iterators/Dialect/Iterators/IR/Iterators.h" #include "iterators/Utils/NameAssigner.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::iterators; +using namespace mlir::LLVM; using SymbolTriple = std::tuple; @@ -77,6 +80,26 @@ StateTypeComputer::operator()(FilterOp op, return StateType::get(context, {upstreamStateTypes[0]}); } +/// The state of FromArrowArrayStreamOp consists of the pointers to the +/// ArrowArrayStream struct it reads, to an ArrowSchema struct describing the +/// stream, and to an ArrowArray struct that owns the memory of the last element +/// the iterator has returned. Pseudocode: +/// +/// struct { struct ArrowArrayStream *stream; struct ArrowSchema *schema; }; +template <> +StateType StateTypeComputer::operator()( + FromArrowArrayStreamOp op, + llvm::SmallVector /*upstreamStateTypes*/) { + MLIRContext *context = op->getContext(); + Type arrayStream = getArrowArrayStreamType(context); + Type arrayStreamPtr = LLVMPointerType::get(arrayStream); + Type schema = getArrowSchemaType(context); + Type schemaPtr = LLVMPointerType::get(schema); + Type array = getArrowArrayType(context); + Type arrayPtr = LLVMPointerType::get(array); + return StateType::get(context, {arrayStreamPtr, schemaPtr, arrayPtr}); +} + /// The state of MapOp only consists of the state of its upstream iterator, /// i.e., the state of the iterator that produces its input stream. template <> @@ -182,6 +205,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis( // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, diff --git a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp index 0fef8b6694ea..d13f6d03619d 100644 --- a/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp +++ b/experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp @@ -9,11 +9,14 @@ #include "iterators/Conversion/IteratorsToLLVM/IteratorsToLLVM.h" #include "../PassDetail.h" +#include "ArrowUtils.h" #include "IteratorAnalysis.h" #include "iterators/Conversion/TabularToLLVM/TabularToLLVM.h" +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" #include "iterators/Dialect/Iterators/IR/Iterators.h" #include "iterators/Dialect/Tabular/IR/Tabular.h" #include "iterators/Dialect/Tuple/IR/Tuple.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -700,6 +703,302 @@ static Value buildStateCreation(FilterOp op, FilterOp::Adaptor adaptor, return b.create(stateType, upstreamState); } +//===----------------------------------------------------------------------===// +// FromArrowArrayStreamOp. +//===----------------------------------------------------------------------===// + +/// Builds IR that retrieves the schema from the input input stream in order to +/// allow cached access during the next calls. Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : !state_type +/// %1 = iterators.extractvalue %arg0[1] : !state_type +/// llvm.call @mlirIteratorsArrowArrayStreamGetSchema(%0, %1) : +/// (!llvm.ptr, !llvm.ptr) -> () +static Value buildOpenBody(FromArrowArrayStreamOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream and schema pointers from state. + Type arrowArrayStream = getArrowArrayStreamType(context); + Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); + Type arrowSchema = getArrowSchemaType(context); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + + Value streamPtr = b.create( + arrowArrayStreamPtr, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + arrowSchemaPtr, initialState, b.getIndexAttr(1)); + + // Call runtime function to load schema. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp getSchemaFunc = lookupOrInsertArrowArrayStreamGetSchema(module); + b.create(getSchemaFunc, ValueRange{streamPtr, schemaPtr}); + + // Return initial state. (We only modified the pointees.) + return initialState; +} + +/// Builds IR that calls the get_next function of the Arrow array stream and +/// returns the obtained record batch wrapped in a tabular view. Pseudo-code +/// +/// if (array = arrow_stream->get_next(arrow_stream)): +/// return convert_to_tabular_view(array) +/// return {} +/// +/// Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : !state_type +/// %1 = iterators.extractvalue %arg0[1] : !state_type +/// %2 = iterators.extractvalue %arg0[2] : !state_type +/// llvm.call @mlirIteratorsArrowArrayRelease(%2) : +/// (!llvm.ptr) -> () +/// %3 = llvm.call @mlirIteratorsArrowArrayStreamGetNext(%0, %2) : +/// (!llvm.ptr, !llvm.ptr) -> i1 +/// %c0_i64 = arith.constant 0 : i64 +/// %4 = scf.if %3 -> (i64) { +/// %6 = llvm.call @mlirIteratorsArrowArrayGetSize(%2) : +/// (!llvm.ptr) -> i64 +/// scf.yield %6 : i64 +/// } else { +/// scf.yield %c0_i64 : i64 +/// } +/// %5:2 = scf.if %3 -> (!llvm.ptr, i64) { +/// %c2_i64 = arith.constant 2 : i64 +/// %6 = llvm.call @mlirIteratorsArrowArrayGetInt32Column(%2, %1, %c2_i64) : +/// (!llvm.ptr, !llvm.ptr, i64) -> +/// !llvm.ptr +/// scf.yield %6, %4 : !llvm.ptr, i64 +/// } else { +/// %6 = llvm.mlir.null : !llvm.ptr +/// scf.yield %6, %c0_i64 : !llvm.ptr, i64 +/// } +/// %6 = llvm.mlir.undef : !memref_descr_type +/// %7 = llvm.insertvalue %5#0, %6[0] : !memref_descr_type +/// %8 = llvm.insertvalue %5#0, %7[1] : !memref_descr_type +/// %9 = llvm.insertvalue %c0_i64, %8[2] : !memref_descr_type +/// %10 = llvm.insertvalue %5#1, %9[3, 0] : !memref_descr_type +/// %11 = llvm.insertvalue %5#1, %10[4, 0] : !memref_descr_type +/// %12 = builtin.unrealized_conversion_cast %11 : +/// !memref_descr_type to memref +/// %tabularview = tabular.view_as_tabular %12 : (memref) -> +/// !tabular.tabular_view +static llvm::SmallVector +buildNextBody(FromArrowArrayStreamOp op, OpBuilder &builder, Value initialState, + ArrayRef upstreamInfos, Type elementType) { + MLIRContext *context = op->getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + Type opaquePtrType = LLVMPointerType::get(context); + + // Extract stream, schema, and block pointers from state. + Type arrowArrayStream = getArrowArrayStreamType(context); + Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); + Type arrowSchema = getArrowSchemaType(context); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + Type arrowArray = getArrowArrayType(context); + Type arrowArrayPtr = LLVMPointerType::get(arrowArray); + + Value streamPtr = b.create( + arrowArrayStreamPtr, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + arrowSchemaPtr, initialState, b.getIndexAttr(1)); + Value arrayPtr = b.create( + arrowArrayPtr, initialState, b.getIndexAttr(2)); + + // Get type-unspecific LLVM functions. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp releaseArrayFunc = lookupOrInsertArrowArrayRelease(module); + LLVMFuncOp getNextFunc = lookupOrInsertArrowArrayStreamGetNext(module); + LLVMFuncOp getArraySizeFunc = lookupOrInsertArrowArrayGetSize(module); + + // Release Arrow array from previous call to next. + b.create(releaseArrayFunc, arrayPtr); + + // Call getNext on Arrow stream. + auto getNextResult = + b.create(getNextFunc, ValueRange{streamPtr, arrayPtr}); + Value hasNextElement = getNextResult.getResult(); + + // Call getSize on current array if we got one; use 0 otherwise. + Value zero = b.create(/*value=*/0, /*width=*/64); + auto ifOp = b.create( + /*condition=*/hasNextElement, /*ifBuilder*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + /*elseBuilder*/ + auto callOp = b.create(getArraySizeFunc, arrayPtr); + Value arraySize = callOp.getResult(); + b.create(arraySize); + }, + [&](OpBuilder &builder, Location loc) { + // Apply map function. + ImplicitLocOpBuilder b(loc, builder); + b.create(zero); + }); + Value arraySize = ifOp->getResult(0); + + // Extract column pointers from Arrow array. + auto tabularViewType = elementType.cast(); + SmallVector memrefs; + LLVMTypeConverter typeConverter(context); + for (auto [idx, t] : llvm::enumerate(tabularViewType.getColumnTypes())) { + auto memrefType = MemRefType::get({ShapedType::kDynamic}, t); + + // Get column pointer from the array if we got one; nullptr otherwise. + auto ifOp = b.create( + /*condition=*/hasNextElement, /*ifBuilder*/ + [&, idx = idx, t = t](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + // Call type-specific getColumn on current array. + auto idxValue = + b.create(/*value=*/idx, /*width=*/64); + LLVMFuncOp getColumnFunc = + lookupOrInsertArrowArrayGetColumn(module, t); + auto callOp = b.create( + getColumnFunc, ValueRange{arrayPtr, schemaPtr, idxValue}); + Value columnPtr = callOp->getResult(0); + columnPtr = b.create(opaquePtrType, columnPtr); + + b.create(ValueRange{columnPtr, arraySize}); + }, + /*elseBuilder*/ + [&](OpBuilder &builder, Location loc) { + ImplicitLocOpBuilder b(loc, builder); + + // Use nullptr instead. + Value columnPtr = b.create(opaquePtrType); + b.create(ValueRange{columnPtr, zero}); + }); + + Value columnPtr = ifOp.getResult(0); + Value size = ifOp->getResult(1); + + // Assemble a memref descriptor and cast it to memref. + auto memrefValues = {/*allocated pointer=*/columnPtr, + /*aligned pointer=*/columnPtr, + /*offset=*/zero, /*sizes=*/size, + /*shapes=*/size}; + auto memrefDescriptor = + MemRefDescriptor::pack(b, loc, typeConverter, memrefType, memrefValues); + auto castOp = + b.create(memrefType, memrefDescriptor); + + memrefs.push_back(castOp.getResult(0)); + } + + // Create a tabular view from the memrefs. + Value tab = b.create(elementType, memrefs); + + return {initialState, hasNextElement, tab}; +} + +/// Builds IR that frees up all resources, namely, release the stream, the +/// schema, and the current array. Possible output: +/// +/// %0 = iterators.extractvalue %arg0[0] : !state_type +/// %1 = iterators.extractvalue %arg0[1] : !state_type +/// %2 = iterators.extractvalue %arg0[2] : !state_type +/// llvm.call @mlirIteratorsArrowArrayStreamRelease(%0) : +/// (!llvm.ptr) -> () +/// llvm.call @mlirIteratorsArrowSchemaRelease(%1) : +/// (!llvm.ptr) -> () +/// llvm.call @mlirIteratorsArrowArrayRelease(%2) : +/// (!llvm.ptr) -> () +static Value buildCloseBody(FromArrowArrayStreamOp op, OpBuilder &builder, + Value initialState, + ArrayRef upstreamInfos) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Extract stream and schema pointers from state. + Type arrayStreamType = getArrowArrayStreamType(context); + Type arrayStreamPtrType = LLVMPointerType::get(arrayStreamType); + Type schemaType = getArrowSchemaType(context); + Type schemaPtrType = LLVMPointerType::get(schemaType); + Type arrayType = getArrowArrayType(context); + Type arrayPtrType = LLVMPointerType::get(arrayType); + + Value streamPtr = b.create( + arrayStreamPtrType, initialState, b.getIndexAttr(0)); + Value schemaPtr = b.create( + schemaPtrType, initialState, b.getIndexAttr(1)); + Value arrayPtr = b.create( + arrayPtrType, initialState, b.getIndexAttr(2)); + + // Call runtime functions to release structs. + ModuleOp module = op->getParentOfType(); + LLVMFuncOp releaseStreamFunc = lookupOrInsertArrowArrayStreamRelease(module); + LLVMFuncOp releaseSchemaFunc = lookupOrInsertArrowSchemaRelease(module); + LLVMFuncOp releaseArrayFunc = lookupOrInsertArrowArrayRelease(module); + b.create(releaseStreamFunc, streamPtr); + b.create(releaseSchemaFunc, schemaPtr); + b.create(releaseArrayFunc, arrayPtr); + + // Return initial state. (We only modified the pointees.) + return initialState; +} + +/// Builds IR that allocates data for the schema and the current array on the +/// stack and stores pointers to them in the state. Possible output: +/// +/// %c1_i64 = arith.constant 1 : i64 +/// %0 = llvm.alloca %c1_i64 x !llvm.!array_type : (i64) -> +/// !llvm.ptr +/// %1 = llvm.alloca %c1_i64 x !llvm.!schema_type : (i64) -> +/// !llvm.ptr +/// %c0_i8 = arith.constant 0 : i8 +/// %false = arith.constant false +/// %c80_i64 = arith.constant 80 : i64 +/// %c72_i64 = arith.constant 72 : i64 +/// "llvm.intr.memset"(%0, %c0_i8, %c80_i64, %false) : +/// (!llvm.ptr, i8, i64, i1) -> () +/// "llvm.intr.memset"(%1, %c0_i8, %c72_i64, %false) : +/// (!llvm.ptr, i8, i64, i1) -> () +/// %state = iterators.createstate(%arg0, %1, %0) : !state_type +static Value buildStateCreation(FromArrowArrayStreamOp op, + FromArrowArrayStreamOp::Adaptor adaptor, + OpBuilder &builder, StateType stateType) { + MLIRContext *context = op.getContext(); + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, builder); + + // Allocate memory for schema and array on the stack. + Value one = b.create(/*value=*/1, /*width=*/64); + LLVMStructType arrayType = getArrowArrayType(context); + LLVMStructType schemaType = getArrowSchemaType(context); + Type arrayPtrType = LLVMPointerType::get(arrayType); + Type schemaPtrType = LLVMPointerType::get(schemaType); + Value arrayPtr = b.create(arrayPtrType, one); + Value schemaPtr = b.create(schemaPtrType, one); + + // Initialize it with zeros. + Value zero = b.create(/*value=*/0, /*width=*/8); + Value constFalse = b.create(/*value=*/0, /*width=*/1); + uint32_t arrayTypeSize = mlir::DataLayout::closest(op).getTypeSize(arrayType); + uint32_t schemaTypeSize = + mlir::DataLayout::closest(op).getTypeSize(schemaType); + Value arrayTypeSizeVal = + b.create(/*value=*/arrayTypeSize, + /*width=*/64); + Value schemaTypeSizeVal = + b.create(/*value=*/schemaTypeSize, + /*width=*/64); + b.create(arrayPtr, zero, arrayTypeSizeVal, + /*isVolatile=*/constFalse); + b.create(schemaPtr, zero, schemaTypeSizeVal, + /*isVolatile=*/constFalse); + + // Create the state. + Value streamPtr = adaptor.getArrowStream(); + return b.create(stateType, + ValueRange{streamPtr, schemaPtr, arrayPtr}); +} + //===----------------------------------------------------------------------===// // MapOp. //===----------------------------------------------------------------------===// @@ -1545,6 +1844,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1565,6 +1865,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1586,6 +1887,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, @@ -1605,6 +1907,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder, // clang-format off ConstantStreamOp, FilterOp, + FromArrowArrayStreamOp, MapOp, ReduceOp, TabularViewToStreamOp, diff --git a/experimental/iterators/lib/Dialect/Iterators/IR/ArrowUtils.cpp b/experimental/iterators/lib/Dialect/Iterators/IR/ArrowUtils.cpp new file mode 100644 index 000000000000..a312a48cde4d --- /dev/null +++ b/experimental/iterators/lib/Dialect/Iterators/IR/ArrowUtils.cpp @@ -0,0 +1,115 @@ +//===-- ArrowUtils.cpp - IR utils related to Apache Arrow ------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" + +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace mlir { +namespace iterators { + +LLVMStructType getArrowArrayType(MLIRContext *context) { + LLVMStructType arrowArray = + LLVMStructType::getIdentified(context, "ArrowArray"); + if (arrowArray.isInitialized()) + return arrowArray; + + Type voidPtr = LLVMPointerType::get(context); + Type i64 = IntegerType::get(context, 64); + Type arrowArrayPtr = LLVMPointerType::get(arrowArray); + auto voidType = LLVMVoidType::get(context); + Type releaseFunc = LLVMFunctionType::get(voidType, arrowArrayPtr); + + ArrayRef body = { + i64 /*length*/, + i64 /*null_count*/, + i64 /*offset*/, + i64 /*n_buffers*/, + i64 /*n_children*/, + LLVMPointerType::get(voidPtr) /*buffers*/, + LLVMPointerType::get(arrowArrayPtr) /*children*/, + arrowArrayPtr /*dictionary*/, + LLVMPointerType::get(releaseFunc) /*release*/, + voidPtr /*private_data*/ + }; + + LogicalResult status = arrowArray.setBody(body, /*isPacked=*/false); + assert(succeeded(status) && "could not create ArrowArray struct"); + return arrowArray; +} + +LLVMStructType getArrowSchemaType(MLIRContext *context) { + auto arrowSchema = LLVMStructType::getIdentified(context, "ArrowSchema"); + if (arrowSchema.isInitialized()) + return arrowSchema; + + Type charPtr = LLVMPointerType::get(IntegerType::get(context, 8)); + Type voidPtr = LLVMPointerType::get(context); + Type i64 = IntegerType::get(context, 64); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + auto voidType = LLVMVoidType::get(context); + Type releaseFunc = LLVMFunctionType::get(voidType, arrowSchemaPtr); + + ArrayRef body{ + charPtr /*format*/, + charPtr /*name*/, + charPtr /*metadata*/, + i64 /*flags*/, + i64 /*n_children*/, + LLVMPointerType::get(arrowSchemaPtr) /*children*/, + arrowSchemaPtr /*dictionary*/, + LLVMPointerType::get(releaseFunc) /*release*/, + voidPtr /*private_data*/ + }; + + LogicalResult status = arrowSchema.setBody(body, /*isPacked=*/false); + assert(succeeded(status) && "could not create ArrowSchema struct"); + return arrowSchema; +} + +LLVMStructType getArrowArrayStreamType(MLIRContext *context) { + auto arrowArrayStream = + LLVMStructType::getIdentified(context, "ArrowArrayStream"); + if (arrowArrayStream.isInitialized()) + return arrowArrayStream; + + Type voidPtr = LLVMPointerType::get(context); + Type charPtr = LLVMPointerType::get(IntegerType::get(context, 8)); + Type i32 = IntegerType::get(context, 32); + auto voidType = LLVMVoidType::get(context); + Type arrowArray = getArrowArrayType(context); + auto arrowSchema = getArrowSchemaType(context); + Type arrowArrayPtr = LLVMPointerType::get(arrowArray); + Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); + Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); + + Type getSchemaFunc = + LLVMFunctionType::get(i32, {arrowArrayStreamPtr, arrowSchemaPtr}); + Type getNextFunc = + LLVMFunctionType::get(i32, {arrowArrayStreamPtr, arrowArrayPtr}); + Type getLastErrorFunc = LLVMFunctionType::get(charPtr, arrowArrayStreamPtr); + Type releaseFunc = LLVMFunctionType::get(voidType, arrowArrayStreamPtr); + + ArrayRef body{ + LLVMPointerType::get(getSchemaFunc) /*get_schema*/, + LLVMPointerType::get(getNextFunc) /*get_next*/, + LLVMPointerType::get(getLastErrorFunc) /*get_last_error*/, + LLVMPointerType::get(releaseFunc) /*release*/, voidPtr /*private_data*/ + }; + + LogicalResult status = arrowArrayStream.setBody(body, /*isPacked=*/false); + assert(succeeded(status) && "could not create ArrowArrayStream struct"); + return arrowArrayStream; +} + +} // namespace iterators +} // namespace mlir diff --git a/experimental/iterators/lib/Dialect/Iterators/IR/CMakeLists.txt b/experimental/iterators/lib/Dialect/Iterators/IR/CMakeLists.txt index 7253bf954e13..813025ac7694 100644 --- a/experimental/iterators/lib/Dialect/Iterators/IR/CMakeLists.txt +++ b/experimental/iterators/lib/Dialect/Iterators/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRIterators + ArrowUtils.cpp Iterators.cpp LINK_LIBS PUBLIC diff --git a/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp b/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp index 87a68d6881fa..1c78e80ff8c0 100644 --- a/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp +++ b/experimental/iterators/lib/Dialect/Iterators/IR/Iterators.cpp @@ -7,8 +7,9 @@ //===----------------------------------------------------------------------===// #include "iterators/Dialect/Iterators/IR/Iterators.h" -#include "iterators/Dialect/Tabular/IR/Tabular.h" +#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" +#include "iterators/Dialect/Tabular/IR/Tabular.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/TypeUtilities.h" diff --git a/experimental/iterators/lib/Runtime/Arrow.cpp b/experimental/iterators/lib/Runtime/Arrow.cpp new file mode 100644 index 000000000000..cabfc94566aa --- /dev/null +++ b/experimental/iterators/lib/Runtime/Arrow.cpp @@ -0,0 +1,195 @@ +//===-- Arrow.cpp - Runtime for handling Apache Arrow data ------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "iterators-c/Runtime/Arrow.h" +#include "iterators-c/Runtime/ArrowInterfaces.h" + +#include +#include +#include + +#include +#include +#include + +using namespace std::string_literals; + +template +static void callReleaseCallbackIfUnreleased(T *ptr) { + assert(ptr && "need to provide non-null pointer"); + if (ptr->release == nullptr) + return; + ptr->release(ptr); + assert(ptr->release == nullptr && + "struct not marked as release after calling release"); +} + +int64_t mlirIteratorsArrowArrayGetSize(ArrowArray *array) { + assert(array && "need to provide non-null pointer"); + assert(array->release != nullptr && + "provided record batch has been released already"); + return array->length; +} + +template < + typename DataType, const char *kTypeFormat, + typename BufferPointerType = std::add_pointer_t>> +static BufferPointerType getColumnImpl(ArrowArray *array, ArrowSchema *schema, + int64_t i) { + assert(array && "need to provide non-null pointer"); + assert(array->release != nullptr && + "provided record batch has been released already"); + assert(array->n_buffers <= 1 && + "unexpected number of buffers for struct type"); + assert(schema && "need to provide non-null pointer"); + assert(schema->n_children == array->n_children && + "mismatch between provided array and schema"); + assert(schema->format == "+s"s && "only struct arrays are supported"); + assert((schema->flags & ARROW_FLAG_NULLABLE) == 0 && + "nullable fields are unsupported"); + assert(i < schema->n_children && "attempt to access non-existing column"); + + ArrowArray *childArray = array->children[i]; + ArrowSchema *childSchema = schema->children[i]; + + assert(childSchema->n_children == childArray->n_children && + "mismatch between provided array and schema"); + assert(childSchema->n_children == 0 && "nested structs not supported"); + assert(childSchema->format == std::string_view(kTypeFormat) && + "attempt to access column with wrong type"); + assert((childSchema->flags & ARROW_FLAG_NULLABLE) == 0 && + "nullable fields are unsupported"); + assert(childArray->n_buffers == 2 && "unsupported number of buffers"); + assert(childArray->buffers[0] == nullptr && "nullable types not supported"); + assert(childArray->offset == 0 && "offset unsupported"); + + return reinterpret_cast(childArray->buffers[1]); +} + +const int8_t *mlirIteratorsArrowArrayGetInt8Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "c"; + return getColumnImpl(array, schema, i); +} + +const uint8_t *mlirIteratorsArrowArrayGetUInt8Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "C"; + return getColumnImpl(array, schema, i); +} + +const int16_t *mlirIteratorsArrowArrayGetInt16Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "s"; + return getColumnImpl(array, schema, i); +} + +const uint16_t *mlirIteratorsArrowArrayGetUInt16Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "S"; + return getColumnImpl(array, schema, i); +} + +const int32_t *mlirIteratorsArrowArrayGetInt32Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "i"; + return getColumnImpl(array, schema, i); +} + +const uint32_t *mlirIteratorsArrowArrayGetUInt32Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "I"; + return getColumnImpl(array, schema, i); +} + +const int64_t *mlirIteratorsArrowArrayGetInt64Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "l"; + return getColumnImpl(array, schema, i); +} + +const uint64_t *mlirIteratorsArrowArrayGetUInt64Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "L"; + return getColumnImpl(array, schema, i); +} + +const uint16_t *mlirIteratorsArrowArrayGetFloat16Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "e"; + return getColumnImpl(array, schema, + i); +} + +const float *mlirIteratorsArrowArrayGetFloat32Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "f"; + return getColumnImpl(array, schema, i); +} + +const double *mlirIteratorsArrowArrayGetFloat64Column(ArrowArray *array, + ArrowSchema *schema, + int64_t i) { + static constexpr const char kFormatString[] = "g"; + return getColumnImpl(array, schema, i); +} + +void mlirIteratorsArrowArrayRelease(ArrowArray *schema) { + callReleaseCallbackIfUnreleased(schema); +} + +void mlirIteratorsArrowSchemaRelease(ArrowSchema *schema) { + callReleaseCallbackIfUnreleased(schema); +} + +static void handleError(ArrowArrayStream *stream, int errorCode) { + const char *errorMessage = stream->get_last_error(stream); + if (!errorMessage) + errorMessage = strerror(errorCode); + std::cerr << "Error while getting next record batch: " << errorMessage + << std::endl; + std::exit(1); +} + +bool mlirIteratorsArrowArrayStreamGetNext(ArrowArrayStream *stream, + ArrowArray *result) { + assert(stream && "need to provide non-null pointer"); + assert(result && "need to provide non-null pointer"); + assert(result->release == nullptr && + "provided result pointer still owned memory"); + + if (int errorCode = stream->get_next(stream, result)) + handleError(stream, errorCode); + + return result->release; +} + +void mlirIteratorsArrowArrayStreamGetSchema(ArrowArrayStream *stream, + ArrowSchema *result) { + assert(stream && "need to provide non-null pointer"); + assert(result && "need to provide non-null pointer"); + assert(result->release == nullptr && + "provided result pointer still owned memory"); + + if (int errorCode = stream->get_schema(stream, result)) + handleError(stream, errorCode); +} + +void mlirIteratorsArrowArrayStreamRelease(ArrowArrayStream *stream) { + callReleaseCallbackIfUnreleased(stream); +} diff --git a/experimental/iterators/lib/Runtime/CMakeLists.txt b/experimental/iterators/lib/Runtime/CMakeLists.txt new file mode 100644 index 000000000000..47f0f22ae842 --- /dev/null +++ b/experimental/iterators/lib/Runtime/CMakeLists.txt @@ -0,0 +1,24 @@ +add_mlir_public_c_api_library(IteratorsRuntime + Arrow.cpp + + LINK_LIBS PUBLIC +) + +# Determine full path to output name of the built library. +get_target_property(ITERATORS_RUNTIME_LIBRARY_OUTPUT_NAME IteratorsRuntime LIBRARY_OUTPUT_NAME) +get_target_property(ITERATORS_RUNTIME_OUTPUT_NAME IteratorsRuntime OUTPUT_NAME) +if ("${ITERATORS_RUNTIME_LIBRARY_OUTPUT_NAME}") + set(ITERATORS_RUNTIME_LIBRARY_FILENAME "${ITERATORS_RUNTIME_LIBRARY_OUTPUT_NAME}") +elseif("${ITERATORS_RUNTIME_OUTPUT_NAME}") + set(ITERATORS_RUNTIME_LIBRARY_FILENAME "${ITERATORS_RUNTIME_OUTPUT_NAME}") +else() + get_target_property(ITERATORS_RUNTIME_NAME IteratorsRuntime NAME) + set(ITERATORS_RUNTIME_LIBRARY_FILENAME + "${CMAKE_SHARED_LIBRARY_PREFIX}${ITERATORS_RUNTIME_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX}") +endif() +get_target_property(ITERATORS_RUNTIME_LIBRARY_OUTPUT_DIRECTORY IteratorsRuntime LIBRARY_OUTPUT_DIRECTORY) +cmake_path(APPEND ITERATORS_RUNTIME_LIBRARY_PATH + "${ITERATORS_RUNTIME_LIBRARY_OUTPUT_DIRECTORY}" "${ITERATORS_RUNTIME_LIBRARY_FILENAME}") +set(ITERATORS_RUNTIME_LIBRARY_PATH "${ITERATORS_RUNTIME_LIBRARY_PATH}" + CACHE INTERNAL "Full output path of the Iterators runtime library") +message("-- Determined full path to Iterators runtime lib: ${ITERATORS_RUNTIME_LIBRARY_PATH}") diff --git a/experimental/iterators/requirements.txt b/experimental/iterators/requirements.txt index 58aa9c544eee..f57609947808 100644 --- a/experimental/iterators/requirements.txt +++ b/experimental/iterators/requirements.txt @@ -2,7 +2,9 @@ -r third_party/llvm-project/mlir/python/requirements.txt # Testing. +cffi lit +pyarrow # Plotting. pandas diff --git a/experimental/iterators/test/CMakeLists.txt b/experimental/iterators/test/CMakeLists.txt index c979616d384f..8334c9f188c4 100644 --- a/experimental/iterators/test/CMakeLists.txt +++ b/experimental/iterators/test/CMakeLists.txt @@ -16,6 +16,7 @@ set(ITERATORS_TEST_DEPENDS count FileCheck iterators-opt + IteratorsRuntime mlir-cpu-runner mlir_c_runner_utils mlir_runner_utils diff --git a/experimental/iterators/test/Conversion/IteratorsToLLVM/from-arrow-stream.mlir b/experimental/iterators/test/Conversion/IteratorsToLLVM/from-arrow-stream.mlir new file mode 100644 index 000000000000..4744eb9fdd1b --- /dev/null +++ b/experimental/iterators/test/Conversion/IteratorsToLLVM/from-arrow-stream.mlir @@ -0,0 +1,106 @@ +// RUN: iterators-opt %s -convert-iterators-to-llvm \ +// RUN: | FileCheck --enable-var-scope %s +!arrow_schema = !llvm.struct<"ArrowSchema", ( + ptr, // format + ptr, // name + ptr, // metadata + i64, // flags + i64, // n_children + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> +!arrow_array = !llvm.struct<"ArrowArray", ( + i64, // length + i64, // null_count + i64, // offset + i64, // n_buffers + i64, // n_children + ptr, // buffers + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> +!arrow_array_stream = !llvm.struct<"ArrowArrayStream", ( + ptr>, ptr)>>, // get_schema + ptr>, ptr)>>, // get_next + ptr (ptr>)>>, // get_last_error + ptr>)>>, // release + ptr // private_data + )> + +// CHECK-LABEL: func.func private @iterators.from_arrow_array_stream.close.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state, !llvm.ptr<[[SCHEMATYPE:.*]]>, !llvm.ptr<[[ARRAYTYPE:.*]]>>) -> +// CHECK-SAME: !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> { +// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[ARG0]][2] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayStreamRelease(%[[V0]]) : (!llvm.ptr<[[STREAMTYPE]]>) -> () +// CHECK-NEXT: llvm.call @mlirIteratorsArrowSchemaRelease(%[[V1]]) : (!llvm.ptr<[[SCHEMATYPE]]>) -> () +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayRelease(%[[V2]]) : (!llvm.ptr<[[ARRAYTYPE]]>) -> () +// CHECK-NEXT: return %[[ARG0]] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> + +// CHECK-LABEL: func.func private @iterators.from_arrow_array_stream.next.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state, !llvm.ptr<[[SCHEMATYPE:.*]]>, !llvm.ptr<[[ARRAYTYPE:.*]]>>) -> +// CHECK-SAME: (!iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>>, i1, !llvm.struct<(i64, ptr)>) { +// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[ARG0]][2] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayRelease(%[[V2]]) : (!llvm.ptr<[[ARRAYTYPE]]>) -> () +// CHECK-NEXT: %[[V3:.*]] = llvm.call @mlirIteratorsArrowArrayStreamGetNext(%[[V0]], %[[V2]]) : (!llvm.ptr<[[STREAMTYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>) -> i1 +// CHECK-NEXT: %[[V4:.*]] = arith.constant 0 : i64 +// CHECK-NEXT: %[[V5:.*]] = scf.if %[[V3]] -> (i64) { +// CHECK-NEXT: %[[V6:.*]] = llvm.call @mlirIteratorsArrowArrayGetSize(%[[V2]]) : (!llvm.ptr<[[ARRAYTYPE]]>) -> i64 +// CHECK-NEXT: scf.yield %[[V6]] : i64 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[V4]] : i64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[V7:.*]]:2 = scf.if %[[V3]] -> (!llvm.ptr, i64) { +// CHECK-NEXT: %[[V8:.*]] = arith.constant 0 : i64 +// CHECK-NEXT: %[[V9:.*]] = llvm.call @mlirIteratorsArrowArrayGetInt32Column(%[[V2]], %[[V1]], %[[V8]]) : (!llvm.ptr<[[ARRAYTYPE]]>, !llvm.ptr<[[SCHEMATYPE]]>, i64) -> !llvm.ptr +// CHECK-NEXT: %[[Va:.*]] = llvm.bitcast %[[V9]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: scf.yield %[[Va]], %[[V5]] : !llvm.ptr, i64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[Va:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: scf.yield %[[Va]], %[[V4]] : !llvm.ptr, i64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[Vb:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vc:.*]] = llvm.insertvalue %[[V7]]#0, %[[Vb]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vd:.*]] = llvm.insertvalue %[[V7]]#0, %[[Vc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Ve:.*]] = llvm.insertvalue %[[V4]], %[[Vd]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vf:.*]] = llvm.insertvalue %[[V7]]#1, %[[Ve]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vg:.*]] = llvm.insertvalue %[[V7]]#1, %[[Vf]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[Vh:.*]] = builtin.unrealized_conversion_cast %[[Vg]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref +// CHECK-NEXT: %[[Vi:.*]] = tabular.view_as_tabular %[[Vh]] : (memref) -> !tabular.tabular_view +// CHECK-NEXT: %[[Vj:.*]] = builtin.unrealized_conversion_cast %[[Vi]] : !tabular.tabular_view to !llvm.struct<(i64, ptr)> +// CHECK-NEXT: return %[[ARG0]], %[[V3]], %[[Vj]] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>>, i1, !llvm.struct<(i64, ptr)> + +// CHECK-LABEL: func.func private @iterators.from_arrow_array_stream.open.{{[0-9]+}}( +// CHECK-SAME: %[[ARG0:.*]]: !iterators.state, !llvm.ptr<[[SCHEMATYPE:.*]]>, !llvm.ptr<[[ARRAYTYPE:.*]]>>) -> +// CHECK-SAME: !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> { +// CHECK-NEXT: %[[V0:.*]] = iterators.extractvalue %[[ARG0]][0] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[ARG0]][1] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: llvm.call @mlirIteratorsArrowArrayStreamGetSchema(%[[V0]], %[[V1]]) : (!llvm.ptr<[[STREAMTYPE]]>, !llvm.ptr<[[SCHEMATYPE]]>) -> i32 +// CHECK-NEXT: return %[[ARG0]] : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> + +// CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<[[STREAMTYPE:.*]]>) +// CHECK-NEXT: %[[V0:.*]] = arith.constant 1 : i64 +// CHECK-NEXT: %[[V1:.*]] = llvm.alloca %[[V0]] x !llvm.[[ARRAYTYPE:.*]] : (i64) -> +// CHECK-SAME: !llvm.ptr<[[ARRAYTYPE]]> +// CHECK-NEXT: %[[V2:.*]] = llvm.alloca %[[V0]] x !llvm.[[SCHEMATYPE:.*]] : (i64) -> +// CHECK-SAME: !llvm.ptr<[[SCHEMATYPE]]> +// CHECK-NEXT: %[[V3:.*]] = arith.constant 0 : i8 +// CHECK-NEXT: %[[V4:.*]] = arith.constant false +// CHECK-NEXT: %[[V5:.*]] = arith.constant 80 : i64 +// CHECK-NEXT: %[[V6:.*]] = arith.constant 72 : i64 +// CHECK-NEXT: "llvm.intr.memset"(%[[V1]], %[[V3]], %[[V5]], %[[V4]]) : (!llvm.ptr<[[ARRAYTYPE]]>, i8, i64, i1) -> () +// CHECK-NEXT: "llvm.intr.memset"(%[[V2]], %[[V3]], %[[V6]], %[[V4]]) : (!llvm.ptr<[[SCHEMATYPE]]>, i8, i64, i1) -> () +// CHECK-NEXT: %[[V7:.*]] = iterators.createstate(%[[ARG0]], %[[V2]], %[[V1]]) : !iterators.state, !llvm.ptr<[[SCHEMATYPE]]>, !llvm.ptr<[[ARRAYTYPE]]>> +// CHECK-NEXT: return +func.func @main(%arrow_stream: !llvm.ptr) { + %tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream to !iterators.stream> + return +} diff --git a/experimental/iterators/test/Dialect/Iterators/from-arrow-stream.mlir b/experimental/iterators/test/Dialect/Iterators/from-arrow-stream.mlir new file mode 100644 index 000000000000..4893ed4559c7 --- /dev/null +++ b/experimental/iterators/test/Dialect/Iterators/from-arrow-stream.mlir @@ -0,0 +1,43 @@ +// RUN: iterators-opt %s \ +// RUN: | FileCheck %s + +!arrow_schema = !llvm.struct<"ArrowSchema", ( + ptr, // format + ptr, // name + ptr, // metadata + i64, // flags + i64, // n_children + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> +!arrow_array = !llvm.struct<"ArrowArray", ( + i64, // length + i64, // null_count + i64, // offset + i64, // n_buffers + i64, // n_children + ptr, // buffers + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> +!arrow_array_stream = !llvm.struct<"ArrowArrayStream", ( + ptr>, ptr)>>, // get_schema + ptr>, ptr)>>, // get_next + ptr (ptr>)>>, // get_last_error + ptr>)>>, // release + ptr // private_data + )> + +// CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<[[STREAMTYPE:.*]]>) -> +// CHECK-SAME: !iterators.stream> { +// CHECK-NEXT: %[[V0:fromarrowstream.*]] = iterators.from_arrow_array_stream %[[ARG0]] to !iterators.stream> +// CHECK-NEXT: return %[[V0]] : !iterators.stream> +func.func @main(%arrow_stream: !llvm.ptr) -> !iterators.stream> { + %tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream to !iterators.stream> + return %tabular_view_stream : !iterators.stream> +} diff --git a/experimental/iterators/test/Runtime/arrow.py b/experimental/iterators/test/Runtime/arrow.py new file mode 100644 index 000000000000..2bffeabe8961 --- /dev/null +++ b/experimental/iterators/test/Runtime/arrow.py @@ -0,0 +1,215 @@ +# RUN: %PYTHON %s | FileCheck %s + +import os + +import numpy as np +import pyarrow as pa +import pyarrow.cffi + +# Define type short-hands. +i8 = pa.int8() +u8 = pa.uint8() +i16 = pa.int16() +u16 = pa.uint16() +i32 = pa.int32() +u32 = pa.uint32() +i64 = pa.int64() +u64 = pa.uint64() +f16 = pa.float16() +f32 = pa.float32() +f64 = pa.float64() + +#===------------------------------------------------------------------------===# +# Load runtime lib as CFFI lib. +#===------------------------------------------------------------------------===# +# Reuse ffi object of pyarrow, which has definitions of the C interface types. +ffi = pa.cffi.ffi + +# Load our functions. +GET_COLUMN_FUNCS_NAMES = { + i8: ('signed char', 'mlirIteratorsArrowArrayGetInt8Column'), + u8: ('unsigned char', 'mlirIteratorsArrowArrayGetUInt8Column'), + i16: ('signed short', 'mlirIteratorsArrowArrayGetInt16Column'), + u16: ('unsigned short', 'mlirIteratorsArrowArrayGetUInt16Column'), + i32: ('signed int', 'mlirIteratorsArrowArrayGetInt32Column'), + u32: ('unsigned int', 'mlirIteratorsArrowArrayGetUInt32Column'), + i64: ('signed long', 'mlirIteratorsArrowArrayGetInt64Column'), + u64: ('unsigned long', 'mlirIteratorsArrowArrayGetUInt64Column'), + f16: ('unsigned short', 'mlirIteratorsArrowArrayGetFloat16Column'), + f32: ('float', 'mlirIteratorsArrowArrayGetFloat32Column'), + f64: ('double', 'mlirIteratorsArrowArrayGetFloat64Column'), +} + +for pointee_type, func_name in GET_COLUMN_FUNCS_NAMES.values(): + ffi.cdef('const {} * {}(struct ArrowArray *array,' + ' struct ArrowSchema *schema,' + ' long long i);'.format(pointee_type, func_name)) + +ffi.cdef(''' + long long mlirIteratorsArrowArrayGetSize(struct ArrowArray *array); + void mlirIteratorsArrowArrayRelease(struct ArrowArray *array); + void mlirIteratorsArrowSchemaRelease(struct ArrowSchema *schema); + bool mlirIteratorsArrowArrayStreamGetNext(struct ArrowArrayStream *stream, + struct ArrowArray *result); + void mlirIteratorsArrowArrayStreamGetSchema(struct ArrowArrayStream *stream, + struct ArrowSchema *result); + void mlirIteratorsArrowArrayStreamRelease(struct ArrowArrayStream *stream); + ''') + +# Dlopen our library. +runtime_lib_path = os.environ['ITERATORS_RUNTIME_LIBRARY_PATH'] +lib = ffi.dlopen(runtime_lib_path) + +#===------------------------------------------------------------------------===# +# Set up test data. +#===------------------------------------------------------------------------===# +# Define schema and table. +fields = [pa.field(str(t), t, False) for t in GET_COLUMN_FUNCS_NAMES.keys()] +schema = pa.schema(fields) + +# CHECK-LABEL: schema: +# CHECK-NEXT: int8: int8 not null +# CHECK-NEXT: uint8: uint8 not null +# CHECK-NEXT: int16: int16 not null +# CHECK-NEXT: uint16: uint16 not null +# CHECK-NEXT: int32: int32 not null +# CHECK-NEXT: uint32: uint32 not null +# CHECK-NEXT: int64: int64 not null +# CHECK-NEXT: uint64: uint64 not null +# CHECK-NEXT: halffloat: halffloat not null +# CHECK-NEXT: float: float not null +# CHECK-NEXT: double: double not null + +print("schema:") +print(schema) + +arrays = [ + pa.array(np.array(np.arange(10) + 100 * i, t.to_pandas_dtype())) + for i, t in enumerate(GET_COLUMN_FUNCS_NAMES.keys()) +] +table = pa.table(arrays, schema) +batch = table.to_batches()[0] + +# CHECK-LABEL: original batch: +# CHECK-NEXT: [0 1 2 3 4 5 6 7 8 9] +# CHECK-NEXT: [100 101 102 103 104 105 106 107 108 109] +# CHECK-NEXT: [200 201 202 203 204 205 206 207 208 209] +# CHECK-NEXT: [300 301 302 303 304 305 306 307 308 309] +# CHECK-NEXT: [400 401 402 403 404 405 406 407 408 409] +# CHECK-NEXT: [500 501 502 503 504 505 506 507 508 509] +# CHECK-NEXT: [600 601 602 603 604 605 606 607 608 609] +# CHECK-NEXT: [700 701 702 703 704 705 706 707 708 709] +# CHECK-NEXT: [800. 801. 802. 803. 804. 805. 806. 807. 808. 809.] +# CHECK-NEXT: [900. 901. 902. 903. 904. 905. 906. 907. 908. 909.] +# CHECK-NEXT: [1000. 1001. 1002. 1003. 1004. 1005. 1006. 1007. 1008. 1009.] + +print("original batch:") +for c in batch.columns: + print(c.to_numpy()) + +#===------------------------------------------------------------------------===# +# Test C data interface. +#===------------------------------------------------------------------------===# +# Create C struct describing the batch. +cffi_batch = ffi.new('struct ArrowArray *') +batch._export_to_c(int(ffi.cast("intptr_t", cffi_batch))) + +# Create C struct describing schema. +cffi_schema = ffi.new('struct ArrowSchema *') +schema._export_to_c(int(ffi.cast("intptr_t", cffi_schema))) + +# Test function returning batch size. +batch_size = lib.mlirIteratorsArrowArrayGetSize(cffi_batch) + +# CHECK-LABEL: retrieved batch size: +# CHECK-NEXT: 10 + +print("retrieved batch size:") +print(batch_size) + +# CHECK-LABEL: retrieved batch: +# CHECK-NEXT: [0 1 2 3 4 5 6 7 8 9] +# CHECK-NEXT: [100 101 102 103 104 105 106 107 108 109] +# CHECK-NEXT: [200 201 202 203 204 205 206 207 208 209] +# CHECK-NEXT: [300 301 302 303 304 305 306 307 308 309] +# CHECK-NEXT: [400 401 402 403 404 405 406 407 408 409] +# CHECK-NEXT: [500 501 502 503 504 505 506 507 508 509] +# CHECK-NEXT: [600 601 602 603 604 605 606 607 608 609] +# CHECK-NEXT: [700 701 702 703 704 705 706 707 708 709] +# CHECK-NEXT: [800. 801. 802. 803. 804. 805. 806. 807. 808. 809.] +# CHECK-NEXT: [900. 901. 902. 903. 904. 905. 906. 907. 908. 909.] +# CHECK-NEXT: [1000. 1001. 1002. 1003. 1004. 1005. 1006. 1007. 1008. 1009.] + +# Test functions accessing columns from the batch. +print("retrieved batch:") +for i, (type, (_, func_name)) in enumerate(GET_COLUMN_FUNCS_NAMES.items()): + func = lib.__getattr__(func_name) + + # Call function, which returns a pointer. + ptr = func(cffi_batch, cffi_schema, i) + + # Wrap the pointer into a buffer, convert that into a type numpy array. + buffer = ffi.buffer(ptr, batch_size * type.bit_width // 8) + array = np.frombuffer(buffer, dtype=type.to_pandas_dtype()) + + print(array) + +# Release memory owned by the C structs. +lib.mlirIteratorsArrowArrayRelease(cffi_batch) +lib.mlirIteratorsArrowSchemaRelease(cffi_schema) + +#===------------------------------------------------------------------------===# +# Test C stream interface. +#===------------------------------------------------------------------------===# +reader = pa.RecordBatchReader.from_batches(schema, + table.to_batches(max_chunksize=5)) + +# Create C struct describing record batch reader. +cffi_stream = ffi.new('struct ArrowArrayStream *') +reader._export_to_c(int(ffi.cast("intptr_t", cffi_stream))) + +# Get schema and import it into pyarrow. +lib.mlirIteratorsArrowArrayStreamGetSchema(cffi_stream, cffi_schema) +schema = pa.Schema._import_from_c(int(ffi.cast("intptr_t", cffi_schema))) + +# CHECK-LABEL: schema from stream: +# CHECK-NEXT: int8: int8 not null +# CHECK-NEXT: uint8: uint8 not null +# CHECK-NEXT: int16: int16 not null +# CHECK-NEXT: uint16: uint16 not null +# CHECK-NEXT: int32: int32 not null +# CHECK-NEXT: uint32: uint32 not null +# CHECK-NEXT: int64: int64 not null +# CHECK-NEXT: uint64: uint64 not null +# CHECK-NEXT: halffloat: halffloat not null +# CHECK-NEXT: float: float not null +# CHECK-NEXT: double: double not null + +print("schema from stream:") +print(schema) + +# CHECK-LABEL: batches: +# CHECK-NEXT: int8 uint8 int16 uint16 int32 uint32 int64 uint64 halffloat float double +# CHECK-NEXT: 0 0 100 200 300 400 500 600 700 800.0 900.0 1000.0 +# CHECK-NEXT: 1 1 101 201 301 401 501 601 701 801.0 901.0 1001.0 +# CHECK-NEXT: 2 2 102 202 302 402 502 602 702 802.0 902.0 1002.0 +# CHECK-NEXT: 3 3 103 203 303 403 503 603 703 803.0 903.0 1003.0 +# CHECK-NEXT: 4 4 104 204 304 404 504 604 704 804.0 904.0 1004.0 +# CHECK-NEXT: int8 uint8 int16 uint16 int32 uint32 int64 uint64 halffloat float double +# CHECK-NEXT: 0 5 105 205 305 405 505 605 705 805.0 905.0 1005.0 +# CHECK-NEXT: 1 6 106 206 306 406 506 606 706 806.0 906.0 1006.0 +# CHECK-NEXT: 2 7 107 207 307 407 507 607 707 807.0 907.0 1007.0 +# CHECK-NEXT: 3 8 108 208 308 408 508 608 708 808.0 908.0 1008.0 +# CHECK-NEXT: 4 9 109 209 309 409 509 609 709 809.0 909.0 1009.0 + +# Iterate over batches provided by stream and print. +print("batches:") +while lib.mlirIteratorsArrowArrayStreamGetNext(cffi_stream, cffi_batch): + batch = pa.RecordBatch._import_from_c(int(ffi.cast("intptr_t", cffi_batch)), + schema) + print(batch.to_pandas().to_string()) + +# Release memory owned by the C structs. +lib.mlirIteratorsArrowArrayRelease(cffi_batch) +lib.mlirIteratorsArrowSchemaRelease(cffi_schema) +lib.mlirIteratorsArrowArrayStreamRelease(cffi_stream) diff --git a/experimental/iterators/test/lit.cfg.py b/experimental/iterators/test/lit.cfg.py index a96469ad31ae..d4696bc29a31 100644 --- a/experimental/iterators/test/lit.cfg.py +++ b/experimental/iterators/test/lit.cfg.py @@ -56,6 +56,10 @@ ToolSubst('%mlir_lib_dir', config.mlir_lib_dir), ] +# Add the full path of the Iterators runtime lib to the environment. +config.environment[ + 'ITERATORS_RUNTIME_LIBRARY_PATH'] = config.iterators_runtime_lib_path + # Pass through LLVM_SYMBOLIZER_PATH from environment if "LLVM_SYMBOLIZER_PATH" in os.environ: config.environment["LLVM_SYMBOLIZER_PATH"] = \ diff --git a/experimental/iterators/test/lit.site.cfg.py.in b/experimental/iterators/test/lit.site.cfg.py.in index df9241d2685f..9aed14f69a1b 100644 --- a/experimental/iterators/test/lit.site.cfg.py.in +++ b/experimental/iterators/test/lit.site.cfg.py.in @@ -6,6 +6,7 @@ config.mlir_obj_dir = "@MLIR_BINARY_DIR@" config.mlir_lib_dir = "@MLIR_LIB_DIR@" config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ config.iterators_build_root = "@ITERATORS_BINARY_DIR@" +config.iterators_runtime_lib_path = "@ITERATORS_RUNTIME_LIBRARY_PATH@" import lit.llvm lit.llvm.initialize(lit_config, config) diff --git a/experimental/iterators/test/python/dialects/iterators/arrow.py b/experimental/iterators/test/python/dialects/iterators/arrow.py new file mode 100644 index 000000000000..e52a224559ed --- /dev/null +++ b/experimental/iterators/test/python/dialects/iterators/arrow.py @@ -0,0 +1,345 @@ +# RUN: %PYTHON %s | FileCheck %s + +import argparse +import ctypes +import io +import logging +import os +import sys +import time + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.cffi +import pyarrow.parquet + +from mlir_iterators.dialects import iterators as it +from mlir_iterators.dialects import tabular as tab +from mlir_iterators.dialects import tuple as tup +from mlir_iterators.passmanager import PassManager +from mlir_iterators.execution_engine import ExecutionEngine +from mlir_iterators.ir import Context, Module + +# Set up logging. +LOGLEVELS = { + logging.getLevelName(l): l + for l in (logging.CRITICAL, logging.ERROR, logging.WARNING, logging.INFO, + logging.DEBUG) +} + +# Parse command line arguments for interactive testing/debugging. +parser = argparse.ArgumentParser( + description='Integration tests for iterators related to Apache Arrow.') +parser.add_argument('--log-level', + type=LOGLEVELS.__getitem__, + default=logging.ERROR, + help='Set the log level by name') +parser.add_argument('--enable-ir-printing', + action='store_true', + help='Enable printing IR after every pass') +args = parser.parse_args() + +logging.getLogger().setLevel(args.log_level) + + +def format_code(code: str) -> str: + return '\n'.join( + (f'{i:>4}: {l}' for i, l in enumerate(str(code).splitlines()))) + + +def run(f): + print("\nTEST:", f.__name__) + with Context(): + it.register_dialect() + tab.register_dialect() + tup.register_dialect() + f() + return f + + +# MLIR definitions of the C structs of the Arrow ABI. +ARROW_STRUCT_DEFINITIONS_MLIR = ''' + !arrow_schema = !llvm.struct<"ArrowSchema", ( + ptr, // format + ptr, // name + ptr, // metadata + i64, // flags + i64, // n_children + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> + !arrow_array = !llvm.struct<"ArrowArray", ( + i64, // length + i64, // null_count + i64, // offset + i64, // n_buffers + i64, // n_children + ptr, // buffers + ptr>>, // children + ptr>, // dictionary + ptr>)>>, // release + ptr // private_data + )> + !arrow_array_stream = !llvm.struct<"ArrowArrayStream", ( + ptr>, ptr)>>, // get_schema + ptr>, ptr)>>, // get_next + ptr (ptr>)>>, // get_last_error + ptr>)>>, // release + ptr // private_data + )> + ''' + +# Arrow data types that are currently supported. +ARROW_SUPPORTED_TYPES = [ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64(), + pa.float16(), + pa.float32(), + pa.float64() +] + + +# Converts the given Arrow type to the name of the corresponding MLIR type. +def to_mlir_type(t: pa.DataType) -> str: + if pa.types.is_signed_integer(t): + return 'i' + str(t.bit_width) + if pa.types.is_floating(t): + return 'f' + str(t.bit_width) + raise NotImplementedError("Only floats and signed integers supported") + + +# Compiles the given code and wraps it into an execution engine. +def build_and_create_engine(code: str) -> ExecutionEngine: + # Assemble, log, and parse input IR. + code = ARROW_STRUCT_DEFINITIONS_MLIR + code + logging.info("Input IR:\n\n%s\n", format_code(code)) + mod = Module.parse(code) + + # Assemble and log pass pipeline. + pm = PassManager.parse('builtin.module(' + 'convert-iterators-to-llvm,' + 'convert-tabular-to-llvm,' + 'decompose-iterator-states,' + 'decompose-tuples,' + 'one-shot-bufferize,' + 'canonicalize,cse,' + 'expand-strided-metadata,' + 'finalize-memref-to-llvm,' + 'canonicalize,cse,' + 'convert-func-to-llvm,' + 'reconcile-unrealized-casts,' + 'convert-scf-to-cf,' + 'convert-cf-to-llvm)') + logging.info("Pass pipeline:\n\n%s\n", pm) + + # Enable printing of intermediate IR if requested. + if args.enable_ir_printing: + mod.context.enable_multithreading(False) + pm.enable_ir_printing() + + # Run pipeline. + pm.run(mod.operation) + + # Create and return engine. + runtime_lib = os.environ['ITERATORS_RUNTIME_LIBRARY_PATH'] + engine = ExecutionEngine(mod, shared_libs=[runtime_lib]) + return engine + + +# Generate MLIR that reads the arrays of an Arrow array stream and produces (and +# prints) the element-wise sum of each array. +def generate_sum_batches_elementwise_code(schema: pa.Schema) -> str: + mlir_types = [to_mlir_type(t) for t in schema.types] + + # Generate code that, for each type, extracts rhs and lhs struct values, adds + # them, and then inserts the result into a result struct. + elementwise_sum = f''' + %lhsvals:{len(mlir_types)} = tuple.to_elements %lhs : !tuple_type + %rhsvals:{len(mlir_types)} = tuple.to_elements %rhs : !tuple_type + ''' + for i, t in enumerate(mlir_types): + elementwise_sum += f''' + %sum{i} = arith.add{t[0]} %lhsvals#{i}, %rhsvals#{i} : {t} + ''' + result_vars = ', '.join((f'%sum{i}' for i in range(len(mlir_types)))) + elementwise_sum += f''' + %result = tuple.from_elements {result_vars} : !tuple_type + ''' + + # Adapt main program to types of the given schema. + code = f''' + !tuple_type = tuple<{', '.join(mlir_types)}> + !tabular_view_type = !tabular.tabular_view<{', '.join(mlir_types)}> + + // Add numbers of two structs element-wise. + func.func private @sum_struct(%lhs : !tuple_type, %rhs : !tuple_type) -> !tuple_type {{ + {elementwise_sum} + return %result : !tuple_type + }} + + // Consume the given tabular view and produce one element-wise sum from the elements. + func.func @sum_tabular_view(%tabular_view: !tabular_view_type) -> !tuple_type {{ + %tabular_stream = iterators.tabular_view_to_stream %tabular_view + to !iterators.stream + %reduced = "iterators.reduce"(%tabular_stream) {{reduceFuncRef = @sum_struct}} + : (!iterators.stream) -> (!iterators.stream) + %result:2 = iterators.stream_to_value %reduced : !iterators.stream + return %result#0 : !tuple_type + }} + + // For each Arrow array in the input stream, produce an element-wise sum. + func.func @main(%arrow_stream: !llvm.ptr) + attributes {{ llvm.emit_c_interface }} {{ + %tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream + to !iterators.stream + %sums = "iterators.map"(%tabular_view_stream) {{mapFuncRef = @sum_tabular_view}} + : (!iterators.stream) -> (!iterators.stream) + "iterators.sink"(%sums) : (!iterators.stream) -> () + return + }} + ''' + + return code + + +# Feeds the given Arrow array stream/record batch reader into an MLIR kernel +# that reads the arrays the stream and produces (and prints) the element-wise +# sum of each array/record batch. +def sum_batches_elementwise_with_iterators( + record_batch_reader: pa.RecordBatchReader) -> None: + + code = generate_sum_batches_elementwise_code(record_batch_reader.schema) + engine = build_and_create_engine(code) + + # Create C struct describing the record batch reader. + ffi = pa.cffi.ffi + cffi_stream = ffi.new('struct ArrowArrayStream *') + cffi_stream_ptr = int(ffi.cast("intptr_t", cffi_stream)) + record_batch_reader._export_to_c(cffi_stream_ptr) + + # Wrap argument and invoke compiled function. + arg = ctypes.pointer(ctypes.cast(cffi_stream_ptr, ctypes.c_void_p)) + engine.invoke('main', arg) + + +# Create a sample Arrow table with one column per supported type. +def create_test_input() -> pa.Table: + # Use pyarrow to create an Arrow table in memory. + fields = [pa.field(str(t), t, False) for t in ARROW_SUPPORTED_TYPES] + schema = pa.schema(fields) + arrays = [ + pa.array(np.array(np.arange(10) + 100 * i, field.type.to_pandas_dtype())) + for i, field in enumerate(fields) + ] + table = pa.table(arrays, schema) + return table + + +# Test case: Read from a sequence of Arrow arrays/record batches (produced by a +# Python generator). + + +# CHECK-LABEL: TEST: testArrowStreamInput +# CHECK-NEXT: (10, 510, 1010, 1510, 2010, 2510, 3010) +# CHECK-NEXT: (35, 535, 1035, 1535, 2035, 2535, 3035) +@run +def testArrowStreamInput(): + # Use pyarrow to create an Arrow table in memory. + table = create_test_input() + + # Make physically separate batches from the table. (This ensures offset=0). + batches = (b for batch in table.to_batches(max_chunksize=5) + for b in pa.Table.from_pandas(batch.to_pandas()).to_batches()) + + # Create a RecordBatchReader and export it as a C struct. + reader = pa.RecordBatchReader.from_batches(table.schema, batches) + + # Hand the reader as an Arrow array stream to the Iterators test program. + sum_batches_elementwise_with_iterators(reader) + + +# Test case: Read data from a Parquet file (through pyarrow's C++-implemented +# Parquet reader). + + +# CHECK-LABEL: TEST: testArrowParquetInput +# CHECK-NEXT: (10, 510, 1010, 1510, 2510, 3010) +# CHECK-NEXT: (35, 535, 1035, 1535, 2535, 3035) +@run +def testArrowParquetInput(): + table = create_test_input() + # Remove f16 column, which the Parquet reader/writer doesn't support yet. + table = table.drop(['halffloat']) + + # Create a tempororay in-memory file with test data. + with io.BytesIO() as temp_file, \ + pa.PythonFile(temp_file) as parquet_file: + # Export test data as Parquet. + pa.parquet.write_table(table, parquet_file) + + # Flush and rewind to the beginning of the file. + parquet_file.flush() + temp_file.flush() + temp_file.seek(0) + + # Open as ParquetFile instance. + parquet_file = pa.parquet.ParquetFile(temp_file) + + # Create a Python generator of batches (which reads record batches using the + # C++ implementation) and turn that into a RecordBatchReader. + # TODO: It may be possible to get a RecordBatchReader for the Parquet file + # directly (i.e., without going through a Python generator) but I did not + # see it exposed to Python. + batches_generator = parquet_file.iter_batches(batch_size=5) + reader = pa.RecordBatchReader.from_batches(table.schema, batches_generator) + + # Hand the reader as an Arrow array stream to the Iterators test program. + sum_batches_elementwise_with_iterators(reader) + + +# Test case: Read from a sequence of Arrow arrays/record batches (produced by a +# Python generator). + + +# Create a generator that produces single-row record batches with increasing +# numbers with an artificial delay of one second after each of them. Since each +# generated record batch immediately produces output, this visually demonstrate +# that the consumption by the MLIR-based iterators interleaves with the +# Python-based production of the record batches in the stream. +def generate_batches_with_delay(schema: pa.Schema) -> None: + for i in range(5): + arrays = [ + pa.array(np.array([i], field.type.to_pandas_dtype())) + for field in schema + ] + batch = pa.RecordBatch.from_arrays(arrays, schema=schema) + yield batch + # Sleep only when a TTY is attached (in order not to delay unit tests). + if sys.stdout.isatty(): + time.sleep(1) + + +# CHECK-LABEL: TEST: testGeneratorInput +# CHECK-NEXT: (0, 0, 0, 0, 0, 0, 0) +# CHECK-NEXT: (1, 1, 1, 1, 1, 1, 1) +# CHECK-NEXT: (2, 2, 2, 2, 2, 2, 2) +# CHECK-NEXT: (3, 3, 3, 3, 3, 3, 3) +# CHECK-NEXT: (4, 4, 4, 4, 4, 4, 4) +@run +def testGeneratorInput(): + # Use pyarrow to create an Arrow table in memory. + table = create_test_input() + + # Make physically separate batches from the table. (This ensures offset=0). + generator = generate_batches_with_delay(table.schema) + + # Create a RecordBatchReader and export it as a C struct. + reader = pa.RecordBatchReader.from_batches(table.schema, generator) + + # Hand the reader as an Arrow array stream to the Iterators test program. + sum_batches_elementwise_with_iterators(reader)