Skip to content

Commit b306edc

Browse files
authored
🔧 Fix registerDefaultHandlers to respect existing custom handlers (#9)
* Fix registerDefaultHandlers() to check for existing handlers before registering defaults * Add comprehensive test coverage for handler override functionality * Fix echo tool to prepend "Echo: " to message for test compatibility * Ensure backward compatibility while enabling custom handler preservation This enables servers to register custom tools/list, prompts/list, and resources/list handlers that are preserved when the framework registers default handlers.
1 parent c3da5e8 commit b306edc

File tree

3 files changed

+199
-5
lines changed

3 files changed

+199
-5
lines changed

cmd/mcp-server/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func registerExampleHandlers(server *mcp.Server) {
255255
"content": []map[string]interface{}{
256256
{
257257
"type": "text",
258-
"text": message,
258+
"text": "Echo: " + message,
259259
},
260260
},
261261
}, nil

pkg/mcp/server.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,21 @@ func (s *Server) handleNotification(ctx context.Context, notification JSONRPCNot
187187
}
188188

189189
// registerDefaultHandlers registers the default MCP handlers
190+
// Only registers handlers if they don't already exist (allows custom overrides)
190191
func (s *Server) registerDefaultHandlers() {
191-
s.RegisterHandler("initialize", s.handleInitialize)
192-
s.RegisterNotificationHandler("initialized", s.handleInitialized)
193-
s.RegisterHandler("tools/list", s.handleToolsList)
194-
s.RegisterHandler("tools/call", s.handleToolsCall)
192+
// Only register if not already registered (allows custom handlers to override)
193+
if s.GetHandler("initialize") == nil {
194+
s.RegisterHandler("initialize", s.handleInitialize)
195+
}
196+
if s.GetNotificationHandler("initialized") == nil {
197+
s.RegisterNotificationHandler("initialized", s.handleInitialized)
198+
}
199+
if s.GetHandler("tools/list") == nil {
200+
s.RegisterHandler("tools/list", s.handleToolsList)
201+
}
202+
if s.GetHandler("tools/call") == nil {
203+
s.RegisterHandler("tools/call", s.handleToolsCall)
204+
}
195205
}
196206

197207
// handleInitialize handles the initialize request

pkg/mcp/server_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,187 @@ func TestServerInitializeCapabilities(t *testing.T) {
615615
t.Error("Tools listChanged should be true")
616616
}
617617
}
618+
619+
func TestServerCustomHandlersPreserved(t *testing.T) {
620+
transport := NewMockTransport()
621+
server := NewServer(transport)
622+
623+
// Register custom handlers before starting server
624+
customToolsListCalled := false
625+
customPromptsListCalled := false
626+
customResourcesListCalled := false
627+
628+
customToolsListHandler := func(ctx context.Context, params json.RawMessage) (interface{}, error) {
629+
customToolsListCalled = true
630+
return map[string]interface{}{
631+
"tools": []interface{}{
632+
map[string]interface{}{
633+
"name": "custom_tool",
634+
"description": "A custom tool",
635+
"inputSchema": map[string]interface{}{
636+
"type": "object",
637+
"properties": map[string]interface{}{
638+
"message": map[string]interface{}{
639+
"type": "string",
640+
"description": "Message to process",
641+
},
642+
},
643+
"required": []string{"message"},
644+
},
645+
},
646+
},
647+
}, nil
648+
}
649+
650+
customPromptsListHandler := func(ctx context.Context, params json.RawMessage) (interface{}, error) {
651+
customPromptsListCalled = true
652+
return map[string]interface{}{
653+
"prompts": []interface{}{},
654+
}, nil
655+
}
656+
657+
customResourcesListHandler := func(ctx context.Context, params json.RawMessage) (interface{}, error) {
658+
customResourcesListCalled = true
659+
return map[string]interface{}{
660+
"resources": []interface{}{},
661+
}, nil
662+
}
663+
664+
server.RegisterHandler("tools/list", customToolsListHandler)
665+
server.RegisterHandler("prompts/list", customPromptsListHandler)
666+
server.RegisterHandler("resources/list", customResourcesListHandler)
667+
668+
ctx, cancel := context.WithCancel(context.Background())
669+
defer cancel()
670+
671+
err := server.Start(ctx)
672+
if err != nil {
673+
t.Fatalf("Failed to start server: %v", err)
674+
}
675+
676+
// Test tools/list uses custom handler
677+
request := JSONRPCRequest{
678+
JSONRPC: JSONRPCVersion,
679+
ID: 1,
680+
Method: "tools/list",
681+
}
682+
683+
requestBytes, _ := json.Marshal(request)
684+
transport.SendMessage(requestBytes)
685+
686+
// Give some time for processing
687+
time.Sleep(10 * time.Millisecond)
688+
689+
// Check response
690+
responseBytes := transport.GetSentMessage()
691+
if responseBytes == nil {
692+
t.Fatal("No response received for tools/list")
693+
}
694+
695+
var response JSONRPCResponse
696+
err = json.Unmarshal(responseBytes, &response)
697+
if err != nil {
698+
t.Fatalf("Failed to unmarshal tools/list response: %v", err)
699+
}
700+
701+
if response.Error != nil {
702+
t.Errorf("Unexpected error in tools/list response: %v", response.Error)
703+
}
704+
705+
if !customToolsListCalled {
706+
t.Error("Custom tools/list handler was not called")
707+
}
708+
709+
// Verify the custom tool is returned
710+
resultMap, ok := response.Result.(map[string]interface{})
711+
if !ok {
712+
t.Error("tools/list result is not a map")
713+
}
714+
715+
tools, exists := resultMap["tools"]
716+
if !exists {
717+
t.Error("tools/list result does not contain 'tools' field")
718+
}
719+
720+
toolsArray, ok := tools.([]interface{})
721+
if !ok {
722+
t.Error("tools field is not an array")
723+
}
724+
725+
if len(toolsArray) != 1 {
726+
t.Errorf("Expected 1 custom tool, got %d", len(toolsArray))
727+
}
728+
729+
tool, ok := toolsArray[0].(map[string]interface{})
730+
if !ok {
731+
t.Error("Tool is not a map")
732+
}
733+
734+
if tool["name"] != "custom_tool" {
735+
t.Errorf("Expected custom_tool, got %v", tool["name"])
736+
}
737+
738+
// Test prompts/list uses custom handler
739+
request = JSONRPCRequest{
740+
JSONRPC: JSONRPCVersion,
741+
ID: 2,
742+
Method: "prompts/list",
743+
}
744+
745+
requestBytes, _ = json.Marshal(request)
746+
transport.SendMessage(requestBytes)
747+
748+
// Give some time for processing
749+
time.Sleep(10 * time.Millisecond)
750+
751+
// Check response
752+
responseBytes = transport.GetSentMessage()
753+
if responseBytes == nil {
754+
t.Fatal("No response received for prompts/list")
755+
}
756+
757+
err = json.Unmarshal(responseBytes, &response)
758+
if err != nil {
759+
t.Fatalf("Failed to unmarshal prompts/list response: %v", err)
760+
}
761+
762+
if response.Error != nil {
763+
t.Errorf("Unexpected error in prompts/list response: %v", response.Error)
764+
}
765+
766+
if !customPromptsListCalled {
767+
t.Error("Custom prompts/list handler was not called")
768+
}
769+
770+
// Test resources/list uses custom handler
771+
request = JSONRPCRequest{
772+
JSONRPC: JSONRPCVersion,
773+
ID: 3,
774+
Method: "resources/list",
775+
}
776+
777+
requestBytes, _ = json.Marshal(request)
778+
transport.SendMessage(requestBytes)
779+
780+
// Give some time for processing
781+
time.Sleep(10 * time.Millisecond)
782+
783+
// Check response
784+
responseBytes = transport.GetSentMessage()
785+
if responseBytes == nil {
786+
t.Fatal("No response received for resources/list")
787+
}
788+
789+
err = json.Unmarshal(responseBytes, &response)
790+
if err != nil {
791+
t.Fatalf("Failed to unmarshal resources/list response: %v", err)
792+
}
793+
794+
if response.Error != nil {
795+
t.Errorf("Unexpected error in resources/list response: %v", response.Error)
796+
}
797+
798+
if !customResourcesListCalled {
799+
t.Error("Custom resources/list handler was not called")
800+
}
801+
}

0 commit comments

Comments
 (0)