fix: sanitize foreign schemas (#1058)
Arrow-js uses brittle `instanceof` checks throughout the code base. These
fail unless the library instance that produced the object is exactly the
same instance that vectordb is using. At a minimum, this means that a user
on arrow version 15 (or any version that doesn't exactly match the version
vectordb uses) will get strange errors when they try to use vectordb.

However, there are even cases where the versions are perfectly identical
and the `instanceof` check still fails. One such example is when using
`vite` (e.g. vitejs/vite#3910)
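
For illustration, a minimal sketch of the failure mode (it assumes a second
copy of the library installed under an npm alias such as `apache-arrow-old`,
the trick the test setup below uses):

    // Two copies of apache-arrow loaded side by side. Each copy has its own
    // Schema class, so `instanceof` cannot bridge them.
    import { Schema } from 'apache-arrow'
    import { Schema as OldSchema, Field, Int32 } from 'apache-arrow-old'

    const foreign = new OldSchema([new Field('id', new Int32())])

    // Structurally this is a perfectly valid schema, yet the check below is
    // false because `foreign` came from a different library instance.
    console.log(foreign instanceof Schema) // false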

This PR solves the problem in a brute-force but workable fashion. If we
encounter a schema that does not pass the `instanceof` check, we attempt to
sanitize it by traversing the object and, if it has all the correct
properties, constructing an equivalent `Schema` instance via a deep clone.
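
As a rough sketch of the idea (the real sanitize.ts added by this PR handles
every Arrow data type; the `isSchemaLike` guard and the elided `sanitizeType`
helper here are illustrative simplifications):

    import { Schema, Field, type DataType } from 'apache-arrow'

    // Duck-type guard: does this object have the shape of a Schema?
    function isSchemaLike (value: any): boolean {
      return value != null && Array.isArray(value.fields)
    }

    // Rebuild a foreign schema by copying its properties into classes from
    // *our* copy of apache-arrow, so later `instanceof` checks succeed.
    function sanitizeSchemaSketch (schemaLike: any): Schema {
      if (schemaLike instanceof Schema) {
        return schemaLike // already ours, nothing to do
      }
      if (!isSchemaLike(schemaLike)) {
        throw new Error('expected a Schema or Schema-like object')
      }
      const fields = schemaLike.fields.map(
        (f: any) => new Field(f.name, sanitizeType(f.type), f.nullable)
      )
      return new Schema(fields, schemaLike.metadata)
    }

    // The real implementation deep-clones every DataType variant; elided here.
    declare function sanitizeType (typeLike: any): DataType
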
westonpace committed Apr 5, 2024
1 parent 785ecfa commit c60a193
Showing 11 changed files with 1,241 additions and 42 deletions.
5 changes: 5 additions & 0 deletions node/.eslintrc.js
@@ -13,5 +13,10 @@ module.exports = {
   },
   rules: {
     "@typescript-eslint/method-signature-style": "off",
+    "@typescript-eslint/quotes": "off",
+    "@typescript-eslint/semi": "off",
+    "@typescript-eslint/explicit-function-return-type": "off",
+    "@typescript-eslint/space-before-function-paren": "off",
+    "@typescript-eslint/indent": "off",
   }
 }
3 changes: 2 additions & 1 deletion node/package.json
@@ -41,6 +41,7 @@
     "@types/temp": "^0.9.1",
     "@types/uuid": "^9.0.3",
     "@typescript-eslint/eslint-plugin": "^5.59.1",
+    "apache-arrow-old": "npm:apache-arrow@13.0.0",
     "cargo-cp-artifact": "^0.1",
     "chai": "^4.3.7",
     "chai-as-promised": "^7.1.1",
@@ -93,4 +94,4 @@
     "@lancedb/vectordb-linux-x64-gnu": "0.4.11",
     "@lancedb/vectordb-win32-x64-msvc": "0.4.11"
   }
-}
+}
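
The `apache-arrow-old` entry is an npm alias that installs a second, older
copy of apache-arrow (13.0.0) under a different name, presumably so the test
suite can build schemas with a "foreign" library instance. A hypothetical
test using it (the import path of the module under test is an assumption):

    import { Schema, Field, Int32 } from 'apache-arrow-old'
    import { makeArrowTable } from '../src/arrow'

    // A schema built by the old copy of arrow-js. Before this fix it would
    // have tripped the instanceof checks inside makeArrowTable.
    const foreignSchema = new Schema([new Field('id', new Int32())])
    const table = makeArrowTable([{ id: 1 }, { id: 2 }], { schema: foreignSchema as any })
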
27 changes: 23 additions & 4 deletions node/src/arrow.ts
@@ -20,19 +20,20 @@ import {
   type Vector,
   FixedSizeList,
   vectorFromArray,
-  type Schema,
+  Schema,
   Table as ArrowTable,
   RecordBatchStreamWriter,
   List,
   RecordBatch,
   makeData,
   Struct,
-  type Float,
+  Float,
   DataType,
   Binary,
   Float32
 } from 'apache-arrow'
 import { type EmbeddingFunction } from './index'
+import { sanitizeSchema } from './sanitize'
 
 /*
  * Options to control how a column should be converted to a vector array
@@ -201,10 +202,13 @@ export function makeArrowTable (
   }
 
   const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
+  if (opt.schema !== undefined && opt.schema !== null) {
+    opt.schema = sanitizeSchema(opt.schema)
+  }
   const columns: Record<string, Vector> = {}
   // TODO: sample dataset to find missing columns
   // Prefer the field ordering of the schema, if present
-  const columnNames = ((options?.schema) != null) ? (options?.schema?.names as string[]) : Object.keys(data[0])
+  const columnNames = ((opt.schema) != null) ? (opt.schema.names as string[]) : Object.keys(data[0])
   for (const colName of columnNames) {
     if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) {
       // The field is present in the schema, but not in the data, skip it
@@ -329,6 +333,9 @@ async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunc
   if (embeddings == null) {
     return table
   }
+  if (schema !== undefined && schema !== null) {
+    schema = sanitizeSchema(schema)
+  }
 
   // Convert from ArrowTable to Record<String, Vector>
   const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
@@ -439,6 +446,9 @@ export async function fromRecordsToBuffer<T> (
   embeddings?: EmbeddingFunction<T>,
   schema?: Schema
 ): Promise<Buffer> {
+  if (schema !== undefined && schema !== null) {
+    schema = sanitizeSchema(schema)
+  }
   const table = await convertToTable(data, embeddings, { schema })
   const writer = RecordBatchFileWriter.writeAll(table)
   return Buffer.from(await writer.toUint8Array())
@@ -456,6 +466,9 @@ export async function fromRecordsToStreamBuffer<T> (
   embeddings?: EmbeddingFunction<T>,
   schema?: Schema
 ): Promise<Buffer> {
+  if (schema !== null && schema !== undefined) {
+    schema = sanitizeSchema(schema)
+  }
   const table = await convertToTable(data, embeddings, { schema })
   const writer = RecordBatchStreamWriter.writeAll(table)
   return Buffer.from(await writer.toUint8Array())
@@ -474,6 +487,9 @@ export async function fromTableToBuffer<T> (
   embeddings?: EmbeddingFunction<T>,
   schema?: Schema
 ): Promise<Buffer> {
+  if (schema !== null && schema !== undefined) {
+    schema = sanitizeSchema(schema)
+  }
   const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
   const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings)
   return Buffer.from(await writer.toUint8Array())
@@ -492,6 +508,9 @@ export async function fromTableToStreamBuffer<T> (
   embeddings?: EmbeddingFunction<T>,
   schema?: Schema
 ): Promise<Buffer> {
+  if (schema !== null && schema !== undefined) {
+    schema = sanitizeSchema(schema)
+  }
   const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
   const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings)
   return Buffer.from(await writer.toUint8Array())
@@ -528,5 +547,5 @@ function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
 
 // Creates an empty Arrow Table
 export function createEmptyTable (schema: Schema): ArrowTable {
-  return new ArrowTable(schema)
+  return new ArrowTable(sanitizeSchema(schema))
 }
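
The net effect: every code path that accepts a schema (makeArrowTable,
applyEmbeddings, the four from*ToBuffer helpers, and createEmptyTable) now
normalizes it before any instanceof-dependent code runs. For example (the
foreign-schema source here is again an assumption):

    import { createEmptyTable } from './arrow'
    import { Schema, Field, Utf8 } from 'apache-arrow-old'

    // The foreign schema is deep-cloned into a native Schema before the
    // ArrowTable constructor ever sees it.
    const empty = createEmptyTable(new Schema([new Field('name', new Utf8())]) as any)
    console.log(empty.numRows) // 0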
