Skip to content

Commit

Permalink
Support UTF8 in nested Apache Arrow data types (e.g. List) (#300)
Browse files Browse the repository at this point in the history
* support UTF8[]

* add tests

* fix test

* format

* clippy

* bump cause github is broken

---------

Co-authored-by: Max Gabrielsson <max@gabrielsson.com>
  • Loading branch information
Jeadie and Maxxen committed Apr 24, 2024
1 parent 0018cd8 commit 5a1729e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pretty_assertions = "1.4.0"
path = "libduckdb-sys"
version = "0.10.1"


[package.metadata.docs.rs]
features = ['vtab', 'chrono']
all-features = false
Expand Down
86 changes: 78 additions & 8 deletions src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,20 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
match value_array.data_type() {
dt if dt.is_primitive() => {
primitive_array_to_vector(value_array.as_ref(), &mut child)?;
for i in 0..array.len() {
let offset = array.value_offsets()[i];
let length = array.value_length(i);
out.set_entry(i, offset.as_(), length.as_());
}
}
DataType::Utf8 => {
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
}
_ => {
return Err("Nested list is not supported yet.".into());
}
}

for i in 0..array.len() {
let offset = array.value_offsets()[i];
let length = array.value_length(i);
out.set_entry(i, offset.as_(), length.as_());
}
Ok(())
}

Expand All @@ -452,10 +455,19 @@ fn fixed_size_list_array_to_vector(
}
out.set_len(value_array.len());
}
DataType::Utf8 => {
string_array_to_vector(as_string_array(value_array.as_ref()), &mut child);
}
_ => {
return Err("Nested list is not supported yet.".into());
}
}
for i in 0..array.len() {
let offset = array.value_offset(i);
let length = array.value_length();
out.set_entry(i, offset as usize, length as usize);
}
out.set_len(value_array.len());

Ok(())
}
Expand Down Expand Up @@ -543,10 +555,12 @@ mod test {
use crate::{Connection, Result};
use arrow::{
array::{
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, Int32Array,
PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray,
Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray,
Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
record_batch::RecordBatch,
};
Expand Down Expand Up @@ -676,6 +690,62 @@ mod test {
Ok(())
}

fn check_generic_array_roundtrip<T>(arry: GenericListArray<T>) -> Result<(), Box<dyn Error>>
where
T: OffsetSizeTrait,
{
let expected_output_array = arry.clone();

let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
let rb = stmt.query_arrow(param)?.next().expect("no record batch");

let output_any_array = rb.column(0);
assert!(output_any_array
.data_type()
.equals_datatype(expected_output_array.data_type()));

match output_any_array.as_list_opt::<T>() {
Some(output_array) => {
assert_eq!(output_array.len(), expected_output_array.len());
for i in 0..output_array.len() {
assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i));
if output_array.is_valid(i) {
assert!(expected_output_array.value(i).eq(&output_array.value(i)));
}
}
}
None => panic!("Expected GenericListArray"),
}

Ok(())
}

#[test]
fn test_array_roundtrip() -> Result<(), Box<dyn Error>> {
check_generic_array_roundtrip(ListArray::new(
Arc::new(Field::new("item", DataType::Utf8, true)),
OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])),
Arc::new(StringArray::from(vec![
Some("foo"),
Some("baz"),
Some("bar"),
Some("foo"),
Some("baz"),
])),
None,
))?;

Ok(())
}

#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;
Expand Down

0 comments on commit 5a1729e

Please sign in to comment.