const vectra = require('vectra');
const path = require('path');
const fs = require('fs');
const express = require('express');
const sanitize = require('sanitize-filename');
const { jsonParser } = require('../express-common');

// Don't forget to add new sources to the SOURCES array
const SOURCES = [
    'transformers',
    'mistral',
    'openai',
    'extras',
    'palm',
    'togetherai',
    'nomicai',
    'cohere',
    'ollama',
    'llamacpp',
    'vllm',
];

/** Vector source used when a request body does not specify one. */
const DEFAULT_SOURCE = 'transformers';

/**
 * Gets the vector for the given text from the given source.
 * The provider module is required lazily, only when its source is selected.
 * @param {string} source - The source of the vector
 * @param {Object} sourceSettings - Settings for the source, if it needs any
 * @param {string} text - The text to get the vector for
 * @param {boolean} isQuery - If the text is a query for embedding search
 * @param {import('../users').UserDirectoryList} directories - The directories object for the user
 * @returns {Promise<number[]>} - The vector for the text
 * @throws {Error} If the source is not one of the known sources
 */
async function getVector(source, sourceSettings, text, isQuery, directories) {
    switch (source) {
        case 'nomicai':
            return require('../vectors/nomicai-vectors').getNomicAIVector(text, source, directories);
        case 'togetherai':
        case 'mistral':
        case 'openai':
            return require('../vectors/openai-vectors').getOpenAIVector(text, source, directories, sourceSettings.model);
        case 'transformers':
            return require('../vectors/embedding').getTransformersVector(text);
        case 'extras':
            return require('../vectors/extras-vectors').getExtrasVector(text, sourceSettings.extrasUrl, sourceSettings.extrasKey);
        case 'palm':
            return require('../vectors/makersuite-vectors').getMakerSuiteVector(text, directories);
        case 'cohere':
            return require('../vectors/cohere-vectors').getCohereVector(text, isQuery, directories, sourceSettings.model);
        case 'llamacpp':
            return require('../vectors/llamacpp-vectors').getLlamaCppVector(text, sourceSettings.apiUrl, directories);
        case 'vllm':
            return require('../vectors/vllm-vectors').getVllmVector(text, sourceSettings.apiUrl, sourceSettings.model, directories);
        case 'ollama':
            return require('../vectors/ollama-vectors').getOllamaVector(text, sourceSettings.apiUrl, sourceSettings.model, sourceSettings.keep, directories);
    }

    throw new Error(`Unknown vector source ${source}`);
}

/**
 * Gets the vector for the given text batch from the given source.
 * Texts are sent to the provider in chunks of 10, one chunk at a time, and the
 * results are concatenated in input order.
 * @param {string} source - The source of the vector
 * @param {Object} sourceSettings - Settings for the source, if it needs any
 * @param {string[]} texts - The array of texts to get the vector for
 * @param {boolean} isQuery - If the text is a query for embedding search
 * @param {import('../users').UserDirectoryList} directories - The directories object for the user
 * @returns {Promise<number[][]>} - The array of vectors for the texts
 * @throws {Error} If the source is unknown (only detected when there is at least one text)
 */
async function getBatchVector(source, sourceSettings, texts, isQuery, directories) {
    const batchSize = 10;
    const batchCount = Math.ceil(texts.length / batchSize);
    const batches = Array.from({ length: batchCount }, (_, i) => texts.slice(i * batchSize, (i + 1) * batchSize));

    const results = [];
    for (const batch of batches) {
        switch (source) {
            case 'nomicai':
                results.push(...await require('../vectors/nomicai-vectors').getNomicAIBatchVector(batch, source, directories));
                break;
            case 'togetherai':
            case 'mistral':
            case 'openai':
                results.push(...await require('../vectors/openai-vectors').getOpenAIBatchVector(batch, source, directories, sourceSettings.model));
                break;
            case 'transformers':
                results.push(...await require('../vectors/embedding').getTransformersBatchVector(batch));
                break;
            case 'extras':
                results.push(...await require('../vectors/extras-vectors').getExtrasBatchVector(batch, sourceSettings.extrasUrl, sourceSettings.extrasKey));
                break;
            case 'palm':
                results.push(...await require('../vectors/makersuite-vectors').getMakerSuiteBatchVector(batch, directories));
                break;
            case 'cohere':
                results.push(...await require('../vectors/cohere-vectors').getCohereBatchVector(batch, isQuery, directories, sourceSettings.model));
                break;
            case 'llamacpp':
                results.push(...await require('../vectors/llamacpp-vectors').getLlamaCppBatchVector(batch, sourceSettings.apiUrl, directories));
                break;
            case 'vllm':
                results.push(...await require('../vectors/vllm-vectors').getVllmBatchVector(batch, sourceSettings.apiUrl, sourceSettings.model, directories));
                break;
            case 'ollama':
                results.push(...await require('../vectors/ollama-vectors').getOllamaBatchVector(batch, sourceSettings.apiUrl, sourceSettings.model, sourceSettings.keep, directories));
                break;
            default:
                throw new Error(`Unknown vector source ${source}`);
        }
    }

    return results;
}

/**
 * Gets the index for the vector collection.
 * The index lives at `<directories.vectors>/<source>/<collectionId>`, with both
 * path segments passed through `sanitize`.
 * @param {import('../users').UserDirectoryList} directories - User directories
 * @param {string} collectionId - The collection ID
 * @param {string} source - The source of the vector
 * @param {boolean} create - Whether to create the index if it doesn't exist
 * @returns {Promise<vectra.LocalIndex>} - The index for the collection
 */
async function getIndex(directories, collectionId, source, create = true) {
    const pathToFile = path.join(directories.vectors, sanitize(source), sanitize(collectionId));
    const store = new vectra.LocalIndex(pathToFile);

    if (create && !await store.isIndexCreated()) {
        await store.createIndex();
    }

    return store;
}

/**
 * Inserts items into the vector collection.
 * Texts are embedded as document (non-query) texts and upserted together with
 * their hash, text and index as metadata.
 * @param {import('../users').UserDirectoryList} directories - User directories
 * @param {string} collectionId - The collection ID
 * @param {string} source - The source of the vector
 * @param {Object} sourceSettings - Settings for the source, if it needs any
 * @param {{ hash: number; text: string; index: number; }[]} items - The items to insert
 * @returns {Promise<void>}
 */
async function insertVectorItems(directories, collectionId, source, sourceSettings, items) {
    const store = await getIndex(directories, collectionId, source);

    await store.beginUpdate();

    const vectors = await getBatchVector(source, sourceSettings, items.map(x => x.text), false, directories);

    for (let i = 0; i < items.length; i++) {
        const item = items[i];
        const vector = vectors[i];
        await store.upsertItem({ vector, metadata: { hash: item.hash, text: item.text, index: item.index } });
    }

    await store.endUpdate();
}

/**
 * Gets the hashes of the items in the vector collection.
 * @param {import('../users').UserDirectoryList} directories - User directories
 * @param {string} collectionId - The collection ID
 * @param {string} source - The source of the vector
 * @returns {Promise<number[]>} - The hashes of the items in the collection
 */
async function getSavedHashes(directories, collectionId, source) {
    const store = await getIndex(directories, collectionId, source);
    const items = await store.listItems();
    return items.map(x => Number(x.metadata.hash));
}

/**
 * Deletes items from the vector collection by hash.
 * @param {import('../users').UserDirectoryList} directories - User directories
 * @param {string} collectionId - The collection ID
 * @param {string} source - The source of the vector
 * @param {number[]} hashes - The hashes of the items to delete
 * @returns {Promise<void>}
 */
async function deleteVectorItems(directories, collectionId, source, hashes) {
    const store = await getIndex(directories, collectionId, source);
    const items = await store.listItemsByMetadata({ hash: { '$in': hashes } });

    await store.beginUpdate();

    for (const item of items) {
        await store.deleteItem(item.id);
    }

    await store.endUpdate();
}

/**
 * Gets the hashes and metadata of the items in the vector collection that match the search text.
 * Only items whose score is at least `threshold` are returned, so `hashes` and
 * `metadata` always describe the same items in the same order (this matches
 * how multiQueryCollection applies the threshold).
 * @param {import('../users').UserDirectoryList} directories - User directories
 * @param {string} collectionId - The collection ID
 * @param {string} source - The source of the vector
 * @param {Object} sourceSettings - Settings for the source, if it needs any
 * @param {string} searchText - The text to search for
 * @param {number} topK - The number of results to return
 * @param {number} threshold - The threshold for the search
 * @returns {Promise<{hashes: number[], metadata: object[]}>} - The hashes and metadata of the items that match the search text
 */
async function queryCollection(directories, collectionId, source, sourceSettings, searchText, topK, threshold) {
    const store = await getIndex(directories, collectionId, source);
    const vector = await getVector(source, sourceSettings, searchText, true, directories);

    const result = await store.queryItems(vector, topK);
    // Apply the threshold once so both returned arrays stay aligned.
    const matches = result.filter(x => x.score >= threshold);
    const metadata = matches.map(x => x.item.metadata);
    const hashes = matches.map(x => Number(x.item.metadata.hash));
    return { metadata, hashes };
}

/**
 * Queries multiple collections for the given search queries. Returns the overall top K results.
 * The search text is embedded once and the same vector is used for every collection.
 * @param {import('../users').UserDirectoryList} directories - User directories
 * @param {string[]} collectionIds - The collection IDs to query
 * @param {string} source - The source of the vector
 * @param {Object} sourceSettings - Settings for the source, if it needs any
 * @param {string} searchText - The text to search for
 * @param {number} topK - The number of results to return
 * @param {number} threshold - The threshold for the search
 *
 * @returns {Promise<Record<string, { hashes: number[], metadata: object[] }>>} - The overall top K results, grouped by collection ID
 */
async function multiQueryCollection(directories, collectionIds, source, sourceSettings, searchText, topK, threshold) {
    const vector = await getVector(source, sourceSettings, searchText, true, directories);
    const results = [];

    for (const collectionId of collectionIds) {
        const store = await getIndex(directories, collectionId, source);
        const result = await store.queryItems(vector, topK);
        results.push(...result.map(result => ({ collectionId, result })));
    }

    // Sort results by descending similarity, apply threshold, and take top K
    const sortedResults = results
        .sort((a, b) => b.result.score - a.result.score)
        .filter(x => x.result.score >= threshold)
        .slice(0, topK);

    /**
     * Group the results by collection ID
     * @type {Record<string, { hashes: number[], metadata: object[] }>}
     */
    const groupedResults = {};
    for (const result of sortedResults) {
        if (!groupedResults[result.collectionId]) {
            groupedResults[result.collectionId] = { hashes: [], metadata: [] };
        }

        groupedResults[result.collectionId].hashes.push(Number(result.result.item.metadata.hash));
        groupedResults[result.collectionId].metadata.push(result.result.item.metadata);
    }

    return groupedResults;
}

/**
 * Interprets an HTTP header value as a boolean flag.
 * Header values arrive as strings, so `Boolean(value)` would treat the string
 * 'false' as true. A missing or empty header, 'false' (any case) and '0' are
 * false; every other value is true.
 * @param {string|string[]|undefined} value - The raw header value
 * @returns {boolean} - The flag carried by the header
 */
function isTruthyHeader(value) {
    const normalized = value === undefined ? '' : String(value).trim().toLowerCase();
    return normalized !== '' && normalized !== 'false' && normalized !== '0';
}

/**
 * Extracts settings for the vectorization sources from the HTTP request headers.
 * Missing string headers are passed through `String()` and therefore come out
 * as the literal string 'undefined'.
 * @param {string} source - Which source to extract settings for.
 * @param {object} request - The HTTP request object.
 * @returns {object} - An object that can be used as `sourceSettings` in functions that take that parameter.
 */
function getSourceSettings(source, request) {
    const headers = request.headers;

    switch (source) {
        case 'togetherai':
            return {
                model: String(headers['x-togetherai-model']),
            };
        case 'openai':
            return {
                model: String(headers['x-openai-model']),
            };
        case 'cohere':
            return {
                model: String(headers['x-cohere-model']),
            };
        case 'llamacpp':
            return {
                apiUrl: String(headers['x-llamacpp-url']),
            };
        case 'vllm':
            return {
                apiUrl: String(headers['x-vllm-url']),
                model: String(headers['x-vllm-model']),
            };
        case 'ollama':
            return {
                apiUrl: String(headers['x-ollama-url']),
                model: String(headers['x-ollama-model']),
                keep: isTruthyHeader(headers['x-ollama-keep']),
            };
        default: {
            // Extras API settings to connect to the Extras embeddings provider
            let extrasUrl = '';
            let extrasKey = '';
            if (source === 'extras') {
                extrasUrl = String(headers['x-extras-url']);
                extrasKey = String(headers['x-extras-key']);
            }

            return {
                extrasUrl: extrasUrl,
                extrasKey: extrasKey,
            };
        }
    }
}

/**
 * Reads the vector source from the request body, falling back to the default source.
 * `String(undefined)` is the truthy string 'undefined', so the value has to be
 * checked before it is converted or the fallback would never apply.
 * @param {import('express').Request} req Express request object
 * @returns {string} - The requested source, or 'transformers' if none was given
 */
function getRequestSource(req) {
    const source = req.body && req.body.source;
    return source ? String(source) : DEFAULT_SOURCE;
}

/**
 * Performs a request to regenerate the index if it is corrupted.
 * A SyntaxError is treated as the sign of a corrupted index (presumably it comes
 * from parsing the stored index file - confirm against vectra). The index named
 * by the request body is deleted and the client is redirected (307) to retry the
 * same request once, marked with `regenerated=true`. Anything else is a 500.
 * @param {import('express').Request} req Express request object
 * @param {import('express').Response} res Express response object
 * @param {Error} error Error object
 * @returns {Promise<any>} Promise
 */
async function regenerateCorruptedIndexErrorHandler(req, res, error) {
    if (error instanceof SyntaxError && !req.query.regenerated) {
        try {
            const body = req.body || {};
            const collectionId = body.collectionId ? String(body.collectionId) : '';
            const source = getRequestSource(req);

            if (collectionId) {
                const index = await getIndex(req.user.directories, collectionId, source, false);
                const exists = await index.isIndexCreated();

                if (exists) {
                    const indexPath = index.folderPath;
                    console.error(`Corrupted index detected at ${indexPath}, regenerating...`);
                    await index.deleteIndex();
                    // Keep any query string that is already on the URL intact.
                    const separator = req.originalUrl.includes('?') ? '&' : '?';
                    return res.redirect(307, `${req.originalUrl}${separator}regenerated=true`);
                }
            }
        } catch (recoveryError) {
            // Recovery is best-effort: never leave the request without a response.
            console.error('Failed to regenerate corrupted index', recoveryError);
        }
    }

    console.error(error);
    return res.sendStatus(500);
}

const router = express.Router();

// Returns the hashes and metadata of the items in one collection that match the search text.
router.post('/query', jsonParser, async (req, res) => {
    try {
        if (!req.body.collectionId || !req.body.searchText) {
            return res.sendStatus(400);
        }

        const collectionId = String(req.body.collectionId);
        const searchText = String(req.body.searchText);
        const topK = Number(req.body.topK) || 10;
        const threshold = Number(req.body.threshold) || 0.0;
        const source = getRequestSource(req);
        const sourceSettings = getSourceSettings(source, req);

        const results = await queryCollection(req.user.directories, collectionId, source, sourceSettings, searchText, topK, threshold);
        return res.json(results);
    } catch (error) {
        return regenerateCorruptedIndexErrorHandler(req, res, error);
    }
});

// Returns the overall top K matches across several collections, grouped by collection ID.
router.post('/query-multi', jsonParser, async (req, res) => {
    try {
        if (!Array.isArray(req.body.collectionIds) || !req.body.searchText) {
            return res.sendStatus(400);
        }

        const collectionIds = req.body.collectionIds.map(x => String(x));
        const searchText = String(req.body.searchText);
        const topK = Number(req.body.topK) || 10;
        const threshold = Number(req.body.threshold) || 0.0;
        const source = getRequestSource(req);
        const sourceSettings = getSourceSettings(source, req);

        const results = await multiQueryCollection(req.user.directories, collectionIds, source, sourceSettings, searchText, topK, threshold);
        return res.json(results);
    } catch (error) {
        return regenerateCorruptedIndexErrorHandler(req, res, error);
    }
});

// Embeds the given items and upserts them into the collection.
router.post('/insert', jsonParser, async (req, res) => {
    try {
        if (!Array.isArray(req.body.items) || !req.body.collectionId) {
            return res.sendStatus(400);
        }

        const collectionId = String(req.body.collectionId);
        const items = req.body.items.map(x => ({ hash: x.hash, text: x.text, index: x.index }));
        const source = getRequestSource(req);
        const sourceSettings = getSourceSettings(source, req);

        await insertVectorItems(req.user.directories, collectionId, source, sourceSettings, items);
        return res.sendStatus(200);
    } catch (error) {
        return regenerateCorruptedIndexErrorHandler(req, res, error);
    }
});

// Lists the hashes stored in the collection.
router.post('/list', jsonParser, async (req, res) => {
    try {
        if (!req.body.collectionId) {
            return res.sendStatus(400);
        }

        const collectionId = String(req.body.collectionId);
        const source = getRequestSource(req);

        const hashes = await getSavedHashes(req.user.directories, collectionId, source);
        return res.json(hashes);
    } catch (error) {
        return regenerateCorruptedIndexErrorHandler(req, res, error);
    }
});

// Deletes the items with the given hashes from the collection.
router.post('/delete', jsonParser, async (req, res) => {
    try {
        if (!Array.isArray(req.body.hashes) || !req.body.collectionId) {
            return res.sendStatus(400);
        }

        const collectionId = String(req.body.collectionId);
        const hashes = req.body.hashes.map(x => Number(x));
        const source = getRequestSource(req);

        await deleteVectorItems(req.user.directories, collectionId, source, hashes);
        return res.sendStatus(200);
    } catch (error) {
        return regenerateCorruptedIndexErrorHandler(req, res, error);
    }
});

// Removes the vector store directory of every known source for the current user.
router.post('/purge-all', jsonParser, async (req, res) => {
    try {
        for (const source of SOURCES) {
            const sourcePath = path.join(req.user.directories.vectors, sanitize(source));
            if (!fs.existsSync(sourcePath)) {
                continue;
            }
            await fs.promises.rm(sourcePath, { recursive: true });
            console.log(`Deleted vector source store at ${sourcePath}`);
        }

        return res.sendStatus(200);
    } catch (error) {
        console.error(error);
        return res.sendStatus(500);
    }
});

// Removes one collection's index from every known source.
router.post('/purge', jsonParser, async (req, res) => {
    try {
        if (!req.body.collectionId) {
            return res.sendStatus(400);
        }

        const collectionId = String(req.body.collectionId);

        for (const source of SOURCES) {
            const index = await getIndex(req.user.directories, collectionId, source, false);

            const exists = await index.isIndexCreated();

            if (!exists) {
                continue;
            }

            const indexPath = index.folderPath;
            await index.deleteIndex();
            console.log(`Deleted vector index at ${indexPath}`);
        }

        return res.sendStatus(200);
    } catch (error) {
        console.error(error);
        return res.sendStatus(500);
    }
});

module.exports = { router };