Skip to content

Improve extensibility of SimpleVectoreStore #1662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,14 @@
import java.nio.file.FileAlreadyExistsException;
import java.nio.file.Files;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.json.JsonMapper;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
Expand All @@ -50,14 +47,16 @@
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;

/**
* SimpleVectorStore is a simple implementation of the VectorStore interface.
*
* Simple, in-memory implementation of the {@link VectorStore} interface.
* <p/>
* It also provides methods to save the current state of the vectors to a file, and to
* load vectors from a file.
*
* <p/>
* For a deeper understanding of the mathematical concepts and computations involved in
* calculating similarity scores among vectors, refer to this
* [resource](https://docs.spring.io/spring-ai/reference/api/vectordbs.html#_understanding_vectors).
Expand All @@ -67,6 +66,8 @@
* @author Mark Pollack
* @author Christian Tzolov
* @author Sebastien Deleuze
* @author John Blum
* @see VectorStore
*/
public class SimpleVectorStore extends AbstractObservationVectorStore {

Expand All @@ -87,54 +88,72 @@ public SimpleVectorStore(EmbeddingModel embeddingModel, ObservationRegistry obse

super(observationRegistry, customObservationConvention);

Objects.requireNonNull(embeddingModel, "EmbeddingModel must not be null");
Assert.notNull(embeddingModel, "EmbeddingModel must not be null");

this.embeddingModel = embeddingModel;
this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
}

@Override
public void doAdd(List<Document> documents) {
for (Document document : documents) {
logger.info("Calling EmbeddingModel for document id = {}", document.getId());
float[] embedding = this.embeddingModel.embed(document);
document.setEmbedding(embedding);
logger.info("Calling EmbeddingModel for Document id = {}", document.getId());
document = embed(document);
this.store.put(document.getId(), document);
}
}

protected Document embed(Document document) {
float[] documentEmbedding = this.embeddingModel.embed(document);
document.setEmbedding(documentEmbedding);
return document;
}

@Override
public Optional<Boolean> doDelete(List<String> idList) {
for (String id : idList) {
this.store.remove(id);
}
idList.forEach(this.store::remove);
return Optional.of(true);
}

@Override
public List<Document> doSimilaritySearch(SearchRequest request) {

if (request.getFilterExpression() != null) {
throw new UnsupportedOperationException(
"The [" + this.getClass() + "] doesn't support metadata filtering!");
"[%s] doesn't support metadata filtering".formatted(getClass().getName()));
}

float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery());
return this.store.values()
.stream()
.map(entry -> new Similarity(entry.getId(),
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
.filter(s -> s.score >= request.getSimilarityThreshold())
.sorted(Comparator.<Similarity>comparingDouble(s -> s.score).reversed())
// @formatter:off
return this.store.values().stream()
.map(document -> computeSimilarity(request, document))
.filter(similarity -> similarity.score >= request.getSimilarityThreshold())
.sorted(Comparator.<Similarity>comparingDouble(similarity -> similarity.score).reversed())
.limit(request.getTopK())
.map(s -> this.store.get(s.key))
.map(similarity -> this.store.get(similarity.key))
.toList();
// @formatter:on
}

protected Similarity computeSimilarity(SearchRequest request, Document document) {

float[] userQueryEmbedding = getUserQueryEmbedding(request);
float[] documentEmbedding = document.getEmbedding();

double score = computeCosineSimilarity(userQueryEmbedding, documentEmbedding);

return new Similarity(document.getId(), score);
}

protected double computeCosineSimilarity(float[] userQueryEmbedding, float[] storedDocumentEmbedding) {
return EmbeddingMath.cosineSimilarity(userQueryEmbedding, storedDocumentEmbedding);
}

/**
* Serialize the vector store content into a file in JSON format.
* @param file the file to save the vector store content
*/
public void save(File file) {
String json = getVectorDbAsJson();

try {
if (!file.exists()) {
logger.info("Creating new vector store file: {}", file);
Expand All @@ -145,28 +164,22 @@ public void save(File file) {
throw new RuntimeException("File already exists: " + file, e);
}
catch (IOException e) {
throw new RuntimeException("Failed to create new file: " + file + ". Reason: " + e.getMessage(), e);
throw new RuntimeException("Failed to create new file: " + file + "; Reason: " + e.getMessage(), e);
}
}
else {
logger.info("Overwriting existing vector store file: {}", file);
}

try (OutputStream stream = new FileOutputStream(file);
Writer writer = new OutputStreamWriter(stream, StandardCharsets.UTF_8)) {
String json = getVectorDbAsJson();
writer.write(json);
writer.flush();
}
}
catch (IOException ex) {
logger.error("IOException occurred while saving vector store file.", ex);
throw new RuntimeException(ex);
}
catch (SecurityException ex) {
logger.error("SecurityException occurred while saving vector store file.", ex);
throw new RuntimeException(ex);
}
catch (NullPointerException ex) {
logger.error("NullPointerException occurred while saving vector store file.", ex);
catch (IOException | NullPointerException | SecurityException ex) {
logger.error("%s occurred while saving vector store file".formatted(ex.getClass().getSimpleName()), ex);
throw new RuntimeException(ex);
}
}
Expand All @@ -176,45 +189,40 @@ public void save(File file) {
* @param file the file to load the vector store content
*/
public void load(File file) {
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {

};
try {
Map<String, Document> deserializedMap = this.objectMapper.readValue(file, typeRef);
this.store = deserializedMap;
}
catch (IOException ex) {
throw new RuntimeException(ex);
}
load(new FileSystemResource(file));
}

/**
* Deserialize the vector store content from a resource in JSON format into memory.
* @param resource the resource to load the vector store content
*/
public void load(Resource resource) {
TypeReference<HashMap<String, Document>> typeRef = new TypeReference<>() {

};
try {
Map<String, Document> deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef);
this.store = deserializedMap;
this.store = this.objectMapper.readValue(resource.getInputStream(), documentMapTypeRef());
}
catch (IOException ex) {
throw new RuntimeException(ex);
}
}

private TypeReference<Map<String, Document>> documentMapTypeRef() {
return new TypeReference<>() {
};
}

private String getVectorDbAsJson() {
ObjectWriter objectWriter = this.objectMapper.writerWithDefaultPrettyPrinter();
String json;

try {
json = objectWriter.writeValueAsString(this.store);
return this.objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(this.store);
}
catch (JsonProcessingException e) {
throw new RuntimeException("Error serializing documentMap to JSON.", e);
throw new RuntimeException("Error serializing Map of Documents to JSON", e);
}
return json;
}

private float[] getUserQueryEmbedding(SearchRequest request) {
return getUserQueryEmbedding(request.getQuery());
}

private float[] getUserQueryEmbedding(String query) {
Expand All @@ -232,9 +240,9 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str

public static class Similarity {

private String key;
private final String key;

private double score;
private final double score;

public Similarity(String key, double score) {
this.key = key;
Expand All @@ -243,16 +251,18 @@ public Similarity(String key, double score) {

}

public final class EmbeddingMath {
public static final class EmbeddingMath {

private EmbeddingMath() {
throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
}

public static double cosineSimilarity(float[] vectorX, float[] vectorY) {

if (vectorX == null || vectorY == null) {
throw new RuntimeException("Vectors must not be null");
throw new IllegalArgumentException("Vectors must not be null");
}

if (vectorX.length != vectorY.length) {
throw new IllegalArgumentException("Vectors lengths must be equal");
}
Expand All @@ -268,20 +278,22 @@ public static double cosineSimilarity(float[] vectorX, float[] vectorY) {
return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
}

public static float dotProduct(float[] vectorX, float[] vectorY) {
private static float dotProduct(float[] vectorX, float[] vectorY) {

if (vectorX.length != vectorY.length) {
throw new IllegalArgumentException("Vectors lengths must be equal");
}

float result = 0;
for (int i = 0; i < vectorX.length; ++i) {
result += vectorX[i] * vectorY[i];

for (int index = 0; index < vectorX.length; ++index) {
result += vectorX[index] * vectorY[index];
}

return result;
}

public static float norm(float[] vector) {
private static float norm(float[] vector) {
return dotProduct(vector, vector);
}

Expand Down