Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 276 additions & 0 deletions dev/specs/searchByEmbedding.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
import type { Payload } from 'payload'

import { postgresAdapter } from '@payloadcms/db-postgres'
import { type SerializedEditorState } from '@payloadcms/richtext-lexical/lexical'
import { chunkRichText, chunkText } from 'helpers/chunkers.js'
import { makeDummyEmbedDocs, makeDummyEmbedQuery, testEmbeddingVersion } from 'helpers/embed.js'
import { createMockAdapter } from 'helpers/mockAdapter.js'
import { getPayload } from 'payload'
import payloadcmsVectorize, {
DbAdapter,
getVectorizedPayload,
type VectorizedPayload,
} from 'payloadcms-vectorize'
import { afterAll, beforeAll, describe, expect, test } from 'vitest'
import { buildDummyConfig, DIMS, getInitialMarkdownContent } from './constants.js'
import {
expectResultsOrderedByScore,
expectResultsRespectWhere,
expectValidVectorSearchResults,
} from './helpers/vectorSearchExpectations.js'
import {
BULK_QUEUE_NAMES,
createTestDb,
destroyPayload,
waitForVectorizationJobs,
} from './utils.js'

const embedFn = makeDummyEmbedQuery(DIMS)

/**
 * Embeds `text` with the shared dummy query embedder and returns the vector as a
 * plain `number[]`, the shape passed to `searchByEmbedding` in these tests.
 */
const embedQueryAsArray = async (text: string): Promise<number[]> => {
  const embedding = await embedFn(text)
  // The embedder may hand back a non-Array iterable (e.g. a typed array); copy it in that case.
  return Array.isArray(embedding) ? embedding : Array.from(embedding)
}

/**
 * Integration tests for `VectorizedPayload.searchByEmbedding`.
 *
 * The suite boots a Payload instance against a Postgres database reachable at
 * localhost:5433 with a single `posts` collection wired into the `default`
 * knowledge pool, then searches with precomputed embeddings.
 *
 * Tests share one database: the ordering and limit tests create no documents
 * of their own and search whatever earlier tests have already indexed.
 */
describe('searchByEmbedding method tests', () => {
  let payload: Payload
  let vectorizedPayload: VectorizedPayload
  let adapter: DbAdapter
  let markdownContent: SerializedEditorState
  const titleAndQuery = 'My query is a title for searchByEmbedding'
  const dbName = 'search_by_embedding_test'

  beforeAll(async () => {
    await createTestDb({ dbName })
    adapter = createMockAdapter()

    const config = await buildDummyConfig({
      jobs: {
        tasks: [],
        autoRun: [
          {
            cron: '*/5 * * * * *', // Run every 5 seconds
            limit: 10,
          },
        ],
      },
      collections: [
        {
          slug: 'posts',
          fields: [
            { name: 'title', type: 'text' },
            { name: 'content', type: 'richText' },
          ],
        },
      ],
      db: postgresAdapter({
        pool: {
          connectionString: `postgresql://postgres:password@localhost:5433/${dbName}`,
        },
      }),
      plugins: [
        payloadcmsVectorize({
          dbAdapter: adapter,
          knowledgePools: {
            default: {
              collections: {
                posts: {
                  // Second parameter is named to avoid shadowing the suite-level `payload`.
                  toKnowledgePool: async (doc, payloadInstance) => {
                    const chunks: Array<{ chunk: string }> = []
                    // Process title
                    if (doc.title) {
                      const titleChunks = chunkText(doc.title)
                      chunks.push(...titleChunks.map((chunk) => ({ chunk })))
                    }
                    // Process content
                    if (doc.content) {
                      const contentChunks = await chunkRichText(doc.content, payloadInstance.config)
                      chunks.push(...contentChunks.map((chunk) => ({ chunk })))
                    }
                    return chunks
                  },
                },
              },
              embeddingConfig: {
                version: testEmbeddingVersion,
                queryFn: makeDummyEmbedQuery(DIMS),
                realTimeIngestionFn: makeDummyEmbedDocs(DIMS),
              },
            },
          },
          bulkQueueNames: BULK_QUEUE_NAMES,
        }),
      ],
    })

    payload = await getPayload({
      config,
      // Timestamped key so this suite gets its own Payload instance key on every run.
      key: `search-by-embedding-test-${Date.now()}`,
      cron: true,
    })

    const vp = getVectorizedPayload(payload)
    if (!vp) {
      throw new Error('Failed to get vectorized payload')
    }
    vectorizedPayload = vp

    markdownContent = await getInitialMarkdownContent(config)
  })

  afterAll(async () => {
    await destroyPayload(payload)
  })

  test('searchByEmbedding with embedding vector returns valid results', async () => {
    // Create a post
    const post = await payload.create({
      collection: 'posts',
      data: {
        title: titleAndQuery,
        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- test data is not typed against the collection
        content: markdownContent as any,
      },
    })

    // Wait for vectorization jobs to complete
    await waitForVectorizationJobs(payload)

    // Search using the embedding directly
    const results = await vectorizedPayload.searchByEmbedding({
      knowledgePool: 'default',
      embedding: await embedQueryAsArray(titleAndQuery),
    })

    expectValidVectorSearchResults(results, {
      checkShape: true,
      expectedTitle: {
        title: titleAndQuery,
        postId: String(post.id),
        embeddingVersion: testEmbeddingVersion,
      },
    })
  })

  test('searchByEmbedding results are ordered by score (highest first)', async () => {
    const results = await vectorizedPayload.searchByEmbedding({
      knowledgePool: 'default',
      embedding: await embedQueryAsArray(titleAndQuery),
    })

    expectResultsOrderedByScore(results)
  })

  test('searchByEmbedding respects limit parameter', async () => {
    const limit = 2
    const results = await vectorizedPayload.searchByEmbedding({
      knowledgePool: 'default',
      embedding: await embedQueryAsArray(titleAndQuery),
      limit,
    })

    expect(results.length).toBeLessThanOrEqual(limit)
  })

  test('searchByEmbedding respects where clause', async () => {
    const sharedText = 'Shared searchable content for embedding search'

    // Create two posts with same text
    const post1 = await payload.create({
      collection: 'posts',
      data: {
        title: sharedText,
        content: null,
      },
    })

    const post2 = await payload.create({
      collection: 'posts',
      data: {
        title: sharedText,
        content: null,
      },
    })

    // Wait for vectorization jobs to complete
    await waitForVectorizationJobs(payload)

    const embedding = await embedQueryAsArray(sharedText)

    // Search without WHERE - should return both
    const resultsAll = await vectorizedPayload.searchByEmbedding({
      knowledgePool: 'default',
      embedding,
    })

    expect(resultsAll.length).toBeGreaterThanOrEqual(2)

    // Search with WHERE clause filtering by docId - should return only post1's chunks
    const resultsFiltered = await vectorizedPayload.searchByEmbedding({
      knowledgePool: 'default',
      embedding,
      where: {
        docId: { equals: String(post1.id) },
      },
    })

    // Guard against a vacuous pass: an empty result set would satisfy any predicate.
    expect(resultsFiltered.length).toBeGreaterThan(0)
    expectResultsRespectWhere(resultsFiltered, (r) => r.docId === String(post1.id))
    // The sibling post with identical text must be filtered out.
    expect(resultsFiltered.some((r) => r.docId === String(post2.id))).toBe(false)
  })

  test('searchByEmbedding with same embedding as search returns similar results', async () => {
    const testQuery = 'test query for comparison'

    // Create a post to search for
    await payload.create({
      collection: 'posts',
      data: {
        title: testQuery,
        content: null,
      },
    })

    // Wait for vectorization jobs to complete
    await waitForVectorizationJobs(payload)

    // Get results using regular search
    const searchResults = await vectorizedPayload.search({
      knowledgePool: 'default',
      query: testQuery,
    })

    // Embed the same query ourselves and use searchByEmbedding
    const embeddingSearchResults = await vectorizedPayload.searchByEmbedding({
      knowledgePool: 'default',
      embedding: await embedQueryAsArray(testQuery),
    })

    // Both should return results (not necessarily identical due to possible reranking in search)
    expect(searchResults.length).toBeGreaterThan(0)
    expect(embeddingSearchResults.length).toBeGreaterThan(0)

    // The top result should be the same document chunk for both approaches
    // since we're using the same embedding
    expect(embeddingSearchResults[0].docId).toBe(searchResults[0].docId)
  })
})
14 changes: 13 additions & 1 deletion src/endpoints/vectorSearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ export const createVectorSearchHandlers = (
const reranked = await rerank.callback(query, candidates)
return reranked.slice(0, effectiveLimit)
}

// Vector-only search: the embedding is handed straight to the adapter.
// Reranking is not supported here because rerankers need text, not vectors.
const searchByEmbedding = async (
  payload: BasePayload,
  embedding: number[],
  knowledgePool: KnowledgePoolName,
  limit?: number,
  where?: Where,
) => {
  const matches = await adapter.search(payload, embedding, knowledgePool, limit, where)
  return matches
}

const requestHandler: PayloadHandler = async (req) => {
if (!req || !req.json) {
return Response.json({ error: 'Request is required' }, { status: 400 })
Expand Down Expand Up @@ -77,5 +89,5 @@ export const createVectorSearchHandlers = (
return Response.json({ error: `Internal server error: ${error}` }, { status: 500 })
}
}
return { vectorSearch, requestHandler }
return { vectorSearch, searchByEmbedding, requestHandler }
}
10 changes: 10 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import type {
KnowledgePoolDynamicConfig,
VectorizedPayload,
VectorSearchQuery,
VectorSearchEmbeddingQuery,
BulkEmbedResult,
RetryFailedBatchResult,
DbAdapter,
Expand Down Expand Up @@ -76,6 +77,7 @@ export type {

// For adapters
VectorSearchResult,
VectorSearchEmbeddingQuery,
EmbeddingRecord,
} from './types.js'

Expand Down Expand Up @@ -357,6 +359,14 @@ export default (pluginOptions: PayloadcmsVectorizeConfig) =>
params.limit,
params.where,
),
// Search with a caller-supplied embedding vector: the params object is unpacked
// into the positional arguments the handler takes.
searchByEmbedding: ({ embedding, knowledgePool, limit, where }: VectorSearchEmbeddingQuery) =>
  vectorSearchHandlers.searchByEmbedding(payload, embedding, knowledgePool, limit, where),
findByIds: (params: {
knowledgePool: KnowledgePoolName
ids: string[]
Expand Down
12 changes: 12 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ export type VectorizedPayload = {
_isBulkEmbedEnabled: (knowledgePool: KnowledgePoolName) => boolean
getDbAdapterCustom: () => Record<string, any> | undefined
search: (params: VectorSearchQuery) => Promise<Array<VectorSearchResult>>
searchByEmbedding: (params: VectorSearchEmbeddingQuery) => Promise<Array<VectorSearchResult>>
findByIds: (params: {
knowledgePool: KnowledgePoolName
ids: string[]
Expand Down Expand Up @@ -349,6 +350,17 @@ export interface VectorSearchQuery {
limit?: number
}

/**
 * Parameters for a vector search driven by a precomputed embedding instead of query text.
 */
export interface VectorSearchEmbeddingQuery {
  /** Name of the knowledge pool to search in. */
  knowledgePool: KnowledgePoolName
  /** Embedding vector used as the search input. */
  embedding: number[]
  /** Optional cap on the number of results returned (default: 10). */
  limit?: number
  /**
   * Optional Payload where clause used to filter results. It can rely on
   * embeddings collection fields or extension fields.
   */
  where?: Where
}

// ==========================================
// Document types for internal collections
// ==========================================
Expand Down
Loading