diff --git a/Cargo.lock b/Cargo.lock index 373a807cc20..ce5252d0d2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9850,6 +9850,7 @@ dependencies = [ "tracing-subscriber", "url", "vortex", + "vortex-parquet-variant", ] [[package]] diff --git a/encodings/parquet-variant/src/arrow.rs b/encodings/parquet-variant/src/arrow.rs new file mode 100644 index 00000000000..bcad493f0ad --- /dev/null +++ b/encodings/parquet-variant/src/arrow.rs @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use arrow_array::Array as _; +use arrow_array::ArrayRef as ArrowArrayRef; +use arrow_array::cast::AsArray; +use arrow_schema::DataType; +use arrow_schema::Field; +use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; +use parquet_variant_compute::VariantArray as ArrowVariantArray; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::VTable; +use vortex_array::arrow::ArrowExport; +use vortex_array::arrow::ArrowExportVTable; +use vortex_array::arrow::ArrowImport; +use vortex_array::arrow::ArrowImportVTable; +use vortex_array::arrow::ArrowSession; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDTypeRef; +use vortex_error::VortexResult; +use vortex_error::vortex_err; +use vortex_session::registry::CachedId; +use vortex_session::registry::Id; + +use crate::ParquetVariant; +use crate::ParquetVariantArrayExt; + +/// Arrow canonical extension name for Parquet Variant storage. +pub const PARQUET_VARIANT_ARROW_EXTENSION_NAME: &str = "arrow.parquet.variant"; + +static ARROW_PARQUET_VARIANT: CachedId = CachedId::new(PARQUET_VARIANT_ARROW_EXTENSION_NAME); + +impl ArrowExportVTable for ParquetVariant { + fn arrow_ext_id(&self) -> Id { + *ARROW_PARQUET_VARIANT + } + + fn vortex_ext_id(&self) -> Id { + ParquetVariant.id() + } + + fn to_arrow_field( + &self, + _name: &str, + _dtype: &ExtDTypeRef, + _session: &ArrowSession, + ) -> VortexResult> { + Ok(None) + } + + fn execute_arrow( + &self, + array: ArrayRef, + target: &Field, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + if target + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str) + != Some(PARQUET_VARIANT_ARROW_EXTENSION_NAME) + || !array.dtype().is_variant() + { + return Ok(ArrowExport::Unsupported(array)); + } + + let executed = array.execute_until::(ctx)?; + let parquet_array = executed + .as_opt::() + .ok_or_else(|| vortex_err!("cannot export Variant without ParquetVariant storage"))?; + let arrow_variant = parquet_array.to_arrow(ctx)?; + Ok(ArrowExport::Exported(Arc::new(arrow_variant.into_inner()))) + } +} + +impl ArrowImportVTable for ParquetVariant { + fn arrow_ext_id(&self) -> Id { + *ARROW_PARQUET_VARIANT + } + + fn from_arrow_field(&self, field: &Field) -> VortexResult> { + if field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str) + != Some(PARQUET_VARIANT_ARROW_EXTENSION_NAME) + { + return Ok(None); + } + + Ok(Some(DType::Variant(field.is_nullable().into()))) + } + + fn from_arrow_array( + &self, + array: ArrowArrayRef, + field: &Field, + dtype: &DType, + ) -> VortexResult { + if !matches!(dtype, DType::Variant(_)) + || field + .metadata() + .get(EXTENSION_TYPE_NAME_KEY) + .map(String::as_str) + != Some(PARQUET_VARIANT_ARROW_EXTENSION_NAME) + || !matches!(array.data_type(), DataType::Struct(_)) + { + return Ok(ArrowImport::Unsupported(array)); + } + + let arrow_variant = ArrowVariantArray::try_new(array.as_struct())?; + let imported = if dtype.is_nullable() { + ParquetVariant::from_arrow_variant_nullable(&arrow_variant)? + } else { + ParquetVariant::from_arrow_variant(&arrow_variant)? + }; + Ok(ArrowImport::Imported(imported.into_array())) + } +} diff --git a/encodings/parquet-variant/src/lib.rs b/encodings/parquet-variant/src/lib.rs index e606e7b4e8c..2f87fffb374 100644 --- a/encodings/parquet-variant/src/lib.rs +++ b/encodings/parquet-variant/src/lib.rs @@ -25,11 +25,110 @@ //! [Arrow canonical extension type]: https://arrow.apache.org/docs/format/CanonicalExtensions.html#parquet-variant mod array; +mod arrow; mod kernel; mod operations; mod validity; mod vtable; +use std::sync::Arc; + pub use array::ParquetVariantArrayExt; +pub use arrow::PARQUET_VARIANT_ARROW_EXTENSION_NAME; +use vortex_array::arrow::ArrowSessionExt; +use vortex_array::session::ArraySessionExt; +use vortex_session::VortexSession; pub use vtable::ParquetVariant; pub use vtable::ParquetVariantArray; + +/// Register Parquet Variant array and Arrow extension support with a session. +pub fn initialize(session: &VortexSession) { + session.arrays().register(ParquetVariant); + session.arrow().register_exporter(Arc::new(ParquetVariant)); + session.arrow().register_importer(Arc::new(ParquetVariant)); +} + +#[cfg(test)] +mod arrow_session_tests { + use std::sync::Arc; + + use arrow_array::Array as _; + use arrow_array::ArrayRef as ArrowArrayRef; + use arrow_array::StructArray; + use arrow_array::cast::AsArray; + use arrow_schema::Field; + use arrow_schema::extension::EXTENSION_TYPE_NAME_KEY; + use parquet_variant::Variant as PqVariant; + use parquet_variant_compute::VariantArrayBuilder; + use vortex_array::VortexSessionExecute; + use vortex_array::arrow::ArrowSessionExt; + use vortex_array::dtype::DType; + use vortex_array::dtype::Nullability; + use vortex_array::session::ArraySession; + use vortex_error::VortexResult; + use vortex_session::VortexSession; + + use crate::ParquetVariant; + + fn session() -> VortexSession { + let session = VortexSession::empty().with::(); + crate::initialize(&session); + session + } + + fn arrow_variant_storage() -> StructArray { + let mut builder = VariantArrayBuilder::new(3); + builder.append_variant(PqVariant::from(42i8)); + builder.append_variant(PqVariant::from(true)); + builder.append_variant(PqVariant::from("vortex")); + builder.build().into_inner() + } + + fn arrow_variant_field(storage: &StructArray) -> Field { + Field::new("variant", storage.data_type().clone(), false).with_metadata( + [( + EXTENSION_TYPE_NAME_KEY.to_string(), + "arrow.parquet.variant".to_string(), + )] + .into(), + ) + } + + #[test] + fn arrow_session_imports_parquet_variant_extension_array() -> VortexResult<()> { + let session = session(); + let storage = arrow_variant_storage(); + let field = arrow_variant_field(&storage); + let imported = session + .arrow() + .from_arrow_array(Arc::new(storage) as ArrowArrayRef, &field)?; + + assert_eq!(imported.dtype(), &DType::Variant(Nullability::NonNullable)); + assert!(imported.as_opt::().is_some()); + Ok(()) + } + + #[test] + fn arrow_session_exports_parquet_variant_extension_array() -> VortexResult<()> { + let session = session(); + let storage = arrow_variant_storage(); + let field = arrow_variant_field(&storage); + let imported = session + .arrow() + .from_arrow_array(Arc::new(storage.clone()) as ArrowArrayRef, &field)?; + + let mut ctx = session.create_execution_ctx(); + let exported = session + .arrow() + .execute_arrow(imported, Some(&field), &mut ctx)?; + let exported = exported.as_struct(); + + assert_eq!(exported.len(), storage.len()); + assert_eq!(exported.column_names(), storage.column_names()); + assert_eq!(exported.fields(), storage.fields()); + for (actual, expected) in exported.columns().iter().zip(storage.columns()) { + assert_eq!(actual.to_data(), expected.to_data()); + } + Ok(()) + } +} diff --git a/java/vortex-jni/build.gradle.kts b/java/vortex-jni/build.gradle.kts index 09892fa6f44..c8f0192babc 100644 --- a/java/vortex-jni/build.gradle.kts +++ b/java/vortex-jni/build.gradle.kts @@ -48,7 +48,9 @@ mavenPublishing { coordinates(groupId = "dev.vortex", artifactId = "vortex-jni", version = "${rootProject.version}") publishToMavenCentral() - signAllPublications() + if (!project.hasProperty("skip.signing")) { + signAllPublications() + } pom { name = "vortex-jni" diff --git a/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java b/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java index 58cdbbf5315..03869ae8598 100644 --- a/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java +++ b/java/vortex-jni/src/test/java/dev/vortex/jni/JNIWriterTest.java @@ -4,6 +4,7 @@ package dev.vortex.jni; import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -26,17 +27,25 @@ import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; public final class JNIWriterTest { + private static final String ARROW_EXTENSION_NAME = "ARROW:extension:name"; + private static final String PARQUET_VARIANT_EXTENSION_NAME = "arrow.parquet.variant"; + private static final byte[] VARIANT_METADATA = new byte[] {0x01, 0x00}; + private static final byte[] VARIANT_INT8_42 = new byte[] {0x0c, 0x2a}; + private static final byte[] VARIANT_TRUE = new byte[] {0x04}; @TempDir Path tempDir; @@ -52,6 +61,45 @@ private static Schema personSchema() { Field.notNullable("age", new ArrowType.Int(32, true)))); } + private static Schema parquetVariantSchema() { + Field variant = new Field( + "variant", + new FieldType( + true, + ArrowType.Struct.INSTANCE, + null, + Map.of(ARROW_EXTENSION_NAME, PARQUET_VARIANT_EXTENSION_NAME)), + List.of( + Field.notNullable("metadata", new ArrowType.Binary()), + Field.nullable("value", new ArrowType.Binary()))); + return new Schema(List.of(variant)); + } + + private static void populateParquetVariantRoot(VectorSchemaRoot root) { + StructVector variant = (StructVector) root.getVector("variant"); + VarBinaryVector metadata = variant.getChild("metadata", VarBinaryVector.class); + VarBinaryVector value = variant.getChild("value", VarBinaryVector.class); + + variant.allocateNew(); + metadata.allocateNew(3); + value.allocateNew(3); + + metadata.setSafe(0, VARIANT_METADATA); + metadata.setSafe(1, VARIANT_METADATA); + metadata.setSafe(2, VARIANT_METADATA); + value.setSafe(0, VARIANT_INT8_42); + value.setSafe(1, VARIANT_TRUE); + value.setNull(2); + variant.setIndexDefined(0); + variant.setIndexDefined(1); + variant.setNull(2); + + metadata.setValueCount(3); + value.setValueCount(3); + variant.setValueCount(3); + root.setRowCount(3); + } + @Test public void testCreateWriter() throws IOException { Path outputPath = tempDir.resolve("test_create.vortex"); @@ -155,4 +203,53 @@ public void testWriteBatch() throws IOException { } } } + + @Test + public void testParquetVariantRoundTrip() throws IOException { + Path outputPath = tempDir.resolve("test_parquet_variant.vortex"); + String writePath = outputPath.toAbsolutePath().toUri().toString(); + + BufferAllocator allocator = ArrowAllocation.rootAllocator(); + Schema schema = parquetVariantSchema(); + + Session session = Session.create(); + try (VortexWriter writer = VortexWriter.create(session, writePath, schema, new HashMap<>(), allocator); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + populateParquetVariantRoot(root); + + try (ArrowArray arrowArray = ArrowArray.allocateNew(allocator); + ArrowSchema arrowSchemaFfi = ArrowSchema.allocateNew(allocator)) { + Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchemaFfi); + writer.writeBatch(arrowArray.memoryAddress(), arrowSchemaFfi.memoryAddress()); + } + } + + assertTrue(Files.exists(outputPath), "output file should exist"); + + DataSource ds = DataSource.open(session, writePath); + Field dataSourceField = ds.arrowSchema(allocator).findField("variant"); + assertEquals( + PARQUET_VARIANT_EXTENSION_NAME, dataSourceField.getMetadata().get(ARROW_EXTENSION_NAME)); + + Scan scan = ds.scan(ScanOptions.of()); + Field scanField = scan.arrowSchema(allocator).findField("variant"); + assertEquals(PARQUET_VARIANT_EXTENSION_NAME, scanField.getMetadata().get(ARROW_EXTENSION_NAME)); + + while (scan.hasNext()) { + Partition p = scan.next(); + try (ArrowReader reader = p.scanArrow(allocator)) { + assertTrue(reader.loadNextBatch()); + VectorSchemaRoot resultRoot = reader.getVectorSchemaRoot(); + StructVector variant = (StructVector) resultRoot.getVector("variant"); + VarBinaryVector metadata = variant.getChild("metadata", VarBinaryVector.class); + VarBinaryVector value = variant.getChild("value", VarBinaryVector.class); + + assertArrayEquals(VARIANT_METADATA, metadata.get(0)); + assertArrayEquals(VARIANT_INT8_42, value.get(0)); + assertArrayEquals(VARIANT_METADATA, metadata.get(1)); + assertArrayEquals(VARIANT_TRUE, value.get(1)); + assertTrue(variant.isNull(2)); + } + } + } } diff --git a/vortex-array/src/arrow/session.rs b/vortex-array/src/arrow/session.rs index 012b9001bd2..0be4d8e0c50 100644 --- a/vortex-array/src/arrow/session.rs +++ b/vortex-array/src/arrow/session.rs @@ -125,7 +125,7 @@ pub trait ArrowExportVTable: 'static + Send + Sync + Debug { ) -> VortexResult; } -/// Plugin layer for importing an Arrow extension-typed array into a Vortex extension array. +/// Plugin layer for importing an Arrow extension-typed array into a Vortex array. /// /// Plugins are dispatched by `arrow_ext_id`. /// @@ -140,7 +140,7 @@ pub trait ArrowImportVTable: 'static + Send + Sync + Debug { #[allow(clippy::wrong_self_convention)] fn from_arrow_field(&self, field: &Field) -> VortexResult>; - /// Convert an Arrow array into a Vortex extension array of `dtype`. + /// Convert an Arrow array into a Vortex array of `dtype`. /// /// Returns ownership of `array` via [`ArrowImport::Unsupported`] when the plugin cannot /// handle the input. @@ -148,7 +148,8 @@ pub trait ArrowImportVTable: 'static + Send + Sync + Debug { fn from_arrow_array( &self, array: ArrowArrayRef, - dtype: &ExtDTypeRef, + field: &Field, + dtype: &DType, ) -> VortexResult; } @@ -490,16 +491,14 @@ impl ArrowSession { let importers = self.importers(&Id::new(extension_name)); if !importers.is_empty() { let dtype = self.from_arrow_field(field)?; - if let DType::Extension(ext_dtype) = dtype { - let mut current = array; - for plugin in importers.iter() { - match plugin.from_arrow_array(current, &ext_dtype)? { - ArrowImport::Imported(arr) => return Ok(arr), - ArrowImport::Unsupported(arr) => current = arr, - } + let mut current = array; + for plugin in importers.iter() { + match plugin.from_arrow_array(current, field, &dtype)? { + ArrowImport::Imported(arr) => return Ok(arr), + ArrowImport::Unsupported(arr) => current = arr, } - return ArrayRef::from_arrow(current.as_ref(), field.is_nullable()); } + return ArrayRef::from_arrow(current.as_ref(), field.is_nullable()); } } self.from_arrow_array_canonical(array, field) diff --git a/vortex-array/src/extension/uuid/arrow.rs b/vortex-array/src/extension/uuid/arrow.rs index eacbe975a55..7d683ae0af0 100644 --- a/vortex-array/src/extension/uuid/arrow.rs +++ b/vortex-array/src/extension/uuid/arrow.rs @@ -125,8 +125,12 @@ impl ArrowImportVTable for Uuid { fn from_arrow_array( &self, array: ArrowArrayRef, - dtype: &ExtDTypeRef, + _field: &Field, + dtype: &DType, ) -> VortexResult { + let DType::Extension(dtype) = dtype else { + return Ok(ArrowImport::Unsupported(array)); + }; if !matches!(array.data_type(), DataType::FixedSizeBinary(UUID_BYTE_LEN)) || !dtype.is::() { diff --git a/vortex-jni/Cargo.toml b/vortex-jni/Cargo.toml index 628c2c89d43..62b89e61fc8 100644 --- a/vortex-jni/Cargo.toml +++ b/vortex-jni/Cargo.toml @@ -29,6 +29,7 @@ tracing = { workspace = true, features = ["std", "log"] } tracing-subscriber = { workspace = true, features = ["env-filter"] } url = { workspace = true } vortex = { workspace = true, features = ["object_store", "files"] } +vortex-parquet-variant = { workspace = true } [dev-dependencies] jni = { workspace = true, features = ["invocation"] } diff --git a/vortex-jni/src/dtype.rs b/vortex-jni/src/dtype.rs index f748135ade1..5a0b54d2c14 100644 --- a/vortex-jni/src/dtype.rs +++ b/vortex-jni/src/dtype.rs @@ -10,8 +10,8 @@ use arrow_array::ffi::FFI_ArrowSchema; use arrow_schema::DataType; use arrow_schema::FieldRef; use arrow_schema::Fields; +use arrow_schema::Schema; use vortex::dtype::DType; -use vortex::dtype::arrow::FromArrowType; use vortex::error::VortexResult; /// Export a Vortex [`DType`] to the Arrow C Data Interface struct at `schema_addr`. Views @@ -24,7 +24,7 @@ pub(crate) fn export_dtype_to_arrow(dtype: &DType, schema_addr: i64) -> VortexRe DataType::Struct(fields) => fields, _ => unreachable!("Vortex DType always exports as a struct"), }; - let schema = arrow_schema::Schema::new(fields); + let schema = Schema::new(fields); let ffi_schema = FFI_ArrowSchema::try_from(&schema)?; unsafe { ptr::write(schema_addr as *mut FFI_ArrowSchema, ffi_schema); @@ -70,9 +70,8 @@ pub(crate) fn strip_views(data_type: DataType) -> DataType { } } -/// Decode an [`FFI_ArrowSchema`] pointed to by `schema_addr` into a Vortex [`DType`]. -pub(crate) fn import_dtype_from_arrow(schema_addr: i64) -> VortexResult { +/// Decode an [`FFI_ArrowSchema`] pointed to by `schema_addr` into an Arrow [`Schema`]. +pub(crate) fn import_arrow_schema(schema_addr: i64) -> VortexResult { let ffi_schema = unsafe { &*(schema_addr as *const FFI_ArrowSchema) }; - let arrow_schema = arrow_schema::Schema::try_from(ffi_schema)?; - Ok(DType::from_arrow(&arrow_schema)) + Ok(Schema::try_from(ffi_schema)?) } diff --git a/vortex-jni/src/session.rs b/vortex-jni/src/session.rs index 9f75d24f564..9adaf544431 100644 --- a/vortex-jni/src/session.rs +++ b/vortex-jni/src/session.rs @@ -16,7 +16,9 @@ use crate::RUNTIME; /// Constructs a fresh [`VortexSession`] bound to the JNI-shared tokio runtime and returns /// an opaque pointer that Java must pass to [`Java_dev_vortex_jni_NativeSession_free`]. pub(crate) fn new_session() -> Box { - Box::new(VortexSession::default().with_handle(RUNTIME.handle())) + let session = VortexSession::default().with_handle(RUNTIME.handle()); + vortex_parquet_variant::initialize(&session); + Box::new(session) } /// SAFETY: caller must pass a pointer previously returned by [`new_session`]. diff --git a/vortex-jni/src/writer.rs b/vortex-jni/src/writer.rs index da30952d87b..a7d1e6432f2 100644 --- a/vortex-jni/src/writer.rs +++ b/vortex-jni/src/writer.rs @@ -13,6 +13,7 @@ use arrow_array::RecordBatch; use arrow_array::StructArray; use arrow_array::ffi::FFI_ArrowArray; use arrow_array::ffi::FFI_ArrowSchema; +use arrow_schema::SchemaRef; use async_fs::File; use futures::SinkExt; use futures::channel::mpsc; @@ -28,12 +29,16 @@ use object_store::ObjectStore; use object_store::path::Path as ObjectStorePath; use url::Url; use vortex::array::ArrayRef; -use vortex::array::arrow::FromArrowArray; +use vortex::array::VTable; +use vortex::array::arrow::ArrowSessionExt; use vortex::array::stream::ArrayStreamAdapter; use vortex::dtype::DType; +use vortex::dtype::Field as DTypeField; +use vortex::dtype::FieldPath; use vortex::error::VortexResult; use vortex::error::vortex_err; use vortex::file::WriteOptionsSessionExt; +use vortex::file::WriteStrategyBuilder; use vortex::file::WriteSummary; use vortex::io::VortexWrite; use vortex::io::compat::Compat; @@ -41,10 +46,14 @@ use vortex::io::object_store::ObjectStoreWrite; use vortex::io::runtime::BlockingRuntime; use vortex::io::runtime::Task; use vortex::io::session::RuntimeSessionExt; +use vortex::layout::LayoutStrategy; +use vortex::layout::layouts::flat::writer::FlatLayoutStrategy; +use vortex::session::VortexSession; use vortex::utils::aliases::hash_map::HashMap; +use vortex_parquet_variant::ParquetVariant; use crate::RUNTIME; -use crate::dtype::import_dtype_from_arrow; +use crate::dtype::import_arrow_schema; use crate::errors::JNIError; use crate::errors::try_or_throw; use crate::file::extract_properties; @@ -81,21 +90,71 @@ fn resolve_store( } } +fn write_options_for_schema( + session: &VortexSession, + write_schema: &DType, +) -> vortex::file::VortexWriteOptions { + let variant_paths = variant_field_paths(write_schema); + if variant_paths.is_empty() { + return session.write_options(); + } + + let mut allowed = vortex::file::ALLOWED_ENCODINGS.clone(); + allowed.insert(ParquetVariant.id()); + let flat: Arc = + Arc::new(FlatLayoutStrategy::default().with_allow_encodings(allowed)); + + let mut strategy = WriteStrategyBuilder::default(); + for path in variant_paths { + strategy = strategy.with_field_writer(path, Arc::clone(&flat)); + } + + session.write_options().with_strategy(strategy.build()) +} + +fn variant_field_paths(dtype: &DType) -> Vec { + let mut paths = Vec::new(); + collect_variant_field_paths(dtype, FieldPath::root(), &mut paths); + paths +} + +fn collect_variant_field_paths(dtype: &DType, path: FieldPath, paths: &mut Vec) { + match dtype { + DType::Variant(_) => paths.push(path), + DType::Struct(fields, _) => { + for (name, field_dtype) in fields.names().iter().zip(fields.fields()) { + collect_variant_field_paths( + &field_dtype, + path.clone().push(DTypeField::from(name.clone())), + paths, + ); + } + } + _ => {} + } +} + /// Native writer holding a write-task handle and a sender that Java pushes batches into. pub struct NativeWriter { handle: Option>>, + session: VortexSession, + arrow_schema: SchemaRef, write_schema: DType, sender: mpsc::Sender>, } impl NativeWriter { pub fn new( + session: VortexSession, + arrow_schema: SchemaRef, write_schema: DType, handle: Task>, sender: mpsc::Sender>, ) -> Self { Self { handle: Some(handle), + session, + arrow_schema, write_schema, sender, } @@ -117,7 +176,10 @@ impl NativeWriter { } fn write_record_batch(&self, batch: RecordBatch) -> VortexResult<()> { - let vortex_batch = ArrayRef::from_arrow(batch, false)?; + let vortex_batch = self + .session + .arrow() + .from_arrow_record_batch(batch, self.arrow_schema.as_ref())?; if !vortex_batch.dtype().eq(&self.write_schema) { return Err(vortex_err!( "write schema mismatch: expected {}, got {}", @@ -162,13 +224,15 @@ pub extern "system" fn Java_dev_vortex_jni_NativeWriter_create( } let session = unsafe { session_ref(session_ptr) }; - let write_schema = import_dtype_from_arrow(arrow_schema_addr)?; + let arrow_schema = Arc::new(import_arrow_schema(arrow_schema_addr)?); + let write_schema = session.arrow().from_arrow_schema(arrow_schema.as_ref())?; let file_path: String = uri.try_to_string(env)?; let properties: HashMap = extract_properties(env, &options)?; let resolved = resolve_store(&file_path, &properties)?; let (tx, rx) = mpsc::channel(WRITE_CHANNEL_CAPACITY); let stream = ArrayStreamAdapter::new(write_schema.clone(), rx); + let write_options = write_options_for_schema(session, &write_schema); let handle = session.handle().spawn(async move { match resolved { @@ -177,21 +241,28 @@ pub extern "system" fn Java_dev_vortex_jni_NativeWriter_create( async_fs::create_dir_all(parent).await?; } let mut file = File::create(path).await?; - let summary = session.write_options().write(&mut file, stream).await?; + let summary = write_options.write(&mut file, stream).await?; file.shutdown().await?; Ok(summary) } ResolvedStore::ObjectStore(store, path) => { let mut write = ObjectStoreWrite::new(Arc::new(Compat::new(store)), &path).await?; - let summary = session.write_options().write(&mut write, stream).await?; + let summary = write_options.write(&mut write, stream).await?; write.shutdown().await?; Ok(summary) } } }); - Ok(Box::new(NativeWriter::new(write_schema, handle, tx)).into_raw()) + Ok(Box::new(NativeWriter::new( + session.clone(), + arrow_schema, + write_schema, + handle, + tx, + )) + .into_raw()) }) } diff --git a/vortex-tensor/src/types/vector/arrow.rs b/vortex-tensor/src/types/vector/arrow.rs index b7e5e7577a7..e186a72bfe4 100644 --- a/vortex-tensor/src/types/vector/arrow.rs +++ b/vortex-tensor/src/types/vector/arrow.rs @@ -141,8 +141,12 @@ impl ArrowImportVTable for Vector { fn from_arrow_array( &self, array: ArrowArrayRef, - dtype: &ExtDTypeRef, + _field: &Field, + dtype: &DType, ) -> VortexResult { + let DType::Extension(dtype) = dtype else { + return Ok(ArrowImport::Unsupported(array)); + }; if !dtype.is::() { return Ok(ArrowImport::Unsupported(array)); } @@ -362,13 +366,11 @@ mod tests { #[test] fn from_arrow_array_returns_unsupported_for_non_fsl() -> VortexResult<()> { let dtype = vector_dtype(false); - let ext = dtype - .as_extension_opt() - .expect("vector dtype should be an extension dtype") - .clone(); + let field = Field::new("embedding", DataType::Int32, false); let int_array: ArrowArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); - let result = ::from_arrow_array(&Vector, int_array, &ext)?; + let result = + ::from_arrow_array(&Vector, int_array, &field, &dtype)?; assert!(matches!(result, ArrowImport::Unsupported(_))); Ok(()) } @@ -386,8 +388,13 @@ mod tests { ExtDType::try_with_vtable(Uuid, UuidMetadata::default(), uuid_storage)?.erased(); let fsl_arrow = arrow_fsl_f32(vec![1.0, 2.0, 3.0], DIM as i32); - let result = - ::from_arrow_array(&Vector, fsl_arrow, &uuid_ext)?; + let field = Field::new("embedding", fsl_arrow.data_type().clone(), false); + let result = ::from_arrow_array( + &Vector, + fsl_arrow, + &field, + &DType::Extension(uuid_ext), + )?; assert!(matches!(result, ArrowImport::Unsupported(_))); Ok(()) }