Skip to content

Commit a464ea3

Browse files
authored
feat(client): add roots addition/removal API and listRoots handler (#118)
1 parent d79569d commit a464ea3

File tree

3 files changed

+294
-7
lines changed

3 files changed

+294
-7
lines changed

api/kotlin-sdk.api

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2688,6 +2688,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/WithMeta$Companion {
26882688
public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextprotocol/kotlin/sdk/shared/Protocol {
26892689
public fun <init> (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;)V
26902690
public synthetic fun <init> (Lio/modelcontextprotocol/kotlin/sdk/Implementation;Lio/modelcontextprotocol/kotlin/sdk/client/ClientOptions;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
2691+
public final fun addRoot (Ljava/lang/String;Ljava/lang/String;)V
2692+
public final fun addRoots (Ljava/util/List;)V
26912693
protected final fun assertCapability (Ljava/lang/String;Ljava/lang/String;)V
26922694
protected fun assertCapabilityForMethod (Lio/modelcontextprotocol/kotlin/sdk/Method;)V
26932695
protected fun assertNotificationCapability (Lio/modelcontextprotocol/kotlin/sdk/Method;)V
@@ -2715,6 +2717,8 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp
27152717
public static synthetic fun ping$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
27162718
public final fun readResource (Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
27172719
public static synthetic fun readResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/ReadResourceRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
2720+
public final fun removeRoot (Ljava/lang/String;)Z
2721+
public final fun removeRoots (Ljava/util/List;)I
27182722
public final fun sendRootsListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
27192723
public final fun setLoggingLevel (Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
27202724
public static synthetic fun setLoggingLevel$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3+
import io.github.oshai.kotlinlogging.KotlinLogging
34
import io.modelcontextprotocol.kotlin.sdk.CallToolRequest
45
import io.modelcontextprotocol.kotlin.sdk.CallToolResult
56
import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase
@@ -21,6 +22,7 @@ import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest
2122
import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult
2223
import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest
2324
import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult
25+
import io.modelcontextprotocol.kotlin.sdk.ListRootsResult
2426
import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest
2527
import io.modelcontextprotocol.kotlin.sdk.ListToolsResult
2628
import io.modelcontextprotocol.kotlin.sdk.LoggingLevel
@@ -29,6 +31,7 @@ import io.modelcontextprotocol.kotlin.sdk.Method
2931
import io.modelcontextprotocol.kotlin.sdk.PingRequest
3032
import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest
3133
import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult
34+
import io.modelcontextprotocol.kotlin.sdk.Root
3235
import io.modelcontextprotocol.kotlin.sdk.RootsListChangedNotification
3336
import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
3437
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
@@ -44,6 +47,8 @@ import kotlinx.serialization.json.JsonObject
4447
import kotlinx.serialization.json.JsonPrimitive
4548
import kotlin.coroutines.cancellation.CancellationException
4649

50+
private val logger = KotlinLogging.logger {}
51+
4752
/**
4853
* Options for configuring the MCP client.
4954
*
@@ -89,6 +94,19 @@ public open class Client(
8994

9095
private val capabilities: ClientCapabilities = options.capabilities
9196

97+
private val roots = mutableMapOf<String, Root>()
98+
99+
init {
100+
logger.debug { "Initializing MCP client with capabilities: $capabilities" }
101+
102+
// Internal handlers for roots
103+
if (capabilities.roots != null) {
104+
setRequestHandler<ListToolsRequest>(Method.Defined.RootsList) { _, _ ->
105+
handleListRoots()
106+
}
107+
}
108+
}
109+
92110
protected fun assertCapability(capability: String, method: String) {
93111
val caps = serverCapabilities
94112
val hasCapability = when (capability) {
@@ -449,6 +467,97 @@ public open class Client(
449467
return request<ListToolsResult>(request, options)
450468
}
451469

470+
/**
471+
* Registers a single root.
472+
*
473+
* @param uri The URI of the root.
474+
* @param name A human-readable name for the root.
475+
* @throws IllegalStateException If the client does not support roots.
476+
*/
477+
public fun addRoot(
478+
uri: String,
479+
name: String,
480+
) {
481+
if (capabilities.roots == null) {
482+
logger.error { "Failed to add root '$name': Client does not support roots capability" }
483+
throw IllegalStateException("Client does not support roots capability.")
484+
}
485+
logger.info { "Adding root: $name ($uri)" }
486+
roots[uri] = Root(uri, name)
487+
}
488+
489+
/**
490+
* Registers multiple roots at once.
491+
*
492+
* @param rootsToAdd A list of [Root] objects to register.
493+
* @throws IllegalStateException If the client does not support roots.
494+
*/
495+
public fun addRoots(rootsToAdd: List<Root>) {
496+
if (capabilities.roots == null) {
497+
logger.error { "Failed to add roots: Client does not support roots capability" }
498+
throw IllegalStateException("Client does not support roots capability.")
499+
}
500+
logger.info { "Adding ${rootsToAdd.size} roots" }
501+
for (r in rootsToAdd) {
502+
logger.info { "Adding root: ${r.name} (${r.uri})" }
503+
roots[r.uri] = r
504+
}
505+
}
506+
507+
/**
508+
* Removes a single root by URI.
509+
*
510+
* @param uri The URI of the root to remove.
511+
* @return True if the root was removed, false if it wasn't found.
512+
* @throws IllegalStateException If the client does not support roots.
513+
*/
514+
public fun removeRoot(uri: String): Boolean {
515+
if (capabilities.roots == null) {
516+
logger.error { "Failed to remove root '$uri': Client does not support roots capability" }
517+
throw IllegalStateException("Client does not support roots capability.")
518+
}
519+
logger.info { "Removing root: $uri" }
520+
val removed = roots.remove(uri) != null
521+
logger.debug {
522+
if (removed) {
523+
"Root removed: $uri"
524+
} else {
525+
"Root not found: $uri"
526+
}
527+
}
528+
return removed
529+
}
530+
531+
/**
532+
* Removes multiple roots at once.
533+
*
534+
* @param uris A list of root URIs to remove.
535+
* @return The number of roots that were successfully removed.
536+
* @throws IllegalStateException If the client does not support roots.
537+
*/
538+
public fun removeRoots(uris: List<String>): Int {
539+
if (capabilities.roots == null) {
540+
logger.error { "Failed to remove roots: Client does not support roots capability" }
541+
throw IllegalStateException("Client does not support roots capability.")
542+
}
543+
logger.info { "Removing ${uris.size} roots" }
544+
var removedCount = 0
545+
for (uri in uris) {
546+
logger.debug { "Removing root: $uri" }
547+
if (roots.remove(uri) != null) {
548+
removedCount++
549+
}
550+
}
551+
logger.info {
552+
if (removedCount > 0) {
553+
"Removed $removedCount roots"
554+
} else {
555+
"No roots were removed"
556+
}
557+
}
558+
return removedCount
559+
}
560+
452561
/**
453562
* Notifies the server that the list of roots has changed.
454563
* Typically used if the client is managing some form of hierarchical structure.
@@ -458,4 +567,11 @@ public open class Client(
458567
public suspend fun sendRootsListChanged() {
459568
notification(RootsListChangedNotification())
460569
}
570+
571+
// --- Internal Handlers ---
572+
573+
private suspend fun handleListRoots(): ListRootsResult {
574+
val rootList = roots.values.toList()
575+
return ListRootsResult(rootList)
576+
}
461577
}

src/jvmTest/kotlin/client/ClientTest.kt

Lines changed: 174 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package client
22

3+
import io.mockk.coEvery
4+
import io.mockk.spyk
35
import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities
46
import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest
57
import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult
68
import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject
79
import io.modelcontextprotocol.kotlin.sdk.Implementation
8-
import io.mockk.coEvery
9-
import io.mockk.spyk
1010
import io.modelcontextprotocol.kotlin.sdk.InMemoryTransport
1111
import io.modelcontextprotocol.kotlin.sdk.InitializeRequest
1212
import io.modelcontextprotocol.kotlin.sdk.InitializeResult
@@ -23,10 +23,17 @@ import io.modelcontextprotocol.kotlin.sdk.LoggingLevel
2323
import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification
2424
import io.modelcontextprotocol.kotlin.sdk.Method
2525
import io.modelcontextprotocol.kotlin.sdk.Role
26+
import io.modelcontextprotocol.kotlin.sdk.Root
27+
import io.modelcontextprotocol.kotlin.sdk.RootsListChangedNotification
2628
import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
2729
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
2830
import io.modelcontextprotocol.kotlin.sdk.TextContent
2931
import io.modelcontextprotocol.kotlin.sdk.Tool
32+
import io.modelcontextprotocol.kotlin.sdk.client.Client
33+
import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
34+
import io.modelcontextprotocol.kotlin.sdk.server.Server
35+
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
36+
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
3037
import kotlinx.coroutines.CompletableDeferred
3138
import kotlinx.coroutines.TimeoutCancellationException
3239
import kotlinx.coroutines.cancel
@@ -35,13 +42,9 @@ import kotlinx.coroutines.launch
3542
import kotlinx.coroutines.test.runTest
3643
import kotlinx.serialization.json.buildJsonObject
3744
import kotlinx.serialization.json.put
38-
import io.modelcontextprotocol.kotlin.sdk.client.Client
39-
import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
4045
import org.junit.jupiter.api.Test
41-
import io.modelcontextprotocol.kotlin.sdk.server.Server
42-
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
43-
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
4446
import org.junit.jupiter.api.assertInstanceOf
47+
import org.junit.jupiter.api.assertThrows
4548
import kotlin.coroutines.cancellation.CancellationException
4649
import kotlin.test.assertEquals
4750
import kotlin.test.assertFailsWith
@@ -628,4 +631,168 @@ class ClientTest {
628631
assertEquals(null, receivedAsResponse.error)
629632
}
630633

634+
@Test
635+
fun `listRoots returns list of roots`() = runTest {
636+
val client = Client(
637+
Implementation(name = "test client", version = "1.0"),
638+
ClientOptions(
639+
capabilities = ClientCapabilities(
640+
roots = ClientCapabilities.Roots(null)
641+
)
642+
)
643+
)
644+
645+
val clientRoots = listOf(
646+
Root(uri = "file:///test-root", name = "testRoot")
647+
)
648+
649+
client.addRoots(clientRoots)
650+
651+
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
652+
653+
val server = Server(
654+
serverInfo = Implementation(name = "test server", version = "1.0"),
655+
options = ServerOptions(
656+
capabilities = ServerCapabilities()
657+
)
658+
)
659+
660+
listOf(
661+
launch { client.connect(clientTransport) },
662+
launch { server.connect(serverTransport) }
663+
).joinAll()
664+
665+
val clientCapabilities = server.clientCapabilities
666+
assertEquals(ClientCapabilities.Roots(null), clientCapabilities?.roots)
667+
668+
val listRootsResult = server.listRoots()
669+
670+
assertEquals(listRootsResult.roots, clientRoots)
671+
}
672+
673+
@Test
674+
fun `addRoot should throw when roots capability is not supported`() = runTest {
675+
val client = Client(
676+
Implementation(name = "test client", version = "1.0"),
677+
ClientOptions(
678+
capabilities = ClientCapabilities()
679+
)
680+
)
681+
682+
// Verify that adding a root throws an exception
683+
val exception = assertThrows<IllegalStateException> {
684+
client.addRoot(uri = "file:///test-root1", name = "testRoot1")
685+
}
686+
assertEquals("Client does not support roots capability.", exception.message)
687+
}
688+
689+
@Test
690+
fun `removeRoot should throw when roots capability is not supported`() = runTest {
691+
val client = Client(
692+
Implementation(name = "test client", version = "1.0"),
693+
ClientOptions(
694+
capabilities = ClientCapabilities()
695+
)
696+
)
697+
698+
// Verify that removing a root throws an exception
699+
val exception = assertThrows<IllegalStateException> {
700+
client.removeRoot(uri = "file:///test-root1")
701+
}
702+
assertEquals("Client does not support roots capability.", exception.message)
703+
}
704+
705+
@Test
706+
fun `removeRoot should remove a root`() = runTest {
707+
val client = Client(
708+
Implementation(name = "test client", version = "1.0"),
709+
ClientOptions(
710+
capabilities = ClientCapabilities(
711+
roots = ClientCapabilities.Roots(null)
712+
)
713+
)
714+
)
715+
716+
// Add some roots
717+
client.addRoots(
718+
listOf(
719+
Root(uri = "file:///test-root1", name = "testRoot1"),
720+
Root(uri = "file:///test-root2", name = "testRoot2"),
721+
)
722+
)
723+
724+
// Remove a root
725+
val result = client.removeRoot("file:///test-root1")
726+
727+
// Verify the root was removed
728+
assertTrue(result, "Root should be removed successfully")
729+
}
730+
731+
@Test
732+
fun `removeRoots should remove multiple roots`() = runTest {
733+
val client = Client(
734+
Implementation(name = "test client", version = "1.0"),
735+
ClientOptions(
736+
capabilities = ClientCapabilities(
737+
roots = ClientCapabilities.Roots(null)
738+
)
739+
)
740+
)
741+
742+
// Add some roots
743+
client.addRoots(
744+
listOf(
745+
Root(uri = "file:///test-root1", name = "testRoot1"),
746+
Root(uri = "file:///test-root2", name = "testRoot2"),
747+
)
748+
)
749+
750+
// Remove multiple roots
751+
val result = client.removeRoots(
752+
listOf("file:///test-root1", "file:///test-root2")
753+
)
754+
755+
// Verify the root was removed
756+
assertEquals(2, result, "Both roots should be removed")
757+
}
758+
759+
@Test
760+
fun `sendRootsListChanged should notify server`() = runTest {
761+
val client = Client(
762+
Implementation(name = "test client", version = "1.0"),
763+
ClientOptions(
764+
capabilities = ClientCapabilities(
765+
roots = ClientCapabilities.Roots(listChanged = true)
766+
)
767+
)
768+
)
769+
770+
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
771+
772+
val server = Server(
773+
serverInfo = Implementation(name = "test server", version = "1.0"),
774+
options = ServerOptions(
775+
capabilities = ServerCapabilities()
776+
)
777+
)
778+
779+
// Track notifications
780+
var rootListChangedNotificationReceived = false
781+
server.setNotificationHandler<RootsListChangedNotification>(Method.Defined.NotificationsRootsListChanged) {
782+
rootListChangedNotificationReceived = true
783+
CompletableDeferred(Unit)
784+
}
785+
786+
listOf(
787+
launch { client.connect(clientTransport) },
788+
launch { server.connect(serverTransport) }
789+
).joinAll()
790+
791+
client.sendRootsListChanged()
792+
793+
assertTrue(
794+
rootListChangedNotificationReceived,
795+
"Notification should be sent when sendRootsListChanged is called"
796+
)
797+
}
631798
}

0 commit comments

Comments
 (0)