Skip to content

Commit

Permalink
AiServiceRegisteredEvent (#89)
Browse files Browse the repository at this point in the history
Closes langchain4j/langchain4j#2112

Publish a `AiServiceRegisteredEvent` Spring Event after registering the
`AiService` bean in `AiServicesAutoConfig`.
This event contains the `AiService` class and its corresponding tools
description information.

Once a user implements the even listener to listen for this event, they
can receive the event during the Spring Boot startup phase and handle
their business logic as needed.
  • Loading branch information
catofdestruction authored Dec 20, 2024
1 parent 3bdf8f2 commit a4280ba
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dev.langchain4j.service.spring;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.agent.tool.ToolSpecifications;
import dev.langchain4j.exception.IllegalConfigurationException;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -9,19 +11,21 @@
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.beans.factory.support.ManagedList;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.annotation.Bean;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.*;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
Expand All @@ -31,7 +35,16 @@
import static dev.langchain4j.service.spring.AiServiceWiringMode.EXPLICIT;
import static java.util.Arrays.asList;

public class AiServicesAutoConfig {
public class AiServicesAutoConfig implements ApplicationEventPublisherAware {

private static final Logger log = LoggerFactory.getLogger(AiServicesAutoConfig.class);

private ApplicationEventPublisher eventPublisher;

@Override
public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) {
this.eventPublisher = eventPublisher;
}

@Bean
BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Expand All @@ -46,7 +59,8 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class);
String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);

Set<String> tools = new HashSet<>();
Set<String> toolBeanNames = new HashSet<>();
List<ToolSpecification> toolSpecifications = new ArrayList<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
try {
String beanClassName = beanFactory.getBeanDefinition(beanName).getBeanClassName();
Expand All @@ -56,7 +70,13 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
Class<?> beanClass = Class.forName(beanClassName);
for (Method beanMethod : beanClass.getDeclaredMethods()) {
if (beanMethod.isAnnotationPresent(Tool.class)) {
tools.add(beanName);
toolBeanNames.add(beanName);
try {
toolSpecifications.add(ToolSpecifications.toolSpecificationFrom(beanMethod));
} catch (Exception e) {
log.warn("Cannot convert %s.%s method annotated with @Tool into ToolSpecification"
.formatted(beanClass.getName(), beanMethod.getName()), e);
}
}
}
} catch (Exception e) {
Expand Down Expand Up @@ -148,14 +168,18 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
propertyValues.add("tools", toManagedList(tools));
propertyValues.add("tools", toManagedList(toolBeanNames));
} else {
throw illegalArgument("Unknown wiring mode: " + aiServiceAnnotation.wiringMode());
}

BeanDefinitionRegistry registry = (BeanDefinitionRegistry) beanFactory;
registry.removeBeanDefinition(aiService);
registry.registerBeanDefinition(lowercaseFirstLetter(aiService), aiServiceBeanDefinition);

if (eventPublisher != null) {
eventPublisher.publishEvent(new AiServiceRegisteredEvent(this, aiServiceClass, toolSpecifications));
}
}
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dev.langchain4j.service.spring.event;

import dev.langchain4j.agent.tool.ToolSpecification;
import org.springframework.context.ApplicationEvent;

import java.util.List;

import static dev.langchain4j.internal.Utils.copyIfNotNull;

public class AiServiceRegisteredEvent extends ApplicationEvent {

private final Class<?> aiServiceClass;
private final List<ToolSpecification> toolSpecifications;

public AiServiceRegisteredEvent(Object source, Class<?> aiServiceClass, List<ToolSpecification> toolSpecifications) {
super(source);
this.aiServiceClass = aiServiceClass;
this.toolSpecifications = copyIfNotNull(toolSpecifications);
}

public Class<?> aiServiceClass() {
return aiServiceClass;
}

public List<ToolSpecification> toolSpecifications() {
return toolSpecifications;
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ApplicationListener;

import java.util.List;

@SpringBootApplication
class AiServiceWithToolsApplication {
class AiServiceWithToolsApplication implements ApplicationListener<AiServiceRegisteredEvent> {

public static void main(String[] args) {
SpringApplication.run(AiServiceWithToolsApplication.class, args);
}

@Override
public void onApplicationEvent(AiServiceRegisteredEvent event) {
Class<?> aiServiceClass = event.aiServiceClass();
List<ToolSpecification> toolSpecifications = event.toolSpecifications();
for (int i = 0; i < toolSpecifications.size(); i++) {
System.out.printf("[%s]: [Tool-%s]: %s%n", aiServiceClass.getSimpleName(), i + 1, toolSpecifications.get(i));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package dev.langchain4j.service.spring.mode.automatic.withTools;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import dev.langchain4j.service.spring.mode.automatic.withTools.aop.ToolObserverAspect;
import dev.langchain4j.service.spring.mode.automatic.withTools.listener.AiServiceRegisteredEventListener;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import java.util.List;

import static dev.langchain4j.service.spring.mode.ApiKeys.OPENAI_API_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY;
import static dev.langchain4j.service.spring.mode.automatic.withTools.AopEnhancedTools.TOOL_OBSERVER_KEY_NAME_DESCRIPTION;
Expand All @@ -16,6 +21,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

class AiServicesAutoConfigIT {
Expand Down Expand Up @@ -69,6 +75,32 @@ void should_create_AI_service_with_tool_that_is_package_private_method_in_packag
});
}

@Test
void should_receive_ai_service_registered_event() {
contextRunner
.withUserConfiguration(AiServiceWithToolsApplication.class)
.run(context -> {

// given
AiServiceRegisteredEventListener listener = context.getBean(AiServiceRegisteredEventListener.class);

// then should receive AiServiceRegisteredEvent
assertTrue(listener.isEventReceived());
assertEquals(1, listener.getReceivedEvents().size());

AiServiceRegisteredEvent event = listener.getReceivedEvents().stream().findFirst().orElse(null);
assertNotNull(event);
assertEquals(AiServiceWithTools.class, event.aiServiceClass());
assertEquals(4, event.toolSpecifications().size());

List<String> tools = event.toolSpecifications().stream().map(ToolSpecification::name).toList();
assertTrue(tools.contains("getCurrentDate"));
assertTrue(tools.contains("getCurrentTime"));
assertTrue(tools.contains("getToolObserverPackageName"));
assertTrue(tools.contains("getToolObserverKey"));
});
}

@Test
void should_create_AI_service_with_tool_which_is_enhanced_by_spring_aop() {
contextRunner
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.listener;

import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationListener;

import java.util.ArrayList;
import java.util.List;

public class AbstractApplicationListener<E extends ApplicationEvent> implements ApplicationListener<E> {
private final List<E> receivedEvents = new ArrayList<>();

@Override
public void onApplicationEvent(E event) {
receivedEvents.add(event);
}

public List<E> getReceivedEvents() {
return receivedEvents;
}

public boolean isEventReceived() {
return !receivedEvents.isEmpty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dev.langchain4j.service.spring.mode.automatic.withTools.listener;

import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import org.springframework.stereotype.Component;

@Component
public class AiServiceRegisteredEventListener extends AbstractApplicationListener<AiServiceRegisteredEvent> {
}

0 comments on commit a4280ba

Please sign in to comment.