-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new FromArrowArrayStreamOp to Iterators dialect.
- Loading branch information
1 parent
062c38d
commit fe2df4e
Showing
7 changed files
with
256 additions
and
1 deletion.
There are no files selected for viewing
40 changes: 40 additions & 0 deletions
40
experimental/iterators/include/iterators/Dialect/Iterators/IR/ArrowUtils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
//===-- ArrowUtils.h - IR utils related to Apache Arrow --------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H | ||
#define ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H | ||
|
||
namespace mlir { | ||
class MLIRContext; | ||
namespace LLVM { | ||
class LLVMStructType; | ||
} // namespace LLVM | ||
} // namespace mlir | ||
|
||
namespace mlir { | ||
namespace iterators { | ||
|
||
/// Returns the LLVM struct type for Arrow arrays of the Arrow C data interface. | ||
/// For the official definition of the type, see | ||
/// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions. | ||
LLVM::LLVMStructType getArrowArrayType(MLIRContext *context); | ||
|
||
/// Returns the LLVM struct type for Arrow schemas of the Arrow C data | ||
/// interface. For the official definition of the type, see | ||
/// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions. | ||
LLVM::LLVMStructType getArrowSchemaType(MLIRContext *context); | ||
|
||
/// Returns the LLVM struct type for Arrow streams of the Arrow C stream | ||
/// interface. For the official definition of the type, see | ||
/// https://arrow.apache.org/docs/format/CStreamInterface.html#structure-definition. | ||
LLVM::LLVMStructType getArrowArrayStreamType(MLIRContext *context); | ||
|
||
} // namespace iterators | ||
} // namespace mlir | ||
|
||
#endif // ITERATORS_DIALECT_ITERATORS_IR_ARROWUTILS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
115 changes: 115 additions & 0 deletions
115
experimental/iterators/lib/Dialects/Iterators/IR/ArrowUtils.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
//===-- ArrowUtils.cpp - IR utils related to Apache Arrow ------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "iterators/Dialect/Iterators/IR/ArrowUtils.h" | ||
|
||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::LLVM; | ||
|
||
namespace mlir { | ||
namespace iterators { | ||
|
||
LLVMStructType getArrowArrayType(MLIRContext *context) { | ||
LLVMStructType arrowArray = | ||
LLVMStructType::getIdentified(context, "ArrowArray"); | ||
if (arrowArray.isInitialized()) | ||
return arrowArray; | ||
|
||
Type voidPtr = LLVMPointerType::get(context); | ||
Type i64 = IntegerType::get(context, 64); | ||
Type arrowArrayPtr = LLVMPointerType::get(arrowArray); | ||
auto voidType = LLVMVoidType::get(context); | ||
Type releaseFunc = LLVMFunctionType::get(voidType, arrowArrayPtr); | ||
|
||
ArrayRef<Type> body = { | ||
i64 /*length*/, | ||
i64 /*null_count*/, | ||
i64 /*offset*/, | ||
i64 /*n_buffers*/, | ||
i64 /*n_children*/, | ||
LLVMPointerType::get(voidPtr) /*buffers*/, | ||
LLVMPointerType::get(arrowArrayPtr) /*children*/, | ||
arrowArrayPtr /*dictionary*/, | ||
LLVMPointerType::get(releaseFunc) /*release*/, | ||
voidPtr /*private_data*/ | ||
}; | ||
|
||
LogicalResult status = arrowArray.setBody(body, /*isPacked=*/false); | ||
assert(succeeded(status) && "could not create ArrowArray struct"); | ||
return arrowArray; | ||
} | ||
|
||
LLVMStructType getArrowSchemaType(MLIRContext *context) { | ||
auto arrowSchema = LLVMStructType::getIdentified(context, "ArrowSchema"); | ||
if (arrowSchema.isInitialized()) | ||
return arrowSchema; | ||
|
||
Type charPtr = LLVMPointerType::get(IntegerType::get(context, 8)); | ||
Type voidPtr = LLVMPointerType::get(context); | ||
Type i64 = IntegerType::get(context, 64); | ||
Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); | ||
auto voidType = LLVMVoidType::get(context); | ||
Type releaseFunc = LLVMFunctionType::get(voidType, arrowSchemaPtr); | ||
|
||
ArrayRef<Type> body = { | ||
charPtr /*format*/, | ||
charPtr /*name*/, | ||
charPtr /*metadata*/, | ||
i64 /*flags*/, | ||
i64 /*n_children*/, | ||
LLVMPointerType::get(arrowSchemaPtr) /*children*/, | ||
arrowSchemaPtr /*dictionary*/, | ||
LLVMPointerType::get(releaseFunc) /*release*/, | ||
voidPtr /*private_data*/ | ||
}; | ||
|
||
LogicalResult status = arrowSchema.setBody(body, /*isPacked=*/false); | ||
assert(succeeded(status) && "could not create ArrowSchema struct"); | ||
return arrowSchema; | ||
} | ||
|
||
LLVMStructType getArrowArrayStreamType(MLIRContext *context) { | ||
auto arrowArrayStream = | ||
LLVMStructType::getIdentified(context, "ArrowArrayStream"); | ||
if (arrowArrayStream.isInitialized()) | ||
return arrowArrayStream; | ||
|
||
Type voidPtr = LLVMPointerType::get(context); | ||
Type charPtr = LLVMPointerType::get(IntegerType::get(context, 8)); | ||
Type i32 = IntegerType::get(context, 32); | ||
auto voidType = LLVMVoidType::get(context); | ||
Type arrowArray = getArrowArrayType(context); | ||
auto arrowSchema = getArrowSchemaType(context); | ||
Type arrowArrayPtr = LLVMPointerType::get(arrowArray); | ||
Type arrowSchemaPtr = LLVMPointerType::get(arrowSchema); | ||
Type arrowArrayStreamPtr = LLVMPointerType::get(arrowArrayStream); | ||
|
||
Type getSchemaFunc = | ||
LLVMFunctionType::get(i32, {arrowArrayStreamPtr, arrowSchemaPtr}); | ||
Type getNextFunc = | ||
LLVMFunctionType::get(i32, {arrowArrayStreamPtr, arrowArrayPtr}); | ||
Type getLastErrorFunc = LLVMFunctionType::get(charPtr, arrowArrayStreamPtr); | ||
Type releaseFunc = LLVMFunctionType::get(voidType, arrowArrayStreamPtr); | ||
|
||
ArrayRef<Type> body = { | ||
LLVMPointerType::get(getSchemaFunc) /*get_schema*/, | ||
LLVMPointerType::get(getNextFunc) /*get_next*/, | ||
LLVMPointerType::get(getLastErrorFunc) /*get_last_error*/, | ||
LLVMPointerType::get(releaseFunc) /*release*/, voidPtr /*private_data*/ | ||
}; | ||
|
||
LogicalResult status = arrowArrayStream.setBody(body, /*isPacked=*/false); | ||
assert(succeeded(status) && "could not create ArrowArrayStream struct"); | ||
return arrowArrayStream; | ||
} | ||
|
||
} // namespace iterators | ||
} // namespace mlir |
1 change: 1 addition & 0 deletions
1
experimental/iterators/lib/Dialects/Iterators/IR/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
add_mlir_library(MLIRIterators | ||
ArrowUtils.cpp | ||
Iterators.cpp | ||
|
||
LINK_LIBS PUBLIC | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
experimental/iterators/test/Dialect/Iterators/from-arrow-stream.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// RUN: iterators-opt %s \ | ||
// RUN: | FileCheck %s | ||
|
||
!arrow_schema = !llvm.struct<"ArrowSchema", ( | ||
ptr<i8>, // format | ||
ptr<i8>, // name | ||
ptr<i8>, // metadata | ||
i64, // flags | ||
i64, // n_children | ||
ptr<ptr<struct<"ArrowSchema">>>, // children | ||
ptr<struct<"ArrowSchema">>, // dictionary | ||
ptr<func<void (ptr<struct<"ArrowSchema">>)>>, // release | ||
ptr // private_data | ||
)> | ||
!arrow_array = !llvm.struct<"ArrowArray", ( | ||
i64, // length | ||
i64, // null_count | ||
i64, // offset | ||
i64, // n_buffers | ||
i64, // n_children | ||
ptr<ptr>, // buffers | ||
ptr<ptr<struct<"ArrowArray">>>, // children | ||
ptr<struct<"ArrowArray">>, // dictionary | ||
ptr<func<void (ptr<struct<"ArrowArray">>)>>, // release | ||
ptr // private_data | ||
)> | ||
!arrow_array_stream = !llvm.struct<"ArrowArrayStream", ( | ||
ptr<func<i32 (ptr<struct<"ArrowArrayStream">>, ptr<!arrow_schema>)>>, // get_schema | ||
ptr<func<i32 (ptr<struct<"ArrowArrayStream">>, ptr<!arrow_array>)>>, // get_next | ||
ptr<func<ptr<i8> (ptr<struct<"ArrowArrayStream">>)>>, // get_last_error | ||
ptr<func<void (ptr<struct<"ArrowArrayStream">>)>>, // release | ||
ptr // private_data | ||
)> | ||
|
||
// CHECK-LABEL: func.func @main( | ||
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<[[STREAMTYPE:.*]]>) -> | ||
// CHECK-SAME: !iterators.stream<!tabular.tabular_view<i32>> { | ||
// CHECK-NEXT: %[[V0:fromarrowstream.*]] = iterators.from_arrow_array_stream %[[ARG0]] to !iterators.stream<!tabular.tabular_view<i32>> | ||
// CHECK-NEXT: return %[[V0]] : !iterators.stream<!tabular.tabular_view<i32>> | ||
func.func @main(%arrow_stream: !llvm.ptr<!arrow_array_stream>) -> !iterators.stream<!tabular.tabular_view<i32>> { | ||
%tabular_view_stream = iterators.from_arrow_array_stream %arrow_stream to !iterators.stream<!tabular.tabular_view<i32>> | ||
return %tabular_view_stream : !iterators.stream<!tabular.tabular_view<i32>> | ||
} |