Skip to content

Commit

Permalink
feat: remove support for non-noise clients
Browse files Browse the repository at this point in the history
jsiebens committed Jan 10, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent b083e26 commit cbcbd61
Showing 9 changed files with 228 additions and 358 deletions.
173 changes: 0 additions & 173 deletions internal/bind/binder.go

This file was deleted.

23 changes: 8 additions & 15 deletions internal/handlers/dns.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,31 @@
package handlers

import (
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/dns"
"github.com/labstack/echo/v4"
"net"
"net/http"
"strings"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"time"
)

func NewDNSHandlers(createBinder bind.Factory, provider dns.Provider) *DNSHandlers {
func NewDNSHandlers(_ key.MachinePublic, provider dns.Provider) *DNSHandlers {
return &DNSHandlers{
createBinder: createBinder,
provider: provider,
provider: provider,
}
}

type DNSHandlers struct {
createBinder bind.Factory
provider dns.Provider
provider dns.Provider
}

func (h *DNSHandlers) SetDNS(c echo.Context) error {
ctx := c.Request().Context()

binder, err := h.createBinder(c)
if err != nil {
return logError(err)
}

req := &tailcfg.SetDNSRequest{}
if err := binder.BindRequest(c, req); err != nil {
if err := c.Bind(req); err != nil {
return logError(err)
}

@@ -58,16 +51,16 @@ func (h *DNSHandlers) SetDNS(c echo.Context) error {
txtrecords, _ := net.LookupTXT(req.Name)
for _, txt := range txtrecords {
if txt == req.Value {
return binder.WriteResponse(c, http.StatusOK, tailcfg.SetDNSResponse{})
return c.JSON(http.StatusOK, tailcfg.SetDNSResponse{})
}
}
case <-timeout:
return binder.WriteResponse(c, http.StatusOK, tailcfg.SetDNSResponse{})
return c.JSON(http.StatusOK, tailcfg.SetDNSResponse{})
case <-notify:
return nil
}
}
}

return binder.WriteResponse(c, http.StatusOK, tailcfg.SetDNSResponse{})
return c.JSON(http.StatusOK, tailcfg.SetDNSResponse{})
}
105 changes: 56 additions & 49 deletions internal/handlers/id_token.go
Original file line number Diff line number Diff line change
@@ -4,63 +4,36 @@ import (
"fmt"
"github.com/go-jose/go-jose/v3"
"github.com/golang-jwt/jwt/v4"
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"net/http"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"time"
)

func NewIDTokenHandlers(createBinder bind.Factory, config *config.Config, repository domain.Repository) *IDTokenHandlers {
func NewIDTokenHandlers(machineKey key.MachinePublic, config *config.Config, repository domain.Repository) *IDTokenHandlers {
return &IDTokenHandlers{
issuer: config.ServerUrl,
jwksUri: config.CreateUrl("/.well-known/jwks"),
createBinder: createBinder,
repository: repository,
machineKey: machineKey,
issuer: config.ServerUrl,
repository: repository,
}
}

type IDTokenHandlers struct {
issuer string
jwksUri string
createBinder bind.Factory
repository domain.Repository
}

func (h *IDTokenHandlers) OpenIDConfig(c echo.Context) error {
v := map[string]interface{}{}

v["issuer"] = h.issuer
v["jwks_uri"] = h.jwksUri
v["subject_types_supported"] = []string{"public"}
v["response_types_supported"] = []string{"id_token"}
v["scopes_supported"] = []string{"openid"}
v["id_token_signing_alg_values_supported"] = []string{"RS256"}
v["claims_supported"] = []string{
"sub",
"aud",
"exp",
"iat",
"iss",
"jti",
"nbf",
func NewOIDCConfigHandlers(config *config.Config, repository domain.Repository) *OIDCConfigHandlers {
return &OIDCConfigHandlers{
issuer: config.ServerUrl,
jwksUri: config.CreateUrl("/.well-known/jwks"),
repository: repository,
}

return c.JSON(http.StatusOK, v)
}

func (h *IDTokenHandlers) Jwks(c echo.Context) error {
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
if err != nil {
return logError(err)
}

pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"}
set := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{pub}}
return c.JSON(http.StatusOK, set)
type IDTokenHandlers struct {
machineKey key.MachinePublic
issuer string
repository domain.Repository
}

func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
@@ -71,17 +44,12 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
return logError(err)
}

binder, err := h.createBinder(c)
if err != nil {
return logError(err)
}

req := &tailcfg.TokenRequest{}
if err := binder.BindRequest(c, req); err != nil {
if err := c.Bind(req); err != nil {
return logError(err)
}

machineKey := binder.Peer().String()
machineKey := h.machineKey.String()
nodeKey := req.NodeKey.String()

var m *domain.Machine
@@ -134,7 +102,46 @@ func (h *IDTokenHandlers) FetchToken(c echo.Context) error {
}

resp := tailcfg.TokenResponse{IDToken: jwtB64}
return binder.WriteResponse(c, http.StatusOK, resp)
return c.JSON(http.StatusOK, resp)
}

type OIDCConfigHandlers struct {
issuer string
jwksUri string
repository domain.Repository
}

func (h *OIDCConfigHandlers) OpenIDConfig(c echo.Context) error {
v := map[string]interface{}{}

v["issuer"] = h.issuer
v["jwks_uri"] = h.jwksUri
v["subject_types_supported"] = []string{"public"}
v["response_types_supported"] = []string{"id_token"}
v["scopes_supported"] = []string{"openid"}
v["id_token_signing_alg_values_supported"] = []string{"RS256"}
v["claims_supported"] = []string{
"sub",
"aud",
"exp",
"iat",
"iss",
"jti",
"nbf",
}

return c.JSON(http.StatusOK, v)
}

func (h *OIDCConfigHandlers) Jwks(c echo.Context) error {
keySet, err := h.repository.GetJSONWebKeySet(c.Request().Context())
if err != nil {
return logError(err)
}

pub := jose.JSONWebKey{Key: keySet.Key.Public(), KeyID: keySet.Key.Id, Algorithm: "RS256", Use: "sig"}
set := jose.JSONWebKeySet{Keys: []jose.JSONWebKey{pub}}
return c.JSON(http.StatusOK, set)
}

func (h *IDTokenHandlers) names(m *domain.Machine) (string, string, string) {
32 changes: 32 additions & 0 deletions internal/handlers/noise.go
Original file line number Diff line number Diff line change
@@ -41,3 +41,35 @@ func (h *NoiseHandlers) Upgrade(c echo.Context) error {
}
return nil
}

type JsonBinder struct {
echo.DefaultBinder
}

func (b JsonBinder) Bind(i interface{}, c echo.Context) error {
if err := b.BindPathParams(c, i); err != nil {
return err
}

method := c.Request().Method
if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead {
if err := b.BindQueryParams(c, i); err != nil {
return err
}
}

if c.Request().ContentLength == 0 {
return nil
}

if err := c.Echo().JSONSerializer.Deserialize(c, i); err != nil {
switch err.(type) {
case *echo.HTTPError:
return err
default:
return echo.NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err)
}
}

return nil
}
86 changes: 63 additions & 23 deletions internal/handlers/poll_net_map.go
Original file line number Diff line number Diff line change
@@ -2,24 +2,29 @@ package handlers

import (
"context"
"github.com/jsiebens/ionscale/internal/bind"
"encoding/binary"
"encoding/json"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/core"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/mapping"
"github.com/klauspost/compress/zstd"
"github.com/labstack/echo/v4"
"net/http"
"sync"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"time"
)

func NewPollNetMapHandler(
createBinder bind.Factory,
machineKey key.MachinePublic,
sessionManager core.PollMapSessionManager,
repository domain.Repository) *PollNetMapHandler {

handler := &PollNetMapHandler{
createBinder: createBinder,
machineKey: machineKey,
sessionManager: sessionManager,
repository: repository,
}
@@ -28,28 +33,24 @@ func NewPollNetMapHandler(
}

type PollNetMapHandler struct {
createBinder bind.Factory
machineKey key.MachinePublic
repository domain.Repository
sessionManager core.PollMapSessionManager
}

func (h *PollNetMapHandler) PollNetMap(c echo.Context) error {
ctx := c.Request().Context()
binder, err := h.createBinder(c)
if err != nil {
return logError(err)
}

req := &tailcfg.MapRequest{}
if err := binder.BindRequest(c, req); err != nil {
if err := c.Bind(req); err != nil {
return logError(err)
}

machineKey := binder.Peer().String()
machineKey := h.machineKey.String()
nodeKey := req.NodeKey.String()

var m *domain.Machine
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
m, err := h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
if err != nil {
return logError(err)
}
@@ -59,13 +60,13 @@ func (h *PollNetMapHandler) PollNetMap(c echo.Context) error {
}

if req.ReadOnly {
return h.handleReadOnly(c, binder, m, req)
return h.handleReadOnly(c, m, req)
} else {
return h.handleUpdate(c, binder, m, req)
return h.handleUpdate(c, m, req)
}
}

func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *domain.Machine, mapRequest *tailcfg.MapRequest) error {
func (h *PollNetMapHandler) handleUpdate(c echo.Context, m *domain.Machine, mapRequest *tailcfg.MapRequest) error {
ctx := c.Request().Context()

now := time.Now().UTC()
@@ -90,7 +91,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *

mapper := mapping.NewPollNetMapper(mapRequest, m.ID, h.repository, h.sessionManager)

response, err := createMapResponse(mapper, binder, false, mapRequest.Compress)
response, err := h.createMapResponse(mapper, false, mapRequest.Compress)
if err != nil {
return logError(err)
}
@@ -101,7 +102,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
// Listen to connection close
notify := c.Request().Context().Done()

keepAliveResponse, err := createKeepAliveResponse(binder, mapRequest)
keepAliveResponse, err := h.createKeepAliveResponse(mapRequest)
if err != nil {
return logError(err)
}
@@ -154,7 +155,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
var payload []byte
var payloadErr error

payload, payloadErr = createMapResponse(mapper, binder, true, mapRequest.Compress)
payload, payloadErr = h.createMapResponse(mapper, true, mapRequest.Compress)

if payloadErr != nil {
return payloadErr
@@ -173,7 +174,7 @@ func (h *PollNetMapHandler) handleUpdate(c echo.Context, binder bind.Binder, m *
}
}

func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m *domain.Machine, request *tailcfg.MapRequest) error {
func (h *PollNetMapHandler) handleReadOnly(c echo.Context, m *domain.Machine, request *tailcfg.MapRequest) error {
ctx := c.Request().Context()

m.HostInfo = domain.HostInfo(*request.Hostinfo)
@@ -184,7 +185,7 @@ func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m
}

mapper := mapping.NewPollNetMapper(request, m.ID, h.repository, h.sessionManager)
payload, err := createMapResponse(mapper, binder, false, request.Compress)
payload, err := h.createMapResponse(mapper, false, request.Compress)
if err != nil {
return logError(err)
}
@@ -193,18 +194,57 @@ func (h *PollNetMapHandler) handleReadOnly(c echo.Context, binder bind.Binder, m
return logError(err)
}

func createKeepAliveResponse(binder bind.Binder, request *tailcfg.MapRequest) ([]byte, error) {
func (h *PollNetMapHandler) createKeepAliveResponse(request *tailcfg.MapRequest) ([]byte, error) {
mapResponse := &tailcfg.MapResponse{
KeepAlive: true,
}

return binder.Marshal(request.Compress, mapResponse)
return h.marshalResponse(request.Compress, mapResponse)
}

func createMapResponse(m *mapping.PollNetMapper, binder bind.Binder, delta bool, compress string) ([]byte, error) {
func (h *PollNetMapHandler) createMapResponse(m *mapping.PollNetMapper, delta bool, compress string) ([]byte, error) {
response, err := m.CreateMapResponse(context.Background(), delta)
if err != nil {
return nil, err
}
return binder.Marshal(compress, response)
return h.marshalResponse(compress, response)
}

func (h *PollNetMapHandler) marshalResponse(compress string, v interface{}) ([]byte, error) {
var payload []byte

marshalled, err := json.Marshal(v)
if err != nil {
return nil, err
}

if compress == "zstd" {
payload = zstdEncode(marshalled)
} else {
payload = marshalled
}

data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, uint32(len(payload)))
data = append(data, payload...)

return data, nil
}

func zstdEncode(in []byte) []byte {
encoder := zstdEncoderPool.Get().(*zstd.Encoder)
out := encoder.EncodeAll(in, nil)
_ = encoder.Close()
zstdEncoderPool.Put(encoder)
return out
}

var zstdEncoderPool = &sync.Pool{
New: func() any {
encoder, err := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return encoder
},
}
26 changes: 11 additions & 15 deletions internal/handlers/query_feature.go
Original file line number Diff line number Diff line change
@@ -2,41 +2,37 @@ package handlers

import (
"fmt"
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/dns"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/labstack/echo/v4"
"net/http"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)

func NewQueryFeatureHandlers(createBinder bind.Factory, dnsProvider dns.Provider, repository domain.Repository) *QueryFeatureHandlers {
func NewQueryFeatureHandlers(machineKey key.MachinePublic, dnsProvider dns.Provider, repository domain.Repository) *QueryFeatureHandlers {
return &QueryFeatureHandlers{
createBinder: createBinder,
repository: repository,
machineKey: machineKey,
dnsProvider: dnsProvider,
repository: repository,
}
}

type QueryFeatureHandlers struct {
createBinder bind.Factory
dnsProvider dns.Provider
repository domain.Repository
machineKey key.MachinePublic
dnsProvider dns.Provider
repository domain.Repository
}

func (h *QueryFeatureHandlers) QueryFeature(c echo.Context) error {
ctx := c.Request().Context()

binder, err := h.createBinder(c)
if err != nil {
return logError(err)
}

req := new(tailcfg.QueryFeatureRequest)
if err := binder.BindRequest(c, req); err != nil {
if err := c.Bind(req); err != nil {
return logError(err)
}

machineKey := binder.Peer().String()
machineKey := h.machineKey.String()
nodeKey := req.NodeKey.String()

resp := tailcfg.QueryFeatureResponse{}
@@ -61,7 +57,7 @@ func (h *QueryFeatureHandlers) QueryFeature(c echo.Context) error {
resp.Text = fmt.Sprintf("Unknown feature request '%s'\n", req.Feature)
}

return binder.WriteResponse(c, http.StatusOK, resp)
return c.JSON(http.StatusOK, resp)
}

const serverMessage = `Enabling HTTPS is required to use Serve:
55 changes: 25 additions & 30 deletions internal/handlers/registration.go
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@ package handlers
import (
"context"
"github.com/jsiebens/ionscale/internal/addr"
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/core"
"github.com/jsiebens/ionscale/internal/domain"
@@ -13,25 +12,26 @@ import (
"net/http"
"net/netip"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/dnsname"
"time"
)

func NewRegistrationHandlers(
createBinder bind.Factory,
machineKey key.MachinePublic,
config *config.Config,
sessionManager core.PollMapSessionManager,
repository domain.Repository) *RegistrationHandlers {
return &RegistrationHandlers{
createBinder: createBinder,
machineKey: machineKey,
sessionManager: sessionManager,
repository: repository,
config: config,
}
}

type RegistrationHandlers struct {
createBinder bind.Factory
machineKey key.MachinePublic
repository domain.Repository
sessionManager core.PollMapSessionManager
config *config.Config
@@ -40,21 +40,16 @@ type RegistrationHandlers struct {
func (h *RegistrationHandlers) Register(c echo.Context) error {
ctx := c.Request().Context()

binder, err := h.createBinder(c)
if err != nil {
return logError(err)
}

req := &tailcfg.RegisterRequest{}
if err := binder.BindRequest(c, req); err != nil {
if err := c.Bind(req); err != nil {
return logError(err)
}

machineKey := binder.Peer().String()
machineKey := h.machineKey.String()
nodeKey := req.NodeKey.String()

var m *domain.Machine
m, err = h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
m, err := h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)

if err != nil {
return logError(err)
@@ -63,7 +58,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
if m != nil {
if m.IsExpired() {
response := tailcfg.RegisterResponse{NodeKeyExpired: true}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

if !req.Expiry.IsZero() && req.Expiry.Before(time.Now()) {
@@ -82,7 +77,7 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
}

response := tailcfg.RegisterResponse{NodeKeyExpired: true}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

sanitizeHostname := dnsname.SanitizeHostname(req.Hostinfo.Hostname)
@@ -111,17 +106,17 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
Login: tLogin,
}

return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

return h.authenticateMachine(c, binder, machineKey, req)
return h.authenticateMachine(c, machineKey, req)
}

func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.Binder, machineKey string, req *tailcfg.RegisterRequest) error {
func (h *RegistrationHandlers) authenticateMachine(c echo.Context, machineKey string, req *tailcfg.RegisterRequest) error {
ctx := c.Request().Context()

if req.Followup != "" {
return h.followup(c, binder, req)
return h.followup(c, req)
}

if req.Auth.AuthKey == "" {
@@ -138,17 +133,17 @@ func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.B
err := h.repository.SaveRegistrationRequest(ctx, &request)
if err != nil {
response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

response := tailcfg.RegisterResponse{AuthURL: authUrl}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
} else {
return h.authenticateMachineWithAuthKey(c, binder, machineKey, req)
return h.authenticateMachineWithAuthKey(c, machineKey, req)
}
}

func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, binder bind.Binder, machineKey string, req *tailcfg.RegisterRequest) error {
func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, machineKey string, req *tailcfg.RegisterRequest) error {
ctx := c.Request().Context()
nodeKey := req.NodeKey.String()

@@ -159,15 +154,15 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi

if authKey == nil {
response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "invalid auth key"}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

tailnet := authKey.Tailnet
user := authKey.User

if err := tailnet.ACLPolicy.CheckTagOwners(req.Hostinfo.RequestTags, &user); err != nil {
response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: err.Error()}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

registeredTags := authKey.Tags
@@ -254,18 +249,18 @@ func (h *RegistrationHandlers) authenticateMachineWithAuthKey(c echo.Context, bi
Login: tLogin,
}

return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req *tailcfg.RegisterRequest) error {
func (h *RegistrationHandlers) followup(c echo.Context, req *tailcfg.RegisterRequest) error {
// Listen to connection close
ctx := c.Request().Context()
notify := ctx.Done()
tick := time.NewTicker(2 * time.Second)

defer func() { tick.Stop() }()

machineKey := binder.Peer().String()
machineKey := h.machineKey.String()

for {
select {
@@ -274,7 +269,7 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req

if err != nil || m == nil {
response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

if m != nil && m.Authenticated {
@@ -291,15 +286,15 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req
User: u,
Login: l,
}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}

if m != nil && len(m.Error) != 0 {
response := tailcfg.RegisterResponse{
MachineAuthorized: len(m.Error) != 0,
Error: m.Error,
}
return binder.WriteResponse(c, http.StatusOK, response)
return c.JSON(http.StatusOK, response)
}
case <-notify:
return nil
38 changes: 14 additions & 24 deletions internal/handlers/ssh_action.go
Original file line number Diff line number Diff line change
@@ -2,28 +2,28 @@ package handlers

import (
"fmt"
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"net/http"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"time"
)

func NewSSHActionHandlers(createBinder bind.Factory, config *config.Config, repository domain.Repository) *SSHActionHandlers {
func NewSSHActionHandlers(machineKey key.MachinePublic, config *config.Config, repository domain.Repository) *SSHActionHandlers {
return &SSHActionHandlers{
createBinder: createBinder,
repository: repository,
config: config,
machineKey: machineKey,
repository: repository,
config: config,
}
}

type SSHActionHandlers struct {
createBinder bind.Factory
repository domain.Repository
config *config.Config
machineKey key.MachinePublic
repository domain.Repository
config *config.Config
}

type sshActionRequestData struct {
@@ -35,13 +35,8 @@ type sshActionRequestData struct {
func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
ctx := c.Request().Context()

binder, err := h.createBinder(c)
if err != nil {
return logError(err)
}

data := new(sshActionRequestData)
if err = c.Bind(data); err != nil {
if err := c.Bind(data); err != nil {
return logError(err)
}

@@ -67,7 +62,7 @@ func (h *SSHActionHandlers) StartAuth(c echo.Context) error {
AllowLocalPortForwarding: true,
}

return binder.WriteResponse(c, http.StatusOK, resp)
return c.JSON(http.StatusOK, resp)
}
}
}
@@ -92,19 +87,14 @@ check:
HoldAndDelegate: fmt.Sprintf("https://unused/machine/ssh/action/check/%s", key),
}

return binder.WriteResponse(c, http.StatusOK, resp)
return c.JSON(http.StatusOK, resp)
}

func (h *SSHActionHandlers) CheckAuth(c echo.Context) error {
// Listen to connection close
ctx := c.Request().Context()
notify := ctx.Done()

binder, err := h.createBinder(c)
if err != nil {
return logError(err)
}

tick := time.NewTicker(2 * time.Second)

defer func() { tick.Stop() }()
@@ -117,7 +107,7 @@ func (h *SSHActionHandlers) CheckAuth(c echo.Context) error {
m, err := h.repository.GetSSHActionRequest(ctx, key)

if err != nil || m == nil {
return binder.WriteResponse(c, http.StatusOK, &tailcfg.SSHAction{Reject: true})
return c.JSON(http.StatusOK, &tailcfg.SSHAction{Reject: true})
}

if m.Action == "accept" {
@@ -127,13 +117,13 @@ func (h *SSHActionHandlers) CheckAuth(c echo.Context) error {
AllowLocalPortForwarding: true,
}
_ = h.repository.DeleteSSHActionRequest(ctx, key)
return binder.WriteResponse(c, http.StatusOK, action)
return c.JSON(http.StatusOK, action)
}

if m.Action == "reject" {
action := &tailcfg.SSHAction{Reject: true}
_ = h.repository.DeleteSSHActionRequest(ctx, key)
return binder.WriteResponse(c, http.StatusOK, action)
return c.JSON(http.StatusOK, action)
}
case <-notify:
return nil
48 changes: 19 additions & 29 deletions internal/server/server.go
Original file line number Diff line number Diff line change
@@ -6,7 +6,6 @@ import (
"fmt"
"github.com/caddyserver/certmagic"
"github.com/jsiebens/ionscale/internal/auth"
"github.com/jsiebens/ionscale/internal/bind"
"github.com/jsiebens/ionscale/internal/config"
"github.com/jsiebens/ionscale/internal/core"
"github.com/jsiebens/ionscale/internal/database"
@@ -107,16 +106,15 @@ func Start(c *config.Config) error {
p.SetMetricsPath(metricsHandler)

createPeerHandler := func(machinePublicKey key.MachinePublic) http.Handler {
binder := bind.DefaultBinder(machinePublicKey)

registrationHandlers := handlers.NewRegistrationHandlers(binder, c, sessionManager, repository)
pollNetMapHandler := handlers.NewPollNetMapHandler(binder, sessionManager, repository)
dnsHandlers := handlers.NewDNSHandlers(binder, dnsProvider)
idTokenHandlers := handlers.NewIDTokenHandlers(binder, c, repository)
sshActionHandlers := handlers.NewSSHActionHandlers(binder, c, repository)
queryFeatureHandlers := handlers.NewQueryFeatureHandlers(binder, dnsProvider, repository)
registrationHandlers := handlers.NewRegistrationHandlers(machinePublicKey, c, sessionManager, repository)
pollNetMapHandler := handlers.NewPollNetMapHandler(machinePublicKey, sessionManager, repository)
dnsHandlers := handlers.NewDNSHandlers(machinePublicKey, dnsProvider)
idTokenHandlers := handlers.NewIDTokenHandlers(machinePublicKey, c, repository)
sshActionHandlers := handlers.NewSSHActionHandlers(machinePublicKey, c, repository)
queryFeatureHandlers := handlers.NewQueryFeatureHandlers(machinePublicKey, dnsProvider, repository)

e := echo.New()
e.Binder = handlers.JsonBinder{}
e.Use(EchoMetrics(p), EchoLogger(httpLogger), EchoErrorHandler(), EchoRecover())
e.POST("/machine/register", registrationHandlers.Register)
e.POST("/machine/map", pollNetMapHandler.PollNetMap)
@@ -131,10 +129,8 @@ func Start(c *config.Config) error {
}

noiseHandlers := handlers.NewNoiseHandlers(serverKey.ControlKey, createPeerHandler)
registrationHandlers := handlers.NewRegistrationHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, sessionManager, repository)
pollNetMapHandler := handlers.NewPollNetMapHandler(bind.BoxBinder(serverKey.LegacyControlKey), sessionManager, repository)
dnsHandlers := handlers.NewDNSHandlers(bind.BoxBinder(serverKey.LegacyControlKey), dnsProvider)
idTokenHandlers := handlers.NewIDTokenHandlers(bind.BoxBinder(serverKey.LegacyControlKey), c, repository)
oidcConfigHandlers := handlers.NewOIDCConfigHandlers(c, repository)

authenticationHandlers := handlers.NewAuthenticationHandlers(
c,
authProvider,
@@ -161,22 +157,16 @@ func Start(c *config.Config) error {
tlsAppHandler.GET("/version", handlers.Version)
tlsAppHandler.GET("/key", handlers.KeyHandler(serverKey))
tlsAppHandler.POST("/ts2021", noiseHandlers.Upgrade)
tlsAppHandler.POST("/machine/:id", registrationHandlers.Register)
tlsAppHandler.POST("/machine/:id/map", pollNetMapHandler.PollNetMap)
tlsAppHandler.POST("/machine/:id/set-dns", dnsHandlers.SetDNS)
tlsAppHandler.GET("/.well-known/jwks", idTokenHandlers.Jwks)
tlsAppHandler.GET("/.well-known/openid-configuration", idTokenHandlers.OpenIDConfig)

auth := tlsAppHandler.Group("/a")
auth.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
TokenLookup: "form:_csrf",
}))
auth.GET("/:flow/:key", authenticationHandlers.StartAuth)
auth.POST("/:flow/:key", authenticationHandlers.ProcessAuth)
auth.GET("/callback", authenticationHandlers.Callback)
auth.POST("/callback", authenticationHandlers.EndAuth)
auth.GET("/success", authenticationHandlers.Success)
auth.GET("/error", authenticationHandlers.Error)
tlsAppHandler.GET("/.well-known/jwks", oidcConfigHandlers.Jwks)
tlsAppHandler.GET("/.well-known/openid-configuration", oidcConfigHandlers.OpenIDConfig)

csrf := middleware.CSRFWithConfig(middleware.CSRFConfig{TokenLookup: "form:_csrf"})
tlsAppHandler.GET("/a/:flow/:key", authenticationHandlers.StartAuth, csrf)
tlsAppHandler.POST("/a/:flow/:key", authenticationHandlers.ProcessAuth, csrf)
tlsAppHandler.GET("/a/callback", authenticationHandlers.Callback, csrf)
tlsAppHandler.POST("/a/callback", authenticationHandlers.EndAuth, csrf)
tlsAppHandler.GET("/a/success", authenticationHandlers.Success, csrf)
tlsAppHandler.GET("/a/error", authenticationHandlers.Error, csrf)

tlsL, err := tlsListener(c)
if err != nil {

0 comments on commit cbcbd61

Please sign in to comment.