Búsqueda de embeddings rápida usando Python

Búsqueda de embeddings rápida usando Python
Photo by Robin Pierre / Unsplash

Estos días he estado enfrascado en una problemática muy interesante. Tengo una base de datos en Postgres en la que guardo muchos embeddings.

A la hora de hacer una búsqueda, estaba descargando los embeddings y haciendo una búsqueda por similaridad de coseno, pero con el volumen de datos que tenemos la búsqueda estaba resultando muy lenta (30-40seg en algunos casos). El objetivo estaba en mantenerlo por debajo de 500ms para que la experiencia de usuario fuera satisfactoria, por lo que estamos hablando de un incremento de velocidad de 1000!

Investigando por internet, me encontré con Faiss, una librería de Facebook para búsqueda eficiente y rápida de embeddings. La instalación fue un poco compleja, pero los resultados han sido impresionantes!

Aquí podemos ver un ejemplo en el que la búsqueda por coseno tardaba ~20s, y tras la optimización con Faiss se reduce a ~0.4s!

Cómo hacer una búsqueda de embeddings con Faiss

En primer lugar, tras seguir los pasos de instalación de la librería, inicializo un índice de Faiss:

dimension = 768 # dimensionalidad de los embeddings
index = faiss.IndexFlatL2(dimension)

Una vez el índice está inicializado, hay que entrenarlo con los datos. En el caso de usar IndexFlatL2, no es necesario el entrenamiento así que podemos añadir directamente los datos:

embeddings = get_embeddings()
index.add(embeddings)

Una vez añadidos los embeddings, solamente hay que llamar al método search:

number_of_results = 10
query = 'bitcoin'
query_embeddings = convert_to_embeddings(query)
matches, indexes = index.search(query_embeddings, number_of_results)

Tras la búsqueda, en indexes tendremos el índice del elemento de embeddings que dio resultado. En mi caso, esto era un poco problemático porque lo que yo quería era el ID de mi base de datos al que se refería el embedding. Esto lo solventé creando un array de IDs que se corresponde con los embeddings, de forma que si accedo a los índices de este array, obtengo directamente los IDs de mi base de datos que están relacionados con la búsqueda del usuario.