diff --git a/dev/specs/searchByEmbedding.spec.ts b/dev/specs/searchByEmbedding.spec.ts new file mode 100644 index 0000000..a3f38ee --- /dev/null +++ b/dev/specs/searchByEmbedding.spec.ts @@ -0,0 +1,276 @@ +import type { Payload } from 'payload' + +import { postgresAdapter } from '@payloadcms/db-postgres' +import { type SerializedEditorState } from '@payloadcms/richtext-lexical/lexical' +import { chunkRichText, chunkText } from 'helpers/chunkers.js' +import { makeDummyEmbedDocs, makeDummyEmbedQuery, testEmbeddingVersion } from 'helpers/embed.js' +import { createMockAdapter } from 'helpers/mockAdapter.js' +import { getPayload } from 'payload' +import payloadcmsVectorize, { + DbAdapter, + getVectorizedPayload, + type VectorizedPayload, +} from 'payloadcms-vectorize' +import { afterAll, beforeAll, describe, expect, test } from 'vitest' +import { buildDummyConfig, DIMS, getInitialMarkdownContent } from './constants.js' +import { + expectResultsOrderedByScore, + expectResultsRespectWhere, + expectValidVectorSearchResults, +} from './helpers/vectorSearchExpectations.js' +import { + BULK_QUEUE_NAMES, + createTestDb, + destroyPayload, + waitForVectorizationJobs, +} from './utils.js' + +const embedFn = makeDummyEmbedQuery(DIMS) + +describe('searchByEmbedding method tests', () => { + let payload: Payload + let vectorizedPayload: VectorizedPayload + let adapter: DbAdapter + let markdownContent: SerializedEditorState + const titleAndQuery = 'My query is a title for searchByEmbedding' + const dbName = 'search_by_embedding_test' + + beforeAll(async () => { + await createTestDb({ dbName }) + adapter = createMockAdapter() + + const config = await buildDummyConfig({ + jobs: { + tasks: [], + autoRun: [ + { + cron: '*/5 * * * * *', // Run every 5 seconds + limit: 10, + }, + ], + }, + collections: [ + { + slug: 'posts', + fields: [ + { name: 'title', type: 'text' }, + { name: 'content', type: 'richText' }, + ], + }, + ], + db: postgresAdapter({ + pool: { + connectionString: `postgresql://postgres:password@localhost:5433/${dbName}`, + }, + }), + plugins: [ + payloadcmsVectorize({ + dbAdapter: adapter, + knowledgePools: { + default: { + collections: { + posts: { + toKnowledgePool: async (doc, payload) => { + const chunks: Array<{ chunk: string }> = [] + // Process title + if (doc.title) { + const titleChunks = chunkText(doc.title) + chunks.push(...titleChunks.map((chunk) => ({ chunk }))) + } + // Process content + if (doc.content) { + const contentChunks = await chunkRichText(doc.content, payload.config) + chunks.push(...contentChunks.map((chunk) => ({ chunk }))) + } + return chunks + }, + }, + }, + embeddingConfig: { + version: testEmbeddingVersion, + queryFn: makeDummyEmbedQuery(DIMS), + realTimeIngestionFn: makeDummyEmbedDocs(DIMS), + }, + }, + }, + bulkQueueNames: BULK_QUEUE_NAMES, + }), + ], + }) + + payload = await getPayload({ + config, + key: `search-by-embedding-test-${Date.now()}`, + cron: true, + }) + + const vp = getVectorizedPayload(payload) + if (!vp) { + throw new Error('Failed to get vectorized payload') + } + vectorizedPayload = vp + + markdownContent = await getInitialMarkdownContent(config) + }) + + afterAll(async () => { + await destroyPayload(payload) + }) + + test('searchByEmbedding with embedding vector returns valid results', async () => { + // Create a post + const post = await payload.create({ + collection: 'posts', + data: { + title: titleAndQuery, + content: markdownContent as unknown as any, + }, + }) + + // Wait for vectorization jobs to complete + await waitForVectorizationJobs(payload) + + // Get the embedding for our query + const queryEmbedding = await embedFn(titleAndQuery) + const embeddingArray = Array.isArray(queryEmbedding) + ? queryEmbedding + : Array.from(queryEmbedding) + + // Search using the embedding directly + const results = await vectorizedPayload.searchByEmbedding({ + knowledgePool: 'default', + embedding: embeddingArray, + }) + + expectValidVectorSearchResults(results, { + checkShape: true, + expectedTitle: { + title: titleAndQuery, + postId: String(post.id), + embeddingVersion: testEmbeddingVersion, + }, + }) + }) + + test('searchByEmbedding results are ordered by score (highest first)', async () => { + // Get the embedding for our query + const queryEmbedding = await embedFn(titleAndQuery) + const embeddingArray = Array.isArray(queryEmbedding) + ? queryEmbedding + : Array.from(queryEmbedding) + + const results = await vectorizedPayload.searchByEmbedding({ + knowledgePool: 'default', + embedding: embeddingArray, + }) + + expectResultsOrderedByScore(results) + }) + + test('searchByEmbedding respects limit parameter', async () => { + // Get the embedding for our query + const queryEmbedding = await embedFn(titleAndQuery) + const embeddingArray = Array.isArray(queryEmbedding) + ? queryEmbedding + : Array.from(queryEmbedding) + + const limit = 2 + const results = await vectorizedPayload.searchByEmbedding({ + knowledgePool: 'default', + embedding: embeddingArray, + limit, + }) + + expect(results.length).toBeLessThanOrEqual(limit) + }) + + test('searchByEmbedding respects where clause', async () => { + const sharedText = 'Shared searchable content for embedding search' + + // Create two posts with same text + const post1 = await payload.create({ + collection: 'posts', + data: { + title: sharedText, + content: null, + }, + }) + + const post2 = await payload.create({ + collection: 'posts', + data: { + title: sharedText, + content: null, + }, + }) + + // Wait for vectorization jobs to complete + await waitForVectorizationJobs(payload) + + // Get the embedding for our query + const queryEmbedding = await embedFn(sharedText) + const embeddingArray = Array.isArray(queryEmbedding) + ? queryEmbedding + : Array.from(queryEmbedding) + + // Search without WHERE - should return both + const resultsAll = await vectorizedPayload.searchByEmbedding({ + knowledgePool: 'default', + embedding: embeddingArray, + }) + + expect(resultsAll.length).toBeGreaterThanOrEqual(2) + + // Search with WHERE clause filtering by docId - should return only one + const resultsFiltered = await vectorizedPayload.searchByEmbedding({ + knowledgePool: 'default', + embedding: embeddingArray, + where: { + docId: { equals: String(post1.id) }, + }, + }) + + expectResultsRespectWhere(resultsFiltered, (r) => r.docId === String(post1.id)) + }) + + test('searchByEmbedding with same embedding as search returns similar results', async () => { + const testQuery = 'test query for comparison' + + // Create a post to search for + const post = await payload.create({ + collection: 'posts', + data: { + title: testQuery, + content: null, + }, + }) + + // Wait for vectorization jobs to complete + await waitForVectorizationJobs(payload) + + // Get results using regular search + const searchResults = await vectorizedPayload.search({ + knowledgePool: 'default', + query: testQuery, + }) + + // Get the embedding and use searchByEmbedding + const queryEmbedding = await embedFn(testQuery) + const embeddingArray = Array.isArray(queryEmbedding) + ? queryEmbedding + : Array.from(queryEmbedding) + + const embeddingSearchResults = await vectorizedPayload.searchByEmbedding({ + knowledgePool: 'default', + embedding: embeddingArray, + }) + + // Both should return results (not necessarily identical due to possible reranking in search) + expect(searchResults.length).toBeGreaterThan(0) + expect(embeddingSearchResults.length).toBeGreaterThan(0) + + // The top result should be the same document chunk for both approaches + // since we're using the same embedding + expect(embeddingSearchResults[0].docId).toBe(searchResults[0].docId) + }) +}) diff --git a/src/endpoints/vectorSearch.ts b/src/endpoints/vectorSearch.ts index 9a0b14b..fd1914f 100644 --- a/src/endpoints/vectorSearch.ts +++ b/src/endpoints/vectorSearch.ts @@ -44,6 +44,18 @@ export const createVectorSearchHandlers = ( const reranked = await rerank.callback(query, candidates) return reranked.slice(0, effectiveLimit) } + + const searchByEmbedding = async ( + payload: BasePayload, + embedding: number[], + knowledgePool: KnowledgePoolName, + limit?: number, + where?: Where, + ) => { + // searchByEmbedding does not support reranking because rerankers need text, not vectors + return adapter.search(payload, embedding, knowledgePool, limit, where) + } + const requestHandler: PayloadHandler = async (req) => { if (!req || !req.json) { return Response.json({ error: 'Request is required' }, { status: 400 }) @@ -77,5 +89,5 @@ export const createVectorSearchHandlers = ( return Response.json({ error: `Internal server error: ${error}` }, { status: 500 }) } } - return { vectorSearch, requestHandler } + return { vectorSearch, searchByEmbedding, requestHandler } } diff --git a/src/index.ts b/src/index.ts index 7621b9c..f716c0d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,6 +8,7 @@ import type { KnowledgePoolDynamicConfig, VectorizedPayload, VectorSearchQuery, + VectorSearchEmbeddingQuery, BulkEmbedResult, RetryFailedBatchResult, DbAdapter, @@ -76,6 +77,7 @@ export type { // For adapters VectorSearchResult, + VectorSearchEmbeddingQuery, EmbeddingRecord, } from './types.js' @@ -357,6 +359,14 @@ export default (pluginOptions: PayloadcmsVectorizeConfig) => params.limit, params.where, ), + searchByEmbedding: (params: VectorSearchEmbeddingQuery) => + vectorSearchHandlers.searchByEmbedding( + payload, + params.embedding, + params.knowledgePool, + params.limit, + params.where, + ), findByIds: (params: { knowledgePool: KnowledgePoolName ids: string[] diff --git a/src/types.ts b/src/types.ts index 2fb10b1..e414e9b 100644 --- a/src/types.ts +++ b/src/types.ts @@ -57,6 +57,7 @@ export type VectorizedPayload = { _isBulkEmbedEnabled: (knowledgePool: KnowledgePoolName) => boolean getDbAdapterCustom: () => Record | undefined search: (params: VectorSearchQuery) => Promise> + searchByEmbedding: (params: VectorSearchEmbeddingQuery) => Promise> findByIds: (params: { knowledgePool: KnowledgePoolName ids: string[] @@ -349,6 +350,17 @@ export interface VectorSearchQuery { limit?: number } +export interface VectorSearchEmbeddingQuery { + /** The knowledge pool to search in */ + knowledgePool: KnowledgePoolName + /** The embedding vector to search with */ + embedding: number[] + /** Optional Payload where clause to filter results. Can rely on embeddings collection fields or extension fields. */ + where?: Where + /** Optional limit for number of results (default: 10) */ + limit?: number +} + // ========================================== // Document types for internal collections // ==========================================