From 657eeb615638cb5b1687d64e68e6393f7dbf0714 Mon Sep 17 00:00:00 2001
From: Brian Sam-Bodden
Date: Fri, 21 Feb 2025 15:35:12 -0700
Subject: [PATCH 1/2] feat(redis): Add Redis-based semantic caching and chat
 memory implementations

Add comprehensive Redis-backed features to enhance Spring AI:

* Add semantic caching for chat responses:
  - SemanticCache interface and Redis implementation using vector similarity
  - SemanticCacheAdvisor for intercepting and caching chat responses
  - Uses vector search to cache and retrieve responses based on query similarity
  - Support for TTL-based cache expiration
  - Improves response times and reduces API costs for similar questions

* Add Redis-based chat memory implementation:
  - RedisChatMemory using RedisJSON + RediSearch for conversation storage
  - Configurable RedisChatMemoryConfig with builder pattern support
  - Message TTL, ordering, multi-conversation and batch operations
  - Efficient conversation history retrieval using RediSearch indexes

* Add integration tests:
  - Comprehensive test coverage using TestContainers
  - Tests for semantic caching features and chat memory operations
  - Integration test for RedisVectorStore with VectorStoreChatMemoryAdvisor
  - Verify chat completion augmentation with vector store content

The Redis implementations enable efficient storage and retrieval of chat
responses and conversation history, with semantic search capabilities and
configurable persistence options.

Signed-off-by: Brian Sam-Bodden
---
 .../ROOT/pages/api/vectordbs/redis.adoc       |  65 ++++
 vector-stores/spring-ai-redis-store/pom.xml   |   7 +
 .../cache/semantic/SemanticCacheAdvisor.java  | 188 ++++++++++
 .../ai/chat/memory/redis/RedisChatMemory.java | 228 +++++++++++
 .../memory/redis/RedisChatMemoryConfig.java   | 158 ++++++++
 .../cache/semantic/DefaultSemanticCache.java  | 354 ++++++++++++++++++
 .../redis/cache/semantic/SemanticCache.java   |  91 +++++
 .../semantic/SemanticCacheAdvisorIT.java      | 226 +++++++++++
 .../chat/memory/redis/RedisChatMemoryIT.java  | 227 +++++++++++
 ...disVectorStoreWithChatMemoryAdvisorIT.java | 133 +++++++
 10 files changed, 1677 insertions(+)
 create mode 100644 vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisor.java
 create mode 100644 vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java
 create mode 100644 vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryConfig.java
 create mode 100644 vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/DefaultSemanticCache.java
 create mode 100644 vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/SemanticCache.java
 create mode 100644 vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java
 create mode 100644 vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryIT.java
 create mode 100644 vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java

diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc
index 99782a0c5f1..59b205c127c 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc
@@ -9,6 +9,8 @@ link:https://redis.io/docs/interact/search-and-query/[Redis Search and Query] ex
 * Store vectors and the associated metadata within hashes or JSON documents
 * Retrieve vectors
 * Perform vector searches
+* Cache chat responses based on semantic similarity
+* Store and query conversation history
 
 == Prerequisites
 
@@ -167,6 +169,69 @@ is converted into the proprietary Redis filter format:
 @country:{UK | NL} @year:[2020 inf]
 ----
 
+=== Semantic Cache Usage
+
+The semantic cache provides vector similarity-based caching of chat responses, exposed as an advisor:
+
+[source,java]
+----
+// Create semantic cache
+SemanticCache semanticCache = DefaultSemanticCache.builder()
+    .embeddingModel(embeddingModel)
+    .jedisClient(jedisClient)
+    .similarityThreshold(0.95) // Optional: defaults to 0.95
+    .build();
+
+// Create cache advisor
+SemanticCacheAdvisor cacheAdvisor = SemanticCacheAdvisor.builder()
+    .cache(semanticCache)
+    .build();
+
+// Use with chat client
+ChatResponse response = ChatClient.builder(chatModel)
+    .build()
+    .prompt("What is the capital of France?")
+    .advisors(cacheAdvisor)
+    .call()
+    .chatResponse();
+
+// Manually interact with cache
+semanticCache.set("query", chatResponse);
+semanticCache.set("query", chatResponse, Duration.ofHours(1)); // With TTL
+Optional<ChatResponse> cached = semanticCache.get("similar query");
+----
+
+=== Chat Memory Usage
+
+RedisChatMemory provides persistent storage for conversation history:
+
+[source,java]
+----
+// Create chat memory
+RedisChatMemory chatMemory = RedisChatMemory.builder()
+    .jedisClient(jedisClient)
+    .timeToLive(Duration.ofHours(24)) // Optional: message TTL
+    .indexName("custom-memory-index") // Optional
+    .keyPrefix("custom-prefix")       // Optional
+    .build();
+
+// Add messages
+chatMemory.add("conversation-1", new UserMessage("Hello"));
+chatMemory.add("conversation-1", new AssistantMessage("Hi there!"));
+
+// Add multiple messages
+chatMemory.add("conversation-1", List.of(
+    new UserMessage("How are you?"),
+    new AssistantMessage("I'm doing well!")
+));
+
+// Retrieve messages
+List<Message> messages = chatMemory.get("conversation-1", 10); // Last 10 messages
+
+// Clear conversation
+chatMemory.clear("conversation-1");
+----
+
 == Manual Configuration
 
 Instead of using the Spring Boot auto-configuration, you can manually configure the Redis vector store. For this you need to add the `spring-ai-redis-store` to your project:
diff --git a/vector-stores/spring-ai-redis-store/pom.xml b/vector-stores/spring-ai-redis-store/pom.xml
index 5b7576df8b6..dafc9f25215 100644
--- a/vector-stores/spring-ai-redis-store/pom.xml
+++ b/vector-stores/spring-ai-redis-store/pom.xml
@@ -101,6 +101,13 @@
 			<scope>test</scope>
 		</dependency>
 
+		<dependency>
+			<groupId>org.springframework.ai</groupId>
+			<artifactId>spring-ai-openai</artifactId>
+			<version>${project.parent.version}</version>
+			<scope>test</scope>
+		</dependency>
+
diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisor.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisor.java
new file mode 100644
index 00000000000..3f9efb5972b
--- /dev/null
+++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisor.java
@@ -0,0 +1,188 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.cache.semantic; + +import org.springframework.ai.chat.client.advisor.api.*; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.vectorstore.redis.cache.semantic.SemanticCache; +import reactor.core.publisher.Flux; + +import java.util.Optional; + +/** + * An advisor implementation that provides semantic caching capabilities for chat + * responses. This advisor intercepts chat requests and checks for semantically similar + * cached responses before allowing the request to proceed to the model. + * + *

+ * <p>
+ * This advisor implements both {@link CallAroundAdvisor} for synchronous operations and
+ * {@link StreamAroundAdvisor} for reactive streaming operations.
+ *
+ * <p>
+ * Key features:
+ * <ul>
+ * <li>Semantic similarity based caching of responses</li>
+ * <li>Support for both synchronous and streaming chat operations</li>
+ * <li>Configurable execution order in the advisor chain</li>
+ * </ul>
+ * + * @author Brian Sam-Bodden + */ +public class SemanticCacheAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + /** The underlying semantic cache implementation */ + private final SemanticCache cache; + + /** The order of this advisor in the chain */ + private final int order; + + /** + * Creates a new semantic cache advisor with default order. + * @param cache The semantic cache implementation to use + */ + public SemanticCacheAdvisor(SemanticCache cache) { + this(cache, Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); + } + + /** + * Creates a new semantic cache advisor with specified order. + * @param cache The semantic cache implementation to use + * @param order The order of this advisor in the chain + */ + public SemanticCacheAdvisor(SemanticCache cache, int order) { + this.cache = cache; + this.order = order; + } + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return this.order; + } + + /** + * Handles synchronous chat requests by checking the cache before proceeding. If a + * semantically similar response is found in the cache, it is returned immediately. + * Otherwise, the request proceeds through the chain and the response is cached. + * @param request The chat request to process + * @param chain The advisor chain to continue processing if needed + * @return The response, either from cache or from the model + */ + @Override + public AdvisedResponse aroundCall(AdvisedRequest request, CallAroundAdvisorChain chain) { + // Check cache first + Optional cached = cache.get(request.userText()); + + if (cached.isPresent()) { + return new AdvisedResponse(cached.get(), request.adviseContext()); + } + + // Cache miss - call the model + AdvisedResponse response = chain.nextAroundCall(request); + + // Cache the response + if (response.response() != null) { + cache.set(request.userText(), response.response()); + } + + return response; + } + + /** + * Handles streaming chat requests by checking the cache before proceeding. If a + * semantically similar response is found in the cache, it is returned as a single + * item flux. Otherwise, the request proceeds through the chain and the final response + * is cached. + * @param request The chat request to process + * @param chain The advisor chain to continue processing if needed + * @return A Flux of responses, either from cache or from the model + */ + @Override + public Flux aroundStream(AdvisedRequest request, StreamAroundAdvisorChain chain) { + // Check cache first + Optional cached = cache.get(request.userText()); + + if (cached.isPresent()) { + return Flux.just(new AdvisedResponse(cached.get(), request.adviseContext())); + } + + // Cache miss - stream from model + return chain.nextAroundStream(request).collectList().flatMapMany(responses -> { + // Cache the final aggregated response + if (!responses.isEmpty()) { + AdvisedResponse last = responses.get(responses.size() - 1); + if (last.response() != null) { + cache.set(request.userText(), last.response()); + } + } + return Flux.fromIterable(responses); + }); + } + + /** + * Creates a new builder for constructing SemanticCacheAdvisor instances. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder class for creating SemanticCacheAdvisor instances. Provides a fluent API + * for configuration. 
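+ * <p>
+ * A minimal usage sketch (assumes an existing {@link SemanticCache} instance named
+ * {@code semanticCache}; the explicit order value is optional):
+ * <pre>{@code
+ * SemanticCacheAdvisor advisor = SemanticCacheAdvisor.builder()
+ *     .cache(semanticCache)
+ *     .order(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER)
+ *     .build();
+ * }</pre>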
+ */ + public static class Builder { + + private SemanticCache cache; + + private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; + + /** + * Sets the semantic cache implementation. + * @param cache The cache implementation to use + * @return This builder instance + */ + public Builder cache(SemanticCache cache) { + this.cache = cache; + return this; + } + + /** + * Sets the advisor order. + * @param order The order value for this advisor + * @return This builder instance + */ + public Builder order(int order) { + this.order = order; + return this; + } + + /** + * Builds and returns a new SemanticCacheAdvisor instance. + * @return A new SemanticCacheAdvisor configured with this builder's settings + */ + public SemanticCacheAdvisor build() { + return new SemanticCacheAdvisor(cache, order); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java new file mode 100644 index 00000000000..a0fc4e3418e --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java @@ -0,0 +1,228 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.redis; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.util.Assert; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.Pipeline; +import redis.clients.jedis.json.Path2; +import redis.clients.jedis.search.*; +import redis.clients.jedis.search.schemafields.NumericField; +import redis.clients.jedis.search.schemafields.SchemaField; +import redis.clients.jedis.search.schemafields.TagField; +import redis.clients.jedis.search.schemafields.TextField; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Redis implementation of {@link ChatMemory} using Redis Stack (RedisJSON + RediSearch). + * Stores chat messages as JSON documents and uses RediSearch for querying. 
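+ * <p>
+ * A minimal usage sketch (the Redis host, port, and TTL below are illustrative
+ * assumptions, not required values):
+ * <pre>{@code
+ * RedisChatMemory memory = RedisChatMemory.builder()
+ *     .jedisClient(new JedisPooled("localhost", 6379))
+ *     .timeToLive(Duration.ofHours(24))
+ *     .build();
+ * memory.add("conversation-1", new UserMessage("Hello"));
+ * List<Message> lastTen = memory.get("conversation-1", 10);
+ * }</pre>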
+ * + * @author Brian Sam-Bodden + */ +public final class RedisChatMemory implements ChatMemory { + + private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class); + + private static final Gson gson = new Gson(); + + private static final Path2 ROOT_PATH = Path2.of("$"); + + private final RedisChatMemoryConfig config; + + private final JedisPooled jedis; + + public RedisChatMemory(RedisChatMemoryConfig config) { + Assert.notNull(config, "Config must not be null"); + this.config = config; + this.jedis = config.getJedisClient(); + + if (config.isInitializeSchema()) { + initializeSchema(); + } + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void add(String conversationId, List messages) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.notNull(messages, "Messages must not be null"); + + final AtomicLong timestampSequence = new AtomicLong(Instant.now().toEpochMilli()); + try (Pipeline pipeline = jedis.pipelined()) { + for (Message message : messages) { + String key = createKey(conversationId, timestampSequence.getAndIncrement()); + String json = gson.toJson(createMessageDocument(conversationId, message)); + pipeline.jsonSet(key, ROOT_PATH, json); + + if (config.getTimeToLiveSeconds() != -1) { + pipeline.expire(key, config.getTimeToLiveSeconds()); + } + } + pipeline.sync(); + } + } + + @Override + public void add(String conversationId, Message message) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.notNull(message, "Message must not be null"); + + String key = createKey(conversationId, Instant.now().toEpochMilli()); + String json = gson.toJson(createMessageDocument(conversationId, message)); + + jedis.jsonSet(key, ROOT_PATH, json); + if (config.getTimeToLiveSeconds() != -1) { + jedis.expire(key, config.getTimeToLiveSeconds()); + } + } + + @Override + public List get(String conversationId, int lastN) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.isTrue(lastN > 0, "LastN must be greater than 0"); + + String queryStr = String.format("@conversation_id:{%s}", RediSearchUtil.escape(conversationId)); + Query query = new Query(queryStr).setSortBy("timestamp", true).limit(0, lastN); + + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + + List messages = new ArrayList<>(); + result.getDocuments().forEach(doc -> { + if (doc.get("$") != null) { + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); + String type = json.get("type").getAsString(); + String content = json.get("content").getAsString(); + + if (MessageType.ASSISTANT.toString().equals(type)) { + messages.add(new AssistantMessage(content)); + } + else if (MessageType.USER.toString().equals(type)) { + messages.add(new UserMessage(content)); + } + } + }); + + return messages; + } + + @Override + public void clear(String conversationId) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + + String queryStr = String.format("@conversation_id:{%s}", RediSearchUtil.escape(conversationId)); + Query query = new Query(queryStr); + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + + try (Pipeline pipeline = jedis.pipelined()) { + result.getDocuments().forEach(doc -> pipeline.del(doc.getId())); + pipeline.sync(); + } + } + + private void initializeSchema() { + try { + if (!jedis.ftList().contains(config.getIndexName())) { + List schemaFields = new ArrayList<>(); + schemaFields.add(new 
TextField("$.content").as("content")); + schemaFields.add(new TextField("$.type").as("type")); + schemaFields.add(new TagField("$.conversation_id").as("conversation_id")); + schemaFields.add(new NumericField("$.timestamp").as("timestamp")); + + String response = jedis.ftCreate(config.getIndexName(), + FTCreateParams.createParams().on(IndexDataType.JSON).prefix(config.getKeyPrefix()), + schemaFields.toArray(new SchemaField[0])); + + if (!response.equals("OK")) { + throw new IllegalStateException("Failed to create index: " + response); + } + } + } + catch (Exception e) { + logger.error("Failed to initialize Redis schema", e); + throw new IllegalStateException("Could not initialize Redis schema", e); + } + } + + private String createKey(String conversationId, long timestamp) { + return String.format("%s%s:%d", config.getKeyPrefix(), escapeKey(conversationId), timestamp); + } + + private Map createMessageDocument(String conversationId, Message message) { + return Map.of("type", message.getMessageType().toString(), "content", message.getText(), "conversation_id", + conversationId, "timestamp", Instant.now().toEpochMilli()); + } + + private String escapeKey(String key) { + return key.replace(":", "\\:"); + } + + /** + * Builder for RedisChatMemory configuration. + */ + public static class Builder { + + private final RedisChatMemoryConfig.Builder configBuilder = RedisChatMemoryConfig.builder(); + + public Builder jedisClient(JedisPooled jedisClient) { + configBuilder.jedisClient(jedisClient); + return this; + } + + public Builder timeToLive(Duration ttl) { + configBuilder.timeToLive(ttl); + return this; + } + + public Builder indexName(String indexName) { + configBuilder.indexName(indexName); + return this; + } + + public Builder keyPrefix(String keyPrefix) { + configBuilder.keyPrefix(keyPrefix); + return this; + } + + public Builder initializeSchema(boolean initialize) { + configBuilder.initializeSchema(initialize); + return this; + } + + public RedisChatMemory build() { + return new RedisChatMemory(configBuilder.build()); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryConfig.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryConfig.java new file mode 100644 index 00000000000..fe4323d5418 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryConfig.java @@ -0,0 +1,158 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.redis; + +import java.time.Duration; + +import redis.clients.jedis.JedisPooled; + +import org.springframework.util.Assert; + +/** + * Configuration class for RedisChatMemory. 
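+ * <p>
+ * A minimal configuration sketch (the index name and key prefix shown match the class
+ * defaults; the TTL is an arbitrary example):
+ * <pre>{@code
+ * RedisChatMemoryConfig config = RedisChatMemoryConfig.builder()
+ *     .jedisClient(jedisClient)
+ *     .indexName("chat-memory-idx")
+ *     .keyPrefix("chat-memory:")
+ *     .timeToLive(Duration.ofHours(24))
+ *     .build();
+ * }</pre>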
+ * + * @author Brian Sam-Bodden + */ +public class RedisChatMemoryConfig { + + public static final String DEFAULT_INDEX_NAME = "chat-memory-idx"; + + public static final String DEFAULT_KEY_PREFIX = "chat-memory:"; + + private final JedisPooled jedisClient; + + private final String indexName; + + private final String keyPrefix; + + private final Integer timeToLiveSeconds; + + private final boolean initializeSchema; + + private RedisChatMemoryConfig(Builder builder) { + Assert.notNull(builder.jedisClient, "JedisPooled client must not be null"); + Assert.hasText(builder.indexName, "Index name must not be empty"); + Assert.hasText(builder.keyPrefix, "Key prefix must not be empty"); + + this.jedisClient = builder.jedisClient; + this.indexName = builder.indexName; + this.keyPrefix = builder.keyPrefix; + this.timeToLiveSeconds = builder.timeToLiveSeconds; + this.initializeSchema = builder.initializeSchema; + } + + public static Builder builder() { + return new Builder(); + } + + public JedisPooled getJedisClient() { + return jedisClient; + } + + public String getIndexName() { + return indexName; + } + + public String getKeyPrefix() { + return keyPrefix; + } + + public Integer getTimeToLiveSeconds() { + return timeToLiveSeconds; + } + + public boolean isInitializeSchema() { + return initializeSchema; + } + + /** + * Builder for RedisChatMemoryConfig. + */ + public static class Builder { + + private JedisPooled jedisClient; + + private String indexName = DEFAULT_INDEX_NAME; + + private String keyPrefix = DEFAULT_KEY_PREFIX; + + private Integer timeToLiveSeconds = -1; + + private boolean initializeSchema = true; + + /** + * Sets the Redis client. + * @param jedisClient the Redis client to use + * @return the builder instance + */ + public Builder jedisClient(JedisPooled jedisClient) { + this.jedisClient = jedisClient; + return this; + } + + /** + * Sets the index name. + * @param indexName the index name to use + * @return the builder instance + */ + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + + /** + * Sets the key prefix. + * @param keyPrefix the key prefix to use + * @return the builder instance + */ + public Builder keyPrefix(String keyPrefix) { + this.keyPrefix = keyPrefix; + return this; + } + + /** + * Sets the time-to-live duration. + * @param ttl the time-to-live duration + * @return the builder instance + */ + public Builder timeToLive(Duration ttl) { + if (ttl != null) { + this.timeToLiveSeconds = (int) ttl.toSeconds(); + } + return this; + } + + /** + * Sets whether to initialize the schema. + * @param initialize true to initialize schema, false otherwise + * @return the builder instance + */ + public Builder initializeSchema(boolean initialize) { + this.initializeSchema = initialize; + return this; + } + + /** + * Builds a new RedisChatMemoryConfig instance. 
+ * @return the new configuration instance + */ + public RedisChatMemoryConfig build() { + return new RedisChatMemoryConfig(this); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/DefaultSemanticCache.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/DefaultSemanticCache.java new file mode 100644 index 00000000000..1309cb6dab5 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/DefaultSemanticCache.java @@ -0,0 +1,354 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore.redis.cache.semantic; + +import com.google.gson.*; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.redis.RedisVectorStore; +import org.springframework.ai.vectorstore.redis.RedisVectorStore.MetadataField; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.Pipeline; +import redis.clients.jedis.search.Query; +import redis.clients.jedis.search.SearchResult; + +import java.lang.reflect.Type; +import java.time.Duration; +import java.util.*; + +/** + * Default implementation of SemanticCache using Redis as the backing store. This + * implementation uses vector similarity search to find cached responses for semantically + * similar queries. + * + * @author Brian Sam-Bodden + */ +public class DefaultSemanticCache implements SemanticCache { + + // Default configuration constants + private static final String DEFAULT_INDEX_NAME = "semantic-cache-index"; + + private static final String DEFAULT_PREFIX = "semantic-cache:"; + + private static final Integer DEFAULT_BATCH_SIZE = 100; + + private static final double DEFAULT_SIMILARITY_THRESHOLD = 0.95; + + // Core components + private final VectorStore vectorStore; + + private final EmbeddingModel embeddingModel; + + private final double similarityThreshold; + + private final Gson gson; + + private final String prefix; + + private final String indexName; + + /** + * Private constructor enforcing builder pattern usage. 
+ */ + private DefaultSemanticCache(VectorStore vectorStore, EmbeddingModel embeddingModel, double similarityThreshold, + String indexName, String prefix) { + this.vectorStore = vectorStore; + this.embeddingModel = embeddingModel; + this.similarityThreshold = similarityThreshold; + this.prefix = prefix; + this.indexName = indexName; + this.gson = createGson(); + } + + /** + * Creates a customized Gson instance with type adapters for special types. + */ + private Gson createGson() { + return new GsonBuilder() // + .registerTypeAdapter(Duration.class, new DurationAdapter()) // + .registerTypeAdapter(ChatResponse.class, new ChatResponseAdapter()) // + .create(); + } + + @Override + public VectorStore getStore() { + return this.vectorStore; + } + + @Override + public void set(String query, ChatResponse response) { + // Convert response to JSON for storage + String responseJson = gson.toJson(response); + String responseText = response.getResult().getOutput().getText(); + + // Create metadata map for the document + Map metadata = new HashMap<>(); + metadata.put("response", responseJson); + metadata.put("response_text", responseText); + + // Create document with query as text (for embedding) and response in metadata + Document document = Document.builder().text(query).metadata(metadata).build(); + + // Check for and remove any existing similar documents + List existing = vectorStore.similaritySearch( + SearchRequest.builder().query(query).topK(1).similarityThreshold(similarityThreshold).build()); + + // If similar document exists, delete it first + if (!existing.isEmpty()) { + vectorStore.delete(List.of(existing.get(0).getId())); + } + + // Add new document to vector store + vectorStore.add(List.of(document)); + } + + @Override + public void set(String query, ChatResponse response, Duration ttl) { + // Generate a unique ID for the document + String docId = UUID.randomUUID().toString(); + + // Convert response to JSON + String responseJson = gson.toJson(response); + String responseText = response.getResult().getOutput().getText(); + + // Create metadata + Map metadata = new HashMap<>(); + metadata.put("response", responseJson); + metadata.put("response_text", responseText); + + // Create document with generated ID + Document document = Document.builder().id(docId).text(query).metadata(metadata).build(); + + // Remove any existing similar documents + List existing = vectorStore.similaritySearch( + SearchRequest.builder().query(query).topK(1).similarityThreshold(similarityThreshold).build()); + + // If similar document exists, delete it first + if (!existing.isEmpty()) { + vectorStore.delete(List.of(existing.get(0).getId())); + } + + // Add document to vector store + vectorStore.add(List.of(document)); + + // Get access to Redis client and set TTL + if (vectorStore instanceof RedisVectorStore redisStore) { + String key = prefix + docId; + redisStore.getJedis().expire(key, ttl.getSeconds()); + } + } + + @Override + public Optional get(String query) { + // Search for similar documents + List similar = vectorStore.similaritySearch( + SearchRequest.builder().query(query).topK(1).similarityThreshold(similarityThreshold).build()); + + if (similar.isEmpty()) { + return Optional.empty(); + } + + Document mostSimilar = similar.get(0); + + // Get stored response JSON from metadata + String responseJson = (String) mostSimilar.getMetadata().get("response"); + if (responseJson == null) { + return Optional.empty(); + } + + // Attempt to parse stored response + try { + ChatResponse response = 
gson.fromJson(responseJson, ChatResponse.class); + return Optional.of(response); + } + catch (JsonParseException e) { + return Optional.empty(); + } + } + + @Override + public void clear() { + Optional nativeClient = vectorStore.getNativeClient(); + if (nativeClient.isPresent()) { + JedisPooled jedis = nativeClient.get(); + + // Delete documents in batches to avoid memory issues + boolean moreRecords = true; + while (moreRecords) { + Query query = new Query("*"); + query.limit(0, DEFAULT_BATCH_SIZE); // Reasonable batch size + query.setNoContent(); + + SearchResult searchResult = jedis.ftSearch(this.indexName, query); + + if (searchResult.getTotalResults() > 0) { + try (Pipeline pipeline = jedis.pipelined()) { + for (redis.clients.jedis.search.Document doc : searchResult.getDocuments()) { + pipeline.jsonDel(doc.getId()); + } + pipeline.syncAndReturnAll(); + } + } + else { + moreRecords = false; + } + } + } + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating DefaultSemanticCache instances. + */ + public static class Builder { + + private VectorStore vectorStore; + + private EmbeddingModel embeddingModel; + + private double similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD; + + private String indexName = DEFAULT_INDEX_NAME; + + private String prefix = DEFAULT_PREFIX; + + private JedisPooled jedisClient; + + // Builder methods with validation + public Builder vectorStore(VectorStore vectorStore) { + this.vectorStore = vectorStore; + return this; + } + + public Builder embeddingModel(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + return this; + } + + public Builder similarityThreshold(double threshold) { + this.similarityThreshold = threshold; + return this; + } + + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + + public Builder prefix(String prefix) { + this.prefix = prefix; + return this; + } + + public Builder jedisClient(JedisPooled jedisClient) { + this.jedisClient = jedisClient; + return this; + } + + public DefaultSemanticCache build() { + if (vectorStore == null) { + if (jedisClient == null) { + throw new IllegalStateException("Either vectorStore or jedisClient must be provided"); + } + if (embeddingModel == null) { + throw new IllegalStateException("EmbeddingModel must be provided"); + } + vectorStore = RedisVectorStore.builder(jedisClient, embeddingModel) + .indexName(indexName) + .prefix(prefix) + .metadataFields( // + MetadataField.text("response"), // + MetadataField.text("response_text"), // + MetadataField.numeric("ttl")) // + .initializeSchema(true) + .build(); + if (vectorStore instanceof RedisVectorStore redisStore) { + redisStore.afterPropertiesSet(); + } + } + return new DefaultSemanticCache(vectorStore, embeddingModel, similarityThreshold, indexName, prefix); + } + + } + + /** + * Type adapter for serializing/deserializing Duration objects. + */ + private static class DurationAdapter implements JsonSerializer, JsonDeserializer { + + @Override + public JsonElement serialize(Duration duration, Type type, JsonSerializationContext context) { + return new JsonPrimitive(duration.toSeconds()); + } + + @Override + public Duration deserialize(JsonElement json, Type type, JsonDeserializationContext context) + throws JsonParseException { + return Duration.ofSeconds(json.getAsLong()); + } + + } + + /** + * Type adapter for serializing/deserializing ChatResponse objects. 
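+ * <p>
+ * Note: only the generated assistant text is round-tripped by this adapter; response
+ * metadata and usage statistics are not preserved when a cached entry is deserialized.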
+ */ + private static class ChatResponseAdapter implements JsonSerializer, JsonDeserializer { + + @Override + public JsonElement serialize(ChatResponse response, Type type, JsonSerializationContext context) { + JsonObject jsonObject = new JsonObject(); + + // Handle generations + JsonArray generations = new JsonArray(); + for (Generation generation : response.getResults()) { + JsonObject generationObj = new JsonObject(); + Message output = (Message) generation.getOutput(); + generationObj.addProperty("text", output.getText()); + generations.add(generationObj); + } + jsonObject.add("generations", generations); + + return jsonObject; + } + + @Override + public ChatResponse deserialize(JsonElement json, Type type, JsonDeserializationContext context) + throws JsonParseException { + JsonObject jsonObject = json.getAsJsonObject(); + + List generations = new ArrayList<>(); + JsonArray generationsArray = jsonObject.getAsJsonArray("generations"); + for (JsonElement element : generationsArray) { + JsonObject generationObj = element.getAsJsonObject(); + String text = generationObj.get("text").getAsString(); + generations.add(new Generation(new AssistantMessage(text))); + } + + return ChatResponse.builder().generations(generations).build(); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/SemanticCache.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/SemanticCache.java new file mode 100644 index 00000000000..d678107a9a7 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/SemanticCache.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore.redis.cache.semantic; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.vectorstore.VectorStore; + +import java.time.Duration; +import java.util.Optional; + +/** + * Interface defining operations for a semantic cache implementation that stores and + * retrieves chat responses based on semantic similarity of queries. This cache uses + * vector embeddings to determine similarity between queries. + * + *

+ * <p>
+ * The semantic cache provides functionality to:
+ * <ul>
+ * <li>Store chat responses with their associated queries</li>
+ * <li>Retrieve responses for semantically similar queries</li>
+ * <li>Support time-based expiration of cached entries</li>
+ * <li>Clear the entire cache</li>
+ * </ul>
+ * + *
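+ * <p>
+ * A minimal usage sketch (the cache instance and query strings are illustrative):
+ * <pre>{@code
+ * semanticCache.set("What is the capital of France?", chatResponse, Duration.ofHours(1));
+ * Optional<ChatResponse> hit = semanticCache.get("Which city is France's capital?");
+ * }</pre>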

+ * Implementations should ensure thread-safety and proper resource management. + * + * @author Brian Sam-Bodden + */ +public interface SemanticCache { + + /** + * Stores a query and its corresponding chat response in the cache. Implementations + * should handle vector embedding of the query and proper storage of both the query + * embedding and response. + * @param query The original query text to be cached + * @param response The chat response associated with the query + */ + void set(String query, ChatResponse response); + + /** + * Stores a query and response in the cache with a specified time-to-live duration. + * After the TTL expires, the entry should be automatically removed from the cache. + * @param query The original query text to be cached + * @param response The chat response associated with the query + * @param ttl The duration after which the cache entry should expire + */ + void set(String query, ChatResponse response, Duration ttl); + + /** + * Retrieves a cached response for a semantically similar query. The implementation + * should: + *

    + *
+ * <ul>
+ * <li>Convert the input query to a vector embedding</li>
+ * <li>Search for similar query embeddings in the cache</li>
+ * <li>Return the response associated with the most similar query if it meets the
+ * similarity threshold</li>
+ * </ul>
+ * @param query The query to find similar responses for + * @return Optional containing the most similar cached response if found and meets + * similarity threshold, empty Optional otherwise + */ + Optional get(String query); + + /** + * Removes all entries from the cache. This operation should be atomic and + * thread-safe. + */ + void clear(); + + /** + * Returns the underlying vector store used by this cache implementation. This allows + * access to lower-level vector operations if needed. + * @return The VectorStore instance used by this cache + */ + VectorStore getStore(); + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java new file mode 100644 index 00000000000..138e7eb7856 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java @@ -0,0 +1,226 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.cache.semantic; + +import com.redis.testcontainers.RedisStackContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.cache.semantic.SemanticCacheAdvisorIT.TestApplication; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.redis.cache.semantic.DefaultSemanticCache; +import org.springframework.ai.vectorstore.redis.cache.semantic.SemanticCache; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import 
redis.clients.jedis.JedisPooled; + +import java.time.Duration; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test the Redis-based advisor that provides semantic caching capabilities for chat + * responses + * + * @author Brian Sam-Bodden + */ +@Testcontainers +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class SemanticCacheAdvisorIT { + + @Container + static RedisStackContainer redisContainer = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class)) + .withUserConfiguration(TestApplication.class) + .withPropertyValues("spring.data.redis.url=" + redisContainer.getRedisURI()); + + @Autowired + OpenAiChatModel openAiChatModel; + + @Autowired + SemanticCache semanticCache; + + @AfterEach + void tearDown() { + semanticCache.clear(); + } + + @Test + void semanticCacheTest() { + this.contextRunner.run(context -> { + String question = "What is the capital of France?"; + String expectedResponse = "Paris is the capital of France."; + + // First, simulate a cached response + semanticCache.set(question, createMockResponse(expectedResponse)); + + // Create advisor + SemanticCacheAdvisor cacheAdvisor = SemanticCacheAdvisor.builder().cache(semanticCache).build(); + + // Test with a semantically similar question + String similarQuestion = "Tell me which city is France's capital?"; + ChatResponse chatResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(similarQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + String response = chatResponse.getResult().getOutput().getText(); + assertThat(response).containsIgnoringCase("Paris"); + + // Test cache miss with a different question + String differentQuestion = "What is the population of Tokyo?"; + ChatResponse newResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(differentQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(newResponse).isNotNull(); + String newResponseText = newResponse.getResult().getOutput().getText(); + assertThat(newResponseText).doesNotContain(expectedResponse); + + // Verify the new response was cached + ChatResponse cachedNewResponse = semanticCache.get(differentQuestion).orElseThrow(); + assertThat(cachedNewResponse.getResult().getOutput().getText()) + .isEqualTo(newResponse.getResult().getOutput().getText()); + }); + } + + @Test + void semanticCacheTTLTest() throws InterruptedException { + this.contextRunner.run(context -> { + String question = "What is the capital of France?"; + String expectedResponse = "Paris is the capital of France."; + + // Set with short TTL + semanticCache.set(question, createMockResponse(expectedResponse), Duration.ofSeconds(2)); + + // Create advisor + SemanticCacheAdvisor cacheAdvisor = SemanticCacheAdvisor.builder().cache(semanticCache).build(); + + // Verify key exists + Optional nativeClient = semanticCache.getStore().getNativeClient(); + assertThat(nativeClient).isPresent(); + JedisPooled jedis = nativeClient.get(); + + Set keys = jedis.keys("semantic-cache:*"); + assertThat(keys).hasSize(1); + String key = keys.iterator().next(); + + // Verify TTL is set + Long ttl = 
jedis.ttl(key); + assertThat(ttl).isGreaterThan(0); + assertThat(ttl).isLessThanOrEqualTo(2); + + // Test cache hit before expiry + String similarQuestion = "Tell me which city is France's capital?"; + ChatResponse chatResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(similarQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).containsIgnoringCase("Paris"); + + // Wait for TTL to expire + Thread.sleep(2100); + + // Verify key is gone + assertThat(jedis.exists(key)).isFalse(); + + // Should get a cache miss and new response + ChatResponse newResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(similarQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(newResponse).isNotNull(); + assertThat(newResponse.getResult().getOutput().getText()).containsIgnoringCase("Paris"); + // Original cached response should be gone, this should be a fresh response + }); + } + + private ChatResponse createMockResponse(String text) { + return ChatResponse.builder().generations(List.of(new Generation(new AssistantMessage(text)))).build(); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestApplication { + + @Bean + public SemanticCache semanticCache(EmbeddingModel embeddingModel, + JedisConnectionFactory jedisConnectionFactory) { + JedisPooled jedisPooled = new JedisPooled(Objects.requireNonNull(jedisConnectionFactory.getPoolConfig()), + jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()); + + return DefaultSemanticCache.builder().embeddingModel(embeddingModel).jedisClient(jedisPooled).build(); + } + + @Bean(name = "openAiEmbeddingModel") + public EmbeddingModel embeddingModel() { + return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + } + + @Bean(name = "openAiChatModel") + public OpenAiChatModel openAiChatModel() { + var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); + var openAiChatOptions = OpenAiChatOptions.builder() + .model("gpt-3.5-turbo") + .temperature(0.4) + .maxTokens(200) + .build(); + return new OpenAiChatModel(openAiApi, openAiChatOptions); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryIT.java new file mode 100644 index 00000000000..dfc9f0c1af8 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryIT.java @@ -0,0 +1,227 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.chat.memory.redis; + +import com.redis.testcontainers.RedisStackContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.time.Duration; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemory using Redis Stack TestContainer. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryIT { + + @Container + static RedisStackContainer redisContainer = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private RedisChatMemory chatMemory; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + chatMemory = RedisChatMemory.builder() + .jedisClient(jedisClient) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + + chatMemory.clear("test-conversation"); + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldStoreAndRetrieveMessages() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Hello")); + chatMemory.add(conversationId, new AssistantMessage("Hi there!")); + chatMemory.add(conversationId, new UserMessage("How are you?")); + + // Retrieve messages + List messages = chatMemory.get(conversationId, 10); + + assertThat(messages).hasSize(3); + assertThat(messages.get(0).getText()).isEqualTo("Hello"); + assertThat(messages.get(1).getText()).isEqualTo("Hi there!"); + assertThat(messages.get(2).getText()).isEqualTo("How are you?"); + }); + } + + @Test + void shouldRespectMessageLimit() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Message 1")); + chatMemory.add(conversationId, new AssistantMessage("Message 2")); + chatMemory.add(conversationId, new UserMessage("Message 3")); + + // Retrieve limited messages + List messages = chatMemory.get(conversationId, 2); + + assertThat(messages).hasSize(2); + }); + } + + @Test + void shouldClearConversation() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Hello")); + chatMemory.add(conversationId, new AssistantMessage("Hi")); + + // Clear conversation + chatMemory.clear(conversationId); + + // Verify messages are cleared + List messages = chatMemory.get(conversationId, 
10); + assertThat(messages).isEmpty(); + }); + } + + @Test + void shouldHandleBatchMessageAddition() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + List messageBatch = List.of(new UserMessage("Message 1"), // + new AssistantMessage("Response 1"), // + new UserMessage("Message 2"), // + new AssistantMessage("Response 2") // + ); + + // Add batch of messages + chatMemory.add(conversationId, messageBatch); + + // Verify all messages were stored + List retrievedMessages = chatMemory.get(conversationId, 10); + assertThat(retrievedMessages).hasSize(4); + }); + } + + @Test + void shouldHandleTimeToLive() throws InterruptedException { + this.contextRunner.run(context -> { + RedisChatMemory shortTtlMemory = RedisChatMemory.builder() + .jedisClient(jedisClient) + .indexName("test-ttl-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .timeToLive(Duration.ofSeconds(2)) + .keyPrefix("short-lived:") + .build(); + + String conversationId = "test-conversation"; + shortTtlMemory.add(conversationId, new UserMessage("This should expire")); + + // Verify message exists + assertThat(shortTtlMemory.get(conversationId, 1)).hasSize(1); + + // Wait for TTL to expire + Thread.sleep(2000); + + // Verify message is gone + assertThat(shortTtlMemory.get(conversationId, 1)).isEmpty(); + }); + } + + @Test + void shouldMaintainMessageOrder() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + // Add messages with minimal delay to test timestamp ordering + chatMemory.add(conversationId, new UserMessage("First")); + Thread.sleep(10); + chatMemory.add(conversationId, new AssistantMessage("Second")); + Thread.sleep(10); + chatMemory.add(conversationId, new UserMessage("Third")); + + List messages = chatMemory.get(conversationId, 10); + assertThat(messages).hasSize(3); + assertThat(messages.get(0).getText()).isEqualTo("First"); + assertThat(messages.get(1).getText()).isEqualTo("Second"); + assertThat(messages.get(2).getText()).isEqualTo("Third"); + }); + } + + @Test + void shouldHandleMultipleConversations() { + this.contextRunner.run(context -> { + String conv1 = "conversation-1"; + String conv2 = "conversation-2"; + + chatMemory.add(conv1, new UserMessage("Conv1 Message")); + chatMemory.add(conv2, new UserMessage("Conv2 Message")); + + List conv1Messages = chatMemory.get(conv1, 10); + List conv2Messages = chatMemory.get(conv2, 10); + + assertThat(conv1Messages).hasSize(1); + assertThat(conv2Messages).hasSize(1); + assertThat(conv1Messages.get(0).getText()).isEqualTo("Conv1 Message"); + assertThat(conv2Messages.get(0).getText()).isEqualTo("Conv2 Message"); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemory chatMemory() { + return RedisChatMemory.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .timeToLive(Duration.ofMinutes(5)) + .build(); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java new file mode 100644 index 00000000000..34f57a7b96f --- /dev/null +++ 
b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java @@ -0,0 +1,133 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.redis; + +import com.redis.testcontainers.RedisStackContainer; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.redis.RedisVectorStore.MetadataField; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Integration tests for RedisVectorStore using Redis Stack TestContainer. 
+ *
+ * @author Brian Sam-Bodden
+ */
+@Testcontainers
+class RedisVectorStoreWithChatMemoryAdvisorIT {
+
+	@Container
+	static RedisStackContainer redisContainer = new RedisStackContainer(
+			RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG));
+
+	float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F };
+
+	@Test
+	@DisplayName("Advised chat should have similar messages from vector store")
+	void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception {
+		// Mock chat model
+		ChatModel chatModel = chatModelAlwaysReturnsTheSameReply();
+		// Mock embedding model
+		EmbeddingModel embeddingModel = embeddingModelShouldAlwaysReturnFakedEmbed();
+
+		// Create Redis store with dimensions matching our fake embeddings
+		RedisVectorStore store = RedisVectorStore
+			.builder(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()), embeddingModel)
+			.metadataFields(MetadataField.tag("conversationId"), MetadataField.tag("messageType"))
+			.initializeSchema(true)
+			.build();
+
+		store.afterPropertiesSet();
+
+		// Initialize store with test data
+		store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")),
+				new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER"))));
+
+		// Run chat with advisor
+		ChatClient.builder(chatModel)
+			.build()
+			.prompt()
+			.user("joke")
+			.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
+			.call()
+			.chatResponse();
+
+		verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel);
+	}
+
+	private static ChatModel chatModelAlwaysReturnsTheSameReply() {
+		ChatModel chatModel = mock(ChatModel.class);
+		ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
+		ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
+				Why don't scientists trust atoms?
+				Because they make up everything!
+				"""))));
+		given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
+		return chatModel;
+	}
+
+	private EmbeddingModel embeddingModelShouldAlwaysReturnFakedEmbed() {
+		EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
+		Mockito.doAnswer(invocationOnMock -> List.of(this.embed, this.embed))
+			.when(embeddingModel)
+			.embed(any(), any(), any());
+		given(embeddingModel.embed(any(String.class))).willReturn(this.embed);
+		given(embeddingModel.dimensions()).willReturn(3); // Explicit dimensions matching the embed array
+		return embeddingModel;
+	}
+
+	private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatModel chatModel) {
+		ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
+		verify(chatModel).call(promptCaptor.capture());
+		assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
+		assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualTo("""
+
+				Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
+
+				---------------------
+				LONG_TERM_MEMORY:
+				Tell me a good joke
+				Tell me a bad joke
+				---------------------
+				""");
+	}
+
+}

From 00a4e2ef486542b19cc402a43e133c34f1a9b768 Mon Sep 17 00:00:00 2001
From: Mark Pollack
Date: Sat, 3 May 2025 16:40:42 -0400
Subject: [PATCH 2/2] rebase

---
 vector-stores/spring-ai-redis-store/pom.xml      | 18 +++++++++++++++++-
 .../semantic/SemanticCacheAdvisorIT.java         | 19 +++++++++++++++----
 ...disVectorStoreWithChatMemoryAdvisorIT.java    |  2 +-
 3 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/vector-stores/spring-ai-redis-store/pom.xml b/vector-stores/spring-ai-redis-store/pom.xml
index dafc9f25215..d708cff8d72 100644
--- a/vector-stores/spring-ai-redis-store/pom.xml
+++ b/vector-stores/spring-ai-redis-store/pom.xml
@@ -55,6 +55,21 @@
 			<artifactId>spring-data-redis</artifactId>
 		</dependency>
 
+		<dependency>
+			<groupId>org.springframework.ai</groupId>
+			<artifactId>spring-ai-client-chat</artifactId>
+			<version>${project.version}</version>
+		</dependency>
+
+		<dependency>
+			<groupId>org.springframework.ai</groupId>
+			<artifactId>spring-ai-advisors-vector-store</artifactId>
+			<version>${project.version}</version>
+			<scope>test</scope>
+		</dependency>
+
 		<dependency>
 			<groupId>redis.clients</groupId>
 			<artifactId>jedis</artifactId>
@@ -108,6 +123,7 @@
 			<scope>test</scope>
-
+
+
diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java
index 138e7eb7856..1b35576b5b4 100644
--- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java
+++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java
@@ -17,6 +17,8 @@
 package org.springframework.ai.chat.cache.semantic;
 
 import com.redis.testcontainers.RedisStackContainer;
+import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.tck.TestObservationRegistry;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
@@ -26,6 +28,7 @@
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
 import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.model.tool.ToolCallingManager;
 import org.springframework.ai.openai.OpenAiChatModel;
 import org.springframework.ai.openai.OpenAiChatOptions;
 import org.springframework.ai.openai.OpenAiEmbeddingModel;
@@ -42,6 +45,8 @@
 import org.springframework.boot.test.context.runner.ApplicationContextRunner;
 import org.springframework.context.annotation.Bean;
 import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
+import org.springframework.retry.support.RetryTemplate;
+
 import org.testcontainers.junit.jupiter.Container;
 import org.testcontainers.junit.jupiter.Testcontainers;
 import redis.clients.jedis.JedisPooled;
@@ -207,18 +212,24 @@ public SemanticCache semanticCache(EmbeddingModel embeddingModel,
 		@Bean(name = "openAiEmbeddingModel")
 		public EmbeddingModel embeddingModel() {
-			return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
+			return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
+		}
+
+		@Bean
+		public TestObservationRegistry observationRegistry() {
+			return TestObservationRegistry.create();
 		}
 
 		@Bean(name = "openAiChatModel")
-		public OpenAiChatModel openAiChatModel() {
-			var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
+		public OpenAiChatModel openAiChatModel(ObservationRegistry observationRegistry) {
+			var openAiApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
 			var openAiChatOptions = OpenAiChatOptions.builder()
 				.model("gpt-3.5-turbo")
 				.temperature(0.4)
 				.maxTokens(200)
 				.build();
-			return new OpenAiChatModel(openAiApi, openAiChatOptions);
+			return new OpenAiChatModel(openAiApi, openAiChatOptions, ToolCallingManager.builder().build(),
+					RetryTemplate.defaultInstance(), observationRegistry);
 		}
 
 	}
 
 }
diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java
index 34f57a7b96f..61f259e3388 100644
--- a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java
+++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java
@@ -22,7 +22,7 @@
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mockito;
 import org.springframework.ai.chat.client.ChatClient;
-import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
+import org.springframework.ai.chat.client.advisor.vectorstore.VectorStoreChatMemoryAdvisor;
 import org.springframework.ai.chat.messages.AssistantMessage;
 import org.springframework.ai.chat.messages.SystemMessage;
 import org.springframework.ai.chat.model.ChatModel;