diff --git a/README.md b/README.md index 1a5a625..5499040 100644 --- a/README.md +++ b/README.md @@ -832,7 +832,7 @@ curl -X POST http://localhost:3000/api/vector-retry-failed-batch \ ### Local API -The plugin provides a `getVectorizedPayload(payload)` function which returns a `vectorizedPayload` object exposing `search`, `queueEmbed`, `bulkEmbed`, and `retryFailedBatch` methods. +The plugin provides a `getVectorizedPayload(payload)` function which returns a `vectorizedPayload` object exposing `search`, `findByIds`, `queueEmbed`, `bulkEmbed`, and `retryFailedBatch` methods. #### Getting the Vectorized Payload Object @@ -883,6 +883,33 @@ const results = await vectorizedPayload.search({ }) ``` +#### `vectorizedPayload.findByIds(params)` + +Fetch stored embedding records by primary key. The `id` of each record is whatever [`search()`](#vectorizedpayloadsearchparams) returns as `result.id`, so a search result round-trips directly. Pass `populateEmbedding: true` to also get the raw embedding vector back (the normal search/query API never returns it) — the building block for "more like this" flows. It defaults to `false`, so by default you get the record's text and metadata without the heavy vector. + +**Params:** `{ knowledgePool: string; ids: string[]; populateEmbedding?: boolean }` (`populateEmbedding` defaults to `false`). + +**Returns:** `Promise<Record<string, EmbeddingRecord | undefined>>` — an object keyed by the ids you passed in. Each requested id is present as a key; a found record is the value, and an unknown or malformed id maps to `undefined`. `EmbeddingRecord` is the search result shape without `score` and with an optional `embedding?: number[]`, present only when `populateEmbedding: true`. 
+ +**Example:** + +```typescript +const id = '<embedding-id>' +const records = await vectorizedPayload.findByIds({ + knowledgePool: 'mainKnowledgePool', + ids: [id], + populateEmbedding: true, +}) + +const record = records[id] +if (record) { + // record.embedding is the raw number[] vector — feed it back into search for "more like this" + console.log(record.embedding!.length, record.chunkText) +} +``` + +Because the result is keyed by id, a search result round-trips directly (`records[searchHit.id]`) and there's no positional alignment to worry about — look records up by id rather than relying on key order. Unknown or malformed ids map to `undefined` (never throw), and an empty `ids` array returns `{}` without touching the backend. + #### `vectorizedPayload.queueEmbed(params)` Manually queue a vectorization job for a document. diff --git a/adapters/README.md b/adapters/README.md index 096a6e4..9b9cbac 100644 --- a/adapters/README.md +++ b/adapters/README.md @@ -110,6 +110,7 @@ import type { KnowledgePoolDynamicConfig, StoreChunkData, VectorSearchResult, + EmbeddingRecord, } from 'payloadcms-vectorize' export type DbAdapter = { @@ -150,6 +151,13 @@ export type DbAdapter = { limit?: number, where?: Where, ) => Promise<Array<VectorSearchResult>> + + findByIds: ( + payload: BasePayload, + poolName: KnowledgePoolName, + ids: string[], + populateEmbedding?: boolean, + ) => Promise<Record<string, EmbeddingRecord | undefined>> } ``` @@ -162,6 +170,7 @@ export type DbAdapter = { | `deleteChunks` | After a source document is deleted. | Remove every chunk where `sourceCollection === ... && docId === ...`. Must be safe to call when no chunks exist (no-op, no throw). | | `hasEmbeddingVersion` | During bulk-embed planning, per candidate document. | Return `true` iff at least one chunk exists with the matching `(sourceCollection, docId, embeddingVersion)` triple. Must filter on **all three** — older `0.7.0` adapters that ignored `embeddingVersion` caused stale embeddings on model bumps. 
| | `search` | Per `/vector-search` request and per `getVectorizedPayload().search()` call. | Translate `where` (Payload-style) into your store's filter language, perform a vector search using `queryEmbedding`, and return up to `limit` results sorted by descending relevance. | +| `findByIds` | Per `getVectorizedPayload().findByIds()` call. | Fetch stored embedding records by primary key. **Return an object keyed by the ids you were given:** every requested id must be present as a key, with a found record as the value and `undefined` for any id that didn't resolve. The raw `embedding` vector is **only included when `populateEmbedding` is `true`** (default `false`) — omit it otherwise so callers that only need text/metadata don't pay for it. Where possible, skip reading the vector at the source (pg: don't select the column; MongoDB: `{ projection: { embedding: 0 } }`); CF's `getByIds` always returns values, so omit them post-fetch. Look up by the same `id` your `search` returns as `result.id`. Unknown **and** malformed ids must map to `undefined` — never throw for a bad id. Validate the id shape against your key type before querying so a malformed id can't error the whole batch (MongoDB drops non-24-hex ids; pg drops ids that don't match the PK column type — numeric for integer PKs, uuid-shaped for `uuid` PKs — before the `IN` query; CF's ids are arbitrary strings, so an unknown one is simply absent from `getByIds`). Empty `ids` returns `{}` without a backend call. | ### Error contract @@ -286,6 +295,14 @@ export const createYourDbVectorIntegration = ( // Return Array sorted by descending score. return [] }, + + findByIds: async (payload, poolName, ids, populateEmbedding = false) => { + // TODO: fetch stored records by primary key. Include the raw `embedding` vector + // only when `populateEmbedding` is true (default false); skip reading it otherwise. 
+ // Return an object keyed by every requested id: a record for hits, `undefined` + // for unknown or malformed ids (never throw for a bad id). + return Object.fromEntries(ids.map((id) => [id, undefined])) + }, } return { adapter } @@ -361,6 +378,26 @@ export interface VectorSearchResult { /** Any extensionFields persisted via storeChunk must round-trip here. */ [key: string]: any } + +export interface EmbeddingRecord { + /** Embedding record ID — the same value your adapter returns as VectorSearchResult.id. */ + id: string + /** Source collection slug (echoed from StoreChunkData). */ + sourceCollection: string + /** Source document ID (echoed from StoreChunkData). */ + docId: string + /** Chunk index within the source document. */ + chunkIndex: number + /** The original chunk text. */ + chunkText: string + /** Embedding model/version string. */ + embeddingVersion: string + /** The raw embedding vector — never returned by `search`, and only present + * when `findByIds` is called with `populateEmbedding: true`. */ + embedding?: number[] + /** Any extensionFields persisted via storeChunk round-trip here. */ + [key: string]: any +} ``` | Field | Required | Notes | @@ -371,6 +408,8 @@ export interface VectorSearchResult { | `chunkText`, `embeddingVersion` | yes | Same. | | `extensionFields.*` | optional | Whatever the user passed in `extensionFields` must be queryable via `where`. | +> `EmbeddingRecord` (returned by `findByIds`) is `VectorSearchResult` without `score` and with an optional raw `embedding?: number[]` — present only when `findByIds` is called with `populateEmbedding: true`. + ## Testing your adapter The dev harness in [`dev/`](../dev) runs the integration suite against any adapter you wire up. 
To test a new adapter: diff --git a/adapters/cf/dev/specs/adapter.spec.ts b/adapters/cf/dev/specs/adapter.spec.ts index 3ecf20e..ee988be 100644 --- a/adapters/cf/dev/specs/adapter.spec.ts +++ b/adapters/cf/dev/specs/adapter.spec.ts @@ -61,6 +61,13 @@ function createMockCloudflareBinding() { } }), + getByIds: vi.fn(async (ids: string[]) => { + return ids + .map((id) => storage.get(id)) + .filter((v): v is { id: string; values: number[]; metadata: any } => v !== undefined) + .map((v) => ({ id: v.id, values: v.values, metadata: v.metadata })) + }), + list: vi.fn(async (options: any) => { const vectors = Array.from(storage.values()).map((item) => ({ id: item.id, @@ -431,4 +438,104 @@ describe('createCloudflareVectorizeIntegration', () => { }) }) }) + + describe('findByIds', () => { + test('returns full EmbeddingRecord including embedding values when populateEmbedding is true', async () => { + const mockBinding = createMockCloudflareBinding() + const { adapter } = createCloudflareVectorizeIntegration({ + config: { default: { dims: DIMS } }, + binding: mockBinding as any, + }) + const mockPayload = createMockPayload(mockBinding) + const embedding = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] + + await adapter.storeChunk(mockPayload, 'default', { + sourceCollection: 'posts', + docId: 'doc-1', + chunkIndex: 0, + chunkText: 'find me', + embeddingVersion: 'v1', + embedding, + extensionFields: { category: 'science' }, + }) + + const id = 'default:posts:doc-1:0' + const records = await adapter.findByIds(mockPayload, 'default', [id], true) + expect(Object.keys(records)).toEqual([id]) + const r = records[id]! 
+ expect(r.id).toBe(id) + expect(r.embedding).toEqual(embedding) + expect(r.sourceCollection).toBe('posts') + expect(r.docId).toBe('doc-1') + expect(r.chunkText).toBe('find me') + expect(r.embeddingVersion).toBe('v1') + expect((r as any).category).toBe('science') + }) + + test('omits embedding values by default', async () => { + const mockBinding = createMockCloudflareBinding() + const { adapter } = createCloudflareVectorizeIntegration({ + config: { default: { dims: DIMS } }, + binding: mockBinding as any, + }) + const mockPayload = createMockPayload(mockBinding) + + await adapter.storeChunk(mockPayload, 'default', { + sourceCollection: 'posts', + docId: 'doc-1', + chunkIndex: 0, + chunkText: 'find me', + embeddingVersion: 'v1', + embedding: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + extensionFields: { category: 'science' }, + }) + + const id = 'default:posts:doc-1:0' + const records = await adapter.findByIds(mockPayload, 'default', [id]) + expect(Object.keys(records)).toEqual([id]) + const r = records[id]! 
+ expect(r.id).toBe(id) + expect(r.embedding).toBeUndefined() + expect(r.chunkText).toBe('find me') + expect((r as any).category).toBe('science') + }) + + test('maps misses to undefined', async () => { + const mockBinding = createMockCloudflareBinding() + const { adapter } = createCloudflareVectorizeIntegration({ + config: { default: { dims: DIMS } }, + binding: mockBinding as any, + }) + const mockPayload = createMockPayload(mockBinding) + await adapter.storeChunk(mockPayload, 'default', { + sourceCollection: 'posts', + docId: 'doc-1', + chunkIndex: 0, + chunkText: 'x', + embeddingVersion: 'v1', + embedding: [0, 0, 0, 0, 0, 0, 0, 0], + extensionFields: {}, + }) + const records = await adapter.findByIds(mockPayload, 'default', [ + 'default:posts:doc-1:0', + 'default:posts:nope:0', + ]) + expect(Object.keys(records).sort()).toEqual( + ['default:posts:doc-1:0', 'default:posts:nope:0'].sort(), + ) + expect(records['default:posts:doc-1:0']!.id).toBe('default:posts:doc-1:0') + expect(records['default:posts:nope:0']).toBeUndefined() + }) + + test('empty ids returns {}', async () => { + const mockBinding = createMockCloudflareBinding() + const { adapter } = createCloudflareVectorizeIntegration({ + config: { default: { dims: DIMS } }, + binding: mockBinding as any, + }) + const mockPayload = createMockPayload(mockBinding) + const records = await adapter.findByIds(mockPayload, 'default', []) + expect(records).toEqual({}) + }) + }) }) diff --git a/adapters/cf/src/findByIds.ts b/adapters/cf/src/findByIds.ts new file mode 100644 index 0000000..c8f2104 --- /dev/null +++ b/adapters/cf/src/findByIds.ts @@ -0,0 +1,48 @@ +import { BasePayload } from 'payload' +import { KnowledgePoolName, EmbeddingRecord } from 'payloadcms-vectorize' +import { getVectorizeBinding } from './types.js' + +const RESERVED_METADATA = ['sourceCollection', 'docId', 'chunkIndex', 'chunkText', 'embeddingVersion'] + +export default async ( + payload: BasePayload, + _poolName: KnowledgePoolName, + ids: 
string[], + populateEmbedding = false, +): Promise> => { + const result: Record = {} + for (const id of ids) result[id] = undefined + if (ids.length === 0) return result + + const binding = getVectorizeBinding(payload) + + try { + const vectors = await binding.getByIds(ids) + if (!vectors) return result + + for (const vector of vectors) { + const metadata = (vector.metadata || {}) as Record + const extensionFields = Object.fromEntries( + Object.entries(metadata).filter(([k]) => !RESERVED_METADATA.includes(k)), + ) + result[vector.id] = { + id: vector.id, + sourceCollection: String(metadata.sourceCollection ?? ''), + docId: String(metadata.docId ?? ''), + chunkIndex: + typeof metadata.chunkIndex === 'number' + ? metadata.chunkIndex + : parseInt(String(metadata.chunkIndex ?? '0'), 10), + chunkText: String(metadata.chunkText ?? ''), + embeddingVersion: String(metadata.embeddingVersion ?? ''), + ...(populateEmbedding ? { embedding: Array.from(vector.values ?? []) } : {}), + ...extensionFields, + } + } + return result + } catch (e) { + const errorMessage = e instanceof Error ? 
e.message : String(e) + payload.logger.error(`[@payloadcms-vectorize/cf] findByIds failed: ${errorMessage}`) + throw new Error(`[@payloadcms-vectorize/cf] findByIds failed: ${errorMessage}`) + } +} diff --git a/adapters/cf/src/index.ts b/adapters/cf/src/index.ts index f3a51bb..44a57b3 100644 --- a/adapters/cf/src/index.ts +++ b/adapters/cf/src/index.ts @@ -5,6 +5,7 @@ import type { CloudflareVectorizeBinding, KnowledgePoolsConfig, VectorizeBinding import cfMappingsCollection, { CF_MAPPINGS_SLUG } from './collections/cfMappings.js' import embed from './embed.js' import search from './search.js' +import findByIds from './findByIds.js' /** * Configuration for Cloudflare Vectorize integration @@ -113,6 +114,8 @@ export const createCloudflareVectorizeIntegration = ( } }, + findByIds, + hasEmbeddingVersion: async (payload, poolName, sourceCollection, docId, embeddingVersion) => { const result = await payload.find({ collection: CF_MAPPINGS_SLUG as CollectionSlug, diff --git a/adapters/mongodb/dev/specs/findByIds.spec.ts b/adapters/mongodb/dev/specs/findByIds.spec.ts new file mode 100644 index 0000000..231a0e4 --- /dev/null +++ b/adapters/mongodb/dev/specs/findByIds.spec.ts @@ -0,0 +1,103 @@ +import { afterAll, beforeAll, describe, expect, test } from 'vitest' +import { MongoClient } from 'mongodb' +import type { BasePayload } from 'payload' +import type { DbAdapter } from 'payloadcms-vectorize' +import { DIMS, MONGO_URI } from './constants.js' +import { buildMongoTestPayload, teardownDbs } from './utils.js' +import { testEmbeddingVersion, makeDummyEmbedDocs, makeDummyEmbedQuery } from '@shared-test/helpers/embed' + +const DB = `mongo_find_by_ids_${Date.now()}` + +describe('mongodb findByIds', () => { + let payload: BasePayload + let adapter: DbAdapter + let embeddingId: string + + beforeAll(async () => { + const built = await buildMongoTestPayload({ + uri: MONGO_URI, + dbName: DB, + pools: { default: { dimensions: DIMS, filterableFields: ['category'] } }, + 
knowledgePools: { + default: { + collections: {}, + extensionFields: [{ name: 'category', type: 'text' }], + embeddingConfig: { + version: testEmbeddingVersion, + queryFn: makeDummyEmbedQuery(DIMS), + realTimeIngestionFn: makeDummyEmbedDocs(DIMS), + }, + }, + }, + }) + payload = built.payload + adapter = built.adapter + + await adapter.storeChunk(payload, 'default', { + sourceCollection: 'posts', + docId: 'doc-1', + chunkIndex: 0, + chunkText: 'find me', + embeddingVersion: testEmbeddingVersion, + embedding: Array(DIMS).fill(0.25), + extensionFields: { category: 'science' }, + }) + + const c = new MongoClient(MONGO_URI) + await c.connect() + const doc = await c.db(`${DB}_vectors`).collection('vectorize_default').findOne({ docId: 'doc-1' }) + embeddingId = String(doc!._id) + await c.close() + }) + + afterAll(async () => { + await teardownDbs(payload, MONGO_URI, DB) + }) + + test('returns full EmbeddingRecord including numeric embedding array when populateEmbedding is true', async () => { + const records = await adapter.findByIds(payload, 'default', [embeddingId], true) + expect(Object.keys(records)).toEqual([embeddingId]) + const r = records[embeddingId]! + expect(r.id).toBe(embeddingId) + expect(Array.isArray(r.embedding)).toBe(true) + expect(r.embedding!.length).toBe(DIMS) + expect(r.embedding!.every((n) => typeof n === 'number')).toBe(true) + expect(r.sourceCollection).toBe('posts') + expect(r.chunkText).toBe('find me') + expect(r.embeddingVersion).toBe(testEmbeddingVersion) + }) + + test('omits the embedding array by default', async () => { + const records = await adapter.findByIds(payload, 'default', [embeddingId]) + expect(Object.keys(records)).toEqual([embeddingId]) + const r = records[embeddingId]! 
+ expect(r.id).toBe(embeddingId) + expect(r.embedding).toBeUndefined() + expect(r.sourceCollection).toBe('posts') + expect(r.chunkText).toBe('find me') + }) + + test('includes extension fields', async () => { + const records = await adapter.findByIds(payload, 'default', [embeddingId]) + expect((records[embeddingId] as any).category).toBe('science') + }) + + test('maps misses and invalid ids to undefined without throwing', async () => { + const records = await adapter.findByIds(payload, 'default', [ + embeddingId, + '000000000000000000000000', + 'not-an-object-id', + ]) + expect(Object.keys(records).sort()).toEqual( + [embeddingId, '000000000000000000000000', 'not-an-object-id'].sort(), + ) + expect(records[embeddingId]!.id).toBe(embeddingId) + expect(records['000000000000000000000000']).toBeUndefined() + expect(records['not-an-object-id']).toBeUndefined() + }) + + test('empty ids returns {}', async () => { + const records = await adapter.findByIds(payload, 'default', []) + expect(records).toEqual({}) + }) +}) diff --git a/adapters/mongodb/src/findByIds.ts b/adapters/mongodb/src/findByIds.ts new file mode 100644 index 0000000..fe90203 --- /dev/null +++ b/adapters/mongodb/src/findByIds.ts @@ -0,0 +1,71 @@ +import type { BasePayload } from 'payload' +import type { EmbeddingRecord } from 'payloadcms-vectorize' +import { ObjectId } from 'mongodb' +import { getMongoClient } from './client.js' +import { RESERVED_FIELDS, type ResolvedPoolConfig } from './types.js' + +export interface MongoFindByIdsCtx { + uri: string + dbName: string + pools: Record +} + +const HEX24 = /^[a-f\d]{24}$/i +const RESERVED_AND_META = new Set([...RESERVED_FIELDS, '_id', 'createdAt', 'updatedAt']) + +export async function findByIdsImpl( + ctx: MongoFindByIdsCtx, + _payload: BasePayload, + poolName: string, + ids: string[], + populateEmbedding = false, +): Promise> { + const result: Record = {} + for (const id of ids) result[id] = undefined + if (ids.length === 0) return result + + const cfg = 
ctx.pools[poolName] + if (!cfg) { + throw new Error( + `[@payloadcms-vectorize/mongodb] Unknown pool "${poolName}". Configured pools: ${Object.keys(ctx.pools).join(', ')}`, + ) + } + + const objectIds = ids.filter((id) => HEX24.test(id)).map((id) => new ObjectId(id)) + if (objectIds.length === 0) return result + + const client = await getMongoClient(ctx.uri) + const docs = await client + .db(ctx.dbName) + .collection(cfg.collectionName) + .find({ _id: { $in: objectIds } }, populateEmbedding ? {} : { projection: { embedding: 0 } }) + .toArray() + + for (const doc of docs) { + const record = mapDocToRecord(doc as Record, populateEmbedding) + result[record.id] = record + } + return result +} + +function mapDocToRecord( + doc: Record, + populateEmbedding: boolean, +): EmbeddingRecord { + const extensionFields = Object.fromEntries( + Object.entries(doc).filter(([k]) => !RESERVED_AND_META.has(k)), + ) + return { + id: String(doc._id), + sourceCollection: String(doc.sourceCollection ?? ''), + docId: String(doc.docId ?? ''), + chunkIndex: + typeof doc.chunkIndex === 'number' ? doc.chunkIndex : Number(doc.chunkIndex ?? 0), + chunkText: String(doc.chunkText ?? ''), + embeddingVersion: String(doc.embeddingVersion ?? ''), + ...(populateEmbedding + ? { embedding: Array.isArray(doc.embedding) ? 
(doc.embedding as number[]) : [] } + : {}), + ...extensionFields, + } +} diff --git a/adapters/mongodb/src/index.ts b/adapters/mongodb/src/index.ts index 0ca05a4..5a0d2d2 100644 --- a/adapters/mongodb/src/index.ts +++ b/adapters/mongodb/src/index.ts @@ -2,6 +2,7 @@ import type { DbAdapter } from 'payloadcms-vectorize' import { getMongoClient } from './client.js' import { storeChunkImpl } from './embed.js' import { searchImpl } from './search.js' +import { findByIdsImpl } from './findByIds.js' import { resolvePoolConfig, type MongoVectorIntegrationConfig, @@ -89,6 +90,9 @@ export const createMongoVectorIntegration = ( search: (payload, queryEmbedding, poolName, limit, where) => searchImpl(getCtx(), payload, queryEmbedding, poolName, limit, where), + + findByIds: (payload, poolName, ids, populateEmbedding) => + findByIdsImpl(getCtx(), payload, poolName, ids, populateEmbedding), } return { adapter } diff --git a/adapters/pg/dev/specs/findByIds.spec.ts b/adapters/pg/dev/specs/findByIds.spec.ts new file mode 100644 index 0000000..f03a260 --- /dev/null +++ b/adapters/pg/dev/specs/findByIds.spec.ts @@ -0,0 +1,243 @@ +import type { Payload } from 'payload' +import { afterAll, beforeAll, describe, expect, test } from 'vitest' +import { postgresAdapter } from '@payloadcms/db-postgres' +import { eq } from '@payloadcms/db-postgres/drizzle' +import { getEmbeddingsTable } from '../../src/drizzle.js' +import { buildDummyConfig, integration, plugin, DIMS } from './constants.js' +import { createTestDb, destroyPayload, waitForVectorizationJobs } from './utils.js' +import { getPayload } from 'payload' +import { chunkText } from '@shared-test/helpers/chunkers' +import { makeDummyEmbedDocs, makeDummyEmbedQuery, testEmbeddingVersion } from '@shared-test/helpers/embed' + +describe('pg findByIds', () => { + let payload: Payload + const dbName = 'pg_find_by_ids_test' + let embeddingId: string + + beforeAll(async () => { + await createTestDb({ dbName }) + const config = await 
buildDummyConfig({ + jobs: { tasks: [], autoRun: [{ cron: '*/5 * * * * *', limit: 10 }] }, + collections: [ + { slug: 'posts', fields: [ + { name: 'title', type: 'text' }, + { name: 'category', type: 'text' }, + ] }, + ], + db: postgresAdapter({ + extensions: ['vector'], + afterSchemaInit: [integration.afterSchemaInitHook], + pool: { connectionString: `postgresql://postgres:password@localhost:5433/${dbName}` }, + }), + plugins: [ + plugin({ + knowledgePools: { + default: { + collections: { + posts: { + toKnowledgePool: async (doc) => { + const chunks: Array<{ chunk: string; category?: string }> = [] + if (doc.title) { + for (const chunk of chunkText(doc.title)) { + chunks.push({ chunk, category: doc.category || 'general' }) + } + } + return chunks + }, + }, + }, + extensionFields: [{ name: 'category', type: 'text' }], + embeddingConfig: { + version: testEmbeddingVersion, + queryFn: makeDummyEmbedQuery(DIMS), + realTimeIngestionFn: makeDummyEmbedDocs(DIMS), + }, + }, + }, + }), + ], + }) + payload = await getPayload({ config, key: `pg-find-by-ids-${Date.now()}`, cron: true }) + + const post = await payload.create({ + collection: 'posts', + data: { title: 'Find me by id', category: 'science' }, + }) + await waitForVectorizationJobs(payload) + const rows = await payload.find({ + collection: 'default' as any, + where: { docId: { equals: String(post.id) } }, + limit: 1, + }) + embeddingId = String(rows.docs[0].id) + }) + + afterAll(async () => { + await destroyPayload(payload) + }) + + test('returns full EmbeddingRecord including numeric embedding array when populateEmbedding is true', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [embeddingId], true) + expect(Object.keys(records)).toEqual([embeddingId]) + const r = records[embeddingId]! 
+ expect(r.id).toBe(embeddingId) + expect(Array.isArray(r.embedding)).toBe(true) + expect(r.embedding!.length).toBe(DIMS) + expect(r.embedding!.every((n) => typeof n === 'number')).toBe(true) + expect(r.sourceCollection).toBe('posts') + expect(typeof r.chunkText).toBe('string') + expect(r.embeddingVersion).toBe(testEmbeddingVersion) + }) + + test('omits the embedding array by default', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [embeddingId]) + expect(Object.keys(records)).toEqual([embeddingId]) + const r = records[embeddingId]! + expect(r.id).toBe(embeddingId) + expect(r.embedding).toBeUndefined() + expect(r.sourceCollection).toBe('posts') + }) + + test('includes extension fields when the pool defines them', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [embeddingId]) + expect((records[embeddingId] as any).category).toBe('science') + }) + + test('maps a well-formed but nonexistent id to undefined', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [embeddingId, '999999']) + expect(Object.keys(records).sort()).toEqual([embeddingId, '999999'].sort()) + expect(records[embeddingId]!.id).toBe(embeddingId) + expect(records['999999']).toBeUndefined() + }) + + test('maps a malformed (non-numeric) id to undefined instead of throwing', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [embeddingId, 'not-an-id']) + expect(Object.keys(records).sort()).toEqual([embeddingId, 'not-an-id'].sort()) + expect(records[embeddingId]!.id).toBe(embeddingId) + expect(records['not-an-id']).toBeUndefined() + }) + + test('empty ids returns {}', async () => { + const records = await integration.adapter.findByIds(payload, 'default', []) + expect(records).toEqual({}) + }) + + test('coerces null chunkText/embeddingVersion to "" (EmbeddingRecord type)', async () => { + // These columns are not required in the embeddings schema, so a row 
can have + // nulls. Set them directly and confirm findByIds returns '' (parity with cf/mongo), + // not null — which would violate EmbeddingRecord's `chunkText: string`. + const table = getEmbeddingsTable('default')! + await (payload.db as any).drizzle + .update(table) + .set({ chunkText: null, embeddingVersion: null }) + .where(eq(table.id, Number(embeddingId))) + + const r = (await integration.adapter.findByIds(payload, 'default', [embeddingId]))[embeddingId]! + expect(r.chunkText).toBe('') + expect(r.embeddingVersion).toBe('') + }) +}) + +describe('pg findByIds (uuid idType)', () => { + let payload: Payload + const dbName = 'pg_find_by_ids_uuid_test' + let embeddingId: string + + beforeAll(async () => { + await createTestDb({ dbName }) + const config = await buildDummyConfig({ + jobs: { tasks: [], autoRun: [{ cron: '*/5 * * * * *', limit: 10 }] }, + collections: [ + { slug: 'posts', fields: [ + { name: 'title', type: 'text' }, + { name: 'category', type: 'text' }, + ] }, + ], + db: postgresAdapter({ + idType: 'uuid', + extensions: ['vector'], + afterSchemaInit: [integration.afterSchemaInitHook], + pool: { connectionString: `postgresql://postgres:password@localhost:5433/${dbName}` }, + }), + plugins: [ + plugin({ + knowledgePools: { + default: { + collections: { + posts: { + toKnowledgePool: async (doc) => { + const chunks: Array<{ chunk: string; category?: string }> = [] + if (doc.title) { + for (const chunk of chunkText(doc.title)) { + chunks.push({ chunk, category: doc.category || 'general' }) + } + } + return chunks + }, + }, + }, + extensionFields: [{ name: 'category', type: 'text' }], + embeddingConfig: { + version: testEmbeddingVersion, + queryFn: makeDummyEmbedQuery(DIMS), + realTimeIngestionFn: makeDummyEmbedDocs(DIMS), + }, + }, + }, + }), + ], + }) + payload = await getPayload({ config, key: `pg-find-by-ids-uuid-${Date.now()}`, cron: true }) + + const post = await payload.create({ + collection: 'posts', + data: { title: 'Find me by uuid', category: 
'science' }, + }) + await waitForVectorizationJobs(payload) + const rows = await payload.find({ + collection: 'default' as any, + where: { docId: { equals: String(post.id) } }, + limit: 1, + }) + embeddingId = String(rows.docs[0].id) + }) + + afterAll(async () => { + await destroyPayload(payload) + }) + + test('embedding id is a uuid, not a numeric PK', () => { + expect(embeddingId).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i, + ) + }) + + test('findByIds resolves a uuid id (regression: numeric-only filter dropped uuids)', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [embeddingId], true) + expect(Object.keys(records)).toEqual([embeddingId]) + const r = records[embeddingId]! + expect(r.id).toBe(embeddingId) + expect(Array.isArray(r.embedding)).toBe(true) + expect(r.embedding!.length).toBe(DIMS) + expect((r as any).category).toBe('science') + }) + + test('maps a well-formed but nonexistent uuid to undefined', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [ + embeddingId, + '00000000-0000-0000-0000-000000000000', + ]) + expect(Object.keys(records).sort()).toEqual( + [embeddingId, '00000000-0000-0000-0000-000000000000'].sort(), + ) + expect(records[embeddingId]!.id).toBe(embeddingId) + expect(records['00000000-0000-0000-0000-000000000000']).toBeUndefined() + }) + + test('maps a malformed (non-uuid) id to undefined instead of throwing', async () => { + const records = await integration.adapter.findByIds(payload, 'default', [embeddingId, '999999']) + expect(Object.keys(records).sort()).toEqual([embeddingId, '999999'].sort()) + expect(records[embeddingId]!.id).toBe(embeddingId) + expect(records['999999']).toBeUndefined() + }) +}) diff --git a/adapters/pg/src/findByIds.ts b/adapters/pg/src/findByIds.ts new file mode 100644 index 0000000..9e51aeb --- /dev/null +++ b/adapters/pg/src/findByIds.ts @@ -0,0 +1,134 @@ +import { inArray } from 
'@payloadcms/db-postgres/drizzle' +import { BasePayload, SanitizedCollectionConfig } from 'payload' +import { KnowledgePoolName, EmbeddingRecord } from 'payloadcms-vectorize' +import toSnakeCase from 'to-snake-case' +import { getEmbeddingsTable } from './drizzle.js' + +export default async ( + payload: BasePayload, + poolName: KnowledgePoolName, + ids: string[], + populateEmbedding = false, +): Promise> => { + const result: Record = {} + for (const id of ids) result[id] = undefined + if (ids.length === 0) return result + + const isPostgres = payload.db?.pool?.query || payload.db?.drizzle + if (!isPostgres) { + throw new Error('[@payloadcms-vectorize/pg] Only works with Postgres') + } + const drizzle = payload.db?.drizzle + if (!drizzle) { + throw new Error('[@payloadcms-vectorize/pg] Drizzle instance not found in adapter') + } + + const collectionConfig = payload.collections[poolName]?.config + if (!collectionConfig) { + throw new Error(`[@payloadcms-vectorize/pg] Collection ${poolName} not found`) + } + + const table = getEmbeddingsTable(poolName) + if (!table) { + throw new Error( + `[@payloadcms-vectorize/pg] Embeddings table for knowledge pool "${poolName}" not registered.`, + ) + } + + // Drop ids that can't match the primary-key column type before querying, so a + // malformed id is treated as a miss instead of making Postgres reject the cast + // and throw for the whole batch. + const queryableIds = ids.filter((id) => idMatchesPkType(table.id, id)) + if (queryableIds.length === 0) return result + + const selectObj: Record = { + id: table.id, + } + if (populateEmbedding) { + selectObj.embedding = table.embedding + } + for (const field of collectionConfig.fields ?? 
[]) { + if (typeof field === 'object' && 'name' in field) { + const name = field.name as string + if (name in table) { + selectObj[name] = table[name] + } else if (toSnakeCase(name) in table) { + selectObj[name] = table[toSnakeCase(name)] + } + } + } + + const rows = await drizzle.select(selectObj).from(table).where(inArray(table.id, queryableIds)) + for (const record of mapRowsToRecords(rows, collectionConfig, populateEmbedding)) { + result[record.id] = record + } + return result +} + +const UUID = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i + +function idMatchesPkType(idColumn: { getSQLType?: () => string }, id: string): boolean { + const sqlType = idColumn.getSQLType?.() ?? '' + if (sqlType === 'integer' || sqlType === 'serial' || sqlType === 'bigint' || sqlType === 'bigserial') { + return /^\d+$/.test(id) + } + if (sqlType === 'uuid') { + return UUID.test(id) + } + return true +} + +function mapRowsToRecords( + rows: Record[], + collectionConfig: SanitizedCollectionConfig, + populateEmbedding: boolean, +): Array { + const numberFields = new Set() + for (const field of collectionConfig.fields) { + if (typeof field === 'object' && 'name' in field && field.type === 'number') { + numberFields.add(field.name) + } + } + + return rows.map((row) => { + const rawDocId = row.docId ?? row.doc_id + const rawChunkIndex = row.chunkIndex ?? row.chunk_index + + const record = { + ...row, + id: String(row.id), + sourceCollection: String(row.sourceCollection ?? ''), + docId: String(rawDocId ?? ''), + chunkIndex: + typeof rawChunkIndex === 'number' ? rawChunkIndex : parseInt(String(rawChunkIndex), 10), + chunkText: String(row.chunkText ?? ''), + embeddingVersion: String(row.embeddingVersion ?? ''), + ...(populateEmbedding ? 
{ embedding: parseEmbedding(row.embedding) } : {}), + } as EmbeddingRecord + + for (const fieldName of numberFields) { + const value = record[fieldName] + if (value != null && typeof value !== 'number') { + const parsed = parseFloat(String(value)) + if (!Number.isNaN(parsed)) { + record[fieldName] = parsed + } + } + } + + return record + }) +} + +function parseEmbedding(value: unknown): number[] { + if (Array.isArray(value)) return value as number[] + if (typeof value === 'string') { + return value + .replace(/^\[/, '') + .replace(/\]$/, '') + .split(',') + .filter((s) => s.length > 0) + .map((s) => Number(s)) + } + return [] +} diff --git a/adapters/pg/src/index.ts b/adapters/pg/src/index.ts index ac28c21..22c25c0 100644 --- a/adapters/pg/src/index.ts +++ b/adapters/pg/src/index.ts @@ -9,6 +9,7 @@ import { fileURLToPath } from 'url' import { dirname, resolve } from 'path' import embed from './embed.js' import search from './search.js' +import findByIds from './findByIds.js' export type { KnowledgePoolsConfig as KnowledgePoolConfig } @@ -93,6 +94,7 @@ export const createPostgresVectorIntegration = ( } }, search, + findByIds, storeChunk: async (payload, poolName, data) => { const embeddingArray = Array.isArray(data.embedding) ? data.embedding : Array.from(data.embedding) diff --git a/adapters/pg/src/search.ts b/adapters/pg/src/search.ts index ce31c56..54dcc20 100644 --- a/adapters/pg/src/search.ts +++ b/adapters/pg/src/search.ts @@ -303,10 +303,13 @@ function mapRowsToResults( const result = { ...row, id: String(row.id), - docId: String(rawDocId), + sourceCollection: String(row.sourceCollection ?? ''), + docId: String(rawDocId ?? ''), score: typeof rawScore === 'number' ? rawScore : parseFloat(String(rawScore)), chunkIndex: typeof rawChunkIndex === 'number' ? rawChunkIndex : parseInt(String(rawChunkIndex), 10), + chunkText: String(row.chunkText ?? ''), + embeddingVersion: String(row.embeddingVersion ?? 
''), } as VectorSearchResult // Ensure any number fields from the schema are numbers in the result diff --git a/dev/helpers/mockAdapter.ts b/dev/helpers/mockAdapter.ts index dacad6f..0659f1d 100644 --- a/dev/helpers/mockAdapter.ts +++ b/dev/helpers/mockAdapter.ts @@ -1,4 +1,4 @@ -import type { DbAdapter, KnowledgePoolName, KnowledgePoolDynamicConfig, StoreChunkData, VectorSearchResult } from 'payloadcms-vectorize' +import type { DbAdapter, EmbeddingRecord, KnowledgePoolName, KnowledgePoolDynamicConfig, StoreChunkData, VectorSearchResult } from 'payloadcms-vectorize' import { createEmbeddingsCollection } from 'payloadcms-vectorize' import type { CollectionSlug, Payload, BasePayload, Where, Config } from 'payload' @@ -195,6 +195,46 @@ export const createMockAdapter = (options: MockAdapterOptions = {}): DbAdapter = .slice(0, limit) .map(({ _score, ...rest }) => rest) }, + + findByIds: async ( + payload: BasePayload, + poolName: KnowledgePoolName, + ids: string[], + populateEmbedding = false, + ): Promise> => { + const records: Record = {} + for (const id of ids) { + records[id] = undefined + const stored = storage.get(`${poolName}:${id}`) + if (!stored) continue + let doc: Record | null + try { + doc = (await payload.findByID({ + collection: poolName as CollectionSlug, + id: stored.id, + })) as Record | null + } catch (e) { + if (e instanceof Error && e.name === 'NotFound') { + continue + } + throw e + } + if (!doc) continue + const { + id: _id, + createdAt: _createdAt, + updatedAt: _updatedAt, + embedding: _embedding, + ...docFields + } = doc + records[id] = { + id: stored.id, + ...(populateEmbedding ? 
{ embedding: stored.embedding } : {}), + ...docFields, + } as EmbeddingRecord + } + return records + }, } } diff --git a/dev/specs/vectorizedPayload.spec.ts b/dev/specs/vectorizedPayload.spec.ts index 65c40ba..f4c36b4 100644 --- a/dev/specs/vectorizedPayload.spec.ts +++ b/dev/specs/vectorizedPayload.spec.ts @@ -200,6 +200,81 @@ describe('VectorizedPayload', () => { }) }) + describe('findByIds method', () => { + let embeddingId: string + + beforeAll(async () => { + const post = await payload.create({ + collection: 'posts', + data: { title: 'FindByIds seed', content: markdownContent as unknown as any }, + }) + await waitForVectorizationJobs(payload) + const rows = await payload.find({ + collection: 'default' as any, + where: { docId: { equals: String(post.id) } }, + limit: 1, + }) + embeddingId = String(rows.docs[0].id) + }) + + test('payload has findByIds method', () => { + const vectorizedPayload = getVectorizedPayload(payload) + expect(typeof vectorizedPayload!.findByIds).toBe('function') + }) + + test('returns the full EmbeddingRecord including the embedding vector when populateEmbedding is true', async () => { + const vectorizedPayload = getVectorizedPayload(payload)! + const records = await vectorizedPayload.findByIds({ + knowledgePool: 'default', + ids: [embeddingId], + populateEmbedding: true, + }) + expect(Object.keys(records)).toEqual([embeddingId]) + const record = records[embeddingId]! + expect(record.id).toBe(embeddingId) + expect(Array.isArray(record.embedding)).toBe(true) + expect(record.embedding!.length).toBe(DIMS) + expect(typeof record.sourceCollection).toBe('string') + expect(typeof record.chunkText).toBe('string') + }) + + test('omits the embedding vector by default', async () => { + const vectorizedPayload = getVectorizedPayload(payload)! + const records = await vectorizedPayload.findByIds({ + knowledgePool: 'default', + ids: [embeddingId], + }) + expect(Object.keys(records)).toEqual([embeddingId]) + const record = records[embeddingId]! 
+ expect(record.id).toBe(embeddingId) + expect(record.embedding).toBeUndefined() + expect(typeof record.sourceCollection).toBe('string') + expect(typeof record.chunkText).toBe('string') + }) + + test('maps unknown ids to undefined (every requested id is a key)', async () => { + const vectorizedPayload = getVectorizedPayload(payload)! + const records = await vectorizedPayload.findByIds({ + knowledgePool: 'default', + ids: [embeddingId, 'definitely-not-an-id-999999'], + }) + expect(Object.keys(records).sort()).toEqual( + [embeddingId, 'definitely-not-an-id-999999'].sort(), + ) + expect(records[embeddingId]!.id).toBe(embeddingId) + expect(records['definitely-not-an-id-999999']).toBeUndefined() + }) + + test('empty ids returns {}', async () => { + const vectorizedPayload = getVectorizedPayload(payload)! + const records = await vectorizedPayload.findByIds({ + knowledgePool: 'default', + ids: [], + }) + expect(records).toEqual({}) + }) + }) + describe('queueEmbed method', () => { test('payload has queueEmbed method', () => { const vectorizedPayload = getVectorizedPayload(payload) diff --git a/src/index.ts b/src/index.ts index bf3ac4e..7621b9c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -76,6 +76,7 @@ export type { // For adapters VectorSearchResult, + EmbeddingRecord, } from './types.js' export { getVectorizedPayload } from './types.js' @@ -356,6 +357,19 @@ export default (pluginOptions: PayloadcmsVectorizeConfig) => params.limit, params.where, ), + findByIds: (params: { + knowledgePool: KnowledgePoolName + ids: string[] + populateEmbedding?: boolean + }) => { + if (params.ids.length === 0) return Promise.resolve({}) + return pluginOptions.dbAdapter.findByIds( + payload, + params.knowledgePool, + params.ids, + params.populateEmbedding ?? 
false, + ) + }, queueEmbed: async ( params: | { diff --git a/src/types.ts b/src/types.ts index e54611d..2fb10b1 100644 --- a/src/types.ts +++ b/src/types.ts @@ -57,6 +57,11 @@ export type VectorizedPayload = { _isBulkEmbedEnabled: (knowledgePool: KnowledgePoolName) => boolean getDbAdapterCustom: () => Record | undefined search: (params: VectorSearchQuery) => Promise> + findByIds: (params: { + knowledgePool: KnowledgePoolName + ids: string[] + populateEmbedding?: boolean + }) => Promise> queueEmbed: ( params: | { @@ -322,6 +327,17 @@ export interface VectorSearchResult { [key: string]: any // Extension fields and other dynamic fields } +export interface EmbeddingRecord { + id: string + sourceCollection: string + docId: string + chunkIndex: number + chunkText: string + embeddingVersion: string + embedding?: number[] + [key: string]: any +} + export interface VectorSearchQuery { /** The knowledge pool to search in */ knowledgePool: KnowledgePoolName @@ -430,4 +446,10 @@ export type DbAdapter = { limit?: number, where?: Where, ) => Promise> + findByIds: ( + payload: BasePayload, + poolName: KnowledgePoolName, + ids: string[], + populateEmbedding?: boolean, + ) => Promise> }