diff --git a/api/kotlin-sdk.api b/api/kotlin-sdk.api index 9a6ef950..ec254cd9 100644 --- a/api/kotlin-sdk.api +++ b/api/kotlin-sdk.api @@ -2688,6 +2688,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/WithMeta$Companion { public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextprotocol/kotlin/sdk/shared/Protocol { public fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;)V public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun addRoot (Ljava/lang/String;Ljava/lang/String;)V + public final fun addRoots (Ljava/util/List;)V protected final fun assertCapability (Ljava/lang/String;Ljava/lang/String;)V protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V @@ -2715,6 +2717,8 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp public static synthetic fun ping$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun readResource (Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun readResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + public final fun removeRoot (Ljava/lang/String;)Z + public final fun removeRoots (Ljava/util/List;)I public final fun sendRootsListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun setLoggingLevel (Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun setLoggingLevel$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 1b70daca..71b61ba7 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -1,5 +1,6 @@ package io.modelcontextprotocol.kotlin.sdk.client +import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.CallToolRequest import io.modelcontextprotocol.kotlin.sdk.CallToolResult import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase @@ -21,6 +22,7 @@ import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult +import io.modelcontextprotocol.kotlin.sdk.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest import io.modelcontextprotocol.kotlin.sdk.ListToolsResult import io.modelcontextprotocol.kotlin.sdk.LoggingLevel @@ -29,6 +31,7 @@ import io.modelcontextprotocol.kotlin.sdk.Method import io.modelcontextprotocol.kotlin.sdk.PingRequest import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.Root import io.modelcontextprotocol.kotlin.sdk.RootsListChangedNotification import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities @@ -44,6 +47,8 @@ import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlin.coroutines.cancellation.CancellationException +private val logger = KotlinLogging.logger {} + /** * Options for configuring the MCP client. * @@ -89,6 +94,19 @@ public open class Client( private val capabilities: ClientCapabilities = options.capabilities + private val roots = mutableMapOf() + + init { + logger.debug { "Initializing MCP client with capabilities: $capabilities" } + + // Internal handlers for roots + if (capabilities.roots != null) { + setRequestHandler(Method.Defined.RootsList) { _, _ -> + handleListRoots() + } + } + } + protected fun assertCapability(capability: String, method: String) { val caps = serverCapabilities val hasCapability = when (capability) { @@ -449,6 +467,97 @@ public open class Client( return request(request, options) } + /** + * Registers a single root. + * + * @param uri The URI of the root. + * @param name A human-readable name for the root. + * @throws IllegalStateException If the client does not support roots. + */ + public fun addRoot( + uri: String, + name: String, + ) { + if (capabilities.roots == null) { + logger.error { "Failed to add root '$name': Client does not support roots capability" } + throw IllegalStateException("Client does not support roots capability.") + } + logger.info { "Adding root: $name ($uri)" } + roots[uri] = Root(uri, name) + } + + /** + * Registers multiple roots at once. + * + * @param rootsToAdd A list of [Root] objects to register. + * @throws IllegalStateException If the client does not support roots. + */ + public fun addRoots(rootsToAdd: List) { + if (capabilities.roots == null) { + logger.error { "Failed to add roots: Client does not support roots capability" } + throw IllegalStateException("Client does not support roots capability.") + } + logger.info { "Adding ${rootsToAdd.size} roots" } + for (r in rootsToAdd) { + logger.info { "Adding root: ${r.name} (${r.uri})" } + roots[r.uri] = r + } + } + + /** + * Removes a single root by URI. + * + * @param uri The URI of the root to remove. + * @return True if the root was removed, false if it wasn't found. + * @throws IllegalStateException If the client does not support roots. + */ + public fun removeRoot(uri: String): Boolean { + if (capabilities.roots == null) { + logger.error { "Failed to remove root '$uri': Client does not support roots capability" } + throw IllegalStateException("Client does not support roots capability.") + } + logger.info { "Removing root: $uri" } + val removed = roots.remove(uri) != null + logger.debug { + if (removed) { + "Root removed: $uri" + } else { + "Root not found: $uri" + } + } + return removed + } + + /** + * Removes multiple roots at once. + * + * @param uris A list of root URIs to remove. + * @return The number of roots that were successfully removed. + * @throws IllegalStateException If the client does not support roots. + */ + public fun removeRoots(uris: List): Int { + if (capabilities.roots == null) { + logger.error { "Failed to remove roots: Client does not support roots capability" } + throw IllegalStateException("Client does not support roots capability.") + } + logger.info { "Removing ${uris.size} roots" } + var removedCount = 0 + for (uri in uris) { + logger.debug { "Removing root: $uri" } + if (roots.remove(uri) != null) { + removedCount++ + } + } + logger.info { + if (removedCount > 0) { + "Removed $removedCount roots" + } else { + "No roots were removed" + } + } + return removedCount + } + /** * Notifies the server that the list of roots has changed. * Typically used if the client is managing some form of hierarchical structure. @@ -458,4 +567,11 @@ public open class Client( public suspend fun sendRootsListChanged() { notification(RootsListChangedNotification()) } + + // --- Internal Handlers --- + + private suspend fun handleListRoots(): ListRootsResult { + val rootList = roots.values.toList() + return ListRootsResult(rootList) + } } diff --git a/src/jvmTest/kotlin/client/ClientTest.kt b/src/jvmTest/kotlin/client/ClientTest.kt index d63f27ed..c1251baa 100644 --- a/src/jvmTest/kotlin/client/ClientTest.kt +++ b/src/jvmTest/kotlin/client/ClientTest.kt @@ -1,12 +1,12 @@ package client +import io.mockk.coEvery +import io.mockk.spyk import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.mockk.coEvery -import io.mockk.spyk import io.modelcontextprotocol.kotlin.sdk.InMemoryTransport import io.modelcontextprotocol.kotlin.sdk.InitializeRequest import io.modelcontextprotocol.kotlin.sdk.InitializeResult @@ -23,10 +23,17 @@ import io.modelcontextprotocol.kotlin.sdk.LoggingLevel import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification import io.modelcontextprotocol.kotlin.sdk.Method import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.Root +import io.modelcontextprotocol.kotlin.sdk.RootsListChangedNotification import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.TextContent import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.cancel @@ -35,13 +42,9 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put -import io.modelcontextprotocol.kotlin.sdk.client.Client -import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions import org.junit.jupiter.api.Test -import io.modelcontextprotocol.kotlin.sdk.server.Server -import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions -import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import org.junit.jupiter.api.assertInstanceOf +import org.junit.jupiter.api.assertThrows import kotlin.coroutines.cancellation.CancellationException import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -628,4 +631,168 @@ class ClientTest { assertEquals(null, receivedAsResponse.error) } + @Test + fun `listRoots returns list of roots`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(null) + ) + ) + ) + + val clientRoots = listOf( + Root(uri = "file:///test-root", name = "testRoot") + ) + + client.addRoots(clientRoots) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val server = Server( + serverInfo = Implementation(name = "test server", version = "1.0"), + options = ServerOptions( + capabilities = ServerCapabilities() + ) + ) + + listOf( + launch { client.connect(clientTransport) }, + launch { server.connect(serverTransport) } + ).joinAll() + + val clientCapabilities = server.clientCapabilities + assertEquals(ClientCapabilities.Roots(null), clientCapabilities?.roots) + + val listRootsResult = server.listRoots() + + assertEquals(listRootsResult.roots, clientRoots) + } + + @Test + fun `addRoot should throw when roots capability is not supported`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities() + ) + ) + + // Verify that adding a root throws an exception + val exception = assertThrows { + client.addRoot(uri = "file:///test-root1", name = "testRoot1") + } + assertEquals("Client does not support roots capability.", exception.message) + } + + @Test + fun `removeRoot should throw when roots capability is not supported`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities() + ) + ) + + // Verify that removing a root throws an exception + val exception = assertThrows { + client.removeRoot(uri = "file:///test-root1") + } + assertEquals("Client does not support roots capability.", exception.message) + } + + @Test + fun `removeRoot should remove a root`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(null) + ) + ) + ) + + // Add some roots + client.addRoots( + listOf( + Root(uri = "file:///test-root1", name = "testRoot1"), + Root(uri = "file:///test-root2", name = "testRoot2"), + ) + ) + + // Remove a root + val result = client.removeRoot("file:///test-root1") + + // Verify the root was removed + assertTrue(result, "Root should be removed successfully") + } + + @Test + fun `removeRoots should remove multiple roots`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(null) + ) + ) + ) + + // Add some roots + client.addRoots( + listOf( + Root(uri = "file:///test-root1", name = "testRoot1"), + Root(uri = "file:///test-root2", name = "testRoot2"), + ) + ) + + // Remove multiple roots + val result = client.removeRoots( + listOf("file:///test-root1", "file:///test-root2") + ) + + // Verify the root was removed + assertEquals(2, result, "Both roots should be removed") + } + + @Test + fun `sendRootsListChanged should notify server`() = runTest { + val client = Client( + Implementation(name = "test client", version = "1.0"), + ClientOptions( + capabilities = ClientCapabilities( + roots = ClientCapabilities.Roots(listChanged = true) + ) + ) + ) + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val server = Server( + serverInfo = Implementation(name = "test server", version = "1.0"), + options = ServerOptions( + capabilities = ServerCapabilities() + ) + ) + + // Track notifications + var rootListChangedNotificationReceived = false + server.setNotificationHandler(Method.Defined.NotificationsRootsListChanged) { + rootListChangedNotificationReceived = true + CompletableDeferred(Unit) + } + + listOf( + launch { client.connect(clientTransport) }, + launch { server.connect(serverTransport) } + ).joinAll() + + client.sendRootsListChanged() + + assertTrue( + rootListChangedNotificationReceived, + "Notification should be sent when sendRootsListChanged is called" + ) + } }