diff --git a/Makefile b/Makefile index 04c7402a..eea9730a 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ -VERSION_DEV := 0.3.5-dev$(shell date +%Y%m%d%H%M) -VERSION := 0.3.4 +VERSION_DEV := 0.3.6-dev$(shell date +%Y%m%d%H%M) +VERSION := 0.3.5 hash_resource: @shasum -a 256 resource/jetkvm_native | cut -d ' ' -f 1 > resource/jetkvm_native.sha256 diff --git a/internal/jsonrpc/router.go b/internal/jsonrpc/router.go new file mode 100644 index 00000000..0534432c --- /dev/null +++ b/internal/jsonrpc/router.go @@ -0,0 +1,300 @@ +package jsonrpc + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log" + "reflect" + "sync" + "sync/atomic" + "time" +) + +type JSONRPCRouter struct { + writer io.Writer + + handlers map[string]*RPCHandler + nextId atomic.Int64 + + responseChannelsMutex sync.Mutex + responseChannels map[int64]chan JSONRPCResponse +} + +func NewJSONRPCRouter(writer io.Writer, handlers map[string]*RPCHandler) *JSONRPCRouter { + return &JSONRPCRouter{ + writer: writer, + handlers: handlers, + + responseChannels: make(map[int64]chan JSONRPCResponse), + } +} + +func (s *JSONRPCRouter) Request(method string, params map[string]interface{}, result interface{}) *JSONRPCResponseError { + id := s.nextId.Add(1) + request := JSONRPCRequest{ + JSONRPC: "2.0", + Method: method, + Params: params, + ID: id, + } + requestBytes, err := json.Marshal(request) + if err != nil { + return &JSONRPCResponseError{ + Code: -32700, + Message: "Parse error", + Data: err, + } + } + + // log.Printf("Sending RPC request: Method=%s, Params=%v, ID=%d", method, params, id) + + responseChan := make(chan JSONRPCResponse, 1) + s.responseChannelsMutex.Lock() + s.responseChannels[id] = responseChan + s.responseChannelsMutex.Unlock() + defer func() { + s.responseChannelsMutex.Lock() + delete(s.responseChannels, id) + s.responseChannelsMutex.Unlock() + }() + + _, err = s.writer.Write(requestBytes) + if err != nil { + return &JSONRPCResponseError{ + Code: -32603, + Message: "Internal error", + Data: err, + } + } + + timeout := time.After(5 * time.Second) + select { + case response := <-responseChan: + if response.Error != nil { + return response.Error + } + + rawResult, err := json.Marshal(response.Result) + if err != nil { + return &JSONRPCResponseError{ + Code: -32603, + Message: "Internal error", + Data: err, + } + } + + if err := json.Unmarshal(rawResult, result); err != nil { + return &JSONRPCResponseError{ + Code: -32603, + Message: "Internal error", + Data: err, + } + } + + return nil + case <-timeout: + return &JSONRPCResponseError{ + Code: -32603, + Message: "Internal error", + Data: "timeout waiting for response", + } + } +} + +type JSONRPCMessage struct { + Method *string `json:"method,omitempty"` + ID *int64 `json:"id,omitempty"` +} + +func (s *JSONRPCRouter) HandleMessage(data []byte) error { + // Data will either be a JSONRPCRequest or JSONRPCResponse object + // We need to determine which one it is + var raw JSONRPCMessage + err := json.Unmarshal(data, &raw) + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: &JSONRPCResponseError{ + Code: -32700, + Message: "Parse error", + }, + ID: 0, + } + return s.writeResponse(errorResponse) + } + + if raw.Method == nil && raw.ID != nil { + var resp JSONRPCResponse + if err := json.Unmarshal(data, &resp); err != nil { + fmt.Println("error unmarshalling response", err) + return err + } + + s.responseChannelsMutex.Lock() + responseChan, ok := s.responseChannels[*raw.ID] + s.responseChannelsMutex.Unlock() + if ok { + responseChan <- resp + } else { + log.Println("No response channel found for ID", resp.ID) + } + return nil + } + + var request JSONRPCRequest + err = json.Unmarshal(data, &request) + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: &JSONRPCResponseError{ + Code: -32700, + Message: "Parse error", + }, + ID: 0, + } + return s.writeResponse(errorResponse) + } + + //log.Printf("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) + handler, ok := s.handlers[request.Method] + if !ok { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: &JSONRPCResponseError{ + Code: -32601, + Message: "Method not found", + }, + ID: request.ID, + } + return s.writeResponse(errorResponse) + } + + result, err := callRPCHandler(handler, request.Params) + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: "2.0", + Error: &JSONRPCResponseError{ + Code: -32603, + Message: "Internal error", + Data: err.Error(), + }, + ID: request.ID, + } + return s.writeResponse(errorResponse) + } + + response := JSONRPCResponse{ + JSONRPC: "2.0", + Result: result, + ID: request.ID, + } + return s.writeResponse(response) +} + +func (s *JSONRPCRouter) writeResponse(response JSONRPCResponse) error { + responseBytes, err := json.Marshal(response) + if err != nil { + return err + } + _, err = s.writer.Write(responseBytes) + return err +} + +func callRPCHandler(handler *RPCHandler, params map[string]interface{}) (interface{}, error) { + handlerValue := reflect.ValueOf(handler.Func) + handlerType := handlerValue.Type() + + if handlerType.Kind() != reflect.Func { + return nil, errors.New("handler is not a function") + } + + numParams := handlerType.NumIn() + args := make([]reflect.Value, numParams) + // Get the parameter names from the RPCHandler + paramNames := handler.Params + + if len(paramNames) != numParams { + return nil, errors.New("mismatch between handler parameters and defined parameter names") + } + + for i := 0; i < numParams; i++ { + paramType := handlerType.In(i) + paramName := paramNames[i] + paramValue, ok := params[paramName] + if !ok { + return nil, errors.New("missing parameter: " + paramName) + } + + convertedValue := reflect.ValueOf(paramValue) + if !convertedValue.Type().ConvertibleTo(paramType) { + if paramType.Kind() == reflect.Slice && (convertedValue.Kind() == reflect.Slice || convertedValue.Kind() == reflect.Array) { + newSlice := reflect.MakeSlice(paramType, convertedValue.Len(), convertedValue.Len()) + for j := 0; j < convertedValue.Len(); j++ { + elemValue := convertedValue.Index(j) + if elemValue.Kind() == reflect.Interface { + elemValue = elemValue.Elem() + } + if !elemValue.Type().ConvertibleTo(paramType.Elem()) { + // Handle float64 to uint8 conversion + if elemValue.Kind() == reflect.Float64 && paramType.Elem().Kind() == reflect.Uint8 { + intValue := int(elemValue.Float()) + if intValue < 0 || intValue > 255 { + return nil, fmt.Errorf("value out of range for uint8: %v", intValue) + } + newSlice.Index(j).SetUint(uint64(intValue)) + } else { + fromType := elemValue.Type() + toType := paramType.Elem() + return nil, fmt.Errorf("invalid element type in slice for parameter %s: from %v to %v", paramName, fromType, toType) + } + } else { + newSlice.Index(j).Set(elemValue.Convert(paramType.Elem())) + } + } + args[i] = newSlice + } else if paramType.Kind() == reflect.Struct && convertedValue.Kind() == reflect.Map { + jsonData, err := json.Marshal(convertedValue.Interface()) + if err != nil { + return nil, fmt.Errorf("failed to marshal map to JSON: %v", err) + } + + newStruct := reflect.New(paramType).Interface() + if err := json.Unmarshal(jsonData, newStruct); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON into struct: %v", err) + } + args[i] = reflect.ValueOf(newStruct).Elem() + } else { + return nil, fmt.Errorf("invalid parameter type for: %s", paramName) + } + } else { + args[i] = convertedValue.Convert(paramType) + } + } + + results := handlerValue.Call(args) + + if len(results) == 0 { + return nil, nil + } + + if len(results) == 1 { + if results[0].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { + if !results[0].IsNil() { + return nil, results[0].Interface().(error) + } + return nil, nil + } + return results[0].Interface(), nil + } + + if len(results) == 2 && results[1].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { + if !results[1].IsNil() { + return nil, results[1].Interface().(error) + } + return results[0].Interface(), nil + } + + return nil, errors.New("unexpected return values from handler") +} diff --git a/internal/jsonrpc/types.go b/internal/jsonrpc/types.go new file mode 100644 index 00000000..ac4f956c --- /dev/null +++ b/internal/jsonrpc/types.go @@ -0,0 +1,32 @@ +package jsonrpc + +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params map[string]interface{} `json:"params,omitempty"` + ID interface{} `json:"id,omitempty"` +} + +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + Result interface{} `json:"result,omitempty"` + Error *JSONRPCResponseError `json:"error,omitempty"` + ID interface{} `json:"id"` +} + +type JSONRPCResponseError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +type JSONRPCEvent struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +type RPCHandler struct { + Func interface{} + Params []string +} diff --git a/internal/plugin/database.go b/internal/plugin/database.go new file mode 100644 index 00000000..6e669dce --- /dev/null +++ b/internal/plugin/database.go @@ -0,0 +1,92 @@ +package plugin + +import ( + "encoding/json" + "fmt" + "os" + "path" + "sync" +) + +const databaseFile = pluginsFolder + "/plugins.json" + +type PluginDatabase struct { + // Map with the plugin name as the key + Plugins map[string]*PluginInstall `json:"plugins"` + + saveMutex sync.Mutex +} + +var pluginDatabase = PluginDatabase{} + +func (d *PluginDatabase) Load() error { + file, err := os.Open(databaseFile) + if os.IsNotExist(err) { + d.Plugins = make(map[string]*PluginInstall) + return nil + } + if err != nil { + return fmt.Errorf("failed to open plugin database: %v", err) + } + defer file.Close() + + if err := json.NewDecoder(file).Decode(d); err != nil { + return fmt.Errorf("failed to decode plugin database: %v", err) + } + + return nil +} + +func (d *PluginDatabase) Save() error { + d.saveMutex.Lock() + defer d.saveMutex.Unlock() + + file, err := os.Create(databaseFile + ".tmp") + if err != nil { + return fmt.Errorf("failed to create plugin database tmp: %v", err) + } + defer file.Close() + + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + if err := encoder.Encode(d); err != nil { + return fmt.Errorf("failed to encode plugin database: %v", err) + } + + if err := os.Rename(databaseFile+".tmp", databaseFile); err != nil { + return fmt.Errorf("failed to move plugin database to active file: %v", err) + } + + return nil +} + +// Find all extract directories that are not referenced in the Plugins map and remove them +func (d *PluginDatabase) CleanupExtractDirectories() error { + extractDirectories, err := os.ReadDir(pluginsExtractsFolder) + if err != nil { + return fmt.Errorf("failed to read extract directories: %v", err) + } + + for _, extractDir := range extractDirectories { + found := false + for _, pluginInstall := range d.Plugins { + for _, extractedFolder := range pluginInstall.ExtractedVersions { + if extractDir.Name() == extractedFolder { + found = true + break + } + } + if found { + break + } + } + + if !found { + if err := os.RemoveAll(path.Join(pluginsExtractsFolder, extractDir.Name())); err != nil { + return fmt.Errorf("failed to remove extract directory: %v", err) + } + } + } + + return nil +} diff --git a/internal/plugin/extract.go b/internal/plugin/extract.go new file mode 100644 index 00000000..9fd8bb80 --- /dev/null +++ b/internal/plugin/extract.go @@ -0,0 +1,95 @@ +package plugin + +import ( + "archive/tar" + "compress/gzip" + "fmt" + "io" + "os" + "path" + "path/filepath" + "strings" + + "github.com/google/uuid" +) + +const pluginsExtractsFolder = pluginsFolder + "/extracts" + +func init() { + _ = os.MkdirAll(pluginsExtractsFolder, 0755) +} + +func extractPlugin(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", fmt.Errorf("failed to open file for extraction: %v", err) + } + defer file.Close() + + var reader io.Reader = file + // TODO: there's probably a better way of doing this without relying on the file extension + if strings.HasSuffix(filePath, ".gz") { + gzipReader, err := gzip.NewReader(file) + if err != nil { + return "", fmt.Errorf("failed to create gzip reader: %v", err) + } + defer gzipReader.Close() + reader = gzipReader + } + + destinationFolder := path.Join(pluginsExtractsFolder, uuid.New().String()) + if err := os.MkdirAll(destinationFolder, 0755); err != nil { + return "", fmt.Errorf("failed to create extracts folder: %v", err) + } + + if err := extractTarball(reader, destinationFolder); err != nil { + if err := os.RemoveAll(destinationFolder); err != nil { + return "", fmt.Errorf("failed to remove failed extraction folder: %v", err) + } + + return "", fmt.Errorf("failed to extract tarball: %v", err) + } + + return destinationFolder, nil +} + +func extractTarball(reader io.Reader, destinationFolder string) error { + tarReader := tar.NewReader(reader) + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar header: %v", err) + } + + // Prevent path traversal attacks + targetPath := filepath.Join(destinationFolder, header.Name) + if !strings.HasPrefix(targetPath, filepath.Clean(destinationFolder)+string(os.PathSeparator)) { + return fmt.Errorf("tar file contains illegal path: %s", header.Name) + } + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + case tar.TypeReg: + file, err := os.OpenFile(targetPath, os.O_CREATE|os.O_WRONLY, os.FileMode(header.Mode)) + if err != nil { + return fmt.Errorf("failed to create file: %v", err) + } + defer file.Close() + + if _, err := io.Copy(file, tarReader); err != nil { + return fmt.Errorf("failed to extract file: %v", err) + } + default: + return fmt.Errorf("unsupported tar entry type: %v", header.Typeflag) + } + } + + return nil +} diff --git a/internal/plugin/install.go b/internal/plugin/install.go new file mode 100644 index 00000000..dedf3291 --- /dev/null +++ b/internal/plugin/install.go @@ -0,0 +1,165 @@ +package plugin + +import ( + "fmt" + "log" + "os" + "os/exec" + "path" + "syscall" +) + +type PluginInstall struct { + Enabled bool `json:"enabled"` + + // Current active version of the plugin + Version string `json:"version"` + + // Map of a plugin version to the extracted directory + ExtractedVersions map[string]string `json:"extracted_versions"` + + manifest *PluginManifest + runningVersion string + processManager *ProcessManager + rpcServer *PluginRpcServer +} + +func (p *PluginInstall) GetManifest() (*PluginManifest, error) { + if p.manifest != nil { + return p.manifest, nil + } + + manifest, err := readManifest(p.GetExtractedFolder()) + if err != nil { + return nil, err + } + + p.manifest = manifest + return manifest, nil +} + +func (p *PluginInstall) GetExtractedFolder() string { + return p.ExtractedVersions[p.Version] +} + +func (p *PluginInstall) GetStatus() (*PluginStatus, error) { + manifest, err := p.GetManifest() + if err != nil { + return nil, fmt.Errorf("failed to get plugin manifest: %v", err) + } + + status := PluginStatus{ + PluginManifest: *manifest, + Enabled: p.Enabled, + } + + // If the rpc server is connected and the plugin is reporting status, use that + if p.rpcServer != nil && + p.rpcServer.status.Status != "disconnected" && + p.rpcServer.status.Status != "unknown" { + status.Status = p.rpcServer.status.Status + status.Message = p.rpcServer.status.Message + + if status.Status == "error" { + status.Message = p.rpcServer.status.Message + } + } else { + status.Status = "stopped" + if p.processManager != nil { + status.Status = "running" + if p.processManager.LastError != nil { + status.Status = "error" + status.Message = p.processManager.LastError.Error() + } + } + log.Printf("Status from process manager: %v", status.Status) + } + + return &status, nil +} + +func (p *PluginInstall) ReconcileSubprocess() error { + manifest, err := p.GetManifest() + if err != nil { + return fmt.Errorf("failed to get plugin manifest: %v", err) + } + + versionRunning := p.runningVersion + + versionShouldBeRunning := p.Version + if !p.Enabled { + versionShouldBeRunning = "" + } + + log.Printf("Reconciling plugin %s running %v, should be running %v", manifest.Name, versionRunning, versionShouldBeRunning) + + if versionRunning == versionShouldBeRunning { + log.Printf("Plugin %s is already running version %s", manifest.Name, versionRunning) + return nil + } + + if p.processManager != nil { + log.Printf("Stopping plugin %s running version %s", manifest.Name, versionRunning) + p.processManager.Disable() + p.processManager = nil + p.runningVersion = "" + err = p.rpcServer.Stop() + if err != nil { + return fmt.Errorf("failed to stop rpc server: %v", err) + } + } + + if versionShouldBeRunning == "" { + return nil + } + + workingDir := path.Join(pluginsFolder, "working_dirs", p.manifest.Name) + err = os.MkdirAll(workingDir, 0755) + if err != nil { + return fmt.Errorf("failed to create working directory: %v", err) + } + + p.rpcServer = NewPluginRpcServer(p, workingDir) + err = p.rpcServer.Start() + if err != nil { + return fmt.Errorf("failed to start rpc server: %v", err) + } + + p.processManager = NewProcessManager(func() *exec.Cmd { + cmd := exec.Command(manifest.BinaryPath) + cmd.Dir = p.GetExtractedFolder() + cmd.Env = append(cmd.Env, + "JETKVM_PLUGIN_SOCK="+p.rpcServer.SocketPath(), + "JETKVM_PLUGIN_WORKING_DIR="+workingDir, + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + // Ensure that the process is killed when the parent dies + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + Pdeathsig: syscall.SIGKILL, + } + return cmd + }) + p.processManager.StartMonitor() + p.processManager.Enable() + p.runningVersion = p.Version + + // Clear out manifest so the new version gets pulled next time + p.manifest = nil + + log.Printf("Started plugin %s version %s", manifest.Name, p.Version) + return nil +} + +func (p *PluginInstall) Shutdown() { + if p.processManager != nil { + p.processManager.Disable() + p.processManager = nil + p.runningVersion = "" + } + + if p.rpcServer != nil { + p.rpcServer.Stop() + } +} diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go new file mode 100644 index 00000000..e72acdfb --- /dev/null +++ b/internal/plugin/plugin.go @@ -0,0 +1,257 @@ +package plugin + +import ( + "encoding/json" + "fmt" + "kvm/internal/storage" + "os" + "path" + + "github.com/google/uuid" +) + +const pluginsFolder = "/userdata/jetkvm/plugins" +const pluginsUploadFolder = pluginsFolder + "/uploads" + +func init() { + _ = os.MkdirAll(pluginsUploadFolder, 0755) + + if err := pluginDatabase.Load(); err != nil { + fmt.Printf("failed to load plugin database: %v\n", err) + } +} + +// Starts all plugins that need to be started +func ReconcilePlugins() { + for _, install := range pluginDatabase.Plugins { + err := install.ReconcileSubprocess() + if err != nil { + fmt.Printf("failed to reconcile subprocess for plugin: %v\n", err) + } + } +} + +func GracefullyShutdownPlugins() { + for _, install := range pluginDatabase.Plugins { + install.Shutdown() + } +} + +func RpcPluginStartUpload(filename string, size int64) (*storage.StorageFileUpload, error) { + sanitizedFilename, err := storage.SanitizeFilename(filename) + if err != nil { + return nil, err + } + + filePath := path.Join(pluginsUploadFolder, sanitizedFilename) + uploadPath := filePath + ".incomplete" + + if _, err := os.Stat(filePath); err == nil { + return nil, fmt.Errorf("file already exists: %s", sanitizedFilename) + } + + var alreadyUploadedBytes int64 = 0 + if stat, err := os.Stat(uploadPath); err == nil { + alreadyUploadedBytes = stat.Size() + } + + uploadId := "plugin_" + uuid.New().String() + file, err := os.OpenFile(uploadPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open file for upload: %v", err) + } + + storage.AddPendingUpload(uploadId, storage.PendingUpload{ + File: file, + Size: size, + AlreadyUploadedBytes: alreadyUploadedBytes, + }) + + return &storage.StorageFileUpload{ + AlreadyUploadedBytes: alreadyUploadedBytes, + DataChannel: uploadId, + }, nil +} + +func RpcPluginExtract(filename string) (*PluginManifest, error) { + sanitizedFilename, err := storage.SanitizeFilename(filename) + if err != nil { + return nil, err + } + + filePath := path.Join(pluginsUploadFolder, sanitizedFilename) + extractFolder, err := extractPlugin(filePath) + if err != nil { + return nil, err + } + + if err := os.Remove(filePath); err != nil { + return nil, fmt.Errorf("failed to delete uploaded file: %v", err) + } + + manifest, err := readManifest(extractFolder) + if err != nil { + return nil, err + } + + // Get existing PluginInstall + install, ok := pluginDatabase.Plugins[manifest.Name] + if !ok { + install = &PluginInstall{ + Enabled: false, + Version: manifest.Version, + ExtractedVersions: make(map[string]string), + } + } + + _, ok = install.ExtractedVersions[manifest.Version] + if ok { + return nil, fmt.Errorf("this version has already been uploaded: %s", manifest.Version) + } + + install.ExtractedVersions[manifest.Version] = extractFolder + pluginDatabase.Plugins[manifest.Name] = install + + if err := pluginDatabase.Save(); err != nil { + return nil, fmt.Errorf("failed to save plugin database: %v", err) + } + + return manifest, nil +} + +func RpcPluginInstall(name string, version string) error { + pluginInstall, ok := pluginDatabase.Plugins[name] + if !ok { + return fmt.Errorf("plugin not found: %s", name) + } + + if pluginInstall.Version == version && pluginInstall.Enabled { + fmt.Printf("Plugin %s is already installed with version %s\n", name, version) + return nil + } + + _, ok = pluginInstall.ExtractedVersions[version] + if !ok { + return fmt.Errorf("plugin version not found: %s", version) + } + + pluginInstall.Version = version + pluginInstall.Enabled = true + pluginDatabase.Plugins[name] = pluginInstall + + if err := pluginDatabase.Save(); err != nil { + return fmt.Errorf("failed to save plugin database: %v", err) + } + + err := pluginInstall.ReconcileSubprocess() + if err != nil { + return fmt.Errorf("failed to start plugin %s: %v", name, err) + } + + // TODO: Determine if the old extract should be removed + + return nil +} + +func RpcPluginList() ([]PluginStatus, error) { + plugins := make([]PluginStatus, 0, len(pluginDatabase.Plugins)) + for pluginName, plugin := range pluginDatabase.Plugins { + status, err := plugin.GetStatus() + if err != nil { + return nil, fmt.Errorf("failed to get plugin status for %s: %v", pluginName, err) + } + plugins = append(plugins, *status) + } + return plugins, nil +} + +func RpcPluginUpdateConfig(name string, enabled bool) (*PluginStatus, error) { + pluginInstall, ok := pluginDatabase.Plugins[name] + if !ok { + return nil, fmt.Errorf("plugin not found: %s", name) + } + + pluginInstall.Enabled = enabled + pluginDatabase.Plugins[name] = pluginInstall + + if err := pluginDatabase.Save(); err != nil { + return nil, fmt.Errorf("failed to save plugin database: %v", err) + } + + err := pluginInstall.ReconcileSubprocess() + if err != nil { + return nil, fmt.Errorf("failed to stop plugin %s: %v", name, err) + } + + status, err := pluginInstall.GetStatus() + if err != nil { + return nil, fmt.Errorf("failed to get plugin status for %s: %v", name, err) + } + return status, nil +} + +func RpcPluginUninstall(name string) error { + pluginInstall, ok := pluginDatabase.Plugins[name] + if !ok { + return fmt.Errorf("plugin not found: %s", name) + } + + pluginInstall.Enabled = false + + err := pluginInstall.ReconcileSubprocess() + if err != nil { + return fmt.Errorf("failed to stop plugin %s: %v", name, err) + } + + delete(pluginDatabase.Plugins, name) + if err := pluginDatabase.Save(); err != nil { + return fmt.Errorf("failed to save plugin database: %v", err) + } + + err = pluginDatabase.CleanupExtractDirectories() + if err != nil { + return fmt.Errorf("failed to cleanup extract directories: %v", err) + } + + return nil +} + +func readManifest(extractFolder string) (*PluginManifest, error) { + manifestPath := path.Join(extractFolder, "manifest.json") + manifestFile, err := os.Open(manifestPath) + if err != nil { + return nil, fmt.Errorf("failed to open manifest file: %v", err) + } + defer manifestFile.Close() + + manifest := PluginManifest{} + if err := json.NewDecoder(manifestFile).Decode(&manifest); err != nil { + return nil, fmt.Errorf("failed to read manifest file: %v", err) + } + + if err := validateManifest(&manifest); err != nil { + return nil, fmt.Errorf("invalid manifest file: %v", err) + } + + return &manifest, nil +} + +func validateManifest(manifest *PluginManifest) error { + if manifest.ManifestVersion != "1" { + return fmt.Errorf("unsupported manifest version: %s", manifest.ManifestVersion) + } + + if manifest.Name == "" { + return fmt.Errorf("missing plugin name") + } + + if manifest.Version == "" { + return fmt.Errorf("missing plugin version") + } + + if manifest.Homepage == "" { + return fmt.Errorf("missing plugin homepage") + } + + return nil +} diff --git a/internal/plugin/process_manager.go b/internal/plugin/process_manager.go new file mode 100644 index 00000000..9d647d88 --- /dev/null +++ b/internal/plugin/process_manager.go @@ -0,0 +1,119 @@ +package plugin + +import ( + "fmt" + "log" + "os/exec" + "syscall" + "time" +) + +// TODO: this can probably be defaulted to this, but overwritten on a per-plugin basis +const ( + gracefulShutdownDelay = 30 * time.Second + maxRestartBackoff = 30 * time.Second +) + +type ProcessManager struct { + cmdGen func() *exec.Cmd + cmd *exec.Cmd + enabled bool + backoff time.Duration + shutdown chan struct{} + restartCh chan struct{} + LastError error +} + +func NewProcessManager(commandGenerator func() *exec.Cmd) *ProcessManager { + return &ProcessManager{ + cmdGen: commandGenerator, + enabled: true, + backoff: 250 * time.Millisecond, + shutdown: make(chan struct{}), + restartCh: make(chan struct{}, 1), + } +} + +func (pm *ProcessManager) StartMonitor() { + go pm.monitor() +} + +func (pm *ProcessManager) monitor() { + for { + select { + case <-pm.shutdown: + pm.terminate() + return + case <-pm.restartCh: + if pm.enabled { + go pm.runProcess() + } + } + } +} + +func (pm *ProcessManager) runProcess() { + pm.LastError = nil + pm.cmd = pm.cmdGen() + log.Printf("Starting process: %v", pm.cmd) + err := pm.cmd.Start() + if err != nil { + log.Printf("Failed to start process: %v", err) + pm.LastError = fmt.Errorf("failed to start process: %w", err) + pm.scheduleRestart() + return + } + + err = pm.cmd.Wait() + if err != nil { + log.Printf("Process exited: %v", err) + pm.LastError = fmt.Errorf("process exited with error: %w", err) + pm.scheduleRestart() + } +} + +func (pm *ProcessManager) scheduleRestart() { + if pm.enabled { + log.Printf("Restarting process in %v...", pm.backoff) + time.Sleep(pm.backoff) + pm.backoff *= 2 // Exponential backoff + if pm.backoff > maxRestartBackoff { + pm.backoff = maxRestartBackoff + } + pm.restartCh <- struct{}{} + } +} + +func (pm *ProcessManager) terminate() { + if pm.cmd.Process != nil { + log.Printf("Sending SIGTERM...") + pm.cmd.Process.Signal(syscall.SIGTERM) + select { + case <-time.After(gracefulShutdownDelay): + log.Printf("Forcing process termination...") + pm.cmd.Process.Kill() + case <-pm.waitForExit(): + log.Printf("Process exited gracefully.") + } + } +} + +func (pm *ProcessManager) waitForExit() <-chan struct{} { + done := make(chan struct{}) + go func() { + pm.cmd.Wait() + close(done) + }() + return done +} + +func (pm *ProcessManager) Enable() { + pm.enabled = true + pm.restartCh <- struct{}{} +} + +func (pm *ProcessManager) Disable() { + pm.enabled = false + close(pm.shutdown) + pm.cmd.Wait() +} diff --git a/internal/plugin/rpc.go b/internal/plugin/rpc.go new file mode 100644 index 00000000..dacb1d89 --- /dev/null +++ b/internal/plugin/rpc.go @@ -0,0 +1,174 @@ +package plugin + +import ( + "context" + "errors" + "fmt" + "kvm/internal/jsonrpc" + "log" + "net" + "os" + "path" + "slices" + "time" +) + +type PluginRpcStatus struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +var ( + PluginRpcStatusDisconnected = PluginRpcStatus{"disconnected", ""} + PluginRpcStatusUnknown = PluginRpcStatus{"unknown", ""} + PluginRpcStatusLoading = PluginRpcStatus{"loading", ""} + PluginRpcStatusPendingConfiguration = PluginRpcStatus{"pending-configuration", ""} + PluginRpcStatusRunning = PluginRpcStatus{"running", ""} + PluginRpcStatusError = PluginRpcStatus{"error", ""} +) + +type PluginRpcSupportedMethods struct { + SupportedRpcMethods []string `json:"supported_rpc_methods"` +} + +type PluginRpcServer struct { + install *PluginInstall + workingDir string + + listener net.Listener + status PluginRpcStatus +} + +func NewPluginRpcServer(install *PluginInstall, workingDir string) *PluginRpcServer { + return &PluginRpcServer{ + install: install, + workingDir: workingDir, + status: PluginRpcStatusDisconnected, + } +} + +func (s *PluginRpcServer) Start() error { + socketPath := s.SocketPath() + _ = os.Remove(socketPath) + listener, err := net.Listen("unix", socketPath) + if err != nil { + return fmt.Errorf("failed to listen on socket: %v", err) + } + s.listener = listener + + s.status = PluginRpcStatusDisconnected + go func() { + for { + conn, err := listener.Accept() + if err != nil { + // If the error indicates the listener is closed, break out + if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" { + log.Println("Listener closed, exiting accept loop.") + return + } + + log.Printf("Failed to accept connection: %v", err) + continue + } + log.Printf("Accepted plugin rpc connection from %v", conn.RemoteAddr()) + + go s.handleConnection(conn) + } + }() + + return nil +} + +func (s *PluginRpcServer) Stop() error { + if s.listener != nil { + s.status = PluginRpcStatusDisconnected + return s.listener.Close() + } + return nil +} + +func (s *PluginRpcServer) Status() PluginRpcStatus { + return s.status +} + +func (s *PluginRpcServer) SocketPath() string { + return path.Join(s.workingDir, "plugin.sock") +} + +func (s *PluginRpcServer) handleConnection(conn net.Conn) { + rpcserver := jsonrpc.NewJSONRPCRouter(conn, map[string]*jsonrpc.RPCHandler{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go s.handleRpcStatus(ctx, rpcserver) + + // Read from the conn and write into rpcserver.HandleMessage + buf := make([]byte, 65*1024) + for { + // TODO: if read 65k bytes, then likey there is more data to read... figure out how to handle this + n, err := conn.Read(buf) + if err != nil { + if errors.Is(err, net.ErrClosed) { + s.status = PluginRpcStatusDisconnected + } else { + log.Printf("Failed to read message: %v", err) + s.status = PluginRpcStatusError + s.status.Message = fmt.Errorf("failed to read message: %v", err).Error() + } + break + } + + err = rpcserver.HandleMessage(buf[:n]) + if err != nil { + log.Printf("Failed to handle message: %v", err) + s.status = PluginRpcStatusError + s.status.Message = fmt.Errorf("failed to handle message: %v", err).Error() + continue + } + } +} + +func (s *PluginRpcServer) handleRpcStatus(ctx context.Context, rpcserver *jsonrpc.JSONRPCRouter) { + s.status = PluginRpcStatusUnknown + + log.Printf("Plugin rpc server started. Getting supported methods...") + var supportedMethodsResponse PluginRpcSupportedMethods + err := rpcserver.Request("getPluginSupportedMethods", nil, &supportedMethodsResponse) + if err != nil { + log.Printf("Failed to get supported methods: %v", err) + s.status = PluginRpcStatusError + s.status.Message = fmt.Errorf("error getting supported methods: %v", err.Message).Error() + } + + log.Printf("Plugin has supported methods: %v", supportedMethodsResponse.SupportedRpcMethods) + + if !slices.Contains(supportedMethodsResponse.SupportedRpcMethods, "getPluginStatus") { + log.Printf("Plugin does not support getPluginStatus method") + return + } + + ticker := time.NewTicker(1 * time.Second) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + var statusResponse PluginRpcStatus + err := rpcserver.Request("getPluginStatus", nil, &statusResponse) + if err != nil { + log.Printf("Failed to get status: %v", err) + if err, ok := err.Data.(error); ok && errors.Is(err, net.ErrClosed) { + s.status = PluginRpcStatusDisconnected + break + } + + s.status = PluginRpcStatusError + s.status.Message = fmt.Errorf("error getting status: %v", err).Error() + continue + } + + s.status = statusResponse + } + } +} diff --git a/internal/plugin/type.go b/internal/plugin/type.go new file mode 100644 index 00000000..de1001a0 --- /dev/null +++ b/internal/plugin/type.go @@ -0,0 +1,18 @@ +package plugin + +type PluginManifest struct { + ManifestVersion string `json:"manifest_version"` + Name string `json:"name"` + Version string `json:"version"` + Description string `json:"description,omitempty"` + Homepage string `json:"homepage"` + BinaryPath string `json:"bin"` + SystemMinVersion string `json:"system_min_version,omitempty"` +} + +type PluginStatus struct { + PluginManifest + Enabled bool `json:"enabled"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} diff --git a/internal/storage/type.go b/internal/storage/type.go new file mode 100644 index 00000000..ba7a1232 --- /dev/null +++ b/internal/storage/type.go @@ -0,0 +1,6 @@ +package storage + +type StorageFileUpload struct { + AlreadyUploadedBytes int64 `json:"alreadyUploadedBytes"` + DataChannel string `json:"dataChannel"` +} diff --git a/internal/storage/uploads.go b/internal/storage/uploads.go new file mode 100644 index 00000000..48fdaf7a --- /dev/null +++ b/internal/storage/uploads.go @@ -0,0 +1,34 @@ +package storage + +import ( + "os" + "sync" +) + +type PendingUpload struct { + File *os.File + Size int64 + AlreadyUploadedBytes int64 +} + +var pendingUploads = make(map[string]PendingUpload) +var pendingUploadsMutex sync.Mutex + +func GetPendingUpload(uploadId string) (PendingUpload, bool) { + pendingUploadsMutex.Lock() + defer pendingUploadsMutex.Unlock() + upload, ok := pendingUploads[uploadId] + return upload, ok +} + +func AddPendingUpload(uploadId string, upload PendingUpload) { + pendingUploadsMutex.Lock() + defer pendingUploadsMutex.Unlock() + pendingUploads[uploadId] = upload +} + +func DeletePendingUpload(uploadId string) { + pendingUploadsMutex.Lock() + defer pendingUploadsMutex.Unlock() + delete(pendingUploads, uploadId) +} diff --git a/internal/storage/utils.go b/internal/storage/utils.go new file mode 100644 index 00000000..e622fc23 --- /dev/null +++ b/internal/storage/utils.go @@ -0,0 +1,19 @@ +package storage + +import ( + "errors" + "path/filepath" + "strings" +) + +func SanitizeFilename(filename string) (string, error) { + cleanPath := filepath.Clean(filename) + if filepath.IsAbs(cleanPath) || strings.Contains(cleanPath, "..") { + return "", errors.New("invalid filename") + } + sanitized := filepath.Base(cleanPath) + if sanitized == "." || sanitized == string(filepath.Separator) { + return "", errors.New("invalid filename") + } + return sanitized, nil +} diff --git a/jsonrpc.go b/jsonrpc.go index 2ce5f189..a80439d6 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -5,50 +5,45 @@ import ( "encoding/json" "errors" "fmt" + "kvm/internal/jsonrpc" + "kvm/internal/plugin" "log" "os" "os/exec" "path/filepath" - "reflect" "github.com/pion/webrtc/v4" ) -type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params map[string]interface{} `json:"params,omitempty"` - ID interface{} `json:"id,omitempty"` +type DataChannelWriter struct { + dataChannel *webrtc.DataChannel } -type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - Result interface{} `json:"result,omitempty"` - Error interface{} `json:"error,omitempty"` - ID interface{} `json:"id"` -} - -type JSONRPCEvent struct { - JSONRPC string `json:"jsonrpc"` - Method string `json:"method"` - Params interface{} `json:"params,omitempty"` +func NewDataChannelWriter(dataChannel *webrtc.DataChannel) *DataChannelWriter { + return &DataChannelWriter{ + dataChannel: dataChannel, + } } -func writeJSONRPCResponse(response JSONRPCResponse, session *Session) { - responseBytes, err := json.Marshal(response) - if err != nil { - log.Println("Error marshalling JSONRPC response:", err) - return - } - err = session.RPCChannel.SendText(string(responseBytes)) +func (w *DataChannelWriter) Write(data []byte) (int, error) { + err := w.dataChannel.SendText(string(data)) if err != nil { log.Println("Error sending JSONRPC response:", err) - return + return 0, err } + return len(data), nil } +func NewDataChannelJsonRpcRouter(dataChannel *webrtc.DataChannel) *jsonrpc.JSONRPCRouter { + return jsonrpc.NewJSONRPCRouter( + NewDataChannelWriter(dataChannel), + rpcHandlers, + ) +} + +// TODO: embed this into the session's rpc server func writeJSONRPCEvent(event string, params interface{}, session *Session) { - request := JSONRPCEvent{ + request := jsonrpc.JSONRPCEvent{ JSONRPC: "2.0", Method: event, Params: params, @@ -69,60 +64,6 @@ func writeJSONRPCEvent(event string, params interface{}, session *Session) { } } -func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { - var request JSONRPCRequest - err := json.Unmarshal(message.Data, &request) - if err != nil { - errorResponse := JSONRPCResponse{ - JSONRPC: "2.0", - Error: map[string]interface{}{ - "code": -32700, - "message": "Parse error", - }, - ID: 0, - } - writeJSONRPCResponse(errorResponse, session) - return - } - - //log.Printf("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) - handler, ok := rpcHandlers[request.Method] - if !ok { - errorResponse := JSONRPCResponse{ - JSONRPC: "2.0", - Error: map[string]interface{}{ - "code": -32601, - "message": "Method not found", - }, - ID: request.ID, - } - writeJSONRPCResponse(errorResponse, session) - return - } - - result, err := callRPCHandler(handler, request.Params) - if err != nil { - errorResponse := JSONRPCResponse{ - JSONRPC: "2.0", - Error: map[string]interface{}{ - "code": -32603, - "message": "Internal error", - "data": err.Error(), - }, - ID: request.ID, - } - writeJSONRPCResponse(errorResponse, session) - return - } - - response := JSONRPCResponse{ - JSONRPC: "2.0", - Result: result, - ID: request.ID, - } - writeJSONRPCResponse(response, session) -} - func rpcPing() (string, error) { return "pong", nil } @@ -315,108 +256,6 @@ func rpcSetSSHKeyState(sshKey string) error { return nil } -func callRPCHandler(handler RPCHandler, params map[string]interface{}) (interface{}, error) { - handlerValue := reflect.ValueOf(handler.Func) - handlerType := handlerValue.Type() - - if handlerType.Kind() != reflect.Func { - return nil, errors.New("handler is not a function") - } - - numParams := handlerType.NumIn() - args := make([]reflect.Value, numParams) - // Get the parameter names from the RPCHandler - paramNames := handler.Params - - if len(paramNames) != numParams { - return nil, errors.New("mismatch between handler parameters and defined parameter names") - } - - for i := 0; i < numParams; i++ { - paramType := handlerType.In(i) - paramName := paramNames[i] - paramValue, ok := params[paramName] - if !ok { - return nil, errors.New("missing parameter: " + paramName) - } - - convertedValue := reflect.ValueOf(paramValue) - if !convertedValue.Type().ConvertibleTo(paramType) { - if paramType.Kind() == reflect.Slice && (convertedValue.Kind() == reflect.Slice || convertedValue.Kind() == reflect.Array) { - newSlice := reflect.MakeSlice(paramType, convertedValue.Len(), convertedValue.Len()) - for j := 0; j < convertedValue.Len(); j++ { - elemValue := convertedValue.Index(j) - if elemValue.Kind() == reflect.Interface { - elemValue = elemValue.Elem() - } - if !elemValue.Type().ConvertibleTo(paramType.Elem()) { - // Handle float64 to uint8 conversion - if elemValue.Kind() == reflect.Float64 && paramType.Elem().Kind() == reflect.Uint8 { - intValue := int(elemValue.Float()) - if intValue < 0 || intValue > 255 { - return nil, fmt.Errorf("value out of range for uint8: %v", intValue) - } - newSlice.Index(j).SetUint(uint64(intValue)) - } else { - fromType := elemValue.Type() - toType := paramType.Elem() - return nil, fmt.Errorf("invalid element type in slice for parameter %s: from %v to %v", paramName, fromType, toType) - } - } else { - newSlice.Index(j).Set(elemValue.Convert(paramType.Elem())) - } - } - args[i] = newSlice - } else if paramType.Kind() == reflect.Struct && convertedValue.Kind() == reflect.Map { - jsonData, err := json.Marshal(convertedValue.Interface()) - if err != nil { - return nil, fmt.Errorf("failed to marshal map to JSON: %v", err) - } - - newStruct := reflect.New(paramType).Interface() - if err := json.Unmarshal(jsonData, newStruct); err != nil { - return nil, fmt.Errorf("failed to unmarshal JSON into struct: %v", err) - } - args[i] = reflect.ValueOf(newStruct).Elem() - } else { - return nil, fmt.Errorf("invalid parameter type for: %s", paramName) - } - } else { - args[i] = convertedValue.Convert(paramType) - } - } - - results := handlerValue.Call(args) - - if len(results) == 0 { - return nil, nil - } - - if len(results) == 1 { - if results[0].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { - if !results[0].IsNil() { - return nil, results[0].Interface().(error) - } - return nil, nil - } - return results[0].Interface(), nil - } - - if len(results) == 2 && results[1].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { - if !results[1].IsNil() { - return nil, results[1].Interface().(error) - } - return results[0].Interface(), nil - } - - return nil, errors.New("unexpected return values from handler") -} - -type RPCHandler struct { - Func interface{} - Params []string -} - func rpcSetMassStorageMode(mode string) (string, error) { log.Printf("[jsonrpc.go:rpcSetMassStorageMode] Setting mass storage mode to: %s", mode) var cdrom bool @@ -508,7 +347,7 @@ func rpcResetConfig() error { } // TODO: replace this crap with code generator -var rpcHandlers = map[string]RPCHandler{ +var rpcHandlers = map[string]*jsonrpc.RPCHandler{ "ping": {Func: rpcPing}, "getDeviceID": {Func: rpcGetDeviceID}, "deregisterDevice": {Func: rpcDeregisterDevice}, @@ -554,4 +393,10 @@ var rpcHandlers = map[string]RPCHandler{ "getWakeOnLanDevices": {Func: rpcGetWakeOnLanDevices}, "setWakeOnLanDevices": {Func: rpcSetWakeOnLanDevices, Params: []string{"params"}}, "resetConfig": {Func: rpcResetConfig}, + "pluginStartUpload": {Func: plugin.RpcPluginStartUpload, Params: []string{"filename", "size"}}, + "pluginExtract": {Func: plugin.RpcPluginExtract, Params: []string{"filename"}}, + "pluginInstall": {Func: plugin.RpcPluginInstall, Params: []string{"name", "version"}}, + "pluginList": {Func: plugin.RpcPluginList}, + "pluginUpdateConfig": {Func: plugin.RpcPluginUpdateConfig, Params: []string{"name", "enabled"}}, + "pluginUninstall": {Func: plugin.RpcPluginUninstall, Params: []string{"name"}}, } diff --git a/main.go b/main.go index 7ff771f5..ce9d1fb4 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package kvm import ( "context" + "kvm/internal/plugin" "log" "net/http" "os" @@ -66,15 +67,20 @@ func Main() { }() //go RunFuseServer() go RunWebServer() + go plugin.ReconcilePlugins() + // If the cloud token isn't set, the client won't be started by default. // However, if the user adopts the device via the web interface, handleCloudRegister will start the client. if config.CloudToken != "" { go RunWebsocketClient() } + sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs log.Println("JetKVM Shutting Down") + + plugin.GracefullyShutdownPlugins() //if fuseServer != nil { // err := setMassStorageImage(" ") // if err != nil { diff --git a/ui/package.json b/ui/package.json index 592a300b..9a7fae57 100644 --- a/ui/package.json +++ b/ui/package.json @@ -10,6 +10,7 @@ "dev": "vite dev --mode=development", "build": "npm run build:prod", "build:device": "tsc && vite build --mode=device --emptyOutDir", + "dev:device": "vite dev --mode=device", "build:prod": "tsc && vite build --mode=production", "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0" }, diff --git a/ui/src/components/MountMediaDialog.tsx b/ui/src/components/MountMediaDialog.tsx index 4aca608f..6f7c96b6 100644 --- a/ui/src/components/MountMediaDialog.tsx +++ b/ui/src/components/MountMediaDialog.tsx @@ -1516,7 +1516,7 @@ function PreUploadedImageItem({ ); } -function ViewHeader({ title, description }: { title: string; description: string }) { +export function ViewHeader({ title, description }: { title: string; description: string }) { return (
+ {plugin.status} +
+{error}
} + {plugin?.message && ( + <> ++ Plugin message: +
++ Plugin configuration coming soon +
+ + + +{plugin.name}
+ ++ Supported formats: TAR, TAR.GZ +
++ {formatters.bytes(uploadedFileSize || 0)} +
++ {formatters.truncateMiddle(uploadedFileName, 40)} has been + uploaded +
+{fileError}
} +{manifest.description}
++ Version: {manifest.version} +
+ ++ An error occurred while attempting to extract the plugin. Please ensure the plugin is valid and try again. +
+{errorMessage}
+