diff --git a/internal/bind/binder.go b/internal/bind/binder.go deleted file mode 100644 index 9e66cfee..00000000 --- a/internal/bind/binder.go +++ /dev/null @@ -1,173 +0,0 @@ -package bind - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "github.com/jsiebens/ionscale/internal/util" - "github.com/klauspost/compress/zstd" - "github.com/labstack/echo/v4" - "io/ioutil" - "sync" - "tailscale.com/smallzstd" - "tailscale.com/types/key" -) - -type Factory func(c echo.Context) (Binder, error) - -type Binder interface { - BindRequest(c echo.Context, v interface{}) error - WriteResponse(c echo.Context, code int, v interface{}) error - Marshal(compress string, v interface{}) ([]byte, error) - Peer() key.MachinePublic -} - -func DefaultBinder(machineKey key.MachinePublic) Factory { - return func(c echo.Context) (Binder, error) { - return &defaultBinder{machineKey: machineKey}, nil - } -} - -func BoxBinder(controlKey key.MachinePrivate) Factory { - return func(c echo.Context) (Binder, error) { - idParam := c.Param("id") - - id, err := util.ParseMachinePublicKey(idParam) - - if err != nil { - return nil, err - } - - return &boxBinder{ - controlKey: controlKey, - machineKey: *id, - }, nil - } -} - -type defaultBinder struct { - machineKey key.MachinePublic -} - -func (d *defaultBinder) BindRequest(c echo.Context, v interface{}) error { - body, err := ioutil.ReadAll(c.Request().Body) - if err != nil { - return err - } - - return json.Unmarshal(body, v) -} - -func (d *defaultBinder) WriteResponse(c echo.Context, code int, v interface{}) error { - marshalled, err := json.Marshal(v) - if err != nil { - return err - } - - c.Response().WriteHeader(code) - _, err = c.Response().Write(marshalled) - - return err -} - -func (d *defaultBinder) Marshal(compress string, v interface{}) ([]byte, error) { - var payload []byte - - marshalled, err := json.Marshal(v) - if err != nil { - return nil, err - } - - if compress == "zstd" { - payload = zstdEncode(marshalled) - } else { - payload = marshalled - } - - data := make([]byte, 4) - binary.LittleEndian.PutUint32(data, uint32(len(payload))) - data = append(data, payload...) - - return data, nil -} - -func (d *defaultBinder) Peer() key.MachinePublic { - return d.machineKey -} - -type boxBinder struct { - controlKey key.MachinePrivate - machineKey key.MachinePublic -} - -func (b *boxBinder) BindRequest(c echo.Context, v interface{}) error { - body, err := ioutil.ReadAll(c.Request().Body) - if err != nil { - return err - } - - decrypted, ok := b.controlKey.OpenFrom(b.machineKey, body) - if !ok { - return fmt.Errorf("unable to decrypt payload") - } - - return json.Unmarshal(decrypted, v) -} - -func (b *boxBinder) WriteResponse(c echo.Context, code int, v interface{}) error { - marshalled, err := json.Marshal(v) - if err != nil { - return err - } - - encrypted := b.controlKey.SealTo(b.machineKey, marshalled) - - c.Response().WriteHeader(code) - _, err = c.Response().Write(encrypted) - - return err -} - -func (b *boxBinder) Marshal(compress string, v interface{}) ([]byte, error) { - var payload []byte - - marshalled, err := json.Marshal(v) - if err != nil { - return nil, err - } - - if compress == "zstd" { - encoded := zstdEncode(marshalled) - payload = b.controlKey.SealTo(b.machineKey, encoded) - } else { - payload = b.controlKey.SealTo(b.machineKey, marshalled) - } - - data := make([]byte, 4) - binary.LittleEndian.PutUint32(data, uint32(len(payload))) - data = append(data, payload...) - - return data, nil -} - -func (b *boxBinder) Peer() key.MachinePublic { - return b.machineKey -} - -func zstdEncode(in []byte) []byte { - encoder := zstdEncoderPool.Get().(*zstd.Encoder) - out := encoder.EncodeAll(in, nil) - encoder.Close() - zstdEncoderPool.Put(encoder) - return out -} - -var zstdEncoderPool = &sync.Pool{ - New: func() any { - encoder, err := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest)) - if err != nil { - panic(err) - } - return encoder - }, -} diff --git a/internal/handlers/dns.go b/internal/handlers/dns.go index 27afdbdd..a46de1f7 100644 --- a/internal/handlers/dns.go +++ b/internal/handlers/dns.go @@ -1,38 +1,31 @@ package handlers import ( - "github.com/jsiebens/ionscale/internal/bind" "github.com/jsiebens/ionscale/internal/dns" "github.com/labstack/echo/v4" "net" "net/http" "strings" "tailscale.com/tailcfg" + "tailscale.com/types/key" "time" ) -func NewDNSHandlers(createBinder bind.Factory, provider dns.Provider) *DNSHandlers { +func NewDNSHandlers(_ key.MachinePublic, provider dns.Provider) *DNSHandlers { return &DNSHandlers{ - createBinder: createBinder, - provider: provider, + provider: provider, } } type DNSHandlers struct { - createBinder bind.Factory - provider dns.Provider + provider dns.Provider } func (h *DNSHandlers) SetDNS(c echo.Context) error { ctx := c.Request().Context() - binder, err := h.createBinder(c) - if err != nil { - return logError(err) - } - req := &tailcfg.SetDNSRequest{} - if err := binder.BindRequest(c, req); err != nil { + if err := c.Bind(req); err != nil { return logError(err) } @@ -58,16 +51,16 @@ func (h *DNSHandlers) SetDNS(c echo.Context) error { txtrecords, _ := net.LookupTXT(req.Name) for _, txt := range txtrecords { if txt == req.Value { - return binder.WriteResponse(c, http.StatusOK, tailcfg.SetDNSResponse{}) + return c.JSON(http.StatusOK, tailcfg.SetDNSResponse{}) } } case <-timeout: - return binder.WriteResponse(c, http.StatusOK, tailcfg.SetDNSResponse{}) + return c.JSON(http.StatusOK, tailcfg.SetDNSResponse{}) case <-notify: return nil } } } - return binder.WriteResponse(c, http.StatusOK, tailcfg.SetDNSResponse{}) + return c.JSON(http.StatusOK, tailcfg.SetDNSResponse{}) } diff --git a/internal/handlers/id_token.go b/internal/handlers/id_token.go index ccca26b4..f500b81c 100644 --- a/internal/handlers/id_token.go +++ b/internal/handlers/id_token.go @@ -4,63 +4,36 @@ import ( "fmt" "github.com/go-jose/go-jose/v3" "github.com/golang-jwt/jwt/v4" - "github.com/jsiebens/ionscale/internal/bind" "github.com/jsiebens/ionscale/internal/config" "github.com/jsiebens/ionscale/internal/domain" "github.com/jsiebens/ionscale/internal/util" "github.com/labstack/echo/v4" "net/http" "tailscale.com/tailcfg" + "tailscale.com/types/key" "time" ) -func NewIDTokenHandlers(createBinder bind.Factory, config *config.Config, repository domain.Repository) *IDTokenHandlers { +func NewIDTokenHandlers(machineKey key.MachinePublic, config *config.Config, repository domain.Repository) *IDTokenHandlers { return &IDTokenHandlers{ - issuer: config.ServerUrl, - jwksUri: config.CreateUrl("/.well-known/jwks"), - createBinder: createBinder, - repository: repository, + machineKey: machineKey, + issuer: config.ServerUrl, + repository: repository, } } -type IDTokenHandlers struct { - issuer string - jwksUri string - createBinder bind.Factory - repository domain.Repository -} - -func (h *IDTokenHandlers) OpenIDConfig(c echo.Context) error { - v := map[string]interface{}{} - - v["issuer"] = h.issuer - v["jwks_uri"] = h.jwksUri - v["subject_types_supported"] = []string{"public"} - v["response_types_supported"] = []string{"id_token"} - v["scopes_supported"] = []string{"openid"} - v["id_token_signing_alg_values_supported"] = []string{"RS256"} - v["claims_supported"] = []string{ - "sub", - "aud", - "exp", - "iat", - "iss", - "jti", - "nbf", +func NewOIDCConfigHandlers(config *config.Config, repository domain.Repository) *OIDCConfigHandlers { + return &OIDCConfigHandlers{ + issuer: config.ServerUrl, + jwksUri: config.CreateUrl("/.well-known/jwks"), + repository: repository, } - - return c.JSON(http.StatusOK, v) } -func (h *IDTokenHandlers) Jwks(c echo.Context) error { - keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context()) - if err != nil { - return logError(err) - } - - pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"} - set := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{pub}} - return c.JSON(http.StatusOK, set) +type IDTokenHandlers struct { + machineKey key.MachinePublic + issuer string + repository domain.Repository } func (h *IDTokenHandlers) FetchToken(c echo.Context) error { @@ -71,17 +44,12 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error { return logError(err) } - binder, err := h.createBinder(c) - if err != nil { - return logError(err) - } - req := &tailcfg.TokenRequest{} - if err := binder.BindRequest(c, req); err != nil { + if err := c.Bind(req); err != nil { return logError(err) } - machineKey := binder.Peer().String() + machineKey := h.machineKey.String() nodeKey := req.NodeKey.String() var m *domain.Machine @@ -134,7 +102,46 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error { } resp := tailcfg.TokenResponse{IDToken: jwtB64} - return binder.WriteResponse(c, http.StatusOK, resp) + return c.JSON(http.StatusOK, resp) +} + +type OIDCConfigHandlers struct { + issuer string + jwksUri string + repository domain.Repository +} + +func (h *OIDCConfigHandlers) OpenIDConfig(c echo.Context) error { + v := map[string]interface{}{} + + v["issuer"] = h.issuer + v["jwks_uri"] = h.jwksUri + v["subject_types_supported"] = []string{"public"} + v["response_types_supported"] = []string{"id_token"} + v["scopes_supported"] = []string{"openid"} + v["id_token_signing_alg_values_supported"] = []string{"RS256"} + v["claims_supported"] = []string{ + "sub", + "aud", + "exp", + "iat", + "iss", + "jti", + "nbf", + } + + return c.JSON(http.StatusOK, v) +} + +func (h *OIDCConfigHandlers) Jwks(c echo.Context) error { + keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context()) + if err != nil { + return logError(err) + } + + pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"} + set := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{pub}} + return c.JSON(http.StatusOK, set) } func (h *IDTokenHandlers) names(m *domain.Machine) (string, string, string) { diff --git a/internal/handlers/noise.go b/internal/handlers/noise.go index 2c0548a1..1dceaba3 100644 --- a/internal/handlers/noise.go +++ b/internal/handlers/noise.go @@ -41,3 +41,35 @@ func (h *NoiseHandlers) Upgrade(c echo.Context) error { } return nil } + +type JsonBinder struct { + echo.DefaultBinder +} + +func (b JsonBinder) Bind(i interface{}, c echo.Context) error { + if err := b.BindPathParams(c, i); err != nil { + return err + } + + method := c.Request().Method + if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { + if err := b.BindQueryParams(c, i); err != nil { + return err + } + } + + if c.Request().ContentLength == 0 { + return nil + } + + if err := c.Echo().JSONSerializer.Deserialize(c, i); err != nil { + switch err.(type) { + case *echo.HTTPError: + return err + default: + return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + } + } + + return nil +} diff --git a/internal/handlers/poll_net_map.go b/internal/handlers/poll_net_map.go index 95071e20..294c0cd1 100644 --- a/internal/handlers/poll_net_map.go +++ b/internal/handlers/poll_net_map.go @@ -2,24 +2,29 @@ package handlers import ( "context" - "github.com/jsiebens/ionscale/internal/bind" + "encoding/binary" + "encoding/json" "github.com/jsiebens/ionscale/internal/config" "github.com/jsiebens/ionscale/internal/core" "github.com/jsiebens/ionscale/internal/domain" "github.com/jsiebens/ionscale/internal/mapping" + "github.com/klauspost/compress/zstd" "github.com/labstack/echo/v4" "net/http" + "sync" + "tailscale.com/smallzstd" "tailscale.com/tailcfg" + "tailscale.com/types/key" "time" ) func NewPollNetMapHandler( - createBinder bind.Factory, + machineKey key.MachinePublic, sessionManager core.PollMapSessionManager, repository domain.Repository) *PollNetMapHandler { handler := &PollNetMapHandler{ - createBinder: createBinder, + machineKey: machineKey, sessionManager: sessionManager, repository: repository, } @@ -28,28 +33,24 @@ func NewPollNetMapHandler( } type PollNetMapHandler struct { - createBinder bind.Factory + machineKey key.MachinePublic repository domain.Repository sessionManager core.PollMapSessionManager } func (h *PollNetMapHandler) PollNetMap(c echo.Context) error { ctx := c.Request().Context() - binder, err := h.createBinder(c) - if err != nil { - return logError(err) - } req := &tailcfg.MapRequest{} - if err := binder.BindRequest(c, req); err != nil { + if err := c.Bind(req); err != nil { return logError(err) } - machineKey := binder.Peer().String() + machineKey := h.machineKey.String() nodeKey := req.NodeKey.String() var m *domain.Machine - m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey) + m, err := h.repository.GetMachineByKeys(ctx, machineKey, nodeKey) if err != nil { return logError(err) } @@ -59,13 +60,13 @@ func (h *PollNetMapHandler) PollNetMap(c echo.Context) error { } if req.ReadOnly { - return h.handleReadOnly(c, binder, m, req) + return h.handleReadOnly(c, m, req) } else { - return h.handleUpdate(c, binder, m, req) + return h.handleUpdate(c, m, req) } } -func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *domain.Machine, mapRequest *tailcfg.MapRequest) error { +func (h *PollNetMapHandler) handleUpdate(c echo.Context, m *domain.Machine, mapRequest *tailcfg.MapRequest) error { ctx := c.Request().Context() now := time.Now().UTC() @@ -90,7 +91,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m * mapper := mapping.NewPollNetMapper(mapRequest, m.ID, h.repository, h.sessionManager) - response, err := createMapResponse(mapper, binder, false, mapRequest.Compress) + response, err := h.createMapResponse(mapper, false, mapRequest.Compress) if err != nil { return logError(err) } @@ -101,7 +102,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m * // Listen to connection close notify := c.Request().Context().Done() - keepAliveResponse, err := createKeepAliveResponse(binder, mapRequest) + keepAliveResponse, err := h.createKeepAliveResponse(mapRequest) if err != nil { return logError(err) } @@ -154,7 +155,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m * var payload []byte var payloadErr error - payload, payloadErr = createMapResponse(mapper, binder, true, mapRequest.Compress) + payload, payloadErr = h.createMapResponse(mapper, true, mapRequest.Compress) if payloadErr != nil { return payloadErr @@ -173,7 +174,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m * } } -func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m *domain.Machine, request *tailcfg.MapRequest) error { +func (h *PollNetMapHandler) handleReadOnly(c echo.Context, m *domain.Machine, request *tailcfg.MapRequest) error { ctx := c.Request().Context() m.HostInfo = domain.HostInfo(*request.Hostinfo) @@ -184,7 +185,7 @@ func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m } mapper := mapping.NewPollNetMapper(request, m.ID, h.repository, h.sessionManager) - payload, err := createMapResponse(mapper, binder, false, request.Compress) + payload, err := h.createMapResponse(mapper, false, request.Compress) if err != nil { return logError(err) } @@ -193,18 +194,57 @@ func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m return logError(err) } -func createKeepAliveResponse(binder bind.Binder, request *tailcfg.MapRequest) ([]byte, error) { +func (h *PollNetMapHandler) createKeepAliveResponse(request *tailcfg.MapRequest) ([]byte, error) { mapResponse := &tailcfg.MapResponse{ KeepAlive: true, } - return binder.Marshal(request.Compress, mapResponse) + return h.marshalResponse(request.Compress, mapResponse) } -func createMapResponse(m *mapping.PollNetMapper, binder bind.Binder, delta bool, compress string) ([]byte, error) { +func (h *PollNetMapHandler) createMapResponse(m *mapping.PollNetMapper, delta bool, compress string) ([]byte, error) { response, err := m.CreateMapResponse(context.Background(), delta) if err != nil { return nil, err } - return binder.Marshal(compress, response) + return h.marshalResponse(compress, response) +} + +func (h *PollNetMapHandler) marshalResponse(compress string, v interface{}) ([]byte, error) { + var payload []byte + + marshalled, err := json.Marshal(v) + if err != nil { + return nil, err + } + + if compress == "zstd" { + payload = zstdEncode(marshalled) + } else { + payload = marshalled + } + + data := make([]byte, 4) + binary.LittleEndian.PutUint32(data, uint32(len(payload))) + data = append(data, payload...) + + return data, nil +} + +func zstdEncode(in []byte) []byte { + encoder := zstdEncoderPool.Get().(*zstd.Encoder) + out := encoder.EncodeAll(in, nil) + _ = encoder.Close() + zstdEncoderPool.Put(encoder) + return out +} + +var zstdEncoderPool = &sync.Pool{ + New: func() any { + encoder, err := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest)) + if err != nil { + panic(err) + } + return encoder + }, } diff --git a/internal/handlers/query_feature.go b/internal/handlers/query_feature.go index 46265875..4404add1 100644 --- a/internal/handlers/query_feature.go +++ b/internal/handlers/query_feature.go @@ -2,41 +2,37 @@ package handlers import ( "fmt" - "github.com/jsiebens/ionscale/internal/bind" "github.com/jsiebens/ionscale/internal/dns" "github.com/jsiebens/ionscale/internal/domain" "github.com/labstack/echo/v4" "net/http" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) -func NewQueryFeatureHandlers(createBinder bind.Factory, dnsProvider dns.Provider, repository domain.Repository) *QueryFeatureHandlers { +func NewQueryFeatureHandlers(machineKey key.MachinePublic, dnsProvider dns.Provider, repository domain.Repository) *QueryFeatureHandlers { return &QueryFeatureHandlers{ - createBinder: createBinder, - repository: repository, + machineKey: machineKey, + dnsProvider: dnsProvider, + repository: repository, } } type QueryFeatureHandlers struct { - createBinder bind.Factory - dnsProvider dns.Provider - repository domain.Repository + machineKey key.MachinePublic + dnsProvider dns.Provider + repository domain.Repository } func (h *QueryFeatureHandlers) QueryFeature(c echo.Context) error { ctx := c.Request().Context() - binder, err := h.createBinder(c) - if err != nil { - return logError(err) - } - req := new(tailcfg.QueryFeatureRequest) - if err := binder.BindRequest(c, req); err != nil { + if err := c.Bind(req); err != nil { return logError(err) } - machineKey := binder.Peer().String() + machineKey := h.machineKey.String() nodeKey := req.NodeKey.String() resp := tailcfg.QueryFeatureResponse{} @@ -61,7 +57,7 @@ func (h *QueryFeatureHandlers) QueryFeature(c echo.Context) error { resp.Text = fmt.Sprintf("Unknown feature request '%s'\n", req.Feature) } - return binder.WriteResponse(c, http.StatusOK, resp) + return c.JSON(http.StatusOK, resp) } const serverMessage = `Enabling HTTPS is required to use Serve: diff --git a/internal/handlers/registration.go b/internal/handlers/registration.go index 5b3414fb..7393335a 100644 --- a/internal/handlers/registration.go +++ b/internal/handlers/registration.go @@ -3,7 +3,6 @@ package handlers import ( "context" "github.com/jsiebens/ionscale/internal/addr" - "github.com/jsiebens/ionscale/internal/bind" "github.com/jsiebens/ionscale/internal/config" "github.com/jsiebens/ionscale/internal/core" "github.com/jsiebens/ionscale/internal/domain" @@ -13,17 +12,18 @@ import ( "net/http" "net/netip" "tailscale.com/tailcfg" + "tailscale.com/types/key" "tailscale.com/util/dnsname" "time" ) func NewRegistrationHandlers( - createBinder bind.Factory, + machineKey key.MachinePublic, config *config.Config, sessionManager core.PollMapSessionManager, repository domain.Repository) *RegistrationHandlers { return &RegistrationHandlers{ - createBinder: createBinder, + machineKey: machineKey, sessionManager: sessionManager, repository: repository, config: config, @@ -31,7 +31,7 @@ func NewRegistrationHandlers( } type RegistrationHandlers struct { - createBinder bind.Factory + machineKey key.MachinePublic repository domain.Repository sessionManager core.PollMapSessionManager config *config.Config @@ -40,21 +40,16 @@ type RegistrationHandlers struct { func (h *RegistrationHandlers) Register(c echo.Context) error { ctx := c.Request().Context() - binder, err := h.createBinder(c) - if err != nil { - return logError(err) - } - req := &tailcfg.RegisterRequest{} - if err := binder.BindRequest(c, req); err != nil { + if err := c.Bind(req); err != nil { return logError(err) } - machineKey := binder.Peer().String() + machineKey := h.machineKey.String() nodeKey := req.NodeKey.String() var m *domain.Machine - m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey) + m, err := h.repository.GetMachineByKeys(ctx, machineKey, nodeKey) if err != nil { return logError(err) @@ -63,7 +58,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error { if m != nil { if m.IsExpired() { response := tailcfg.RegisterResponse{NodeKeyExpired: true} - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } if !req.Expiry.IsZero() && req.Expiry.Before(time.Now()) { @@ -82,7 +77,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error { } response := tailcfg.RegisterResponse{NodeKeyExpired: true} - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname) @@ -111,17 +106,17 @@ func (h *RegistrationHandlers) Register(c echo.Context) error { Login: tLogin, } - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } - return h.authenticateMachine(c, binder, machineKey, req) + return h.authenticateMachine(c, machineKey, req) } -func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.Binder, machineKey string, req *tailcfg.RegisterRequest) error { +func (h *RegistrationHandlers) authenticateMachine(c echo.Context, machineKey string, req *tailcfg.RegisterRequest) error { ctx := c.Request().Context() if req.Followup != "" { - return h.followup(c, binder, req) + return h.followup(c, req) } if req.Auth.AuthKey == "" { @@ -138,17 +133,17 @@ func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.B err := h.repository.SaveRegistrationRequest(ctx, &request) if err != nil { response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"} - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } response := tailcfg.RegisterResponse{AuthURL: authUrl} - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } else { - return h.authenticateMachineWithAuthKey(c, binder, machineKey, req) + return h.authenticateMachineWithAuthKey(c, machineKey, req) } } -func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, binder bind.Binder, machineKey string, req *tailcfg.RegisterRequest) error { +func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, machineKey string, req *tailcfg.RegisterRequest) error { ctx := c.Request().Context() nodeKey := req.NodeKey.String() @@ -159,7 +154,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi if authKey == nil { response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "invalid auth key"} - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } tailnet := authKey.Tailnet @@ -167,7 +162,7 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi if err := tailnet.ACLPolicy.CheckTagOwners(req.Hostinfo.RequestTags, &user); err != nil { response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: err.Error()} - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } registeredTags := authKey.Tags @@ -254,10 +249,10 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi Login: tLogin, } - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } -func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req *tailcfg.RegisterRequest) error { +func (h *RegistrationHandlers) followup(c echo.Context, req *tailcfg.RegisterRequest) error { // Listen to connection close ctx := c.Request().Context() notify := ctx.Done() @@ -265,7 +260,7 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req defer func() { tick.Stop() }() - machineKey := binder.Peer().String() + machineKey := h.machineKey.String() for { select { @@ -274,7 +269,7 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req if err != nil || m == nil { response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"} - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } if m != nil && m.Authenticated { @@ -291,7 +286,7 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req User: u, Login: l, } - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } if m != nil && len(m.Error) != 0 { @@ -299,7 +294,7 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req MachineAuthorized: len(m.Error) != 0, Error: m.Error, } - return binder.WriteResponse(c, http.StatusOK, response) + return c.JSON(http.StatusOK, response) } case <-notify: return nil diff --git a/internal/handlers/ssh_action.go b/internal/handlers/ssh_action.go index e502d7fa..7fb9a61c 100644 --- a/internal/handlers/ssh_action.go +++ b/internal/handlers/ssh_action.go @@ -2,28 +2,28 @@ package handlers import ( "fmt" - "github.com/jsiebens/ionscale/internal/bind" "github.com/jsiebens/ionscale/internal/config" "github.com/jsiebens/ionscale/internal/domain" "github.com/jsiebens/ionscale/internal/util" "github.com/labstack/echo/v4" "net/http" "tailscale.com/tailcfg" + "tailscale.com/types/key" "time" ) -func NewSSHActionHandlers(createBinder bind.Factory, config *config.Config, repository domain.Repository) *SSHActionHandlers { +func NewSSHActionHandlers(machineKey key.MachinePublic, config *config.Config, repository domain.Repository) *SSHActionHandlers { return &SSHActionHandlers{ - createBinder: createBinder, - repository: repository, - config: config, + machineKey: machineKey, + repository: repository, + config: config, } } type SSHActionHandlers struct { - createBinder bind.Factory - repository domain.Repository - config *config.Config + machineKey key.MachinePublic + repository domain.Repository + config *config.Config } type sshActionRequestData struct { @@ -35,13 +35,8 @@ type sshActionRequestData struct { func (h *SSHActionHandlers) StartAuth(c echo.Context) error { ctx := c.Request().Context() - binder, err := h.createBinder(c) - if err != nil { - return logError(err) - } - data := new(sshActionRequestData) - if err = c.Bind(data); err != nil { + if err := c.Bind(data); err != nil { return logError(err) } @@ -67,7 +62,7 @@ func (h *SSHActionHandlers) StartAuth(c echo.Context) error { AllowLocalPortForwarding: true, } - return binder.WriteResponse(c, http.StatusOK, resp) + return c.JSON(http.StatusOK, resp) } } } @@ -92,7 +87,7 @@ check: HoldAndDelegate: fmt.Sprintf("https://unused/machine/ssh/action/check/%s", key), } - return binder.WriteResponse(c, http.StatusOK, resp) + return c.JSON(http.StatusOK, resp) } func (h *SSHActionHandlers) CheckAuth(c echo.Context) error { @@ -100,11 +95,6 @@ func (h *SSHActionHandlers) CheckAuth(c echo.Context) error { ctx := c.Request().Context() notify := ctx.Done() - binder, err := h.createBinder(c) - if err != nil { - return logError(err) - } - tick := time.NewTicker(2 * time.Second) defer func() { tick.Stop() }() @@ -117,7 +107,7 @@ func (h *SSHActionHandlers) CheckAuth(c echo.Context) error { m, err := h.repository.GetSSHActionRequest(ctx, key) if err != nil || m == nil { - return binder.WriteResponse(c, http.StatusOK, &tailcfg.SSHAction{Reject: true}) + return c.JSON(http.StatusOK, &tailcfg.SSHAction{Reject: true}) } if m.Action == "accept" { @@ -127,13 +117,13 @@ func (h *SSHActionHandlers) CheckAuth(c echo.Context) error { AllowLocalPortForwarding: true, } _ = h.repository.DeleteSSHActionRequest(ctx, key) - return binder.WriteResponse(c, http.StatusOK, action) + return c.JSON(http.StatusOK, action) } if m.Action == "reject" { action := &tailcfg.SSHAction{Reject: true} _ = h.repository.DeleteSSHActionRequest(ctx, key) - return binder.WriteResponse(c, http.StatusOK, action) + return c.JSON(http.StatusOK, action) } case <-notify: return nil diff --git a/internal/server/server.go b/internal/server/server.go index a4cbd4e9..e126fbad 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/caddyserver/certmagic" "github.com/jsiebens/ionscale/internal/auth" - "github.com/jsiebens/ionscale/internal/bind" "github.com/jsiebens/ionscale/internal/config" "github.com/jsiebens/ionscale/internal/core" "github.com/jsiebens/ionscale/internal/database" @@ -107,16 +106,15 @@ func Start(c *config.Config) error { p.SetMetricsPath(metricsHandler) createPeerHandler := func(machinePublicKey key.MachinePublic) http.Handler { - binder := bind.DefaultBinder(machinePublicKey) - - registrationHandlers := handlers.NewRegistrationHandlers(binder, c, sessionManager, repository) - pollNetMapHandler := handlers.NewPollNetMapHandler(binder, sessionManager, repository) - dnsHandlers := handlers.NewDNSHandlers(binder, dnsProvider) - idTokenHandlers := handlers.NewIDTokenHandlers(binder, c, repository) - sshActionHandlers := handlers.NewSSHActionHandlers(binder, c, repository) - queryFeatureHandlers := handlers.NewQueryFeatureHandlers(binder, dnsProvider, repository) + registrationHandlers := handlers.NewRegistrationHandlers(machinePublicKey, c, sessionManager, repository) + pollNetMapHandler := handlers.NewPollNetMapHandler(machinePublicKey, sessionManager, repository) + dnsHandlers := handlers.NewDNSHandlers(machinePublicKey, dnsProvider) + idTokenHandlers := handlers.NewIDTokenHandlers(machinePublicKey, c, repository) + sshActionHandlers := handlers.NewSSHActionHandlers(machinePublicKey, c, repository) + queryFeatureHandlers := handlers.NewQueryFeatureHandlers(machinePublicKey, dnsProvider, repository) e := echo.New() + e.Binder = handlers.JsonBinder{} e.Use(EchoMetrics(p), EchoLogger(httpLogger), EchoErrorHandler(), EchoRecover()) e.POST("/machine/register", registrationHandlers.Register) e.POST("/machine/map", pollNetMapHandler.PollNetMap) @@ -131,10 +129,8 @@ func Start(c *config.Config) error { } noiseHandlers := handlers.NewNoiseHandlers(serverKey.ControlKey, createPeerHandler) - registrationHandlers := handlers.NewRegistrationHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, sessionManager, repository) - pollNetMapHandler := handlers.NewPollNetMapHandler(bind.BoxBinder(serverKey.LegacyControlKey), sessionManager, repository) - dnsHandlers := handlers.NewDNSHandlers(bind.BoxBinder(serverKey.LegacyControlKey), dnsProvider) - idTokenHandlers := handlers.NewIDTokenHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, repository) + oidcConfigHandlers := handlers.NewOIDCConfigHandlers(c, repository) + authenticationHandlers := handlers.NewAuthenticationHandlers( c, authProvider, @@ -161,22 +157,16 @@ func Start(c *config.Config) error { tlsAppHandler.GET("/version", handlers.Version) tlsAppHandler.GET("/key", handlers.KeyHandler(serverKey)) tlsAppHandler.POST("/ts2021", noiseHandlers.Upgrade) - tlsAppHandler.POST("/machine/:id", registrationHandlers.Register) - tlsAppHandler.POST("/machine/:id/map", pollNetMapHandler.PollNetMap) - tlsAppHandler.POST("/machine/:id/set-dns", dnsHandlers.SetDNS) - tlsAppHandler.GET("/.well-known/jwks", idTokenHandlers.Jwks) - tlsAppHandler.GET("/.well-known/openid-configuration", idTokenHandlers.OpenIDConfig) - - auth := tlsAppHandler.Group("/a") - auth.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ - TokenLookup: "form:_csrf", - })) - auth.GET("/:flow/:key", authenticationHandlers.StartAuth) - auth.POST("/:flow/:key", authenticationHandlers.ProcessAuth) - auth.GET("/callback", authenticationHandlers.Callback) - auth.POST("/callback", authenticationHandlers.EndAuth) - auth.GET("/success", authenticationHandlers.Success) - auth.GET("/error", authenticationHandlers.Error) + tlsAppHandler.GET("/.well-known/jwks", oidcConfigHandlers.Jwks) + tlsAppHandler.GET("/.well-known/openid-configuration", oidcConfigHandlers.OpenIDConfig) + + csrf := middleware.CSRFWithConfig(middleware.CSRFConfig{TokenLookup: "form:_csrf"}) + tlsAppHandler.GET("/a/:flow/:key", authenticationHandlers.StartAuth, csrf) + tlsAppHandler.POST("/a/:flow/:key", authenticationHandlers.ProcessAuth, csrf) + tlsAppHandler.GET("/a/callback", authenticationHandlers.Callback, csrf) + tlsAppHandler.POST("/a/callback", authenticationHandlers.EndAuth, csrf) + tlsAppHandler.GET("/a/success", authenticationHandlers.Success, csrf) + tlsAppHandler.GET("/a/error", authenticationHandlers.Error, csrf) tlsL, err := tlsListener(c) if err != nil {