Skip to content

Commit

Permalink
core: improve feature methods params validation
Browse files Browse the repository at this point in the history
- also introduce RequestUri
  • Loading branch information
mkouba committed Jan 21, 2025
1 parent 7bfc57f commit b9dd528
Show file tree
Hide file tree
Showing 20 changed files with 131 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.quarkiverse.mcp.server.PromptMessage;
import io.quarkiverse.mcp.server.PromptResponse;
import io.quarkiverse.mcp.server.RequestId;
import io.quarkiverse.mcp.server.RequestUri;
import io.quarkiverse.mcp.server.Resource;
import io.quarkiverse.mcp.server.ResourceContent;
import io.quarkiverse.mcp.server.ResourceContents;
Expand Down Expand Up @@ -54,6 +55,7 @@ class DotNames {
static final DotName MCP_CONNECTION = DotName.createSimple(McpConnection.class);
static final DotName MCP_LOG = DotName.createSimple(McpLog.class);
static final DotName REQUEST_ID = DotName.createSimple(RequestId.class);
static final DotName REQUEST_URI = DotName.createSimple(RequestUri.class);
static final DotName CONTENT = DotName.createSimple(Content.class);
static final DotName TEXT_CONTENT = DotName.createSimple(TextContent.class);
static final DotName IMAGE_CONTENT = DotName.createSimple(ImageContent.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import io.quarkiverse.mcp.server.runtime.ResourceManager;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateCompleteManager;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateManager;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateManager.VariableMatcher;
import io.quarkiverse.mcp.server.runtime.ResultMappers;
import io.quarkiverse.mcp.server.runtime.ToolManager;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
Expand Down Expand Up @@ -125,7 +126,7 @@ void collectFeatureMethods(BeanDiscoveryFinishedBuildItem beanDiscovery, Invoker
AnnotationInstance featureAnnotation = getFeatureAnnotation(method);
if (featureAnnotation != null) {
Feature feature = getFeature(featureAnnotation);
validateFeatureMethod(method, feature);
validateFeatureMethod(method, feature, featureAnnotation);
String name;
if (feature == PROMPT_COMPLETE || feature == RESOURCE_TEMPLATE_COMPLETE) {
name = featureAnnotation.value().asString();
Expand Down Expand Up @@ -364,9 +365,11 @@ void registerForReflection(List<FeatureMethodBuildItem> featureMethods,
for (FeatureMethodBuildItem m : featureMethods) {
for (org.jboss.jandex.Type paramType : m.getMethod().parameterTypes()) {
if (paramType.kind() == Kind.PRIMITIVE
|| paramType.name().equals(DotNames.STRING)
|| paramType.name().equals(DotNames.MCP_CONNECTION)
|| paramType.name().equals(DotNames.MCP_LOG)
|| paramType.name().equals(DotNames.REQUEST_ID)) {
|| paramType.name().equals(DotNames.REQUEST_ID)
|| paramType.name().equals(DotNames.REQUEST_URI)) {
continue;
}
reflectiveHierarchies.produce(ReflectiveHierarchyBuildItem.builder(paramType).build());
Expand All @@ -380,7 +383,7 @@ void registerForReflection(List<FeatureMethodBuildItem> featureMethods,
reflectiveHierarchies.produce(ReflectiveHierarchyBuildItem.builder(Map.class).build());
}

private void validateFeatureMethod(MethodInfo method, Feature feature) {
private void validateFeatureMethod(MethodInfo method, Feature feature, AnnotationInstance featureAnnotation) {
if (Modifier.isStatic(method.flags())) {
throw new IllegalStateException("MCP feature method must not be static: " + method);
}
Expand All @@ -392,7 +395,7 @@ private void validateFeatureMethod(MethodInfo method, Feature feature) {
case PROMPT_COMPLETE -> validatePromptCompleteMethod(method);
case TOOL -> validateToolMethod(method);
case RESOURCE -> validateResourceMethod(method);
case RESOURCE_TEMPLATE -> validateResourceTemplateMethod(method);
case RESOURCE_TEMPLATE -> validateResourceTemplateMethod(method, featureAnnotation);
case RESOURCE_TEMPLATE_COMPLETE -> validateResourceTemplateCompleteMethod(method);
default -> throw new IllegalArgumentException("Unsupported feature: " + feature);
}
Expand All @@ -413,11 +416,10 @@ private void validatePromptMethod(MethodInfo method) {
throw new IllegalStateException("Unsupported Prompt method return type: " + method);
}

List<MethodParameterInfo> arguments = method.parameters().stream()
.filter(p -> providerFrom(p.type()) == Provider.PARAMS).toList();
for (MethodParameterInfo arg : arguments) {
if (!arg.type().name().equals(DotNames.STRING)) {
throw new IllegalStateException("Prompt method must only consume String arguments: " + method);
List<MethodParameterInfo> parameters = parameters(method);
for (MethodParameterInfo param : parameters) {
if (!param.type().name().equals(DotNames.STRING)) {
throw new IllegalStateException("Prompt method must only consume String parameters: " + method);
}
}
}
Expand All @@ -437,9 +439,8 @@ private void validatePromptCompleteMethod(MethodInfo method) {
throw new IllegalStateException("Unsupported Prompt complete method return type: " + method);
}

List<MethodParameterInfo> arguments = method.parameters().stream()
.filter(p -> providerFrom(p.type()) == Provider.PARAMS).toList();
if (arguments.size() != 1 || !arguments.get(0).type().name().equals(DotNames.STRING)) {
List<MethodParameterInfo> parameters = parameters(method);
if (parameters.size() != 1 || !parameters.get(0).type().name().equals(DotNames.STRING)) {
throw new IllegalStateException("Prompt complete must consume exactly one String argument: " + method);
}
}
Expand All @@ -456,9 +457,8 @@ private void validateResourceTemplateCompleteMethod(MethodInfo method) {
throw new IllegalStateException("Unsupported Resource template complete method return type: " + method);
}

List<MethodParameterInfo> arguments = method.parameters().stream()
.filter(p -> providerFrom(p.type()) == Provider.PARAMS).toList();
if (arguments.size() != 1 || !arguments.get(0).type().name().equals(DotNames.STRING)) {
List<MethodParameterInfo> parameters = parameters(method);
if (parameters.size() != 1 || !parameters.get(0).type().name().equals(DotNames.STRING)) {
throw new IllegalStateException("Resource template complete must consume exactly one String argument: " + method);
}
}
Expand Down Expand Up @@ -496,17 +496,19 @@ private void validateResourceMethod(MethodInfo method) {
if (!RESOURCE_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported Resource method return type: " + method);
}
if (method.parametersCount() > 1
|| (method.parametersCount() == 1
&& !method.parameterName(0).equals("uri")
&& !method.parameterType(0).name().equals(DotNames.STRING))) {

List<MethodParameterInfo> parameters = parameters(method);
if (parameters.size() > 1
|| parameters.size() == 1
&& !parameters.get(0).name().equals("uri")
&& !parameters.get(0).type().name().equals(DotNames.STRING)) {
throw new IllegalStateException(
"Resource method may accept zero paramateres or a single parameter of name 'uri' and type String: "
+ method);
}
}

private void validateResourceTemplateMethod(MethodInfo method) {
private void validateResourceTemplateMethod(MethodInfo method, AnnotationInstance featureAnnotation) {
org.jboss.jandex.Type type = method.returnType();
if (DotNames.UNI.equals(type.name()) && type.kind() == Kind.PARAMETERIZED_TYPE) {
type = type.asParameterizedType().arguments().get(0);
Expand All @@ -517,7 +519,28 @@ private void validateResourceTemplateMethod(MethodInfo method) {
if (!RESOURCE_TYPES.contains(type)) {
throw new IllegalStateException("Unsupported Resource template method return type: " + method);
}
// TODO validate params

AnnotationValue uriTemplateValue = featureAnnotation.value("uriTemplate");
if (uriTemplateValue == null) {
throw new IllegalStateException("URI template not found");
}
VariableMatcher variableMatcher = ResourceTemplateManager.createMatcherFromUriTemplate(uriTemplateValue.asString());

List<MethodParameterInfo> parameters = parameters(method);
for (MethodParameterInfo param : parameters) {
if (!param.type().name().equals(DotNames.STRING)) {
throw new IllegalStateException("Resource template method must only consume String parameters: " + method);
}
if (!variableMatcher.variables().contains(param.name())) {
throw new IllegalStateException(
"Parameter [" + param.name() + "] does not match an URI template variable: " + method);
}
}
}

private List<MethodParameterInfo> parameters(MethodInfo method) {
return method.parameters().stream()
.filter(p -> providerFrom(p.type()) == Provider.PARAMS).toList();
}

private boolean hasFeatureMethod(BeanInfo bean) {
Expand Down Expand Up @@ -595,6 +618,8 @@ private FeatureArgument.Provider providerFrom(org.jboss.jandex.Type type) {
return FeatureArgument.Provider.REQUEST_ID;
} else if (type.name().equals(DotNames.MCP_LOG)) {
return FeatureArgument.Provider.MCP_LOG;
} else if (type.name().equals(DotNames.REQUEST_URI)) {
return FeatureArgument.Provider.REQUEST_URI;
} else {
return FeatureArgument.Provider.PARAMS;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.quarkiverse.mcp.server;

/**
* Resource and resource template methods may accept the requested URI.
*
* @see Resource
* @see ResourceTemplate
*/
public record RequestUri(String value) {

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import io.quarkiverse.mcp.server.McpConnection;

public record ArgumentProviders(Map<String, Object> args, McpConnection connection, Object requestId, Responder responder) {
public record ArgumentProviders(Map<String, Object> args, McpConnection connection, Object requestId, String uri,
Responder responder) {

Object getArg(String name) {
return args != null ? args.get(name) : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void complete(Object id, JsonObject ref, JsonObject argument, Responder responde
String key = referenceName + "_" + argumentName;

ArgumentProviders argProviders = new ArgumentProviders(
Map.of(argumentName, argument.getString("value")), connection, id, responder);
Map.of(argumentName, argument.getString("value")), connection, id, null, responder);

try {
Future<CompletionResponse> fu = execute(key, argProviders);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public boolean isParam() {
public enum Provider {
PARAMS,
REQUEST_ID,
REQUEST_URI,
MCP_CONNECTION,
MCP_LOG
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;

import io.quarkiverse.mcp.server.RequestId;
import io.quarkiverse.mcp.server.RequestUri;
import io.quarkiverse.mcp.server.runtime.FeatureArgument.Provider;
import io.quarkus.arc.Arc;
import io.quarkus.arc.ManagedContext;
Expand Down Expand Up @@ -79,6 +80,8 @@ protected Object[] prepareArguments(FeatureMetadata<?> metadata, ArgumentProvide
ret[idx] = argProviders.connection();
} else if (arg.provider() == Provider.REQUEST_ID) {
ret[idx] = new RequestId(argProviders.requestId());
} else if (arg.provider() == Provider.REQUEST_URI) {
ret[idx] = new RequestUri(argProviders.uri());
} else if (arg.provider() == Provider.MCP_LOG) {
ret[idx] = logs.computeIfAbsent(logKey(metadata),
key -> new McpLogImpl(argProviders.connection()::logLevel, metadata.info().declaringClassName(), key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void promptsGet(JsonObject message, Responder responder, McpConnection connectio
LOG.debugf("Get prompt %s [id: %s]", promptName, id);

ArgumentProviders argProviders = new ArgumentProviders(params.getJsonObject("arguments").getMap(), connection, id,
responder);
null, responder);

try {
Future<PromptResponse> fu = manager.execute(promptName, argProviders);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ protected Object[] prepareArguments(FeatureMetadata<?> metadata, ArgumentProvide
if (metadata.feature() == Feature.RESOURCE_TEMPLATE) {
// Use variable matching to extract method arguments
Map<String, Object> matchedVariables = resourceTemplateManager.getVariableMatcher(metadata.info().name())
.matchVariables(argProviders.args().get("uri").toString());
matchedVariables.putIfAbsent("uri", argProviders.args().get("uri"));
.matchVariables(argProviders.uri());
argProviders = new ArgumentProviders(
matchedVariables, argProviders.connection(), argProviders.requestId(), argProviders.responder());
matchedVariables, argProviders.connection(), argProviders.requestId(), argProviders.uri(),
argProviders.responder());
}
return super.prepareArguments(metadata, argProviders);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void resourcesRead(JsonObject message, Responder responder, McpConnection connec
}
LOG.debugf("Read resource %s [id: %s]", resourceUri, id);

ArgumentProviders argProviders = new ArgumentProviders(Map.of("uri", resourceUri), connection, id, responder);
ArgumentProviders argProviders = new ArgumentProviders(Map.of(), connection, id, resourceUri, responder);

try {
Future<ResourceResponse> fu = manager.execute(resourceUri, argProviders);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ protected McpException notFound(String id) {
return new McpException("Invalid resource uri: " + id, JsonRPC.RESOURCE_NOT_FOUND);
}

static VariableMatcher createMatcherFromUriTemplate(String uriTemplate) {
public static VariableMatcher createMatcherFromUriTemplate(String uriTemplate) {
// Find variables
List<String> variables = new ArrayList<>();
Matcher m = Pattern.compile("\\{(\\w+)\\}").matcher(uriTemplate);
Expand All @@ -71,7 +71,7 @@ static VariableMatcher createMatcherFromUriTemplate(String uriTemplate) {
record ResourceTemplateMetadata(VariableMatcher variableMatcher, FeatureMetadata<ResourceResponse> metadata) {
}

record VariableMatcher(Pattern pattern, List<String> variables) {
public record VariableMatcher(Pattern pattern, List<String> variables) {

boolean matches(String uri) {
return pattern.matcher(uri).matches();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void toolsCall(JsonObject message, Responder responder, McpConnection connection
LOG.debugf("Call tool %s [id: %s]", toolName, id);

ArgumentProviders argProviders = new ArgumentProviders(params.getJsonObject("arguments").getMap(), connection, id,
responder);
null, responder);

try {
Future<ToolResponse> fu = manager.execute(toolName, argProviders);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ ifndef::add-copy-button-to-env-var[]
Environment variable: `+++QUARKUS_CLIENT_LOGGING_DEFAULT_LEVEL+++`
endif::add-copy-button-to-env-var[]
--
a|`alert`, `critical`, `debug`, `emergency`, `error`, `info`, `notice`, `warning`
a|`debug`, `info`, `notice`, `warning`, `error`, `critical`, `alert`, `emergency`
|`info`

a|icon:lock[title=Fixed at build time] [[quarkus-mcp-server-core_quarkus-auto-ping-interval]] [.property-path]##link:#quarkus-mcp-server-core_quarkus-auto-ping-interval[`quarkus.auto-ping-interval`]##
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import io.quarkiverse.mcp.server.CompleteArg;
import io.quarkiverse.mcp.server.CompleteResourceTemplate;
import io.quarkiverse.mcp.server.RequestUri;
import io.quarkiverse.mcp.server.ResourceTemplate;
import io.quarkiverse.mcp.server.TextResourceContents;
import io.quarkus.logging.Log;
Expand All @@ -13,8 +14,8 @@ public class MyResourceTemplates {
static final List<String> NAMES = List.of("Martin", "Lu", "Jachym", "Vojtik", "Onda");

@ResourceTemplate(uriTemplate = "file:///{foo}/{bar}")
TextResourceContents foo_template(String foo, String bar, String uri) {
return TextResourceContents.create(uri, foo + ":" + bar);
TextResourceContents foo_template(String foo, String bar, RequestUri uri) {
return TextResourceContents.create(uri.value(), foo + ":" + bar);
}

@CompleteResourceTemplate("foo_template")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.util.List;

import io.quarkiverse.mcp.server.RequestUri;
import io.quarkiverse.mcp.server.Resource;
import io.quarkiverse.mcp.server.ResourceResponse;
import io.quarkiverse.mcp.server.TextResourceContents;
Expand All @@ -14,27 +15,27 @@
public class MyResources {

@Resource(uri = "file:///project/alpha")
ResourceResponse alpha(String uri) {
ResourceResponse alpha(RequestUri uri) {
checkExecutionModel(true);
checkDuplicatedContext();
checkRequestContext();
return new ResourceResponse(List.of(new TextResourceContents(uri, "1", null)));
return new ResourceResponse(List.of(new TextResourceContents(uri.value(), "1", null)));
}

@Resource(uri = "file:///project/uni_alpha")
Uni<ResourceResponse> uni_alpha(String uri) {
Uni<ResourceResponse> uni_alpha(RequestUri uri) {
checkExecutionModel(false);
checkDuplicatedContext();
checkRequestContext();
return Uni.createFrom().item(new ResourceResponse(List.of(new TextResourceContents(uri, "2", null))));
return Uni.createFrom().item(new ResourceResponse(List.of(new TextResourceContents(uri.value(), "2", null))));
}

@Resource(uri = "file:///project/bravo")
TextResourceContents bravo(String uri) {
TextResourceContents bravo(RequestUri uri) {
checkExecutionModel(true);
checkDuplicatedContext();
checkRequestContext();
return new TextResourceContents(uri, "3", null);
return new TextResourceContents(uri.value(), "3", null);
}

@Resource(uri = "file:///project/uni_bravo")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkiverse.mcp.server.RequestUri;
import io.quarkiverse.mcp.server.Resource;
import io.quarkiverse.mcp.server.ResourceResponse;
import io.quarkiverse.mcp.server.runtime.JsonRPC;
Expand Down Expand Up @@ -36,7 +37,7 @@ public void testError() throws URISyntaxException {
public static class MyResources {

@Resource(uri = "file:///project/alpha")
ResourceResponse alpha(String uri) {
ResourceResponse alpha(RequestUri uri) {
throw new NullPointerException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.util.List;

import io.quarkiverse.mcp.server.RequestUri;
import io.quarkiverse.mcp.server.ResourceResponse;
import io.quarkiverse.mcp.server.ResourceTemplate;
import io.quarkiverse.mcp.server.TextResourceContents;
Expand All @@ -22,11 +23,11 @@ ResourceResponse alpha(String path) {
}

@ResourceTemplate(uriTemplate = "file:///{foo}/{bar}")
TextResourceContents bravo(String foo, String bar, String uri) {
TextResourceContents bravo(String foo, String bar, RequestUri uri) {
checkExecutionModel(true);
checkDuplicatedContext();
checkRequestContext();
return TextResourceContents.create(uri, foo + ":" + bar);
return TextResourceContents.create(uri.value(), foo + ":" + bar);
}

}
Loading

0 comments on commit b9dd528

Please sign in to comment.