diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f70936c7e..1e130db9d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,8 +63,24 @@ jobs: with: java-version: ${{ matrix.java-version }} distribution: 'adopt' + - name: Install python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + - name: Install Python dependencies for MCP tests + run: | + cd python + uv sync --all-extras - name: Run Java Tests - run: tools/ut.sh -j + run: | + # Add Python venv to PATH so Java tests can find MCP dependencies + export PATH="${{ github.workspace }}/python/.venv/bin:$PATH" + export PYTHONPATH="${{ github.workspace }}/python/.venv/lib/python3.11/site-packages:$PYTHONPATH" + tools/ut.sh -j python_tests: name: ut-python [${{ matrix.os }}] [python-${{ matrix.python-version}}] diff --git a/api/pom.xml b/api/pom.xml index a540361f6..9f5203316 100644 --- a/api/pom.xml +++ b/api/pom.xml @@ -65,4 +65,26 @@ under the License. + + + java-11 + + [11,17) + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + org/apache/flink/agents/api/annotation/MCPServer.java + + + + + + + + \ No newline at end of file diff --git a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java index 37a5d5327..506b2848b 100644 --- a/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java +++ b/api/src/main/java/org/apache/flink/agents/api/agents/ReActAgent.java @@ -91,7 +91,7 @@ public ReActAgent( "Output schema must be RowTypeInfo or Pojo class."); } Prompt schemaPrompt = - new Prompt( + Prompt.fromText( String.format( "The final response should be json format, and match the schema %s", jsonSchema)); diff --git a/api/src/main/java/org/apache/flink/agents/api/annotation/MCPServer.java b/api/src/main/java/org/apache/flink/agents/api/annotation/MCPServer.java new file mode 100644 index 000000000..f427c96c3 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/annotation/MCPServer.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to mark a method as an MCP server resource that should be managed by the agent plan. + * + *

Methods annotated with @MCPServer will be scanned during agent plan creation. The agent plan + * will automatically: + * + *

+ * + *

Example usage: + * + *

{@code
+ * public class MyAgent extends Agent {
+ *     @MCPServer
+ *     public static MCPServer myMcpServer() {
+ *         return MCPServer.builder("http://localhost:8000/mcp")
+ *             .timeout(Duration.ofSeconds(30))
+ *             .build();
+ *     }
+ *
+ *     @ChatModelSetup
+ *     public static ChatModel chatModel() {
+ *         return new ChatModel.Builder()
+ *             .prompt("greeting")  // MCP prompt from server
+ *             .tools(List.of("calculator"))  // MCP tool from server
+ *             .build();
+ *     }
+ * }
+ * }
+ * + *

This is the Java equivalent of Python's {@code @mcp_server} decorator. + * + * @see org.apache.flink.agents.integrations.mcp.MCPServer + * @see org.apache.flink.agents.integrations.mcp.MCPTool + * @see org.apache.flink.agents.integrations.mcp.MCPPrompt + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface MCPServer {} diff --git a/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java b/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java index bb0be8bb8..28651c286 100644 --- a/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java +++ b/api/src/main/java/org/apache/flink/agents/api/prompt/Prompt.java @@ -25,6 +25,7 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonSubTypes; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeInfo; import java.util.ArrayList; @@ -33,66 +34,85 @@ import java.util.Map; import java.util.Objects; import java.util.function.Function; -import java.util.regex.Pattern; import java.util.stream.Collectors; /** - * Prompt for a language model + * Abstract base class for prompts. * - *

The template can be either a string or a sequence of ChatMessage objects. + *

This is the base class for all prompt implementations in Flink Agents. Subclasses must + * implement the formatting methods to generate text or messages from templates. + * + *

Common implementations: + * + *

+ * + *

Example usage: + * + *

{@code
+ * // String template
+ * Prompt prompt = Prompt.fromText("Hello {name}!");
+ * String result = prompt.formatString(Map.of("name", "World"));
+ * // result: "Hello World!"
+ *
+ * // Message template
+ * List messages = List.of(
+ *     new ChatMessage(MessageRole.SYSTEM, "You are {role}"),
+ *     new ChatMessage(MessageRole.USER, "Help me with {task}")
+ * );
+ * Prompt prompt2 = Prompt.fromMessages(messages);
+ * List result2 = prompt2.formatMessages(
+ *     MessageRole.SYSTEM,
+ *     Map.of("role", "an assistant", "task", "coding")
+ * );
+ * }
+ * + * @see org.apache.flink.agents.integrations.mcp.MCPPrompt */ -public class Prompt extends SerializableResource { - private static final String FIELD_TEMPLATE = "template"; +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") +@JsonSubTypes({@JsonSubTypes.Type(value = Prompt.LocalPrompt.class, name = "LocalPrompt")}) +public abstract class Prompt extends SerializableResource { - @JsonProperty(FIELD_TEMPLATE) - private final PromptTemplate template; - - @JsonCreator - private Prompt(@JsonProperty(FIELD_TEMPLATE) PromptTemplate promptTemplate) { - this.template = promptTemplate; - } - - public Prompt(String template) { - this.template = PromptTemplate.fromString(template); + /** + * Create a prompt from a text string template. + * + * @param text The text template with placeholders like {variable} + * @return A Prompt instance + */ + public static Prompt fromText(String text) { + return new LocalPrompt(text); } - public Prompt(List template) { - this.template = PromptTemplate.fromMessages(template); + /** + * Create a prompt from a sequence of chat messages. + * + * @param messages The list of chat messages forming the prompt template + * @return A Prompt instance + */ + public static Prompt fromMessages(List messages) { + return new LocalPrompt(messages); } - public String formatString(Map kwargs) { - return template.match( - // Handle string template - content -> format(content, kwargs), - // Handle messages template - messages -> { - List formattedMessages = new ArrayList<>(); - for (ChatMessage message : messages) { - String formattedContent = format(message.getContent(), kwargs); - String formatted = message.getRole().getValue() + ": " + formattedContent; - formattedMessages.add(formatted); - } - return String.join("\n", formattedMessages); - }); - } + /** + * Generate a text string from the prompt template with additional arguments. + * + * @param kwargs Key-value pairs to substitute in the template + * @return The formatted prompt as a string + */ + public abstract String formatString(Map kwargs); - public List formatMessages(MessageRole defaultRole, Map kwargs) { - return template.match( - // Handle string template - content -> - new ArrayList<>( - Collections.singletonList( - new ChatMessage(defaultRole, format(content, kwargs)))), - // Handle messages template - messages -> - messages.stream() - .map( - message -> - new ChatMessage( - message.getRole(), - format(message.getContent(), kwargs))) - .collect(Collectors.toList())); - } + /** + * Generate a list of ChatMessage from the prompt template with additional arguments. + * + * @param defaultRole The default message role (usually SYSTEM) + * @param kwargs Key-value pairs to substitute in the template + * @return List of formatted chat messages + */ + public abstract List formatMessages( + MessageRole defaultRole, Map kwargs); @JsonIgnore @Override @@ -100,98 +120,165 @@ public ResourceType getResourceType() { return ResourceType.PROMPT; } - private static final Pattern BRACE_PATTERN = Pattern.compile("\\{([^}]+)\\}"); + /** + * Local prompt implementation for language models. + * + *

This prompt implementation uses a local template that can be either a string or a sequence + * of ChatMessage objects. The template supports placeholder substitution using {variable} + * syntax. + * + *

While this class is public, users should prefer using {@link Prompt#fromText(String)} or + * {@link Prompt#fromMessages(List)} factory methods to create prompt instances instead of + * directly instantiating this class. + */ + public static class LocalPrompt extends Prompt { + private static final String FIELD_TEMPLATE = "template"; - /** Format template string with keyword arguments */ - private static String format(String template, Map kwargs) { - if (template == null) { - return ""; - } + @JsonProperty(FIELD_TEMPLATE) + private final transient PromptTemplate template; - String result = template; - for (Map.Entry entry : kwargs.entrySet()) { - String placeholder = "{" + entry.getKey() + "}"; - String value = entry.getValue() != null ? entry.getValue() : ""; - result = result.replace(placeholder, value); + @JsonCreator + private LocalPrompt(@JsonProperty(FIELD_TEMPLATE) PromptTemplate promptTemplate) { + this.template = promptTemplate; } - return result; - } - @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) - private abstract static class PromptTemplate { - public static PromptTemplate fromString(String content) { - return new StringTemplate(content); + public LocalPrompt(String template) { + this.template = PromptTemplate.fromString(template); } - public static PromptTemplate fromMessages(List messages) { - return new MessagesTemplate(messages); + public LocalPrompt(List template) { + this.template = PromptTemplate.fromMessages(template); } - /** - * Pattern matching method for type-safe operations. This replaces instanceof checks and - * casting. - */ - public abstract T match( - Function onString, Function, T> onMessages); - } + @Override + public String formatString(Map kwargs) { + return template.match( + // Handle string template + content -> format(content, kwargs), + // Handle messages template + messages -> { + List formattedMessages = new ArrayList<>(); + for (ChatMessage message : messages) { + String formattedContent = format(message.getContent(), kwargs); + String formatted = + message.getRole().getValue() + ": " + formattedContent; + formattedMessages.add(formatted); + } + return String.join("\n", formattedMessages); + }); + } - /** String template implementation. */ - private static class StringTemplate extends PromptTemplate { - private static final String FIELD_CONTENT = "content"; + @Override + public List formatMessages( + MessageRole defaultRole, Map kwargs) { + return template.match( + // Handle string template + content -> + new ArrayList<>( + Collections.singletonList( + new ChatMessage(defaultRole, format(content, kwargs)))), + // Handle messages template + messages -> + messages.stream() + .map( + message -> + new ChatMessage( + message.getRole(), + format(message.getContent(), kwargs))) + .collect(Collectors.toList())); + } - @JsonProperty(FIELD_CONTENT) - private final String content; + /** Format template string with keyword arguments */ + private static String format(String template, Map kwargs) { + if (template == null) { + return ""; + } - @JsonCreator - public StringTemplate(@JsonProperty(FIELD_CONTENT) String content) { - this.content = Objects.requireNonNull(content, "content cannot be null"); + String result = template; + for (Map.Entry entry : kwargs.entrySet()) { + String placeholder = "{" + entry.getKey() + "}"; + String value = entry.getValue() != null ? entry.getValue() : ""; + result = result.replace(placeholder, value); + } + return result; } - public String getContent() { - return content; - } + @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS) + private abstract static class PromptTemplate { + public static PromptTemplate fromString(String content) { + return new StringTemplate(content); + } - @Override - public T match( - Function onString, Function, T> onMessages) { - return onString.apply(content); - } + public static PromptTemplate fromMessages(List messages) { + return new MessagesTemplate(messages); + } - @Override - public String toString() { - return "StringTemplate{content='" + content + "'}"; + /** + * Pattern matching method for type-safe operations. This replaces instanceof checks and + * casting. + */ + public abstract T match( + Function onString, Function, T> onMessages); } - } - /** Messages template implementation. */ - private static class MessagesTemplate extends PromptTemplate { - private static final String FIELD_MESSAGES = "messages"; + /** String template implementation. */ + private static class StringTemplate extends PromptTemplate { + private static final String FIELD_CONTENT = "content"; - @JsonProperty(FIELD_MESSAGES) - private final List messages; + @JsonProperty(FIELD_CONTENT) + private final String content; - @JsonCreator - public MessagesTemplate(@JsonProperty(FIELD_MESSAGES) List messages) { - Objects.requireNonNull(messages, "messages cannot be null"); - if (messages.isEmpty()) { - throw new IllegalArgumentException("Messages cannot be empty"); + @JsonCreator + public StringTemplate(@JsonProperty(FIELD_CONTENT) String content) { + this.content = Objects.requireNonNull(content, "content cannot be null"); } - this.messages = new ArrayList<>(messages); - } - public List getMessages() { - return new ArrayList<>(messages); - } + public String getContent() { + return content; + } - @Override - public T match( - Function onString, Function, T> onMessages) { - return onMessages.apply(new ArrayList<>(messages)); + @Override + public T match( + Function onString, Function, T> onMessages) { + return onString.apply(content); + } + + @Override + public String toString() { + return "StringTemplate{content='" + content + "'}"; + } } - @Override - public String toString() { - return "MessagesTemplate{messages=" + messages.size() + " items}"; + /** Messages template implementation. */ + private static class MessagesTemplate extends PromptTemplate { + private static final String FIELD_MESSAGES = "messages"; + + @JsonProperty(FIELD_MESSAGES) + private final List messages; + + @JsonCreator + public MessagesTemplate(@JsonProperty(FIELD_MESSAGES) List messages) { + Objects.requireNonNull(messages, "messages cannot be null"); + if (messages.isEmpty()) { + throw new IllegalArgumentException("Messages cannot be empty"); + } + this.messages = new ArrayList<>(messages); + } + + public List getMessages() { + return new ArrayList<>(messages); + } + + @Override + public T match( + Function onString, Function, T> onMessages) { + return onMessages.apply(new ArrayList<>(messages)); + } + + @Override + public String toString() { + return "MessagesTemplate{messages=" + messages.size() + " items}"; + } } } } diff --git a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java index 9a11503d4..65c83152a 100644 --- a/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/chat/model/BaseChatModelTest.java @@ -94,14 +94,14 @@ void setUp() { null); // Create simple prompt - simplePrompt = new Prompt("You are a helpful assistant. User says: {user_input}"); + simplePrompt = Prompt.fromText("You are a helpful assistant. User says: {user_input}"); // Create conversation prompt List conversationTemplate = Arrays.asList( new ChatMessage(MessageRole.SYSTEM, "You are a helpful AI assistant."), new ChatMessage(MessageRole.USER, "{user_message}")); - conversationPrompt = new Prompt(conversationTemplate); + conversationPrompt = Prompt.fromMessages(conversationTemplate); } @Test @@ -118,7 +118,7 @@ void testBasicChat() { // Format the prompt with variables Prompt formattedPrompt = - new Prompt(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); + Prompt.fromMessages(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); ChatMessage response = chatModel.chat(formattedPrompt.formatMessages(MessageRole.USER, new HashMap<>())); @@ -135,7 +135,8 @@ void testChatWithConversationPrompt() { variables.put("user_message", "What's the weather like?"); Prompt formattedPrompt = - new Prompt(conversationPrompt.formatMessages(MessageRole.SYSTEM, variables)); + Prompt.fromMessages( + conversationPrompt.formatMessages(MessageRole.SYSTEM, variables)); ChatMessage response = chatModel.chat(formattedPrompt.formatMessages(MessageRole.USER, new HashMap<>())); @@ -148,7 +149,7 @@ void testChatWithConversationPrompt() { @Test @DisplayName("Test chat with empty prompt") void testChatWithEmptyPrompt() { - Prompt emptyPrompt = new Prompt(""); + Prompt emptyPrompt = Prompt.fromText(""); ChatMessage response = chatModel.chat(emptyPrompt.formatMessages(MessageRole.USER, new HashMap<>())); @@ -169,7 +170,7 @@ void testChatWithMultipleUserMessages() { new ChatMessage( MessageRole.USER, "Second message - this should be the response")); - Prompt multiPrompt = new Prompt(multipleMessages); + Prompt multiPrompt = Prompt.fromMessages(multipleMessages); ChatMessage response = chatModel.chat(multiPrompt.formatMessages(MessageRole.USER, new HashMap<>())); @@ -187,7 +188,7 @@ void testChatModelConfiguration() { variables.put("user_input", "Test message"); Prompt formattedPrompt = - new Prompt(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); + Prompt.fromMessages(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); ChatMessage response = chatModel.chat(formattedPrompt.formatMessages(MessageRole.USER, new HashMap<>())); @@ -199,7 +200,7 @@ void testChatModelConfiguration() { @DisplayName("Test chat with system-only prompt") void testChatWithSystemOnlyPrompt() { Prompt systemOnlyPrompt = - new Prompt( + Prompt.fromMessages( Arrays.asList( new ChatMessage(MessageRole.SYSTEM, "System instruction only"))); @@ -218,7 +219,7 @@ void testChatResponseFormat() { variables.put("user_input", "Format test"); Prompt formattedPrompt = - new Prompt(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); + Prompt.fromMessages(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); ChatMessage response = chatModel.chat(formattedPrompt.formatMessages(MessageRole.USER, new HashMap<>())); @@ -243,7 +244,7 @@ void testChatWithLongInput() { variables.put("user_input", longInput.toString()); Prompt formattedPrompt = - new Prompt(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); + Prompt.fromMessages(simplePrompt.formatMessages(MessageRole.SYSTEM, variables)); ChatMessage response = chatModel.chat(formattedPrompt.formatMessages(MessageRole.USER, new HashMap<>())); diff --git a/api/src/test/java/org/apache/flink/agents/api/prompt/PromptTest.java b/api/src/test/java/org/apache/flink/agents/api/prompt/PromptTest.java index 454c496f4..ddd1852ec 100644 --- a/api/src/test/java/org/apache/flink/agents/api/prompt/PromptTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/prompt/PromptTest.java @@ -53,7 +53,7 @@ void setUp() { String textTemplate = "You are a product review analyzer, please generate a score and the dislike reasons " + "(if any) for the review. The product {product_id} is {description}, and user review is '{review}'."; - textPrompt = new Prompt(textTemplate); + textPrompt = Prompt.fromText(textTemplate); // Create message-based prompt template List messageTemplate = @@ -65,7 +65,7 @@ void setUp() { new ChatMessage( MessageRole.USER, "The product {product_id} is {description}, and user review is '{review}'.")); - messagesPrompt = new Prompt(messageTemplate); + messagesPrompt = Prompt.fromMessages(messageTemplate); // Set up test variables variables = new HashMap<>(); @@ -170,7 +170,7 @@ void testPromptResourceType() { @Test @DisplayName("Test empty prompt") void testEmptyPrompt() { - Prompt emptyPrompt = new Prompt(""); + Prompt emptyPrompt = Prompt.fromText(""); String result = emptyPrompt.formatString(new HashMap<>()); assertEquals("", result); @@ -183,7 +183,7 @@ void testEmptyPrompt() { @DisplayName("Test prompt with special characters") void testPromptWithSpecialCharacters() { String specialTemplate = "Handle special chars: {text} with symbols like @#$%^&*()"; - Prompt specialPrompt = new Prompt(specialTemplate); + Prompt specialPrompt = Prompt.fromText(specialTemplate); Map specialVars = new HashMap<>(); specialVars.put("text", "Hello & Welcome!"); @@ -197,7 +197,7 @@ void testPromptWithSpecialCharacters() { @DisplayName("Test prompt with nested braces") void testPromptWithNestedBraces() { String nestedTemplate = "JSON example: {{\"key\": \"{value}\"}}"; - Prompt nestedPrompt = new Prompt(nestedTemplate); + Prompt nestedPrompt = Prompt.fromText(nestedTemplate); Map nestedVars = new HashMap<>(); nestedVars.put("value", "test"); @@ -220,7 +220,7 @@ void testComplexConversationPrompt() { "I'd be happy to help with {task}. Let me know what specifically you need."), new ChatMessage(MessageRole.USER, "{user_request}")); - Prompt conversationPrompt = new Prompt(conversationTemplate); + Prompt conversationPrompt = Prompt.fromMessages(conversationTemplate); Map conversationVars = new HashMap<>(); conversationVars.put("assistant_type", "an AI assistant"); diff --git a/docs/content/docs/development/prompts.md b/docs/content/docs/development/prompts.md index 17ffe3cc9..591fabf8e 100644 --- a/docs/content/docs/development/prompts.md +++ b/docs/content/docs/development/prompts.md @@ -103,7 +103,7 @@ String PRODUCT_SUGGESTION_PROMPT_STR = + "{input}"; -Prompt productSuggestionPrompt = new Prompt(PRODUCT_SUGGESTION_PROMPT_STR); +Prompt productSuggestionPrompt = Prompt.fromText(PRODUCT_SUGGESTION_PROMPT_STR); ``` {{< /tab >}} @@ -158,7 +158,7 @@ review_analysis_prompt = Prompt.from_messages( {{< tab "Java" >}} ```java Prompt reviewAnalysisPrompt = - new Prompt( + new Prompt.fromMessages( Arrays.asList( new ChatMessage( MessageRole.SYSTEM, diff --git a/docs/content/docs/development/react_agent.md b/docs/content/docs/development/react_agent.md index a22054f32..a8de5072d 100644 --- a/docs/content/docs/development/react_agent.md +++ b/docs/content/docs/development/react_agent.md @@ -139,7 +139,7 @@ String systemPromptString = + "..."; // Prompt for review analysis react agent. -Prompt myPrompt = new Prompt( +Prompt myPrompt = Prompt.fromMessages( Arrays.asList( new ChatMessage(MessageRole.SYSTEM, systemPromptString), new ChatMessage( diff --git a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java index 10905fc54..510c451d3 100644 --- a/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java +++ b/e2e-test/flink-agents-end-to-end-tests-integration/src/test/java/org/apache/flink/agents/integration/test/ReActAgentTest.java @@ -166,7 +166,7 @@ private static Agent getAgent() { .build(); Prompt prompt = - new Prompt( + Prompt.fromMessages( List.of( new ChatMessage( MessageRole.SYSTEM, diff --git a/examples/src/main/java/org/apache/flink/agents/examples/agents/CustomTypesAndResources.java b/examples/src/main/java/org/apache/flink/agents/examples/agents/CustomTypesAndResources.java index f18b45392..fd3b8aefe 100644 --- a/examples/src/main/java/org/apache/flink/agents/examples/agents/CustomTypesAndResources.java +++ b/examples/src/main/java/org/apache/flink/agents/examples/agents/CustomTypesAndResources.java @@ -56,14 +56,14 @@ public class CustomTypesAndResources { + "no need to disclose whether the tool was used."; public static final Prompt REVIEW_ANALYSIS_PROMPT = - new Prompt( + Prompt.fromMessages( Arrays.asList( new ChatMessage(MessageRole.SYSTEM, REVIEW_ANALYSIS_SYSTEM_PROMPT_STR), new ChatMessage(MessageRole.USER, "\"input\":\n" + "{input}"))); // Prompt for review analysis react agent public static final Prompt REVIEW_ANALYSIS_REACT_PROMPT = - new Prompt( + Prompt.fromMessages( Arrays.asList( new ChatMessage(MessageRole.SYSTEM, REVIEW_ANALYSIS_SYSTEM_PROMPT_STR), new ChatMessage( @@ -91,7 +91,7 @@ public class CustomTypesAndResources { + "{input}"; public static final Prompt PRODUCT_SUGGESTION_PROMPT = - new Prompt(PRODUCT_SUGGESTION_PROMPT_STR); + Prompt.fromText(PRODUCT_SUGGESTION_PROMPT_STR); /** * Tool for notifying the shipping manager when product received a negative review due to diff --git a/integrations/mcp/pom.xml b/integrations/mcp/pom.xml new file mode 100644 index 000000000..040cf78c8 --- /dev/null +++ b/integrations/mcp/pom.xml @@ -0,0 +1,95 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations + 0.2-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-mcp + Flink Agents : Integrations: MCP + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + org.apache.flink + flink-shaded-jackson + ${flink.shaded.jackson.version}-${flink.shaded.version} + + + + + + java-17 + + [17,) + + + + io.modelcontextprotocol.sdk + mcp + 0.16.0 + + + + + + java-11 + + [11,17) + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + org/apache/flink/agents/integrations/mcp/**/*.java + + + org/apache/flink/agents/integrations/mcp/**/*.java + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/mcp/**/*Test.java + + + + + + + + + diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPContentExtractor.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPContentExtractor.java new file mode 100644 index 000000000..50e62428f --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPContentExtractor.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.HashMap; +import java.util.Map; + +/** + * Utility class for extracting and normalizing MCP content items. + * + *

MCP servers can return various content types (text, images, embedded resources). This utility + * converts them to Java-friendly objects. + */ +public class MCPContentExtractor { + + /** + * Extract and normalize a single MCP content item. + * + * @param contentItem A content item from MCP (TextContent, ImageContent, etc.) + * @return Normalized content as String or Map + */ + public static Object extractContentItem(Object contentItem) { + if (contentItem instanceof McpSchema.TextContent) { + return extractTextContent((McpSchema.TextContent) contentItem); + } else if (contentItem instanceof McpSchema.ImageContent) { + return extractImageContent((McpSchema.ImageContent) contentItem); + } else if (contentItem instanceof McpSchema.EmbeddedResource) { + return extractEmbeddedResource((McpSchema.EmbeddedResource) contentItem); + } else { + // Handle unknown content types as string + return contentItem != null ? contentItem.toString() : ""; + } + } + + /** + * Extract text content from MCP TextContent. + * + * @param textContent The text content + * @return The text as a string + */ + private static String extractTextContent(McpSchema.TextContent textContent) { + return textContent.text(); + } + + /** + * Extract image content from MCP ImageContent. + * + * @param imageContent The image content + * @return A map with image details + */ + private static Map extractImageContent(McpSchema.ImageContent imageContent) { + Map result = new HashMap<>(); + result.put("type", "image"); + result.put("data", imageContent.data()); + result.put("mimeType", imageContent.mimeType()); + return result; + } + + /** + * Extract embedded resource from MCP EmbeddedResource. + * + * @param embeddedResource The embedded resource + * @return A map with resource details + */ + private static Map extractEmbeddedResource( + McpSchema.EmbeddedResource embeddedResource) { + Map result = new HashMap<>(); + result.put("type", "resource"); + + var resource = embeddedResource.resource(); + if (resource instanceof McpSchema.TextResourceContents) { + McpSchema.TextResourceContents textResource = (McpSchema.TextResourceContents) resource; + result.put("uri", textResource.uri()); + result.put("text", textResource.text()); + } else if (resource instanceof McpSchema.BlobResourceContents) { + McpSchema.BlobResourceContents blobResource = (McpSchema.BlobResourceContents) resource; + result.put("uri", blobResource.uri()); + result.put("blob", blobResource.blob()); + } + + return result; + } + + /** Private constructor to prevent instantiation. */ + private MCPContentExtractor() { + throw new UnsupportedOperationException("Utility class"); + } +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPPrompt.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPPrompt.java new file mode 100644 index 000000000..a456be486 --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPPrompt.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * MCP prompt definition that extends the base Prompt class. + * + *

This represents a prompt managed by an MCP server. Unlike static prompts, MCP prompts are + * fetched dynamically from the server and can accept arguments. + */ +public class MCPPrompt extends Prompt { + + private static final String FIELD_NAME = "name"; + private static final String FIELD_DESCRIPTION = "description"; + private static final String FIELD_ARGUMENTS = "arguments"; + private static final String FIELD_MCP_SERVER = "mcpServer"; + + @JsonProperty(FIELD_NAME) + private final String name; + + @JsonProperty(FIELD_DESCRIPTION) + private final String description; + + @JsonProperty(FIELD_ARGUMENTS) + private final Map promptArguments; + + @JsonProperty(FIELD_MCP_SERVER) + private final MCPServer mcpServer; + + /** Represents an argument that can be passed to an MCP prompt. */ + public static class PromptArgument { + @JsonProperty("name") + private final String name; + + @JsonProperty("description") + private final String description; + + @JsonProperty("required") + private final boolean required; + + @JsonCreator + public PromptArgument( + @JsonProperty("name") String name, + @JsonProperty("description") String description, + @JsonProperty("required") boolean required) { + this.name = Objects.requireNonNull(name, "name cannot be null"); + this.description = description; + this.required = required; + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public boolean isRequired() { + return required; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PromptArgument that = (PromptArgument) o; + return required == that.required + && Objects.equals(name, that.name) + && Objects.equals(description, that.description); + } + + @Override + public int hashCode() { + return Objects.hash(name, description, required); + } + } + + /** + * Create a new MCPPrompt. + * + * @param name The prompt name + * @param description The prompt description + * @param promptArguments Map of argument names to argument definitions + * @param mcpServer The MCP server reference + */ + @JsonCreator + public MCPPrompt( + @JsonProperty(FIELD_NAME) String name, + @JsonProperty(FIELD_DESCRIPTION) String description, + @JsonProperty(FIELD_ARGUMENTS) Map promptArguments, + @JsonProperty(FIELD_MCP_SERVER) MCPServer mcpServer) { + this.name = Objects.requireNonNull(name, "name cannot be null"); + this.description = description; + this.promptArguments = + promptArguments != null ? new HashMap<>(promptArguments) : new HashMap<>(); + this.mcpServer = Objects.requireNonNull(mcpServer, "mcpServer cannot be null"); + } + + public String getName() { + return name; + } + + public String getDescription() { + return description; + } + + public Map getPromptArguments() { + return new HashMap<>(promptArguments); + } + + @JsonIgnore + public MCPServer getMcpServer() { + return mcpServer; + } + + /** + * Format the prompt as a string with the given arguments. Overrides the base Prompt class to + * fetch prompts from the MCP server. + * + * @param arguments Arguments to pass to the prompt (String keys and values) + * @return The formatted prompt as a string + */ + @Override + public String formatString(Map arguments) { + List messages = formatMessages(MessageRole.SYSTEM, arguments); + return messages.stream() + .map(msg -> msg.getRole().getValue() + ": " + msg.getContent()) + .collect(Collectors.joining("\n")); + } + + /** + * Format the prompt as a list of chat messages with the given arguments. Overrides the base + * Prompt class to fetch prompts from the MCP server. + * + * @param defaultRole The default role for messages (usually SYSTEM) + * @param kwargs Arguments to pass to the prompt (String keys and values) + * @return List of formatted chat messages + */ + @Override + public List formatMessages(MessageRole defaultRole, Map kwargs) { + Map objectArgs = new HashMap<>(kwargs); + return formatMessages(objectArgs); + } + + /** + * Format the prompt as a list of chat messages with the given arguments. + * + * @param arguments Arguments to pass to the prompt (Object values) + * @return List of formatted chat messages + */ + private List formatMessages(Map arguments) { + return mcpServer.getPrompt(name, validateAndPrepareArguments(arguments)); + } + + /** + * Validate that all required arguments are present and prepare the arguments map. + * + * @param arguments The provided arguments + * @return A validated map of arguments + * @throws IllegalArgumentException if required arguments are missing + */ + private Map validateAndPrepareArguments(Map arguments) { + Map result = new HashMap<>(); + + for (PromptArgument arg : promptArguments.values()) { + if (arg.isRequired()) { + if (arguments == null || !arguments.containsKey(arg.getName())) { + throw new IllegalArgumentException( + "Missing required argument: " + arg.getName()); + } + result.put(arg.getName(), arguments.get(arg.getName())); + } else if (arguments != null && arguments.containsKey(arg.getName())) { + result.put(arg.getName(), arguments.get(arg.getName())); + } + } + + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MCPPrompt mcpPrompt = (MCPPrompt) o; + return Objects.equals(name, mcpPrompt.name) + && Objects.equals(description, mcpPrompt.description) + && Objects.equals(promptArguments, mcpPrompt.promptArguments) + && Objects.equals(mcpServer, mcpPrompt.mcpServer); + } + + @Override + public int hashCode() { + return Objects.hash(name, description, promptArguments, mcpServer); + } + + @Override + public String toString() { + return String.format("MCPPrompt{name='%s', server='%s'}", name, mcpServer.getEndpoint()); + } +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java new file mode 100644 index 000000000..a116c4bdc --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.resource.SerializableResource; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.integrations.mcp.auth.ApiKeyAuth; +import org.apache.flink.agents.integrations.mcp.auth.Auth; +import org.apache.flink.agents.integrations.mcp.auth.BasicAuth; +import org.apache.flink.agents.integrations.mcp.auth.BearerTokenAuth; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; + +import java.net.URI; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Resource representing an MCP server and exposing its tools/prompts. + * + *

This is a logical container for MCP tools and prompts; it is not directly invokable. It uses + * the official MCP Java SDK to communicate with MCP servers via HTTP/SSE. + * + *

Authentication is supported through the {@link Auth} interface with multiple implementations: + * + *

    + *
  • {@link BearerTokenAuth} - For OAuth 2.0 and JWT tokens + *
  • {@link BasicAuth} - For username/password authentication + *
  • {@link ApiKeyAuth} - For API key authentication via custom headers + *
+ * + *

Example with OAuth authentication: + * + *

{@code
+ * MCPServer server = MCPServer.builder("https://api.example.com/mcp")
+ *     .auth(new BearerTokenAuth("your-oauth-token"))
+ *     .timeout(Duration.ofSeconds(30))
+ *     .build();
+ *
+ * List tools = server.listTools();
+ * server.close();
+ * }
+ * + *

Reference: MCP Java Client + */ +public class MCPServer extends SerializableResource { + + private static final String FIELD_ENDPOINT = "endpoint"; + private static final String FIELD_HEADERS = "headers"; + private static final String FIELD_TIMEOUT_SECONDS = "timeoutSeconds"; + private static final String FIELD_AUTH = "auth"; + + @JsonProperty(FIELD_ENDPOINT) + private final String endpoint; + + @JsonProperty(FIELD_HEADERS) + private final Map headers; + + @JsonProperty(FIELD_TIMEOUT_SECONDS) + private final long timeoutSeconds; + + @JsonProperty(FIELD_AUTH) + private final Auth auth; + + @JsonIgnore private transient McpSyncClient client; + + /** Builder for MCPServer with fluent API. */ + public static class Builder { + private String endpoint; + private final Map headers = new HashMap<>(); + private long timeoutSeconds = 30; + private Auth auth = null; + + public Builder endpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + public Builder header(String key, String value) { + this.headers.put(key, value); + return this; + } + + public Builder headers(Map headers) { + this.headers.putAll(headers); + return this; + } + + public Builder timeout(Duration timeout) { + this.timeoutSeconds = timeout.getSeconds(); + return this; + } + + public Builder auth(Auth auth) { + this.auth = auth; + return this; + } + + public MCPServer build() { + return new MCPServer(endpoint, headers, timeoutSeconds, auth); + } + } + + /** + * Creates a new MCPServer instance. + * + * @param endpoint The HTTP endpoint of the MCP server + */ + public MCPServer(String endpoint) { + this(endpoint, new HashMap<>(), 30, null); + } + + @JsonCreator + public MCPServer( + @JsonProperty(FIELD_ENDPOINT) String endpoint, + @JsonProperty(FIELD_HEADERS) Map headers, + @JsonProperty(FIELD_TIMEOUT_SECONDS) long timeoutSeconds, + @JsonProperty(FIELD_AUTH) Auth auth) { + this.endpoint = Objects.requireNonNull(endpoint, "endpoint cannot be null"); + this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); + this.timeoutSeconds = timeoutSeconds; + this.auth = auth; + } + + public static Builder builder(String endpoint) { + return new Builder().endpoint(endpoint); + } + + @Override + @JsonIgnore + public ResourceType getResourceType() { + return ResourceType.MCP_SERVER; + } + + public String getEndpoint() { + return endpoint; + } + + public Map getHeaders() { + return new HashMap<>(headers); + } + + public long getTimeoutSeconds() { + return timeoutSeconds; + } + + public Auth getAuth() { + return auth; + } + + /** + * Get or create a synchronized MCP client. + * + * @return The MCP sync client + */ + @JsonIgnore + private synchronized McpSyncClient getClient() { + if (client == null) { + client = createClient(); + } + return client; + } + + /** + * Create a new MCP client with the configured transport. + * + * @return A new MCP sync client + */ + private McpSyncClient createClient() { + validateHttpUrl(); + + var requestBuilder = HttpRequest.newBuilder().timeout(Duration.ofSeconds(timeoutSeconds)); + + // Add custom headers + headers.forEach(requestBuilder::header); + + // Apply authentication if configured + if (auth != null) { + auth.applyAuth(requestBuilder); + } + + // Create transport based on type + var transport = + HttpClientStreamableHttpTransport.builder(endpoint) + .requestBuilder(requestBuilder) + .build(); + + // Build and initialize the client + var mcpClient = + McpClient.sync(transport) + .requestTimeout(Duration.ofSeconds(timeoutSeconds)) + .build(); + + mcpClient.initialize(); + return mcpClient; + } + + /** Validate that the endpoint is a valid HTTP URL. */ + private void validateHttpUrl() { + try { + URI uri = URI.create(endpoint); + String scheme = uri.getScheme(); + if (scheme == null || (!scheme.equals("http") && !scheme.equals("https"))) { + throw new IllegalArgumentException( + "Invalid HTTP endpoint: " + endpoint + ". Scheme must be http or https"); + } + if (uri.getHost() == null || uri.getHost().isEmpty()) { + throw new IllegalArgumentException( + "Invalid HTTP endpoint: " + endpoint + ". Host cannot be empty"); + } + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid HTTP endpoint: " + endpoint, e); + } + } + + /** + * List available tools from the MCP server. + * + * @return List of MCPTool instances + */ + public List listTools() { + McpSyncClient mcpClient = getClient(); + McpSchema.ListToolsResult toolsResult = mcpClient.listTools(); + + List tools = new ArrayList<>(); + for (McpSchema.Tool toolData : toolsResult.tools()) { + ToolMetadata metadata = + new ToolMetadata( + toolData.name(), + toolData.description() != null ? toolData.description() : "", + serializeInputSchema(toolData.inputSchema())); + + MCPTool tool = new MCPTool(metadata, this); + tools.add(tool); + } + + return tools; + } + + /** + * Get a specific tool by name. + * + * @param name The tool name + * @return The MCPTool instance + * @throws IllegalArgumentException if tool not found + */ + public MCPTool getTool(String name) { + List tools = listTools(); + return tools.stream() + .filter(tool -> tool.getName().equals(name)) + .findFirst() + .orElseThrow( + () -> + new IllegalArgumentException( + "Tool '" + + name + + "' not found on MCP server at " + + endpoint)); + } + + /** + * Get tool metadata by name. + * + * @param name The tool name + * @return The ToolMetadata + */ + public ToolMetadata getToolMetadata(String name) { + return getTool(name).getMetadata(); + } + + /** + * Call a tool on the MCP server. + * + * @param toolName The name of the tool to call + * @param arguments The arguments to pass to the tool + * @return The result as a list of content items + */ + public List callTool(String toolName, Map arguments) { + McpSyncClient mcpClient = getClient(); + McpSchema.CallToolRequest request = + new McpSchema.CallToolRequest( + toolName, arguments != null ? arguments : new HashMap<>()); + McpSchema.CallToolResult result = mcpClient.callTool(request); + + List content = new ArrayList<>(); + for (var item : result.content()) { + content.add(MCPContentExtractor.extractContentItem(item)); + } + + return content; + } + + /** + * List available prompts from the MCP server. + * + * @return List of MCPPrompt instances + */ + public List listPrompts() { + McpSyncClient mcpClient = getClient(); + McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts(); + + List prompts = new ArrayList<>(); + for (McpSchema.Prompt promptData : promptsResult.prompts()) { + Map argumentsMap = new HashMap<>(); + if (promptData.arguments() != null) { + for (var arg : promptData.arguments()) { + argumentsMap.put( + arg.name(), + new MCPPrompt.PromptArgument( + arg.name(), arg.description(), arg.required())); + } + } + + MCPPrompt prompt = + new MCPPrompt(promptData.name(), promptData.description(), argumentsMap, this); + prompts.add(prompt); + } + + return prompts; + } + + /** + * Get a prompt by name with optional arguments. + * + * @param name The prompt name + * @param arguments Optional arguments for the prompt + * @return List of chat messages + */ + public List getPrompt(String name, Map arguments) { + McpSyncClient mcpClient = getClient(); + McpSchema.GetPromptRequest request = + new McpSchema.GetPromptRequest( + name, arguments != null ? arguments : new HashMap<>()); + McpSchema.GetPromptResult result = mcpClient.getPrompt(request); + + List chatMessages = new ArrayList<>(); + for (var message : result.messages()) { + if (message.content() instanceof McpSchema.TextContent) { + var textContent = (McpSchema.TextContent) message.content(); + MessageRole role = MessageRole.valueOf(message.role().name().toUpperCase()); + chatMessages.add(new ChatMessage(role, textContent.text())); + } + } + + return chatMessages; + } + + /** Close the MCP client and clean up resources. */ + public void close() { + if (client != null) { + try { + client.closeGracefully(); + } catch (Exception e) { + // Ignore exceptions during cleanup + } finally { + client = null; + } + } + } + + /** Serialize input schema to JSON string. */ + private String serializeInputSchema(Object inputSchema) { + if (inputSchema == null) { + return "{\"type\":\"object\",\"properties\":{}}"; + } + try { + return new ObjectMapper().writeValueAsString(inputSchema); + } catch (Exception e) { + return "{\"type\":\"object\",\"properties\":{}}"; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MCPServer that = (MCPServer) o; + return timeoutSeconds == that.timeoutSeconds + && Objects.equals(endpoint, that.endpoint) + && Objects.equals(headers, that.headers) + && Objects.equals(auth, that.auth); + } + + @Override + public int hashCode() { + return Objects.hash(endpoint, headers, timeoutSeconds, auth); + } + + @Override + public String toString() { + return String.format("MCPServer{endpoint='%s'}", endpoint); + } +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPTool.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPTool.java new file mode 100644 index 000000000..1b30eb56e --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPTool.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; +import org.apache.flink.agents.api.tools.ToolType; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * MCP tool definition that can be called directly. + * + *

This represents a single tool from an MCP server. It extends the base Tool class and delegates + * actual execution to the MCP server. + */ +public class MCPTool extends Tool { + + private static final String FIELD_MCP_SERVER = "mcpServer"; + + @JsonProperty(FIELD_MCP_SERVER) + private final MCPServer mcpServer; + + /** + * Create a new MCPTool. + * + * @param metadata The tool metadata + * @param mcpServer The MCP server reference + */ + @JsonCreator + public MCPTool( + @JsonProperty("metadata") ToolMetadata metadata, + @JsonProperty(FIELD_MCP_SERVER) MCPServer mcpServer) { + super(metadata); + this.mcpServer = Objects.requireNonNull(mcpServer, "mcpServer cannot be null"); + } + + @Override + @JsonIgnore + public ToolType getToolType() { + return ToolType.MCP; + } + + /** + * Call the MCP tool with the given parameters. + * + * @param parameters The tool parameters + * @return The tool response + */ + @Override + public ToolResponse call(ToolParameters parameters) { + long startTime = System.currentTimeMillis(); + + try { + Map arguments = new HashMap<>(); + for (String paramName : parameters.getParameterNames()) { + arguments.put(paramName, parameters.getParameter(paramName)); + } + + List result = mcpServer.callTool(metadata.getName(), arguments); + + long executionTime = System.currentTimeMillis() - startTime; + + // Return the result (could be text, images, or other content) + return ToolResponse.success(result, executionTime, metadata.getName()); + + } catch (Exception e) { + long executionTime = System.currentTimeMillis() - startTime; + String errorMessage = + "Error calling MCP tool '" + metadata.getName() + "': " + e.getMessage(); + return ToolResponse.error(errorMessage, executionTime, metadata.getName()); + } + } + + /** + * Get the MCP server associated with this tool. + * + * @return The MCP server + */ + @JsonIgnore + public MCPServer getMcpServer() { + return mcpServer; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MCPTool mcpTool = (MCPTool) o; + return Objects.equals(metadata, mcpTool.metadata) + && Objects.equals(mcpServer, mcpTool.mcpServer); + } + + @Override + public int hashCode() { + return Objects.hash(metadata, mcpServer); + } + + @Override + public String toString() { + return String.format( + "MCPTool{name='%s', server='%s'}", metadata.getName(), mcpServer.getEndpoint()); + } +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/ApiKeyAuth.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/ApiKeyAuth.java new file mode 100644 index 000000000..ae36f8018 --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/ApiKeyAuth.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp.auth; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import java.net.http.HttpRequest; +import java.util.Objects; + +/** + * API key authentication for HTTP requests. + * + *

This authentication method adds a custom header with an API key to requests. Common header + * names include "X-API-Key", "Api-Key", or any custom header. + * + *

Example usage: + * + *

{@code
+ * MCPServer server = MCPServer.builder("https://api.example.com/mcp")
+ *     .auth(new ApiKeyAuth("X-API-Key", "your-api-key"))
+ *     .build();
+ * }
+ */ +public class ApiKeyAuth implements Auth { + + private static final String AUTH_TYPE = "api_key"; + private static final String FIELD_HEADER_NAME = "headerName"; + private static final String FIELD_API_KEY = "apiKey"; + + @JsonProperty(FIELD_HEADER_NAME) + private final String headerName; + + @JsonProperty(FIELD_API_KEY) + private final String apiKey; + + /** + * Create a new API key authentication. + * + * @param headerName The name of the header to use (e.g., "X-API-Key") + * @param apiKey The API key value + */ + @JsonCreator + public ApiKeyAuth( + @JsonProperty(FIELD_HEADER_NAME) String headerName, + @JsonProperty(FIELD_API_KEY) String apiKey) { + this.headerName = Objects.requireNonNull(headerName, "headerName cannot be null"); + this.apiKey = Objects.requireNonNull(apiKey, "apiKey cannot be null"); + } + + @Override + public void applyAuth(HttpRequest.Builder requestBuilder) { + requestBuilder.header(headerName, apiKey); + } + + @Override + public String getAuthType() { + return AUTH_TYPE; + } + + public String getHeaderName() { + return headerName; + } + + public String getApiKey() { + return apiKey; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ApiKeyAuth that = (ApiKeyAuth) o; + return Objects.equals(headerName, that.headerName) && Objects.equals(apiKey, that.apiKey); + } + + @Override + public int hashCode() { + return Objects.hash(headerName, apiKey); + } + + @Override + public String toString() { + return "ApiKeyAuth{headerName='" + headerName + "', apiKey=***}"; + } +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/Auth.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/Auth.java new file mode 100644 index 000000000..3f3789bfd --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/Auth.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp.auth; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonSubTypes; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeInfo; + +import java.io.Serializable; +import java.net.http.HttpRequest; + +/** + * Base interface for authentication mechanisms used with MCP servers. + * + *

Implementations of this interface define how to apply authentication to HTTP requests made to + * MCP servers. + * + * @see BearerTokenAuth + * @see BasicAuth + * @see ApiKeyAuth + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "authType") +@JsonSubTypes({ + @JsonSubTypes.Type(value = BearerTokenAuth.class, name = "bearer"), + @JsonSubTypes.Type(value = BasicAuth.class, name = "basic"), + @JsonSubTypes.Type(value = ApiKeyAuth.class, name = "api_key") +}) +public interface Auth extends Serializable { + + /** + * Apply authentication to an HTTP request builder. + * + * @param requestBuilder The HTTP request builder to apply authentication to + */ + void applyAuth(HttpRequest.Builder requestBuilder); + + /** + * Get the type of authentication for serialization purposes. + * + * @return The authentication type identifier + */ + @org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore + String getAuthType(); +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/BasicAuth.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/BasicAuth.java new file mode 100644 index 000000000..b8c3f3915 --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/BasicAuth.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp.auth; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import java.net.http.HttpRequest; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Objects; + +/** + * Basic authentication (username/password) for HTTP requests. + * + *

This authentication method adds an "Authorization: Basic {credentials}" header to requests, + * where credentials is the base64-encoded string "username:password". + * + *

Example usage: + * + *

{@code
+ * MCPServer server = MCPServer.builder("https://api.example.com/mcp")
+ *     .auth(new BasicAuth("username", "password"))
+ *     .build();
+ * }
+ */ +public class BasicAuth implements Auth { + + private static final String AUTH_TYPE = "basic"; + private static final String FIELD_USERNAME = "username"; + private static final String FIELD_PASSWORD = "password"; + + @JsonProperty(FIELD_USERNAME) + private final String username; + + @JsonProperty(FIELD_PASSWORD) + private final String password; + + /** + * Create a new Basic authentication. + * + * @param username The username for authentication + * @param password The password for authentication + */ + @JsonCreator + public BasicAuth( + @JsonProperty(FIELD_USERNAME) String username, + @JsonProperty(FIELD_PASSWORD) String password) { + this.username = Objects.requireNonNull(username, "username cannot be null"); + this.password = Objects.requireNonNull(password, "password cannot be null"); + } + + @Override + public void applyAuth(HttpRequest.Builder requestBuilder) { + String credentials = username + ":" + password; + String encoded = + Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.UTF_8)); + requestBuilder.header("Authorization", "Basic " + encoded); + } + + @Override + public String getAuthType() { + return AUTH_TYPE; + } + + public String getUsername() { + return username; + } + + public String getPassword() { + return password; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BasicAuth basicAuth = (BasicAuth) o; + return Objects.equals(username, basicAuth.username) + && Objects.equals(password, basicAuth.password); + } + + @Override + public int hashCode() { + return Objects.hash(username, password); + } + + @Override + public String toString() { + return "BasicAuth{username='" + username + "', password=***}"; + } +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/BearerTokenAuth.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/BearerTokenAuth.java new file mode 100644 index 000000000..ab0e0948b --- /dev/null +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/auth/BearerTokenAuth.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp.auth; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; + +import java.net.http.HttpRequest; +import java.util.Objects; + +/** + * Bearer token authentication for OAuth 2.0 and similar token-based authentication schemes. + * + *

This authentication method adds an "Authorization: Bearer {token}" header to requests. + * + *

Example usage: + * + *

{@code
+ * MCPServer server = MCPServer.builder("https://api.example.com/mcp")
+ *     .auth(new BearerTokenAuth("your-oauth-token"))
+ *     .build();
+ * }
+ */ +public class BearerTokenAuth implements Auth { + + private static final String AUTH_TYPE = "bearer"; + private static final String FIELD_TOKEN = "token"; + + @JsonProperty(FIELD_TOKEN) + private final String token; + + /** + * Create a new Bearer token authentication. + * + * @param token The bearer token to use for authentication + */ + @JsonCreator + public BearerTokenAuth(@JsonProperty(FIELD_TOKEN) String token) { + this.token = Objects.requireNonNull(token, "token cannot be null"); + } + + @Override + public void applyAuth(HttpRequest.Builder requestBuilder) { + requestBuilder.header("Authorization", "Bearer " + token); + } + + @Override + public String getAuthType() { + return AUTH_TYPE; + } + + public String getToken() { + return token; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BearerTokenAuth that = (BearerTokenAuth) o; + return Objects.equals(token, that.token); + } + + @Override + public int hashCode() { + return Objects.hash(token); + } + + @Override + public String toString() { + return "BearerTokenAuth{token=***}"; + } +} diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPPromptTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPPromptTest.java new file mode 100644 index 000000000..599e4ed0e --- /dev/null +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPPromptTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES; +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link MCPPrompt}. */ +class MCPPromptTest { + + private static final String DEFAULT_ENDPOINT = "http://localhost:8000/mcp"; + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Create MCPPrompt with required arguments") + void testCreationWithRequiredArgs() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + Map args = new HashMap<>(); + args.put("name", new MCPPrompt.PromptArgument("name", "User name", true)); + + MCPPrompt prompt = new MCPPrompt("greeting", "Greeting prompt", args, server); + + assertThat(prompt.getName()).isEqualTo("greeting"); + assertThat(prompt.getDescription()).isEqualTo("Greeting prompt"); + assertThat(prompt.getPromptArguments()).hasSize(1); + assertThat(prompt.getMcpServer()).isEqualTo(server); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Create MCPPrompt with optional arguments") + void testCreationWithOptionalArgs() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + Map args = new HashMap<>(); + args.put("city", new MCPPrompt.PromptArgument("city", "City name", true)); + args.put("units", new MCPPrompt.PromptArgument("units", "Temperature units", false)); + + MCPPrompt prompt = new MCPPrompt("weather", "Weather prompt", args, server); + + assertThat(prompt.getPromptArguments()).hasSize(2); + assertThat(prompt.getPromptArguments().get("city").isRequired()).isTrue(); + assertThat(prompt.getPromptArguments().get("units").isRequired()).isFalse(); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Validate argument handling - required vs optional") + void testArgumentValidation() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + Map args = new HashMap<>(); + args.put("required", new MCPPrompt.PromptArgument("required", "Required", true)); + args.put("optional", new MCPPrompt.PromptArgument("optional", "Optional", false)); + + MCPPrompt prompt = new MCPPrompt("test", "Test prompt", args, server); + + // Verify the prompt was created with correct arguments + assertThat(prompt.getPromptArguments()).hasSize(2); + assertThat(prompt.getPromptArguments().get("required").isRequired()).isTrue(); + assertThat(prompt.getPromptArguments().get("optional").isRequired()).isFalse(); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test PromptArgument creation and getters") + void testPromptArgument() { + MCPPrompt.PromptArgument arg = new MCPPrompt.PromptArgument("city", "City name", true); + + assertThat(arg.getName()).isEqualTo("city"); + assertThat(arg.getDescription()).isEqualTo("City name"); + assertThat(arg.isRequired()).isTrue(); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test PromptArgument equals and hashCode") + void testPromptArgumentEquals() { + MCPPrompt.PromptArgument arg1 = new MCPPrompt.PromptArgument("name", "Name", true); + MCPPrompt.PromptArgument arg2 = new MCPPrompt.PromptArgument("name", "Name", true); + MCPPrompt.PromptArgument arg3 = new MCPPrompt.PromptArgument("name", "Different", true); + + assertThat(arg1).isEqualTo(arg2).hasSameHashCodeAs(arg2).isNotEqualTo(arg3); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test MCPPrompt equals and hashCode") + void testEquals() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + Map args = new HashMap<>(); + args.put("arg1", new MCPPrompt.PromptArgument("arg1", "Arg 1", true)); + + MCPPrompt prompt1 = new MCPPrompt("test", "Test", args, server); + MCPPrompt prompt2 = new MCPPrompt("test", "Test", args, server); + MCPPrompt prompt3 = new MCPPrompt("other", "Other", args, server); + + assertThat(prompt1).isEqualTo(prompt2).hasSameHashCodeAs(prompt2).isNotEqualTo(prompt3); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test toString") + void testToString() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + MCPPrompt prompt = new MCPPrompt("greeting", "Greeting prompt", new HashMap<>(), server); + + String str = prompt.toString(); + assertThat(str).contains("MCPPrompt").contains("greeting").contains(DEFAULT_ENDPOINT); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("JSON serialization and deserialization") + void testJsonSerialization() throws Exception { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + Map args = new HashMap<>(); + args.put("topic", new MCPPrompt.PromptArgument("topic", "Topic name", true)); + args.put("style", new MCPPrompt.PromptArgument("style", "Writing style", false)); + + MCPPrompt original = new MCPPrompt("essay", "Essay prompt", args, server); + + ObjectMapper mapper = new ObjectMapper(); + // Configure to ignore unknown properties during deserialization + mapper.configure(FAIL_ON_UNKNOWN_PROPERTIES, false); + + String json = mapper.writeValueAsString(original); + + MCPPrompt deserialized = mapper.readValue(json, MCPPrompt.class); + + assertThat(deserialized.getName()).isEqualTo(original.getName()); + assertThat(deserialized.getDescription()).isEqualTo(original.getDescription()); + assertThat(deserialized.getPromptArguments()).hasSize(original.getPromptArguments().size()); + assertThat(deserialized.getMcpServer()).isEqualTo(original.getMcpServer()); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Arguments map is immutable from outside") + void testArgumentsImmutability() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + Map args = new HashMap<>(); + args.put("arg1", new MCPPrompt.PromptArgument("arg1", "Arg 1", true)); + + MCPPrompt prompt = new MCPPrompt("test", "Test", args, server); + + // Modify original map + args.put("arg2", new MCPPrompt.PromptArgument("arg2", "Arg 2", false)); + + // Prompt should not be affected + assertThat(prompt.getPromptArguments()).hasSize(1); + assertThat(prompt.getPromptArguments()).doesNotContainKey("arg2"); + + // Modify returned map + Map returnedArgs = prompt.getPromptArguments(); + returnedArgs.put("arg3", new MCPPrompt.PromptArgument("arg3", "Arg 3", false)); + + // Prompt should not be affected + assertThat(prompt.getPromptArguments()).hasSize(1); + assertThat(prompt.getPromptArguments()).doesNotContainKey("arg3"); + } +} diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPSerializationTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPSerializationTest.java new file mode 100644 index 000000000..1849476cc --- /dev/null +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPSerializationTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES; +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for MCP classes serialization. + * + *

This test ensures that MCP objects (MCPServer, MCPTool, MCPPrompt) can be properly serialized + * and deserialized by Flink's serialization framework. It verifies that all fields are preserved + */ +class MCPSerializationTest { + + private static final String DEFAULT_ENDPOINT = "http://localhost:8000/mcp"; + + /** + * Create an ObjectMapper configured to ignore unknown properties during deserialization. This + * is needed because base classes may have getters that are serialized. + */ + private ObjectMapper createObjectMapper() { + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(FAIL_ON_UNKNOWN_PROPERTIES, false); + return mapper; + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test MCPServer JSON serialization and deserialization") + void testMCPServerJsonSerialization() throws Exception { + MCPServer original = + MCPServer.builder(DEFAULT_ENDPOINT) + .timeout(Duration.ofSeconds(30)) + .header("X-Custom", "value") + .build(); + + ObjectMapper mapper = createObjectMapper(); + String json = mapper.writeValueAsString(original); + MCPServer deserialized = mapper.readValue(json, MCPServer.class); + + assertThat(deserialized.getEndpoint()).isEqualTo(original.getEndpoint()); + assertThat(deserialized.getTimeoutSeconds()).isEqualTo(original.getTimeoutSeconds()); + assertThat(deserialized.getHeaders()).isEqualTo(original.getHeaders()); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test MCPTool JSON serialization and deserialization") + void testMCPToolJsonSerialization() throws Exception { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + ToolMetadata metadata = + new ToolMetadata("add", "Add numbers", "{\"type\":\"object\",\"properties\":{}}"); + MCPTool original = new MCPTool(metadata, server); + + ObjectMapper mapper = createObjectMapper(); + String json = mapper.writeValueAsString(original); + MCPTool deserialized = mapper.readValue(json, MCPTool.class); + + assertThat(deserialized.getName()).isEqualTo(original.getName()); + assertThat(deserialized.getMetadata()).isEqualTo(original.getMetadata()); + assertThat(deserialized.getMcpServer()).isEqualTo(original.getMcpServer()); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test MCPPrompt JSON serialization and deserialization") + void testMCPPromptJsonSerialization() throws Exception { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + Map args = new HashMap<>(); + args.put("name", new MCPPrompt.PromptArgument("name", "User name", true)); + + MCPPrompt original = new MCPPrompt("greeting", "Greet user", args, server); + + ObjectMapper mapper = createObjectMapper(); + String json = mapper.writeValueAsString(original); + MCPPrompt deserialized = mapper.readValue(json, MCPPrompt.class); + + assertThat(deserialized.getName()).isEqualTo(original.getName()); + assertThat(deserialized.getDescription()).isEqualTo(original.getDescription()); + assertThat(deserialized.getPromptArguments()).hasSameSizeAs(original.getPromptArguments()); + assertThat(deserialized.getMcpServer()).isEqualTo(original.getMcpServer()); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test HashMap serialization in MCP objects") + void testHashMapSerialization() throws Exception { + // This specifically tests that the HashMap instances in MCP objects + // don't cause Kryo serialization issues (like Arrays$ArrayList did) + + Map headers = new HashMap<>(); + headers.put("Header1", "Value1"); + headers.put("Header2", "Value2"); + headers.put("Header3", "Value3"); + + MCPServer server = MCPServer.builder(DEFAULT_ENDPOINT).headers(headers).build(); + + Map args = new HashMap<>(); + args.put("arg1", new MCPPrompt.PromptArgument("arg1", "Argument 1", true)); + args.put("arg2", new MCPPrompt.PromptArgument("arg2", "Argument 2", false)); + args.put("arg3", new MCPPrompt.PromptArgument("arg3", "Argument 3", true)); + + MCPPrompt prompt = new MCPPrompt("test", "Test prompt", args, server); + + // Serialize + ObjectMapper mapper = createObjectMapper(); + String json = mapper.writeValueAsString(prompt); + + // Deserialize + MCPPrompt deserialized = mapper.readValue(json, MCPPrompt.class); + + // Verify HashMaps are properly serialized + assertThat(deserialized.getMcpServer().getHeaders()).hasSize(3); + assertThat(deserialized.getPromptArguments()).hasSize(3); + assertThat(deserialized.getMcpServer().getHeaders()).containsEntry("Header1", "Value1"); + assertThat(deserialized.getPromptArguments()).containsKey("arg1"); + } +} diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java new file mode 100644 index 000000000..24456033a --- /dev/null +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.integrations.mcp.auth.ApiKeyAuth; +import org.apache.flink.agents.integrations.mcp.auth.BasicAuth; +import org.apache.flink.agents.integrations.mcp.auth.BearerTokenAuth; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link MCPServer}. */ +class MCPServerTest { + + private static final String DEFAULT_ENDPOINT = "http://localhost:8000/mcp"; + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Create MCPServer with builder") + void testBuilderCreation() { + MCPServer server = + MCPServer.builder(DEFAULT_ENDPOINT) + .header("X-Custom-Header", "value") + .timeout(Duration.ofSeconds(30)) + .auth(new BearerTokenAuth("test-token")) + .build(); + + assertThat(server.getEndpoint()).isEqualTo("http://localhost:8000/mcp"); + assertThat(server.getHeaders()).containsEntry("X-Custom-Header", "value"); + assertThat(server.getTimeoutSeconds()).isEqualTo(30); + assertThat(server.getAuth()).isInstanceOf(BearerTokenAuth.class); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Create MCPServer with simple constructor") + void testSimpleConstructor() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + + assertThat(server.getEndpoint()).isEqualTo(DEFAULT_ENDPOINT); + assertThat(server.getHeaders()).isEmpty(); + assertThat(server.getTimeoutSeconds()).isEqualTo(30); + assertThat(server.getAuth()).isNull(); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Builder with multiple headers") + void testBuilderWithMultipleHeaders() { + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer token"); + headers.put("X-API-Key", "key123"); + + MCPServer server = MCPServer.builder(DEFAULT_ENDPOINT).headers(headers).build(); + + assertThat(server.getHeaders()).hasSize(2); + assertThat(server.getHeaders()).containsEntry("Authorization", "Bearer token"); + assertThat(server.getHeaders()).containsEntry("X-API-Key", "key123"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test different authentication types") + void testAuthenticationTypes() { + // Bearer token auth + MCPServer bearerServer = + MCPServer.builder(DEFAULT_ENDPOINT).auth(new BearerTokenAuth("my-token")).build(); + assertThat(bearerServer.getAuth()).isInstanceOf(BearerTokenAuth.class); + + // Basic auth + MCPServer basicServer = + MCPServer.builder(DEFAULT_ENDPOINT).auth(new BasicAuth("user", "pass")).build(); + assertThat(basicServer.getAuth()).isInstanceOf(BasicAuth.class); + + // API key auth + MCPServer apiKeyServer = + MCPServer.builder(DEFAULT_ENDPOINT) + .auth(new ApiKeyAuth("X-API-Key", "secret")) + .build(); + assertThat(apiKeyServer.getAuth()).isInstanceOf(ApiKeyAuth.class); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Validate HTTP endpoint") + void testEndpointValidation() { + // Valid endpoints + new MCPServer(DEFAULT_ENDPOINT); + new MCPServer("https://api.example.com/mcp"); + + // Null endpoint + assertThatThrownBy(() -> new MCPServer(null)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("endpoint cannot be null"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test resource type") + void testResourceType() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + assertThat(server.getResourceType()).isEqualTo(ResourceType.MCP_SERVER); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test equals and hashCode") + void testEqualsAndHashCode() { + MCPServer server1 = + MCPServer.builder(DEFAULT_ENDPOINT).timeout(Duration.ofSeconds(30)).build(); + + MCPServer server2 = + MCPServer.builder(DEFAULT_ENDPOINT).timeout(Duration.ofSeconds(30)).build(); + + MCPServer server3 = + MCPServer.builder(DEFAULT_ENDPOINT).timeout(Duration.ofSeconds(60)).build(); + + assertThat(server1).hasSameHashCodeAs(server2).isEqualTo(server2).isNotEqualTo(server3); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test toString") + void testToString() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + assertThat(server.toString()).contains("MCPServer"); + assertThat(server.toString()).contains(DEFAULT_ENDPOINT); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("JSON serialization and deserialization") + void testJsonSerialization() throws Exception { + MCPServer original = + MCPServer.builder(DEFAULT_ENDPOINT) + .header("X-Custom", "value") + .timeout(Duration.ofSeconds(45)) + .auth(new BearerTokenAuth("test-token")) + .build(); + + ObjectMapper mapper = new ObjectMapper(); + // Configure to ignore unknown properties during deserialization + mapper.configure(FAIL_ON_UNKNOWN_PROPERTIES, false); + + String json = mapper.writeValueAsString(original); + + MCPServer deserialized = mapper.readValue(json, MCPServer.class); + + assertThat(deserialized.getEndpoint()).isEqualTo(original.getEndpoint()); + assertThat(deserialized.getHeaders()).isEqualTo(original.getHeaders()); + assertThat(deserialized.getTimeoutSeconds()).isEqualTo(original.getTimeoutSeconds()); + assertThat(deserialized.getAuth()).isEqualTo(original.getAuth()); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("JSON serialization with different auth types") + void testJsonSerializationWithAuth() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + + // Bearer token auth + MCPServer bearerServer = + MCPServer.builder(DEFAULT_ENDPOINT).auth(new BearerTokenAuth("token")).build(); + String bearerJson = mapper.writeValueAsString(bearerServer); + MCPServer bearerDeserialized = mapper.readValue(bearerJson, MCPServer.class); + assertThat(bearerDeserialized.getAuth()).isInstanceOf(BearerTokenAuth.class); + + // Basic auth + MCPServer basicServer = + MCPServer.builder(DEFAULT_ENDPOINT).auth(new BasicAuth("user", "pass")).build(); + String basicJson = mapper.writeValueAsString(basicServer); + MCPServer basicDeserialized = mapper.readValue(basicJson, MCPServer.class); + assertThat(basicDeserialized.getAuth()).isInstanceOf(BasicAuth.class); + + // API key auth + MCPServer apiKeyServer = + MCPServer.builder(DEFAULT_ENDPOINT) + .auth(new ApiKeyAuth("X-API-Key", "secret")) + .build(); + String apiKeyJson = mapper.writeValueAsString(apiKeyServer); + MCPServer apiKeyDeserialized = mapper.readValue(apiKeyJson, MCPServer.class); + assertThat(apiKeyDeserialized.getAuth()).isInstanceOf(ApiKeyAuth.class); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Headers are immutable from outside") + void testHeadersImmutability() { + Map headers = new HashMap<>(); + headers.put("X-Test", "value"); + + MCPServer server = MCPServer.builder(DEFAULT_ENDPOINT).headers(headers).build(); + + // Modify original map + headers.put("X-New", "new-value"); + + // Server should not be affected + assertThat(server.getHeaders()).hasSize(1); + assertThat(server.getHeaders()).doesNotContainKey("X-New"); + + // Modify returned map + Map returnedHeaders = server.getHeaders(); + returnedHeaders.put("X-Another", "another-value"); + + // Server should not be affected + assertThat(server.getHeaders()).hasSize(1); + assertThat(server.getHeaders()).doesNotContainKey("X-Another"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Close server gracefully") + void testClose() { + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + // Should not throw any exception + server.close(); + server.close(); // Calling twice should be safe + } +} diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPToolTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPToolTest.java new file mode 100644 index 000000000..10e5c1261 --- /dev/null +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPToolTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolType; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +import static org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES; +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link MCPTool}. */ +class MCPToolTest { + + private static final String DEFAULT_ENDPOINT = "http://localhost:8000/mcp"; + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Create MCPTool with metadata and server") + void testCreation() { + ToolMetadata metadata = new ToolMetadata("add", "Add two numbers", "{\"type\":\"object\"}"); + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + + MCPTool tool = new MCPTool(metadata, server); + + assertThat(tool.getName()).isEqualTo("add"); + assertThat(tool.getMetadata()).isEqualTo(metadata); + assertThat(tool.getMcpServer()).isEqualTo(server); + assertThat(tool.getToolType()).isEqualTo(ToolType.MCP); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test tool metadata access") + void testToolMetadataAccess() { + ToolMetadata metadata = + new ToolMetadata( + "calculator", + "Calculate mathematical expressions", + "{\"type\":\"object\",\"properties\":{\"expression\":{\"type\":\"string\"}}}"); + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + + MCPTool tool = new MCPTool(metadata, server); + + assertThat(tool.getMetadata()).isEqualTo(metadata); + assertThat(tool.getMetadata().getName()).isEqualTo("calculator"); + assertThat(tool.getMetadata().getDescription()) + .isEqualTo("Calculate mathematical expressions"); + assertThat(tool.getMetadata().getInputSchema()).contains("expression"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test tool name and description getters") + void testToolGetters() { + ToolMetadata metadata = + new ToolMetadata("multiply", "Multiply two numbers", "{\"type\":\"object\"}"); + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + + MCPTool tool = new MCPTool(metadata, server); + + assertThat(tool.getName()).isEqualTo("multiply"); + assertThat(tool.getDescription()).isEqualTo("Multiply two numbers"); + assertThat(tool.getMcpServer()).isEqualTo(server); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test equals and hashCode") + void testEqualsAndHashCode() { + ToolMetadata metadata1 = new ToolMetadata("tool1", "Description", "{\"type\":\"object\"}"); + MCPServer server1 = new MCPServer(DEFAULT_ENDPOINT); + + MCPTool tool1 = new MCPTool(metadata1, server1); + MCPTool tool2 = new MCPTool(metadata1, server1); + + ToolMetadata metadata2 = new ToolMetadata("tool2", "Description", "{\"type\":\"object\"}"); + MCPTool tool3 = new MCPTool(metadata2, server1); + + assertThat(tool1).isEqualTo(tool2).hasSameHashCodeAs(tool2).isNotEqualTo(tool3); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test toString") + void testToString() { + ToolMetadata metadata = new ToolMetadata("add", "Add numbers", "{\"type\":\"object\"}"); + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + MCPTool tool = new MCPTool(metadata, server); + + String str = tool.toString(); + assertThat(str).contains("MCPTool").contains("add").contains(DEFAULT_ENDPOINT); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("JSON serialization and deserialization") + void testJsonSerialization() throws Exception { + ToolMetadata metadata = + new ToolMetadata("multiply", "Multiply two numbers", "{\"type\":\"object\"}"); + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + MCPTool original = new MCPTool(metadata, server); + + ObjectMapper mapper = new ObjectMapper(); + // Configure to ignore unknown properties during deserialization + mapper.configure(FAIL_ON_UNKNOWN_PROPERTIES, false); + + String json = mapper.writeValueAsString(original); + + MCPTool deserialized = mapper.readValue(json, MCPTool.class); + + assertThat(deserialized.getName()).isEqualTo(original.getName()); + assertThat(deserialized.getMetadata()).isEqualTo(original.getMetadata()); + assertThat(deserialized.getMcpServer()).isEqualTo(original.getMcpServer()); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test tool with complex input schema") + void testToolWithComplexSchema() { + String complexSchema = + "{\"type\":\"object\",\"properties\":{" + + "\"param1\":{\"type\":\"string\",\"description\":\"First parameter\"}," + + "\"param2\":{\"type\":\"number\",\"description\":\"Second parameter\"}," + + "\"param3\":{\"type\":\"boolean\",\"description\":\"Third parameter\"}" + + "},\"required\":[\"param1\",\"param2\"]}"; + + ToolMetadata metadata = + new ToolMetadata("complexTool", "Tool with complex schema", complexSchema); + MCPServer server = new MCPServer(DEFAULT_ENDPOINT); + + MCPTool tool = new MCPTool(metadata, server); + + assertThat(tool.getMetadata().getInputSchema()).contains("param1"); + assertThat(tool.getMetadata().getInputSchema()).contains("param2"); + assertThat(tool.getMetadata().getInputSchema()).contains("param3"); + assertThat(tool.getMetadata().getInputSchema()).contains("required"); + } +} diff --git a/integrations/pom.xml b/integrations/pom.xml index 75a5d6695..e8f6176e5 100644 --- a/integrations/pom.xml +++ b/integrations/pom.xml @@ -41,6 +41,7 @@ under the License. chat-models embedding-models vector-stores + mcp diff --git a/plan/pom.xml b/plan/pom.xml index 917a3585e..6089e16a0 100644 --- a/plan/pom.xml +++ b/plan/pom.xml @@ -40,6 +40,12 @@ under the License. flink-agents-api ${project.version} + + org.apache.flink + flink-agents-integrations-mcp + ${project.version} + test + org.apache.flink flink-shaded-jackson @@ -57,8 +63,45 @@ under the License. ${assertj.version} test + + org.testcontainers + testcontainers + 2.0.2 + test + + + + java-11 + + [11,17) + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/AgentPlanDeclareMCPServerTest.java + + + + + + + + diff --git a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java index 88ec9bdce..56adbeecc 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/AgentPlan.java @@ -59,6 +59,10 @@ import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import static org.apache.flink.agents.api.resource.ResourceType.MCP_SERVER; +import static org.apache.flink.agents.api.resource.ResourceType.PROMPT; +import static org.apache.flink.agents.api.resource.ResourceType.TOOL; + /** Agent plan compiled from user defined agent. */ @JsonSerialize(using = AgentPlanJsonSerializer.class) @JsonDeserialize(using = AgentPlanJsonDeserializer.class) @@ -311,12 +315,52 @@ private void extractTool(Method method) throws Exception { FunctionTool tool = new FunctionTool(metadata, javaFunction); JavaSerializableResourceProvider provider = - JavaSerializableResourceProvider.createResourceProvider( - name, ResourceType.TOOL, tool); + JavaSerializableResourceProvider.createResourceProvider(name, TOOL, tool); addResourceProvider(provider); } + private void extractMCPServer(Method method) throws Exception { + // Use reflection to handle MCP classes to support Java 11 without MCP + String name = method.getName(); + Object mcpServer = method.invoke(null); + + addResourceProvider( + JavaSerializableResourceProvider.createResourceProvider( + name, MCP_SERVER, (SerializableResource) mcpServer)); + + // Call listTools() via reflection + Method listToolsMethod = mcpServer.getClass().getMethod("listTools"); + @SuppressWarnings("unchecked") + Iterable tools = + (Iterable) listToolsMethod.invoke(mcpServer); + + for (SerializableResource tool : tools) { + Method getNameMethod = tool.getClass().getMethod("getName"); + String toolName = (String) getNameMethod.invoke(tool); + addResourceProvider( + JavaSerializableResourceProvider.createResourceProvider(toolName, TOOL, tool)); + } + + // Call listPrompts() via reflection + Method listPromptsMethod = mcpServer.getClass().getMethod("listPrompts"); + @SuppressWarnings("unchecked") + Iterable prompts = + (Iterable) listPromptsMethod.invoke(mcpServer); + + for (SerializableResource prompt : prompts) { + Method getNameMethod = prompt.getClass().getMethod("getName"); + String promptName = (String) getNameMethod.invoke(prompt); + addResourceProvider( + JavaSerializableResourceProvider.createResourceProvider( + promptName, PROMPT, prompt)); + } + + // Call close() via reflection + Method closeMethod = mcpServer.getClass().getMethod("close"); + closeMethod.invoke(mcpServer); + } + private void extractResourceProvidersFromAgent(Agent agent) throws Exception { Class agentClass = agent.getClass(); @@ -339,8 +383,7 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { if (fieldValue instanceof Resource) { Resource resource = (Resource) fieldValue; ResourceProvider provider = - createResourceProvider( - resourceName, ResourceType.TOOL, resource, agentClass); + createResourceProvider(resourceName, TOOL, resource, agentClass); addResourceProvider(provider); } } catch (IllegalAccessException e) { @@ -382,7 +425,7 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { JavaSerializableResourceProvider provider = JavaSerializableResourceProvider.createResourceProvider( - promptName, ResourceType.PROMPT, prompt); + promptName, PROMPT, prompt); addResourceProvider(provider); } else if (method.isAnnotationPresent(ChatModelSetup.class)) { @@ -395,6 +438,19 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { extractResource(ResourceType.EMBEDDING_MODEL_CONNECTION, method); } else if (method.isAnnotationPresent(VectorStore.class)) { extractResource(ResourceType.VECTOR_STORE, method); + } else if (Modifier.isStatic(method.getModifiers())) { + // Check for MCPServer annotation using reflection to support Java 11 without MCP + try { + Class mcpServerAnnotation = + Class.forName("org.apache.flink.agents.api.annotation.MCPServer"); + if (method.isAnnotationPresent( + (Class) + mcpServerAnnotation)) { + extractMCPServer(method); + } + } catch (ClassNotFoundException e) { + // MCP annotation not available (Java 11 build), skip MCP processing + } } } @@ -415,17 +471,15 @@ private void extractResourceProvidersFromAgent(Agent agent) throws Exception { } addResourceProvider(provider); } - } else if (type == ResourceType.PROMPT) { + } else if (type == PROMPT) { for (Map.Entry kv : entry.getValue().entrySet()) { JavaSerializableResourceProvider provider = JavaSerializableResourceProvider.createResourceProvider( - kv.getKey(), - ResourceType.PROMPT, - (SerializableResource) kv.getValue()); + kv.getKey(), PROMPT, (SerializableResource) kv.getValue()); addResourceProvider(provider); } - } else if (type == ResourceType.TOOL) { + } else if (type == TOOL) { for (Map.Entry kv : entry.getValue().entrySet()) { extractTool( ((org.apache.flink.agents.api.tools.FunctionTool) kv.getValue()) diff --git a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonPrompt.java b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonPrompt.java index fee43f4bd..2e7965377 100644 --- a/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonPrompt.java +++ b/plan/src/main/java/org/apache/flink/agents/plan/resource/python/PythonPrompt.java @@ -26,10 +26,10 @@ import java.util.Map; /** - * PythonPrompt is a subclass of Prompt that provides a method to parse a Python prompt from a - * serialized map. + * PythonPrompt is a subclass of Prompt.LocalPrompt that provides a method to parse a Python prompt + * from a serialized map. */ -public class PythonPrompt extends Prompt { +public class PythonPrompt extends Prompt.LocalPrompt { public PythonPrompt(String template) { super(template); } diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java index c8c22b9e6..20c996a78 100644 --- a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareChatModelTest.java @@ -111,7 +111,7 @@ void retrieveAndChat() throws Exception { agentPlan.getResource("testChatModel", ResourceType.CHAT_MODEL); assertNotNull(model); - Prompt prompt = new Prompt("Hello world"); + Prompt prompt = Prompt.fromText("Hello world"); ChatMessage reply = model.chat(prompt.formatMessages(MessageRole.USER, new HashMap<>())); assertEquals(MessageRole.ASSISTANT, reply.getRole()); @@ -128,7 +128,7 @@ void jsonRoundTrip() throws Exception { BaseChatModelSetup model = (BaseChatModelSetup) restored.getResource("testChatModel", ResourceType.CHAT_MODEL); ChatMessage reply = - model.chat(new Prompt("Hi").formatMessages(MessageRole.USER, new HashMap<>())); + model.chat(Prompt.fromText("Hi").formatMessages(MessageRole.USER, new HashMap<>())); assertEquals("ok:Hi", reply.getContent()); } diff --git a/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java new file mode 100644 index 000000000..cb8ae9a69 --- /dev/null +++ b/plan/src/test/java/org/apache/flink/agents/plan/AgentPlanDeclareMCPServerTest.java @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.plan; + +import org.apache.flink.agents.api.Agent; +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.annotation.Action; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.api.prompt.Prompt; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.integrations.mcp.MCPPrompt; +import org.apache.flink.agents.integrations.mcp.MCPServer; +import org.apache.flink.agents.integrations.mcp.MCPTool; +import org.apache.flink.agents.plan.resourceprovider.ResourceProvider; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.*; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +import java.io.File; +import java.net.HttpURLConnection; +import java.net.URL; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * Tests for MCP server integration with AgentPlan. + * + *

This test verifies that MCP servers, tools, and prompts are properly discovered and registered + * in the agent plan, following the pattern from {@link AgentPlanDeclareToolMethodTest}. + * + *

Uses the Python MCP server from python/flink_agents/api/tests/mcp/mcp_server.py. + */ +class AgentPlanDeclareMCPServerTest { + + private static Process pythonMcpServerProcess; + private static final String MCP_SERVER_SCRIPT = + "python/flink_agents/api/tests/mcp/mcp_server.py"; + private static final String MCP_ENDPOINT = "http://127.0.0.1:8000/mcp"; + + private AgentPlan agentPlan; + + /** Test agent with MCP server annotation. */ + static class TestMCPAgent extends Agent { + + @org.apache.flink.agents.api.annotation.MCPServer + public static MCPServer testMcpServer() { + return MCPServer.builder(MCP_ENDPOINT).timeout(Duration.ofSeconds(30)).build(); + } + + @Action(listenEvents = {InputEvent.class}) + public void process(Event event, RunnerContext ctx) { + // no-op + } + } + + @BeforeAll + static void beforeAll() throws Exception { + // Get the project root directory + File projectRoot = new File(System.getProperty("user.dir")).getParentFile(); + + // Try to find Python executable (prefer venv if available) + String pythonExecutable = findPythonExecutable(projectRoot); + + // Check if Python 3 is available + boolean pythonAvailable = false; + try { + Process pythonCheck = new ProcessBuilder(pythonExecutable, "--version").start(); + pythonCheck.waitFor(5, TimeUnit.SECONDS); + pythonAvailable = pythonCheck.exitValue() == 0; + } catch (Exception e) { + System.err.println("Python3 not available: " + e.getMessage()); + } + + assumeTrue( + pythonAvailable, + "python3 is not available or not in PATH. Skipping MCP server tests."); + + File mcpServerScript = new File(projectRoot, MCP_SERVER_SCRIPT); + + assumeTrue( + mcpServerScript.exists(), + "MCP server script not found at: " + mcpServerScript.getAbsolutePath()); + + // Start Python MCP server process + ProcessBuilder pb = + new ProcessBuilder(pythonExecutable, mcpServerScript.getAbsolutePath()) + .redirectErrorStream(true); + pythonMcpServerProcess = pb.start(); + + // Wait for server to be ready with health check + boolean serverReady = false; + int maxRetries = 30; // 30 seconds max + for (int i = 0; i < maxRetries; i++) { + if (isServerReady(MCP_ENDPOINT)) { + serverReady = true; + break; + } + Thread.sleep(1000); + } + + if (!serverReady && pythonMcpServerProcess != null) { + pythonMcpServerProcess.destroy(); + } + + assumeTrue( + serverReady, + "MCP server did not start within 30 seconds. " + + "Check that Python dependencies (mcp, dotenv) are installed."); + } + + /** + * Find the Python executable. Prefers venv python if available, otherwise uses system python3. + * + * @param projectRoot The project root directory + * @return Path to python executable + */ + private static String findPythonExecutable(File projectRoot) { + // Try to find venv python first (used in CI and when building locally) + File venvPython = new File(projectRoot, "python/.venv/bin/python3"); + if (venvPython.exists() && venvPython.canExecute()) { + return venvPython.getAbsolutePath(); + } + + // Fallback to system python3 + return "python3"; + } + + /** + * Check if the MCP server is ready by attempting to connect to the endpoint. + * + * @param endpoint The MCP server endpoint + * @return true if server is ready, false otherwise + */ + private static boolean isServerReady(String endpoint) { + try { + URL url = new URL(endpoint); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.setRequestMethod("GET"); + connection.setConnectTimeout(1000); + connection.setReadTimeout(1000); + int responseCode = connection.getResponseCode(); + // MCP server might return 404 or other codes, we just want to know it's responding + return responseCode > 0; + } catch (Exception e) { + // Server not ready yet + return false; + } + } + + @BeforeEach + void setup() throws Exception { + agentPlan = new AgentPlan(new TestMCPAgent()); + } + + @AfterAll + static void afterAll() { + if (pythonMcpServerProcess != null) { + pythonMcpServerProcess.destroy(); + try { + pythonMcpServerProcess.waitFor(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Discover @MCPServer method and register MCP server") + void discoverMCPServer() { + Map> providers = + agentPlan.getResourceProviders(); + assertTrue(providers.containsKey(ResourceType.MCP_SERVER)); + Map mcpServerProviders = providers.get(ResourceType.MCP_SERVER); + assertTrue(mcpServerProviders.containsKey("testMcpServer")); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Discover and register tools from MCP server") + void discoverToolsFromMCPServer() { + Map> providers = + agentPlan.getResourceProviders(); + assertTrue(providers.containsKey(ResourceType.TOOL)); + + Map toolProviders = providers.get(ResourceType.TOOL); + assertTrue(toolProviders.containsKey("add"), "add tool should be discovered"); + assertEquals(1, toolProviders.size(), "Should have exactly 1 tool from Python server"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Discover and register prompts from MCP server") + void discoverPromptsFromMCPServer() { + Map> providers = + agentPlan.getResourceProviders(); + assertTrue(providers.containsKey(ResourceType.PROMPT)); + + Map promptProviders = providers.get(ResourceType.PROMPT); + assertTrue(promptProviders.containsKey("ask_sum"), "ask_sum prompt should be discovered"); + assertEquals(1, promptProviders.size(), "Should have exactly 1 prompt from Python server"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Retrieve MCP tool from AgentPlan - add tool") + void retrieveMCPToolAdd() throws Exception { + Tool tool = (Tool) agentPlan.getResource("add", ResourceType.TOOL); + assertNotNull(tool); + assertInstanceOf(MCPTool.class, tool); + + MCPTool mcpTool = (MCPTool) tool; + assertEquals("add", mcpTool.getName()); + // Verify description starts with expected text + assertTrue( + mcpTool.getMetadata() + .getDescription() + .startsWith("Get the detailed information of a specified IP address."), + "Description should start with expected text"); + // Verify input schema contains expected parameters + String schema = mcpTool.getMetadata().getInputSchema(); + assertTrue(schema.contains("a"), "Schema should contain parameter 'a'"); + assertTrue(schema.contains("b"), "Schema should contain parameter 'b'"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Retrieve MCP prompt from AgentPlan - ask_sum") + void retrieveMCPPromptAskSum() throws Exception { + Prompt prompt = (Prompt) agentPlan.getResource("ask_sum", ResourceType.PROMPT); + assertNotNull(prompt); + assertInstanceOf(MCPPrompt.class, prompt); + + MCPPrompt mcpPrompt = (MCPPrompt) prompt; + assertEquals("ask_sum", mcpPrompt.getName()); + assertEquals("Prompt of add tool.", mcpPrompt.getDescription()); + // ask_sum prompt should have 'a' and 'b' as arguments + Map args = mcpPrompt.getPromptArguments(); + assertTrue(args.containsKey("a"), "Should have 'a' argument"); + assertTrue(args.containsKey("b"), "Should have 'b' argument"); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("AgentPlan JSON serialization with MCP resources") + void testAgentPlanJsonSerializableWithMCP() throws Exception { + ObjectMapper mapper = new ObjectMapper(); + String json = mapper.writeValueAsString(agentPlan); + + // Verify JSON contains MCP resources + assertTrue(json.contains("add"), "JSON should contain add tool"); + assertTrue(json.contains("ask_sum"), "JSON should contain ask_sum prompt"); + assertTrue(json.contains("mcp_server"), "JSON should contain mcp_server type"); + + // Verify serialization works without errors + assertNotNull(json); + assertFalse(json.isEmpty()); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test MCP server is closed after discovery") + void testMCPServerClosedAfterDiscovery() throws Exception { + // The MCPServer.close() should be called after listTools() and listPrompts() + // We verify this indirectly by checking that the plan was created successfully + assertNotNull(agentPlan); + assertTrue(agentPlan.getResourceProviders().containsKey(ResourceType.MCP_SERVER)); + assertTrue(agentPlan.getResourceProviders().containsKey(ResourceType.TOOL)); + assertTrue(agentPlan.getResourceProviders().containsKey(ResourceType.PROMPT)); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Test metadata from MCP tool - add") + void testMCPToolMetadata() throws Exception { + Tool tool = (Tool) agentPlan.getResource("add", ResourceType.TOOL); + ToolMetadata metadata = tool.getMetadata(); + + assertEquals("add", metadata.getName()); + // Verify description starts with expected text (full docstring includes Args/Returns) + assertTrue( + metadata.getDescription() + .startsWith("Get the detailed information of a specified IP address."), + "Description should start with expected text"); + assertNotNull(metadata.getInputSchema()); + + String schema = metadata.getInputSchema(); + // Verify the tool has expected parameters + assertTrue(schema.contains("a"), "Schema should contain 'a' parameter"); + assertTrue(schema.contains("b"), "Schema should contain 'b' parameter"); + } +} diff --git a/pom.xml b/pom.xml index baa903dc3..813e0e57d 100644 --- a/pom.xml +++ b/pom.xml @@ -36,7 +36,7 @@ under the License. UTF-8 - 11 + 17 ${target.java.version} ${target.java.version} 2.27.1 @@ -87,6 +87,30 @@ under the License. + + java-11 + + [11,17) + + + 11 + 11 + 11 + + + + + java-17 + + [17,) + + + 17 + 17 + 17 + + + release @@ -238,7 +262,7 @@ under the License. ${spotless.skip} - 1.7 + 1.15.0 diff --git a/python/flink_agents/api/tests/mcp/mcp_server.py b/python/flink_agents/api/tests/mcp/mcp_server.py index 4cba1edbf..2e295d0b7 100644 --- a/python/flink_agents/api/tests/mcp/mcp_server.py +++ b/python/flink_agents/api/tests/mcp/mcp_server.py @@ -16,10 +16,14 @@ # limitations under the License. ################################################################################# -import dotenv -from mcp.server.fastmcp import FastMCP +try: + import dotenv + dotenv.load_dotenv() +except ImportError: + # dotenv is optional for this test server + pass -dotenv.load_dotenv() +from mcp.server.fastmcp import FastMCP # Create MCP server mcp = FastMCP("BasicServer") diff --git a/runtime/pom.xml b/runtime/pom.xml index 4a210c507..651161aa1 100644 --- a/runtime/pom.xml +++ b/runtime/pom.xml @@ -133,5 +133,4 @@ under the License. ${log4j2.version} - \ No newline at end of file