Skip to content

Commit

Permalink
Implement vector field type
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-hui committed Sep 15, 2023
1 parent 81d4423 commit ad58cc1
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 6 deletions.
47 changes: 47 additions & 0 deletions dev/src/field-value.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,49 @@ import {

import api = proto.google.firestore.v1;

export class VectorValue {
private readonly values: number[];
constructor(values: number[] | undefined) {
this.values = values || [];
}

public toArray(): number[] {
return this.values;
}

/**
* @private
*/
toProto(serializer: Serializer): api.IValue {
return serializer.encodeVector({
arrayValue: {
values: this.values.map(value => {
return {
doubleValue: value,
};
}),
},
});
}

/**
* @private
*/
static fromProto(valueArray: api.IValue): VectorValue {
const values = valueArray.arrayValue?.values?.map(v => {
return v.doubleValue!;
});
return new VectorValue(values);
}

/**
* @private
*/
isEqual(other: VectorValue): boolean {
return this.values === other.values;
}
}

/**
* Sentinel values that can be used when writing documents with set(), create()
* or update().
Expand All @@ -40,6 +83,10 @@ export class FieldValue implements firestore.FieldValue {
/** @private */
constructor() {}

static vector(values?: number[]): VectorValue {
return new VectorValue(values);
}

/**
* Returns a sentinel for use with update() or set() with {merge:true} to mark
* a field for deletion.
Expand Down
50 changes: 44 additions & 6 deletions dev/src/serializer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {DocumentData} from '@google-cloud/firestore';
import * as proto from '../protos/firestore_v1_proto_api';

import {detectValueType} from './convert';
import {DeleteTransform, FieldTransform} from './field-value';
import {DeleteTransform, FieldTransform, VectorValue} from './field-value';
import {GeoPoint} from './geo-point';
import {DocumentReference, Firestore} from './index';
import {FieldPath, QualifiedResourcePath} from './path';
Expand All @@ -38,6 +38,10 @@ import api = proto.google.firestore.v1;
*/
const MAX_DEPTH = 20;

const RESERVED_MAP_KEY = '__type__';
const RESERVED_MAP_KEY_VECTOR_VALUE = '__vector__';
const RESERVED_VECTOR_MAP_VECTORS_KEY = 'value';

/**
* An interface for Firestore types that can be serialized to Protobuf.
*
Expand Down Expand Up @@ -168,6 +172,10 @@ export class Serializer {
};
}

if (val instanceof VectorValue) {
return val.toProto(this);
}

if (isObject(val)) {
const toProto = val['toProto'];
if (typeof toProto === 'function') {
Expand Down Expand Up @@ -217,6 +225,22 @@ export class Serializer {
throw new Error(`Cannot encode value: ${val}`);
}

/**
* @private
*/
encodeVector(vectorValue: api.IValue): api.IValue {
return {
mapValue: {
fields: {
[RESERVED_MAP_KEY]: {
stringValue: RESERVED_MAP_KEY_VECTOR_VALUE,
},
[RESERVED_VECTOR_MAP_VECTORS_KEY]: vectorValue,
},
},
};
}

/**
* Decodes a single Firestore 'Value' Protobuf.
*
Expand Down Expand Up @@ -263,15 +287,27 @@ export class Serializer {
return null;
}
case 'mapValue': {
const obj: DocumentData = {};
const fields = proto.mapValue!.fields;
if (fields) {
for (const prop of Object.keys(fields)) {
obj[prop] = this.decodeValue(fields[prop]);
const props = Object.keys(fields);
if (
props.indexOf(RESERVED_MAP_KEY) !== -1 &&
this.decodeValue(fields[RESERVED_MAP_KEY]) ===
RESERVED_MAP_KEY_VECTOR_VALUE
) {
return VectorValue.fromProto(
fields[RESERVED_VECTOR_MAP_VECTORS_KEY]
);
} else {
const obj: DocumentData = {};
for (const prop of Object.keys(fields)) {
obj[prop] = this.decodeValue(fields[prop]);
}
return obj;
}
} else {
return {};
}

return obj;
}
case 'geoPointValue': {
return GeoPoint.fromProto(proto.geoPointValue!);
Expand Down Expand Up @@ -367,6 +403,8 @@ export function validateUserInput(
'If you want to ignore undefined values, enable `ignoreUndefinedProperties`.'
);
}
} else if (value instanceof VectorValue) {
// OK
} else if (value instanceof DeleteTransform) {
if (inArray) {
throw new Error(
Expand Down
4 changes: 4 additions & 0 deletions dev/src/v1/firestore_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,10 @@ export class FirestoreClient {
database: request.database ?? '',
});
this.initialize();
// @ts-ignore

Check failure on line 859 in dev/src/v1/firestore_client.ts

View workflow job for this annotation

GitHub Actions / lint

Do not use "@ts-ignore" because it alters compilation errors
console.log(
`update is ${JSON.stringify(request.writes![0].update, null, 2)}`
);
return this.innerApiCalls.commit(request, options, callback);
}
/**
Expand Down
11 changes: 11 additions & 0 deletions dev/system-test/firestore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,17 @@ describe('DocumentReference class', () => {
return promise;
});

it.only('can write and read vector embeddings', async () => {

Check failure on line 1018 in dev/system-test/firestore.ts

View workflow job for this annotation

GitHub Actions / lint

'it.only' is restricted from being used
const ref = randomCol.doc();
await ref.create({
vectorEmpty: FieldValue.vector([1, 3]),
vector1: FieldValue.vector([1, 2, 3.99]),
});
const snap1 = await ref.get();
expect(snap1.get('vectorEmpty')).to.deep.equal(FieldValue.vector());
// expect(snap1.get('vector1')).to.deep.equal(FieldValue.vector([1, 2, 3.99]));
});

describe('watch', () => {
const currentDeferred = new DeferredPromise<DocumentSnapshot>();

Expand Down
77 changes: 77 additions & 0 deletions dev/test/document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,43 @@ describe('serialize document', () => {
return ref.set({ref});
});
});

it('is able to translate FirestoreVector to internal representation with set', () => {
const overrides: ApiOverride = {
commit: request => {
requestEquals(
request,
set({
document: document('documentId', 'embedding1', {
mapValue: {
fields: {
__type__: {
stringValue: '__vector__',
},
value: {
arrayValue: {
values: [
{doubleValue: 0},
{doubleValue: 1},
{doubleValue: 2},
],
},
},
},
},
}),
})
);
return response(writeResult(1));
},
};

return createInstance(overrides).then(firestore => {
return firestore.doc('collectionId/documentId').set({
embedding1: FieldValue.vector([0, 1, 2]),
});
});
});
});

describe('deserialize document', () => {
Expand Down Expand Up @@ -598,6 +635,46 @@ describe('deserialize document', () => {
});
});

it('deserializes FirestoreVector', () => {
const overrides: ApiOverride = {
batchGetDocuments: () => {
return stream(
found(
document('documentId', 'embedding', {
mapValue: {
fields: {
__type__: {
stringValue: '__vector__',
},
value: {
arrayValue: {
values: [
{doubleValue: -41.0},
{doubleValue: 0},
{doubleValue: 42},
],
},
},
},
},
})
)
);
},
};

return createInstance(overrides).then(firestore => {
return firestore
.doc('collectionId/documentId')
.get()
.then(res => {
expect(res.get('embedding')).to.deep.equal(
FieldValue.vector([-41.0, 0, 42])
);
});
});
});

it("doesn't deserialize unsupported types", () => {
const overrides: ApiOverride = {
batchGetDocuments: () => {
Expand Down

0 comments on commit ad58cc1

Please sign in to comment.