1
1
package client
2
2
3
+ import io.mockk.coEvery
4
+ import io.mockk.spyk
3
5
import io.modelcontextprotocol.kotlin.sdk.ClientCapabilities
4
6
import io.modelcontextprotocol.kotlin.sdk.CreateMessageRequest
5
7
import io.modelcontextprotocol.kotlin.sdk.CreateMessageResult
6
8
import io.modelcontextprotocol.kotlin.sdk.EmptyJsonObject
7
9
import io.modelcontextprotocol.kotlin.sdk.Implementation
8
- import io.mockk.coEvery
9
- import io.mockk.spyk
10
10
import io.modelcontextprotocol.kotlin.sdk.InMemoryTransport
11
11
import io.modelcontextprotocol.kotlin.sdk.InitializeRequest
12
12
import io.modelcontextprotocol.kotlin.sdk.InitializeResult
@@ -23,10 +23,17 @@ import io.modelcontextprotocol.kotlin.sdk.LoggingLevel
23
23
import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification
24
24
import io.modelcontextprotocol.kotlin.sdk.Method
25
25
import io.modelcontextprotocol.kotlin.sdk.Role
26
+ import io.modelcontextprotocol.kotlin.sdk.Root
27
+ import io.modelcontextprotocol.kotlin.sdk.RootsListChangedNotification
26
28
import io.modelcontextprotocol.kotlin.sdk.SUPPORTED_PROTOCOL_VERSIONS
27
29
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
28
30
import io.modelcontextprotocol.kotlin.sdk.TextContent
29
31
import io.modelcontextprotocol.kotlin.sdk.Tool
32
+ import io.modelcontextprotocol.kotlin.sdk.client.Client
33
+ import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
34
+ import io.modelcontextprotocol.kotlin.sdk.server.Server
35
+ import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
36
+ import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
30
37
import kotlinx.coroutines.CompletableDeferred
31
38
import kotlinx.coroutines.TimeoutCancellationException
32
39
import kotlinx.coroutines.cancel
@@ -35,13 +42,9 @@ import kotlinx.coroutines.launch
35
42
import kotlinx.coroutines.test.runTest
36
43
import kotlinx.serialization.json.buildJsonObject
37
44
import kotlinx.serialization.json.put
38
- import io.modelcontextprotocol.kotlin.sdk.client.Client
39
- import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
40
45
import org.junit.jupiter.api.Test
41
- import io.modelcontextprotocol.kotlin.sdk.server.Server
42
- import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
43
- import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
44
46
import org.junit.jupiter.api.assertInstanceOf
47
+ import org.junit.jupiter.api.assertThrows
45
48
import kotlin.coroutines.cancellation.CancellationException
46
49
import kotlin.test.assertEquals
47
50
import kotlin.test.assertFailsWith
@@ -628,4 +631,168 @@ class ClientTest {
628
631
assertEquals(null , receivedAsResponse.error)
629
632
}
630
633
634
+ @Test
635
+ fun `listRoots returns list of roots` () = runTest {
636
+ val client = Client (
637
+ Implementation (name = " test client" , version = " 1.0" ),
638
+ ClientOptions (
639
+ capabilities = ClientCapabilities (
640
+ roots = ClientCapabilities .Roots (null )
641
+ )
642
+ )
643
+ )
644
+
645
+ val clientRoots = listOf (
646
+ Root (uri = " file:///test-root" , name = " testRoot" )
647
+ )
648
+
649
+ client.addRoots(clientRoots)
650
+
651
+ val (clientTransport, serverTransport) = InMemoryTransport .createLinkedPair()
652
+
653
+ val server = Server (
654
+ serverInfo = Implementation (name = " test server" , version = " 1.0" ),
655
+ options = ServerOptions (
656
+ capabilities = ServerCapabilities ()
657
+ )
658
+ )
659
+
660
+ listOf (
661
+ launch { client.connect(clientTransport) },
662
+ launch { server.connect(serverTransport) }
663
+ ).joinAll()
664
+
665
+ val clientCapabilities = server.clientCapabilities
666
+ assertEquals(ClientCapabilities .Roots (null ), clientCapabilities?.roots)
667
+
668
+ val listRootsResult = server.listRoots()
669
+
670
+ assertEquals(listRootsResult.roots, clientRoots)
671
+ }
672
+
673
+ @Test
674
+ fun `addRoot should throw when roots capability is not supported` () = runTest {
675
+ val client = Client (
676
+ Implementation (name = " test client" , version = " 1.0" ),
677
+ ClientOptions (
678
+ capabilities = ClientCapabilities ()
679
+ )
680
+ )
681
+
682
+ // Verify that adding a root throws an exception
683
+ val exception = assertThrows<IllegalStateException > {
684
+ client.addRoot(uri = " file:///test-root1" , name = " testRoot1" )
685
+ }
686
+ assertEquals(" Client does not support roots capability." , exception.message)
687
+ }
688
+
689
+ @Test
690
+ fun `removeRoot should throw when roots capability is not supported` () = runTest {
691
+ val client = Client (
692
+ Implementation (name = " test client" , version = " 1.0" ),
693
+ ClientOptions (
694
+ capabilities = ClientCapabilities ()
695
+ )
696
+ )
697
+
698
+ // Verify that removing a root throws an exception
699
+ val exception = assertThrows<IllegalStateException > {
700
+ client.removeRoot(uri = " file:///test-root1" )
701
+ }
702
+ assertEquals(" Client does not support roots capability." , exception.message)
703
+ }
704
+
705
+ @Test
706
+ fun `removeRoot should remove a root` () = runTest {
707
+ val client = Client (
708
+ Implementation (name = " test client" , version = " 1.0" ),
709
+ ClientOptions (
710
+ capabilities = ClientCapabilities (
711
+ roots = ClientCapabilities .Roots (null )
712
+ )
713
+ )
714
+ )
715
+
716
+ // Add some roots
717
+ client.addRoots(
718
+ listOf (
719
+ Root (uri = " file:///test-root1" , name = " testRoot1" ),
720
+ Root (uri = " file:///test-root2" , name = " testRoot2" ),
721
+ )
722
+ )
723
+
724
+ // Remove a root
725
+ val result = client.removeRoot(" file:///test-root1" )
726
+
727
+ // Verify the root was removed
728
+ assertTrue(result, " Root should be removed successfully" )
729
+ }
730
+
731
+ @Test
732
+ fun `removeRoots should remove multiple roots` () = runTest {
733
+ val client = Client (
734
+ Implementation (name = " test client" , version = " 1.0" ),
735
+ ClientOptions (
736
+ capabilities = ClientCapabilities (
737
+ roots = ClientCapabilities .Roots (null )
738
+ )
739
+ )
740
+ )
741
+
742
+ // Add some roots
743
+ client.addRoots(
744
+ listOf (
745
+ Root (uri = " file:///test-root1" , name = " testRoot1" ),
746
+ Root (uri = " file:///test-root2" , name = " testRoot2" ),
747
+ )
748
+ )
749
+
750
+ // Remove multiple roots
751
+ val result = client.removeRoots(
752
+ listOf (" file:///test-root1" , " file:///test-root2" )
753
+ )
754
+
755
+ // Verify the root was removed
756
+ assertEquals(2 , result, " Both roots should be removed" )
757
+ }
758
+
759
+ @Test
760
+ fun `sendRootsListChanged should notify server` () = runTest {
761
+ val client = Client (
762
+ Implementation (name = " test client" , version = " 1.0" ),
763
+ ClientOptions (
764
+ capabilities = ClientCapabilities (
765
+ roots = ClientCapabilities .Roots (listChanged = true )
766
+ )
767
+ )
768
+ )
769
+
770
+ val (clientTransport, serverTransport) = InMemoryTransport .createLinkedPair()
771
+
772
+ val server = Server (
773
+ serverInfo = Implementation (name = " test server" , version = " 1.0" ),
774
+ options = ServerOptions (
775
+ capabilities = ServerCapabilities ()
776
+ )
777
+ )
778
+
779
+ // Track notifications
780
+ var rootListChangedNotificationReceived = false
781
+ server.setNotificationHandler<RootsListChangedNotification >(Method .Defined .NotificationsRootsListChanged ) {
782
+ rootListChangedNotificationReceived = true
783
+ CompletableDeferred (Unit )
784
+ }
785
+
786
+ listOf (
787
+ launch { client.connect(clientTransport) },
788
+ launch { server.connect(serverTransport) }
789
+ ).joinAll()
790
+
791
+ client.sendRootsListChanged()
792
+
793
+ assertTrue(
794
+ rootListChangedNotificationReceived,
795
+ " Notification should be sent when sendRootsListChanged is called"
796
+ )
797
+ }
631
798
}
0 commit comments