diff --git a/internal/jsonrpc/router.go b/internal/jsonrpc/router.go new file mode 100644 index 00000000..616bd475 --- /dev/null +++ b/internal/jsonrpc/router.go @@ -0,0 +1,179 @@ +package jsonrpc + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "reflect" +) + +type JSONRPCRouter struct { + writer io.Writer + + handlers map[string]*RPCHandler +} + +func NewJSONRPCRouter(writer io.Writer, handlers map[string]*RPCHandler) *JSONRPCRouter { + return &JSONRPCRouter{ + writer: writer, + handlers: handlers, + } +} + +func (s *JSONRPCRouter) HandleMessage(data []byte) error { + var request JSONRPCRequest + err := json.Unmarshal(data, &request) + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]interface{}{ + "code": -32700, + "message": "Parse error", + }, + ID: 0, + } + return s.writeResponse(errorResponse) + } + + //log.Printf("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) + handler, ok := s.handlers[request.Method] + if !ok { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]interface{}{ + "code": -32601, + "message": "Method not found", + }, + ID: request.ID, + } + return s.writeResponse(errorResponse) + } + + result, err := callRPCHandler(handler, request.Params) + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: map[string]interface{}{ + "code": -32603, + "message": "Internal error", + "data": err.Error(), + }, + ID: request.ID, + } + return s.writeResponse(errorResponse) + } + + response := JSONRPCResponse{ + JSONRPC: "2.0", + Result: result, + ID: request.ID, + } + return s.writeResponse(response) +} + +func (s *JSONRPCRouter) writeResponse(response JSONRPCResponse) error { + responseBytes, err := json.Marshal(response) + if err != nil { + return err + } + _, err = s.writer.Write(responseBytes) + return err +} + +func callRPCHandler(handler *RPCHandler, params map[string]interface{}) (interface{}, error) { + handlerValue := reflect.ValueOf(handler.Func) + handlerType := handlerValue.Type() + + if handlerType.Kind() != reflect.Func { + return nil, errors.New("handler is not a function") + } + + numParams := handlerType.NumIn() + args := make([]reflect.Value, numParams) + // Get the parameter names from the RPCHandler + paramNames := handler.Params + + if len(paramNames) != numParams { + return nil, errors.New("mismatch between handler parameters and defined parameter names") + } + + for i := 0; i < numParams; i++ { + paramType := handlerType.In(i) + paramName := paramNames[i] + paramValue, ok := params[paramName] + if !ok { + return nil, errors.New("missing parameter: " + paramName) + } + + convertedValue := reflect.ValueOf(paramValue) + if !convertedValue.Type().ConvertibleTo(paramType) { + if paramType.Kind() == reflect.Slice && (convertedValue.Kind() == reflect.Slice || convertedValue.Kind() == reflect.Array) { + newSlice := reflect.MakeSlice(paramType, convertedValue.Len(), convertedValue.Len()) + for j := 0; j < convertedValue.Len(); j++ { + elemValue := convertedValue.Index(j) + if elemValue.Kind() == reflect.Interface { + elemValue = elemValue.Elem() + } + if !elemValue.Type().ConvertibleTo(paramType.Elem()) { + // Handle float64 to uint8 conversion + if elemValue.Kind() == reflect.Float64 && paramType.Elem().Kind() == reflect.Uint8 { + intValue := int(elemValue.Float()) + if intValue < 0 || intValue > 255 { + return nil, fmt.Errorf("value out of range for uint8: %v", intValue) + } + newSlice.Index(j).SetUint(uint64(intValue)) + } else { + fromType := elemValue.Type() + toType := paramType.Elem() + return nil, fmt.Errorf("invalid element type in slice for parameter %s: from %v to %v", paramName, fromType, toType) + } + } else { + newSlice.Index(j).Set(elemValue.Convert(paramType.Elem())) + } + } + args[i] = newSlice + } else if paramType.Kind() == reflect.Struct && convertedValue.Kind() == reflect.Map { + jsonData, err := json.Marshal(convertedValue.Interface()) + if err != nil { + return nil, fmt.Errorf("failed to marshal map to JSON: %v", err) + } + + newStruct := reflect.New(paramType).Interface() + if err := json.Unmarshal(jsonData, newStruct); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON into struct: %v", err) + } + args[i] = reflect.ValueOf(newStruct).Elem() + } else { + return nil, fmt.Errorf("invalid parameter type for: %s", paramName) + } + } else { + args[i] = convertedValue.Convert(paramType) + } + } + + results := handlerValue.Call(args) + + if len(results) == 0 { + return nil, nil + } + + if len(results) == 1 { + if results[0].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { + if !results[0].IsNil() { + return nil, results[0].Interface().(error) + } + return nil, nil + } + return results[0].Interface(), nil + } + + if len(results) == 2 && results[1].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { + if !results[1].IsNil() { + return nil, results[1].Interface().(error) + } + return results[0].Interface(), nil + } + + return nil, errors.New("unexpected return values from handler") +} diff --git a/internal/jsonrpc/types.go b/internal/jsonrpc/types.go new file mode 100644 index 00000000..30f8a2ca --- /dev/null +++ b/internal/jsonrpc/types.go @@ -0,0 +1,26 @@ +package jsonrpc + +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params map[string]interface{} `json:"params,omitempty"` + ID interface{} `json:"id,omitempty"` +} + +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + Result interface{} `json:"result,omitempty"` + Error interface{} `json:"error,omitempty"` + ID interface{} `json:"id"` +} + +type JSONRPCEvent struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +type RPCHandler struct { + Func interface{} + Params []string +} diff --git a/jsonrpc.go b/jsonrpc.go index 2ce5f189..708a9445 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -5,50 +5,44 @@ import ( "encoding/json" "errors" "fmt" + "kvm/internal/jsonrpc" "log" "os" "os/exec" "path/filepath" - "reflect" "github.com/pion/webrtc/v4" ) -type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]interface{} `json:"params,omitempty"` - ID interface{} `json:"id,omitempty"` +type DataChannelWriter struct { + dataChannel *webrtc.DataChannel } -type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - Result interface{} `json:"result,omitempty"` - Error interface{} `json:"error,omitempty"` - ID interface{} `json:"id"` -} - -type JSONRPCEvent struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params interface{} `json:"params,omitempty"` +func NewDataChannelWriter(dataChannel *webrtc.DataChannel) *DataChannelWriter { + return &DataChannelWriter{ + dataChannel: dataChannel, + } } -func writeJSONRPCResponse(response JSONRPCResponse, session *Session) { - responseBytes, err := json.Marshal(response) - if err != nil { - log.Println("Error marshalling JSONRPC response:", err) - return - } - err = session.RPCChannel.SendText(string(responseBytes)) +func (w *DataChannelWriter) Write(data []byte) (int, error) { + err := w.dataChannel.SendText(string(data)) if err != nil { log.Println("Error sending JSONRPC response:", err) - return + return 0, err } + return len(data), nil } +func NewDataChannelJsonRpcRouter(dataChannel *webrtc.DataChannel) *jsonrpc.JSONRPCRouter { + return jsonrpc.NewJSONRPCRouter( + NewDataChannelWriter(dataChannel), + rpcHandlers, + ) +} + +// TODO: embed this into the session's rpc server func writeJSONRPCEvent(event string, params interface{}, session *Session) { - request := JSONRPCEvent{ + request := jsonrpc.JSONRPCEvent{ JSONRPC: "2.0", Method: event, Params: params, @@ -69,60 +63,6 @@ func writeJSONRPCEvent(event string, params interface{}, session *Session) { } } -func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { - var request JSONRPCRequest - err := json.Unmarshal(message.Data, &request) - if err != nil { - errorResponse := JSONRPCResponse{ - JSONRPC: "2.0", - Error: map[string]interface{}{ - "code": -32700, - "message": "Parse error", - }, - ID: 0, - } - writeJSONRPCResponse(errorResponse, session) - return - } - - //log.Printf("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) - handler, ok := rpcHandlers[request.Method] - if !ok { - errorResponse := JSONRPCResponse{ - JSONRPC: "2.0", - Error: map[string]interface{}{ - "code": -32601, - "message": "Method not found", - }, - ID: request.ID, - } - writeJSONRPCResponse(errorResponse, session) - return - } - - result, err := callRPCHandler(handler, request.Params) - if err != nil { - errorResponse := JSONRPCResponse{ - JSONRPC: "2.0", - Error: map[string]interface{}{ - "code": -32603, - "message": "Internal error", - "data": err.Error(), - }, - ID: request.ID, - } - writeJSONRPCResponse(errorResponse, session) - return - } - - response := JSONRPCResponse{ - JSONRPC: "2.0", - Result: result, - ID: request.ID, - } - writeJSONRPCResponse(response, session) -} - func rpcPing() (string, error) { return "pong", nil } @@ -315,108 +255,6 @@ func rpcSetSSHKeyState(sshKey string) error { return nil } -func callRPCHandler(handler RPCHandler, params map[string]interface{}) (interface{}, error) { - handlerValue := reflect.ValueOf(handler.Func) - handlerType := handlerValue.Type() - - if handlerType.Kind() != reflect.Func { - return nil, errors.New("handler is not a function") - } - - numParams := handlerType.NumIn() - args := make([]reflect.Value, numParams) - // Get the parameter names from the RPCHandler - paramNames := handler.Params - - if len(paramNames) != numParams { - return nil, errors.New("mismatch between handler parameters and defined parameter names") - } - - for i := 0; i < numParams; i++ { - paramType := handlerType.In(i) - paramName := paramNames[i] - paramValue, ok := params[paramName] - if !ok { - return nil, errors.New("missing parameter: " + paramName) - } - - convertedValue := reflect.ValueOf(paramValue) - if !convertedValue.Type().ConvertibleTo(paramType) { - if paramType.Kind() == reflect.Slice && (convertedValue.Kind() == reflect.Slice || convertedValue.Kind() == reflect.Array) { - newSlice := reflect.MakeSlice(paramType, convertedValue.Len(), convertedValue.Len()) - for j := 0; j < convertedValue.Len(); j++ { - elemValue := convertedValue.Index(j) - if elemValue.Kind() == reflect.Interface { - elemValue = elemValue.Elem() - } - if !elemValue.Type().ConvertibleTo(paramType.Elem()) { - // Handle float64 to uint8 conversion - if elemValue.Kind() == reflect.Float64 && paramType.Elem().Kind() == reflect.Uint8 { - intValue := int(elemValue.Float()) - if intValue < 0 || intValue > 255 { - return nil, fmt.Errorf("value out of range for uint8: %v", intValue) - } - newSlice.Index(j).SetUint(uint64(intValue)) - } else { - fromType := elemValue.Type() - toType := paramType.Elem() - return nil, fmt.Errorf("invalid element type in slice for parameter %s: from %v to %v", paramName, fromType, toType) - } - } else { - newSlice.Index(j).Set(elemValue.Convert(paramType.Elem())) - } - } - args[i] = newSlice - } else if paramType.Kind() == reflect.Struct && convertedValue.Kind() == reflect.Map { - jsonData, err := json.Marshal(convertedValue.Interface()) - if err != nil { - return nil, fmt.Errorf("failed to marshal map to JSON: %v", err) - } - - newStruct := reflect.New(paramType).Interface() - if err := json.Unmarshal(jsonData, newStruct); err != nil { - return nil, fmt.Errorf("failed to unmarshal JSON into struct: %v", err) - } - args[i] = reflect.ValueOf(newStruct).Elem() - } else { - return nil, fmt.Errorf("invalid parameter type for: %s", paramName) - } - } else { - args[i] = convertedValue.Convert(paramType) - } - } - - results := handlerValue.Call(args) - - if len(results) == 0 { - return nil, nil - } - - if len(results) == 1 { - if results[0].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { - if !results[0].IsNil() { - return nil, results[0].Interface().(error) - } - return nil, nil - } - return results[0].Interface(), nil - } - - if len(results) == 2 && results[1].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { - if !results[1].IsNil() { - return nil, results[1].Interface().(error) - } - return results[0].Interface(), nil - } - - return nil, errors.New("unexpected return values from handler") -} - -type RPCHandler struct { - Func interface{} - Params []string -} - func rpcSetMassStorageMode(mode string) (string, error) { log.Printf("[jsonrpc.go:rpcSetMassStorageMode] Setting mass storage mode to: %s", mode) var cdrom bool @@ -508,7 +346,7 @@ func rpcResetConfig() error { } // TODO: replace this crap with code generator -var rpcHandlers = map[string]RPCHandler{ +var rpcHandlers = map[string]*jsonrpc.RPCHandler{ "ping": {Func: rpcPing}, "getDeviceID": {Func: rpcGetDeviceID}, "deregisterDevice": {Func: rpcDeregisterDevice}, diff --git a/webrtc.go b/webrtc.go index 20ffb99c..49a7b414 100644 --- a/webrtc.go +++ b/webrtc.go @@ -75,8 +75,9 @@ func newSession() (*Session, error) { switch d.Label() { case "rpc": session.RPCChannel = d + rpcServer := NewDataChannelJsonRpcRouter(d) d.OnMessage(func(msg webrtc.DataChannelMessage) { - go onRPCMessage(msg, session) + go rpcServer.HandleMessage(msg.Data) }) triggerOTAStateUpdate() triggerVideoStateUpdate()