Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Implement vector field type #1899

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 40 additions & 0 deletions dev/src/field-value.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,42 @@ import {

import api = proto.google.firestore.v1;

export class VectorValue implements firestore.VectorValue {
private readonly _values: number[];
constructor(values: number[] | undefined) {
// Making a copy of the parameter.
this._values = (values || []).map(n => n);
}

public toArray(): number[] {
return this._values.map(n => n);
}

/**
* @private
*/
toProto(serializer: Serializer): api.IValue {
return serializer.encodeVector(this._values);
}

/**
* @private
*/
static fromProto(valueArray: api.IValue): VectorValue {
const values = valueArray.arrayValue?.values?.map(v => {
return v.doubleValue!;
});
return new VectorValue(values);
}

/**
* @private
*/
isEqual(other: VectorValue): boolean {
return this._values === other._values;
}
}

/**
* Sentinel values that can be used when writing documents with set(), create()
* or update().
Expand All @@ -40,6 +76,10 @@ export class FieldValue implements firestore.FieldValue {
/** @private */
constructor() {}

static vector(values?: number[]): VectorValue {
return new VectorValue(values);
}

/**
* Returns a sentinel for use with update() or set() with {merge:true} to mark
* a field for deletion.
Expand Down
57 changes: 51 additions & 6 deletions dev/src/serializer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {DocumentData} from '@google-cloud/firestore';
import * as proto from '../protos/firestore_v1_proto_api';

import {detectValueType} from './convert';
import {DeleteTransform, FieldTransform} from './field-value';
import {DeleteTransform, FieldTransform, VectorValue} from './field-value';
import {GeoPoint} from './geo-point';
import {DocumentReference, Firestore} from './index';
import {FieldPath, QualifiedResourcePath} from './path';
Expand All @@ -38,6 +38,10 @@ import api = proto.google.firestore.v1;
*/
const MAX_DEPTH = 20;

const RESERVED_MAP_KEY = '__type__';
const RESERVED_MAP_KEY_VECTOR_VALUE = '__vector__';
const VECTOR_MAP_VECTORS_KEY = 'value';

/**
* An interface for Firestore types that can be serialized to Protobuf.
*
Expand Down Expand Up @@ -168,6 +172,10 @@ export class Serializer {
};
}

if (val instanceof VectorValue) {
return val.toProto(this);
}

if (isObject(val)) {
const toProto = val['toProto'];
if (typeof toProto === 'function') {
Expand Down Expand Up @@ -217,6 +225,31 @@ export class Serializer {
throw new Error(`Cannot encode value: ${val}`);
}

/**
* @private
*/
encodeVector(rawVector: number[]): api.IValue {
// A Firestore Vector is a map with reserved key/value pairs.
return {
mapValue: {
fields: {
[RESERVED_MAP_KEY]: {
stringValue: RESERVED_MAP_KEY_VECTOR_VALUE,
},
[VECTOR_MAP_VECTORS_KEY]: {
arrayValue: {
values: rawVector.map(value => {
return {
doubleValue: value,
};
}),
},
},
},
},
};
}

/**
* Decodes a single Firestore 'Value' Protobuf.
*
Expand Down Expand Up @@ -263,15 +296,25 @@ export class Serializer {
return null;
}
case 'mapValue': {
const obj: DocumentData = {};
const fields = proto.mapValue!.fields;
if (fields) {
for (const prop of Object.keys(fields)) {
obj[prop] = this.decodeValue(fields[prop]);
const props = Object.keys(fields);
if (
props.indexOf(RESERVED_MAP_KEY) !== -1 &&
this.decodeValue(fields[RESERVED_MAP_KEY]) ===
RESERVED_MAP_KEY_VECTOR_VALUE
) {
return VectorValue.fromProto(fields[VECTOR_MAP_VECTORS_KEY]);
} else {
const obj: DocumentData = {};
for (const prop of Object.keys(fields)) {
obj[prop] = this.decodeValue(fields[prop]);
}
return obj;
}
} else {
return {};
}

return obj;
}
case 'geoPointValue': {
return GeoPoint.fromProto(proto.geoPointValue!);
Expand Down Expand Up @@ -367,6 +410,8 @@ export function validateUserInput(
'If you want to ignore undefined values, enable `ignoreUndefinedProperties`.'
);
}
} else if (value instanceof VectorValue) {
// OK
} else if (value instanceof DeleteTransform) {
if (inArray) {
throw new Error(
Expand Down
109 changes: 109 additions & 0 deletions dev/system-test/firestore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,30 @@
return promise;
});

it.only('can write and read vector embeddings', async () => {

Check failure on line 1021 in dev/system-test/firestore.ts

View workflow job for this annotation

GitHub Actions / lint

'it.only' is restricted from being used
const ref = randomCol.doc();
await ref.create({
vectorEmpty: FieldValue.vector(),
vector1: FieldValue.vector([1, 2, 3.99]),
});
await ref.set({
vectorEmpty: FieldValue.vector(),
vector1: FieldValue.vector([1, 2, 3.99]),
vector2: FieldValue.vector([0, 0, 0]),
});
await ref.update({
vector3: FieldValue.vector([-1, -200, -999]),
});

const snap1 = await ref.get();
expect(snap1.get('vectorEmpty')).to.deep.equal(FieldValue.vector());
expect(snap1.get('vector1')).to.deep.equal(FieldValue.vector([1, 2, 3.99]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using VectorValue.isEqual instead of to.deep.equal because to.deep.equal relies on equality testing using private members.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to change this in the other PR to save myself from conflict resolving, hopefully that is OK with you.

expect(snap1.get('vector2')).to.deep.equal(FieldValue.vector([0, 0, 0]));
expect(snap1.get('vector3')).to.deep.equal(
FieldValue.vector([-1, -200, -999])
);
});

describe('watch', () => {
const currentDeferred = new DeferredPromise<DocumentSnapshot>();

Expand Down Expand Up @@ -1311,6 +1335,91 @@
const result2 = await ref2.get();
expect(result2.data()).to.deep.equal([1, 2, 3]);
});

it.only('can listen to documents with vectors', async () => {

Check failure on line 1339 in dev/system-test/firestore.ts

View workflow job for this annotation

GitHub Actions / lint

'it.only' is restricted from being used
const ref = randomCol.doc();
const initialDeferred = new Deferred<void>();
const createDeferred = new Deferred<void>();
const setDeferred = new Deferred<void>();
const updateDeferred = new Deferred<void>();
const deleteDeferred = new Deferred<void>();

const expected = [
initialDeferred,
createDeferred,
setDeferred,
updateDeferred,
deleteDeferred,
];
let idx = 0;
let document: DocumentSnapshot | null = null;

const unlisten = randomCol
.where('purpose', '==', 'vector tests')
.onSnapshot(snap => {
expected[idx].resolve();
idx += 1;
if (snap.docs.length > 0) {
document = snap.docs[0];
} else {
document = null;
}
});

await initialDeferred.promise;
expect(document).to.be.null;

await ref.create({
purpose: 'vector tests',
vectorEmpty: FieldValue.vector(),
vector1: FieldValue.vector([1, 2, 3.99]),
});

await createDeferred.promise;
expect(document).to.be.not.null;
expect(document!.get('vectorEmpty')).to.deep.equal(FieldValue.vector());
expect(document!.get('vector1')).to.deep.equal(
FieldValue.vector([1, 2, 3.99])
);

await ref.set({
purpose: 'vector tests',
vectorEmpty: FieldValue.vector(),
vector1: FieldValue.vector([1, 2, 3.99]),
vector2: FieldValue.vector([0, 0, 0]),
});
await setDeferred.promise;
expect(document).to.be.not.null;
expect(document!.get('vectorEmpty')).to.deep.equal(FieldValue.vector());
expect(document!.get('vector1')).to.deep.equal(
FieldValue.vector([1, 2, 3.99])
);
expect(document!.get('vector2')).to.deep.equal(
FieldValue.vector([0, 0, 0])
);

await ref.update({
vector3: FieldValue.vector([-1, -200, -999]),
});
await updateDeferred.promise;
expect(document).to.be.not.null;
expect(document!.get('vectorEmpty')).to.deep.equal(FieldValue.vector());
expect(document!.get('vector1')).to.deep.equal(
FieldValue.vector([1, 2, 3.99])
);
expect(document!.get('vector2')).to.deep.equal(
FieldValue.vector([0, 0, 0])
);
expect(document!.get('vector3')).to.deep.equal(
FieldValue.vector([-1, -200, -999])
);

await ref.delete();
await deleteDeferred.promise;
expect(document).to.be.null;

unlisten();
});
});

describe('runs query on a large collection', () => {
Expand Down
77 changes: 77 additions & 0 deletions dev/test/document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,43 @@ describe('serialize document', () => {
return ref.set({ref});
});
});

it('is able to translate FirestoreVector to internal representation with set', () => {
const overrides: ApiOverride = {
commit: request => {
requestEquals(
request,
set({
document: document('documentId', 'embedding1', {
mapValue: {
fields: {
__type__: {
stringValue: '__vector__',
},
value: {
arrayValue: {
values: [
{doubleValue: 0},
{doubleValue: 1},
{doubleValue: 2},
],
},
},
},
},
}),
})
);
return response(writeResult(1));
},
};

return createInstance(overrides).then(firestore => {
return firestore.doc('collectionId/documentId').set({
embedding1: FieldValue.vector([0, 1, 2]),
});
});
});
});

describe('deserialize document', () => {
Expand Down Expand Up @@ -599,6 +636,46 @@ describe('deserialize document', () => {
});
});

it('deserializes FirestoreVector', () => {
const overrides: ApiOverride = {
batchGetDocuments: () => {
return stream(
found(
document('documentId', 'embedding', {
mapValue: {
fields: {
__type__: {
stringValue: '__vector__',
},
value: {
arrayValue: {
values: [
{doubleValue: -41.0},
{doubleValue: 0},
{doubleValue: 42},
],
},
},
},
},
})
)
);
},
};

return createInstance(overrides).then(firestore => {
return firestore
.doc('collectionId/documentId')
.get()
.then(res => {
expect(res.get('embedding')).to.deep.equal(
FieldValue.vector([-41.0, 0, 42])
);
});
});
});

it("doesn't deserialize unsupported types", () => {
const overrides: ApiOverride = {
batchGetDocuments: () => {
Expand Down
2 changes: 1 addition & 1 deletion dev/test/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ describe('FirestoreTypeConverter', () => {
await newDocRef.set({stringProperty: 'foo', numberProperty: 42});
await newDocRef.update({a: 'newFoo', b: 43});
const snapshot = await newDocRef.get();
const data: MyModelType = snapshot.data()!;
const data = snapshot.data()!;
expect(data.stringProperty).to.equal('newFoo');
expect(data.numberProperty).to.equal(43);
}
Expand Down