Skip to content

Commit

Permalink
Add new FromArrowArrayStreamOp to Iterators dialect.
Browse files Browse the repository at this point in the history
  • Loading branch information
ingomueller-net committed Feb 20, 2023
1 parent 062c38d commit fe2df4e
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===-- ArrowUtils.h - IR utils related to Apache Arrow --------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H
#define ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H

namespace mlir {
class MLIRContext;
namespace LLVM {
class LLVMStructType;
} // namespace LLVM
} // namespace mlir

namespace mlir {
namespace iterators {

/// Returns the LLVM struct type for Arrow arrays of the Arrow C data interface.
/// For the official definition of the type, see
/// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions.
LLVM::LLVMStructType getArrowArrayType(MLIRContext *context);

/// Returns the LLVM struct type for Arrow schemas of the Arrow C data
/// interface. For the official definition of the type, see
/// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions.
LLVM::LLVMStructType getArrowSchemaType(MLIRContext *context);

/// Returns the LLVM struct type for Arrow streams of the Arrow C stream
/// interface. For the official definition of the type, see
/// https://arrow.apache.org/docs/format/CStreamInterface.html#structure-definition.
LLVM::LLVMStructType getArrowArrayStreamType(MLIRContext *context);

} // namespace iterators
} // namespace mlir

#endif // ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,17 @@

#include "iterators/Dialect/Iterators/IR/IteratorsOpsDialect.h.inc"

namespace mlir {
namespace LLVM {
class LLVMPointerType;
} // namespace LLVM
} // namespace mlir

namespace mlir {
namespace iterators {
#include "iterators/Dialect/Iterators/IR/IteratorsOpInterfaces.h.inc"
#include "iterators/Dialect/Iterators/IR/IteratorsTypeInterfaces.h.inc"

} // namespace iterators
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,54 @@ def Iterators_FilterOp : Iterators_Op<"filter",
}];
}

def Iterators_FromArrowArrayStreamOp : Iterators_Op<"from_arrow_array_stream", [
TypesMatchWith<"type of operand must be an LLVM 'struct ArrowArrayStream'",
"result", "arrowStream",
"mlir::LLVM::LLVMPointerType::get("
" mlir::iterators::getArrowArrayStreamType("
" $_self.getContext()))">,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Converts an arrow array stream into a stream of tabular views";
let description = [{
Wraps an Apache Arrow Array Stream using the Arrow C stream interface in
an `Iterators_Op`, i.e., converts the `ArrowArrayStream`, which produces a
stream of `ArrowArrays`, into an `Iterator_Op` that produces a stream of
`Tabular_TabularView`. This allows to read from any producer of Apache Arrow
Arrays (a.k.a "record batches"), including file readers for CSV, JSON,
Parquet, and others, libraries and languages such as pyarrow (and thus
Python and pandas), remote processes via Arrow IPC, and many others.

**Limitations:** Currently, only a few numeric data types, only ArrowArrays
without offset, and no nullable or nested types are supported. These
limitations are inherited in part from the runtime and in part from
TabularView. Furthermore, returning TabularViews, which are references,
limits the lifetime of each returned element until when the next element is
produced, so it is up to consuming iterator ops to copy any buffers that
have to live longer. Finally, Arrow array streams have no "reset"/"re-open"
method, i.e., they can only consumed once, so the lowering to
open/next/close, in which that is possible, breaks if re-opening is
attempted.

Example:
```mlir
func.func @main(%external_input: !llvm.ptr<!arrow_array_stream>) ->
!iterators.stream<!tabular.tabular_view<i32>> {
%tabular_view_stream = iterators.from_arrow_array_stream %external_input
to !iterators.stream<!tabular.tabular_view<i32>>
```
}];
let arguments = (ins LLVM_PointerTo<LLVM_AnyStruct>:$arrowStream);
let results = (outs Iterators_StreamOf<Tabular_TabularView>:$result);
let assemblyFormat = "$arrowStream attr-dict `to` qualified(type($result))";
let extraClassDefinition = [{
/// Implement OpAsmOpInterface.
void $cppClass::getAsmResultNames(
llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
setNameFn(getResult(), "fromarrowstream");
}
}];
}

def Iterators_MapOp : Iterators_Op<"map",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
Expand Down
115 changes: 115 additions & 0 deletions experimental/iterators/lib/Dialects/Iterators/IR/ArrowUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//===-- ArrowUtils.cpp - IR utils related to Apache Arrow ------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "iterators/Dialect/Iterators/IR/ArrowUtils.h"

#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::LLVM;

namespace mlir {
namespace iterators {

LLVMStructType getArrowArrayType(MLIRContext *context) {
LLVMStructType arrowArray =
LLVMStructType::getIdentified(context, "ArrowArray");
if (arrowArray.isInitialized())
return arrowArray;

Type voidPtr = LLVMPointerType::get(context);
Type i64 = IntegerType::get(context, 64);
Type arrowArrayPtr = LLVMPointerType::get(arrowArray);
auto voidType = LLVMVoidType::get(context);
Type releaseFunc = LLVMFunctionType::get(voidType, arrowArrayPtr);

ArrayRef<Type> body = {
i64 /*length*/,
i64 /*null_count*/,
i64 /*offset*/,
i64 /*n_buffers*/,
i64 /*n_children*/,
LLVMPointerType::get(voidPtr) /*buffers*/,
LLVMPointerType::get(arrowArrayPtr) /*children*/,
arrowArrayPtr /*dictionary*/,
LLVMPointerType::get(releaseFunc) /*release*/,
voidPtr /*private_data*/
};

LogicalResult status = arrowArray.setBody(body, /*isPacked=*/false);
assert(succeeded(status) && "could not create ArrowArray struct");
return arrowArray;
}

LLVMStructType getArrowSchemaType(MLIRContext *context) {
auto arrowSchema = LLVMStructType::getIdentified(context, "ArrowSchema");
if (arrowSchema.isInitialized())
return arrowSchema;

Type charPtr = LLVMPointerType::get(IntegerType::get(context, 8));
Type voidPtr = LLVMPointerType::get(context);
Type i64 = IntegerType::get(context, 64);
Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema);
auto voidType = LLVMVoidType::get(context);
Type releaseFunc = LLVMFunctionType::get(voidType, arrowSchemaPtr);

ArrayRef<Type> body = {
charPtr /*format*/,
charPtr /*name*/,
charPtr /*metadata*/,
i64 /*flags*/,
i64 /*n_children*/,
LLVMPointerType::get(arrowSchemaPtr) /*children*/,
arrowSchemaPtr /*dictionary*/,
LLVMPointerType::get(releaseFunc) /*release*/,
voidPtr /*private_data*/
};

LogicalResult status = arrowSchema.setBody(body, /*isPacked=*/false);
assert(succeeded(status) && "could not create ArrowSchema struct");
return arrowSchema;
}

LLVMStructType getArrowArrayStreamType(MLIRContext *context) {
auto arrowArrayStream =
LLVMStructType::getIdentified(context, "ArrowArrayStream");
if (arrowArrayStream.isInitialized())
return arrowArrayStream;

Type voidPtr = LLVMPointerType::get(context);
Type charPtr = LLVMPointerType::get(IntegerType::get(context, 8));
Type i32 = IntegerType::get(context, 32);
auto voidType = LLVMVoidType::get(context);
Type arrowArray = getArrowArrayType(context);
auto arrowSchema = getArrowSchemaType(context);
Type arrowArrayPtr = LLVMPointerType::get(arrowArray);
Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema);
Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream);

Type getSchemaFunc =
LLVMFunctionType::get(i32, {arrowArrayStreamPtr, arrowSchemaPtr});
Type getNextFunc =
LLVMFunctionType::get(i32, {arrowArrayStreamPtr, arrowArrayPtr});
Type getLastErrorFunc = LLVMFunctionType::get(charPtr, arrowArrayStreamPtr);
Type releaseFunc = LLVMFunctionType::get(voidType, arrowArrayStreamPtr);

ArrayRef<Type> body = {
LLVMPointerType::get(getSchemaFunc) /*get_schema*/,
LLVMPointerType::get(getNextFunc) /*get_next*/,
LLVMPointerType::get(getLastErrorFunc) /*get_last_error*/,
LLVMPointerType::get(releaseFunc) /*release*/, voidPtr /*private_data*/
};

LogicalResult status = arrowArrayStream.setBody(body, /*isPacked=*/false);
assert(succeeded(status) && "could not create ArrowArrayStream struct");
return arrowArrayStream;
}

} // namespace iterators
} // namespace mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_library(MLIRIterators
ArrowUtils.cpp
Iterators.cpp

LINK_LIBS PUBLIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
//===----------------------------------------------------------------------===//

#include "iterators/Dialect/Iterators/IR/Iterators.h"
#include "iterators/Dialect/Tabular/IR/Tabular.h"

#include "iterators/Dialect/Iterators/IR/ArrowUtils.h"
#include "iterators/Dialect/Tabular/IR/Tabular.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LogicalResult.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: iterators-opt %s \
// RUN: | FileCheck %s

!arrow_schema = !llvm.struct<"ArrowSchema", (
ptr<i8>, // format
ptr<i8>, // name
ptr<i8>, // metadata
i64, // flags
i64, // n_children
ptr<ptr<struct<"ArrowSchema">>>, // children
ptr<struct<"ArrowSchema">>, // dictionary
ptr<func<void (ptr<struct<"ArrowSchema">>)>>, // release
ptr // private_data
)>
!arrow_array = !llvm.struct<"ArrowArray", (
i64, // length
i64, // null_count
i64, // offset
i64, // n_buffers
i64, // n_children
ptr<ptr>, // buffers
ptr<ptr<struct<"ArrowArray">>>, // children
ptr<struct<"ArrowArray">>, // dictionary
ptr<func<void (ptr<struct<"ArrowArray">>)>>, // release
ptr // private_data
)>
!arrow_array_stream = !llvm.struct<"ArrowArrayStream", (
ptr<func<i32 (ptr<struct<"ArrowArrayStream">>, ptr<!arrow_schema>)>>, // get_schema
ptr<func<i32 (ptr<struct<"ArrowArrayStream">>, ptr<!arrow_array>)>>, // get_next
ptr<func<ptr<i8> (ptr<struct<"ArrowArrayStream">>)>>, // get_last_error
ptr<func<void (ptr<struct<"ArrowArrayStream">>)>>, // release
ptr // private_data
)>

// CHECK-LABEL: func.func @main(
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<[[STREAMTYPE:.*]]>) ->
// CHECK-SAME: !iterators.stream<!tabular.tabular_view<i32>> {
// CHECK-NEXT: %[[V0:fromarrowstream.*]] = iterators.from_arrow_array_stream %[[ARG0]] to !iterators.stream<!tabular.tabular_view<i32>>
// CHECK-NEXT: return %[[V0]] : !iterators.stream<!tabular.tabular_view<i32>>
func.func @main(%arrow_stream: !llvm.ptr<!arrow_array_stream>) -> !iterators.stream<!tabular.tabular_view<i32>> {
%tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream to !iterators.stream<!tabular.tabular_view<i32>>
return %tabular_view_stream : !iterators.stream<!tabular.tabular_view<i32>>
}

0 comments on commit fe2df4e

Please sign in to comment.