diff --git a/.mockery.yml b/.mockery.yml index 95acc7ce3e..dde86d0dce 100644 --- a/.mockery.yml +++ b/.mockery.yml @@ -1,16 +1,19 @@ quiet: False disable-version-string: True with-expecter: True -filename: 'mock_{{.InterfaceName}}.go' -dir: '{{.InterfaceDir}}' -mockname: 'Mock{{.InterfaceName}}' -outpkg: '{{.PackageName}}' +filename: "mock_{{.InterfaceName}}.go" +dir: "{{.InterfaceDir}}" +mockname: "Mock{{.InterfaceName}}" +outpkg: "{{.PackageName}}" inpackage: True packages: - github.com/nucleuscloud/neosync/backend/internal/jwt: + github.com/nucleuscloud/neosync/backend/internal/auth/jwt: # config: interfaces: JwtValidator: + github.com/nucleuscloud/neosync/backend/internal/auth/authmw: + interfaces: + AuthClient: github.com/nucleuscloud/neosync/backend/internal/nucleusdb: # config: interfaces: @@ -19,6 +22,7 @@ packages: # config: interfaces: Querier: + DBTX: github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect: # config: interfaces: diff --git a/backend/gen/go/db/mock_DBTX.go b/backend/gen/go/db/mock_DBTX.go new file mode 100644 index 0000000000..ccdbc5dbe1 --- /dev/null +++ b/backend/gen/go/db/mock_DBTX.go @@ -0,0 +1,280 @@ +// Code generated by mockery. DO NOT EDIT. + +package db_queries + +import ( + context "context" + + pgconn "github.com/jackc/pgx/v5/pgconn" + mock "github.com/stretchr/testify/mock" + + pgx "github.com/jackc/pgx/v5" +) + +// MockDBTX is an autogenerated mock type for the DBTX type +type MockDBTX struct { + mock.Mock +} + +type MockDBTX_Expecter struct { + mock *mock.Mock +} + +func (_m *MockDBTX) EXPECT() *MockDBTX_Expecter { + return &MockDBTX_Expecter{mock: &_m.Mock} +} + +// CopyFrom provides a mock function with given fields: ctx, tableName, columnNames, rowSrc +func (_m *MockDBTX) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + ret := _m.Called(ctx, tableName, columnNames, rowSrc) + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)); ok { + return rf(ctx, tableName, columnNames, rowSrc) + } + if rf, ok := ret.Get(0).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) int64); ok { + r0 = rf(ctx, tableName, columnNames, rowSrc) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) error); ok { + r1 = rf(ctx, tableName, columnNames, rowSrc) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDBTX_CopyFrom_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CopyFrom' +type MockDBTX_CopyFrom_Call struct { + *mock.Call +} + +// CopyFrom is a helper method to define mock.On call +// - ctx context.Context +// - tableName pgx.Identifier +// - columnNames []string +// - rowSrc pgx.CopyFromSource +func (_e *MockDBTX_Expecter) CopyFrom(ctx interface{}, tableName interface{}, columnNames interface{}, rowSrc interface{}) *MockDBTX_CopyFrom_Call { + return &MockDBTX_CopyFrom_Call{Call: _e.mock.On("CopyFrom", ctx, tableName, columnNames, rowSrc)} +} + +func (_c *MockDBTX_CopyFrom_Call) Run(run func(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource)) *MockDBTX_CopyFrom_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(pgx.Identifier), args[2].([]string), args[3].(pgx.CopyFromSource)) + }) + return _c +} + +func (_c *MockDBTX_CopyFrom_Call) Return(_a0 int64, _a1 error) *MockDBTX_CopyFrom_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDBTX_CopyFrom_Call) RunAndReturn(run func(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)) *MockDBTX_CopyFrom_Call { + _c.Call.Return(run) + return _c +} + +// Exec provides a mock function with given fields: _a0, _a1, _a2 +func (_m *MockDBTX) Exec(_a0 context.Context, _a1 string, _a2 ...interface{}) (pgconn.CommandTag, error) { + var _ca []interface{} + _ca = append(_ca, _a0, _a1) + _ca = append(_ca, _a2...) + ret := _m.Called(_ca...) + + var r0 pgconn.CommandTag + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)); ok { + return rf(_a0, _a1, _a2...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgconn.CommandTag); ok { + r0 = rf(_a0, _a1, _a2...) + } else { + r0 = ret.Get(0).(pgconn.CommandTag) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { + r1 = rf(_a0, _a1, _a2...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDBTX_Exec_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Exec' +type MockDBTX_Exec_Call struct { + *mock.Call +} + +// Exec is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 string +// - _a2 ...interface{} +func (_e *MockDBTX_Expecter) Exec(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_Exec_Call { + return &MockDBTX_Exec_Call{Call: _e.mock.On("Exec", + append([]interface{}{_a0, _a1}, _a2...)...)} +} + +func (_c *MockDBTX_Exec_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_Exec_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockDBTX_Exec_Call) Return(_a0 pgconn.CommandTag, _a1 error) *MockDBTX_Exec_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDBTX_Exec_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)) *MockDBTX_Exec_Call { + _c.Call.Return(run) + return _c +} + +// Query provides a mock function with given fields: _a0, _a1, _a2 +func (_m *MockDBTX) Query(_a0 context.Context, _a1 string, _a2 ...interface{}) (pgx.Rows, error) { + var _ca []interface{} + _ca = append(_ca, _a0, _a1) + _ca = append(_ca, _a2...) + ret := _m.Called(_ca...) + + var r0 pgx.Rows + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgx.Rows, error)); ok { + return rf(_a0, _a1, _a2...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Rows); ok { + r0 = rf(_a0, _a1, _a2...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(pgx.Rows) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { + r1 = rf(_a0, _a1, _a2...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDBTX_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query' +type MockDBTX_Query_Call struct { + *mock.Call +} + +// Query is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 string +// - _a2 ...interface{} +func (_e *MockDBTX_Expecter) Query(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_Query_Call { + return &MockDBTX_Query_Call{Call: _e.mock.On("Query", + append([]interface{}{_a0, _a1}, _a2...)...)} +} + +func (_c *MockDBTX_Query_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_Query_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockDBTX_Query_Call) Return(_a0 pgx.Rows, _a1 error) *MockDBTX_Query_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDBTX_Query_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgx.Rows, error)) *MockDBTX_Query_Call { + _c.Call.Return(run) + return _c +} + +// QueryRow provides a mock function with given fields: _a0, _a1, _a2 +func (_m *MockDBTX) QueryRow(_a0 context.Context, _a1 string, _a2 ...interface{}) pgx.Row { + var _ca []interface{} + _ca = append(_ca, _a0, _a1) + _ca = append(_ca, _a2...) + ret := _m.Called(_ca...) + + var r0 pgx.Row + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Row); ok { + r0 = rf(_a0, _a1, _a2...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(pgx.Row) + } + } + + return r0 +} + +// MockDBTX_QueryRow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryRow' +type MockDBTX_QueryRow_Call struct { + *mock.Call +} + +// QueryRow is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 string +// - _a2 ...interface{} +func (_e *MockDBTX_Expecter) QueryRow(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_QueryRow_Call { + return &MockDBTX_QueryRow_Call{Call: _e.mock.On("QueryRow", + append([]interface{}{_a0, _a1}, _a2...)...)} +} + +func (_c *MockDBTX_QueryRow_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_QueryRow_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockDBTX_QueryRow_Call) Return(_a0 pgx.Row) *MockDBTX_QueryRow_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockDBTX_QueryRow_Call) RunAndReturn(run func(context.Context, string, ...interface{}) pgx.Row) *MockDBTX_QueryRow_Call { + _c.Call.Return(run) + return _c +} + +// NewMockDBTX creates a new instance of MockDBTX. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockDBTX(t interface { + mock.TestingT + Cleanup(func()) +}) *MockDBTX { + mock := &MockDBTX{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/backend/internal/apikey/apikey.go b/backend/internal/apikey/apikey.go new file mode 100644 index 0000000000..7c8cd7e7ff --- /dev/null +++ b/backend/internal/apikey/apikey.go @@ -0,0 +1,40 @@ +package apikey + +import ( + "fmt" + "regexp" + "strings" + + "github.com/google/uuid" +) + +const ( + prefix = "neo" + accountTokenId = "at" + v1 = "v1" + separator = "_" + + uuidPattern = `[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-4[0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}` +) + +var ( + v1Prefix = strings.Join([]string{prefix, accountTokenId, v1}, separator) + + v1AccountTokenPattern = fmt.Sprintf( + `^(%s)%s(%s)%sv([\d+])%s%s$`, + prefix, separator, accountTokenId, separator, separator, uuidPattern, + ) + v1AccountTokenRegex = regexp.MustCompile(v1AccountTokenPattern) +) + +func NewV1AccountKey() string { + return v1AccountKey(uuid.NewString()) +} + +func v1AccountKey(suffix string) string { + return v1Prefix + separator + suffix +} + +func IsValidV1AccountKey(apikey string) bool { + return v1AccountTokenRegex.MatchString(apikey) +} diff --git a/backend/internal/apikey/apikey_test.go b/backend/internal/apikey/apikey_test.go new file mode 100644 index 0000000000..efd1ddad85 --- /dev/null +++ b/backend/internal/apikey/apikey_test.go @@ -0,0 +1,26 @@ +package apikey + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewV1AccountKey(t *testing.T) { + assert.NotEmpty(t, NewV1AccountKey()) +} + +func Test_v1AccountKey(t *testing.T) { + assert.Equal( + t, + v1AccountKey("foo-bar"), + "neo_at_v1_foo-bar", + ) +} + +func Test_IsValidV1AccountKey(t *testing.T) { + assert.True( + t, + IsValidV1AccountKey(NewV1AccountKey()), + ) +} diff --git a/backend/internal/auth/apikey/client.go b/backend/internal/auth/apikey/client.go new file mode 100644 index 0000000000..3d6e91010d --- /dev/null +++ b/backend/internal/auth/apikey/client.go @@ -0,0 +1,68 @@ +package auth_apikey + +import ( + "context" + "errors" + "net/http" + + db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" + "github.com/nucleuscloud/neosync/backend/internal/apikey" + nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" + "github.com/nucleuscloud/neosync/backend/internal/utils" +) + +type TokenContextKey struct{} +type TokenContextData struct { + RawToken string + ApiKey *db_queries.NeosyncApiAccountApiKey +} + +var ( + InvalidApiKeyErr = errors.New("token is not a valid neosync api key") +) + +type Queries interface { + GetAccountApiKeyByKeyValue(ctx context.Context, db db_queries.DBTX, apiKey string) (db_queries.NeosyncApiAccountApiKey, error) +} + +type Client struct { + q Queries + db db_queries.DBTX +} + +func New( + queries Queries, + db db_queries.DBTX, +) *Client { + return &Client{q: queries, db: db} +} + +func (c *Client) InjectTokenCtx(ctx context.Context, header http.Header) (context.Context, error) { + token, err := utils.GetBearerTokenFromHeader(header, "Authorization") + if err != nil { + return nil, err + } + if !apikey.IsValidV1AccountKey(token) { + return nil, InvalidApiKeyErr + } + + apiKey, err := c.q.GetAccountApiKeyByKeyValue(ctx, c.db, token) + if err != nil { + return nil, err + } + + newctx := context.WithValue(ctx, TokenContextKey{}, &TokenContextData{ + RawToken: token, + ApiKey: &apiKey, + }) + return newctx, err +} + +func GetTokenDataFromCtx(ctx context.Context) (*TokenContextData, error) { + data, ok := ctx.Value(TokenContextKey{}).(*TokenContextData) + if !ok { + return nil, nucleuserrors.NewUnauthenticated("ctx does not contain TokenContextData or unable to cast struct") + } + + return data, nil +} diff --git a/backend/internal/auth/apikey/client_test.go b/backend/internal/auth/apikey/client_test.go new file mode 100644 index 0000000000..987661236e --- /dev/null +++ b/backend/internal/auth/apikey/client_test.go @@ -0,0 +1,102 @@ +package auth_apikey + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" + "github.com/nucleuscloud/neosync/backend/internal/apikey" + "github.com/stretchr/testify/mock" + "github.com/zeebo/assert" +) + +func Test_Client_New(t *testing.T) { + mockQuerier := db_queries.NewMockQuerier(t) + mockDbTx := db_queries.NewMockDBTX(t) + + assert.NotNil(t, New(mockQuerier, mockDbTx)) +} + +func Test_Client_InjectTokenCtx(t *testing.T) { + mockQuerier := db_queries.NewMockQuerier(t) + mockDbTx := db_queries.NewMockDBTX(t) + + client := New(mockQuerier, mockDbTx) + + fakeToken := apikey.NewV1AccountKey() + apiKeyRecord := db_queries.NeosyncApiAccountApiKey{ + ID: pgtype.UUID{Valid: true}, + } + mockQuerier.On("GetAccountApiKeyByKeyValue", mock.Anything, mock.Anything, fakeToken). + Return(apiKeyRecord, nil) + + newctx, err := client.InjectTokenCtx(context.Background(), http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", fakeToken)}, + }) + assert.NoError(t, err) + assert.NotNil(t, newctx) + + data, err := GetTokenDataFromCtx(newctx) + assert.NoError(t, err) + assert.NotNil(t, data) + assert.Equal( + t, + data, + &TokenContextData{ + RawToken: fakeToken, + ApiKey: &apiKeyRecord, + }, + ) +} + +func Test_Client_InjectTokenCtx_InvalidHeader(t *testing.T) { + client := &Client{} + _, err := client.InjectTokenCtx(context.Background(), http.Header{"Authorization": []string{}}) + assert.Error(t, err) +} + +func Test_Client_InjectTokenCtx_InvalidToken(t *testing.T) { + mockQuerier := db_queries.NewMockQuerier(t) + mockDbTx := db_queries.NewMockDBTX(t) + + client := New(mockQuerier, mockDbTx) + + newctx, err := client.InjectTokenCtx(context.Background(), http.Header{ + "Authorization": []string{"Bearer 123"}, + }) + assert.Error(t, err) + assert.Nil(t, newctx) +} + +func Test_Client_InjectTokenCtx_NotFoundKeyValue(t *testing.T) { + mockQuerier := db_queries.NewMockQuerier(t) + mockDbTx := db_queries.NewMockDBTX(t) + + client := New(mockQuerier, mockDbTx) + + fakeToken := apikey.NewV1AccountKey() + + mockQuerier.On("GetAccountApiKeyByKeyValue", mock.Anything, mock.Anything, fakeToken). + Return(db_queries.NeosyncApiAccountApiKey{}, pgx.ErrNoRows) + + newctx, err := client.InjectTokenCtx(context.Background(), http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", fakeToken)}, + }) + assert.Error(t, err) + assert.Nil(t, newctx) +} + +func Test_GetTokenDataFromCtx(t *testing.T) { + ctx := context.WithValue(context.Background(), TokenContextKey{}, &TokenContextData{}) + _, err := GetTokenDataFromCtx(ctx) + assert.NoError(t, err) +} + +func Test_GetTokenDataFromCtx_UnAuthenticated(t *testing.T) { + _, err := GetTokenDataFromCtx(context.Background()) + assert.Error(t, err) +} diff --git a/backend/internal/auth/authmw/auth.go b/backend/internal/auth/authmw/auth.go new file mode 100644 index 0000000000..e3ba29e8b2 --- /dev/null +++ b/backend/internal/auth/authmw/auth.go @@ -0,0 +1,35 @@ +package authmw + +import ( + "context" + "errors" + "net/http" + + auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey" +) + +type AuthClient interface { + InjectTokenCtx(ctx context.Context, header http.Header) (context.Context, error) +} + +type AuthMiddleware struct { + jwtClient AuthClient + apiKeyClient AuthClient +} + +func New( + jwtClient AuthClient, + apiKeyClient AuthClient, +) *AuthMiddleware { + return &AuthMiddleware{jwtClient: jwtClient, apiKeyClient: apiKeyClient} +} + +func (n *AuthMiddleware) InjectTokenCtx(ctx context.Context, header http.Header) (context.Context, error) { + ctx, err := n.apiKeyClient.InjectTokenCtx(ctx, header) + if err != nil && !errors.Is(err, auth_apikey.InvalidApiKeyErr) { + return nil, err + } else if err != nil && errors.Is(err, auth_apikey.InvalidApiKeyErr) { + return n.jwtClient.InjectTokenCtx(ctx, header) + } + return ctx, nil +} diff --git a/backend/internal/auth/authmw/auth_test.go b/backend/internal/auth/authmw/auth_test.go new file mode 100644 index 0000000000..4022894f4b --- /dev/null +++ b/backend/internal/auth/authmw/auth_test.go @@ -0,0 +1,65 @@ +package authmw + +import ( + "context" + "errors" + "net/http" + "testing" + + auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey" + "github.com/stretchr/testify/mock" + "github.com/zeebo/assert" +) + +func Test_New(t *testing.T) { + mockAuthClient := NewMockAuthClient(t) + mw := New(mockAuthClient, mockAuthClient) + assert.NotNil(t, mw) +} + +func Test_AuthMiddleware_InjectTokenCtx_ApiKey(t *testing.T) { + mockJwt := NewMockAuthClient(t) + mockApiKey := NewMockAuthClient(t) + + mw := New(mockJwt, mockApiKey) + + mockApiKey.On("InjectTokenCtx", mock.Anything, mock.Anything). + Return(context.Background(), nil) + + _, err := mw.InjectTokenCtx(context.Background(), http.Header{ + "Authorization": []string{"Bearer foo"}, + }) + assert.NoError(t, err) +} + +func Test_AuthMiddleware_InjectTokenCtx_ApiKey_InternalError(t *testing.T) { + mockJwt := NewMockAuthClient(t) + mockApiKey := NewMockAuthClient(t) + + mw := New(mockJwt, mockApiKey) + + mockApiKey.On("InjectTokenCtx", mock.Anything, mock.Anything). + Return(nil, errors.New("internal")) + + _, err := mw.InjectTokenCtx(context.Background(), http.Header{ + "Authorization": []string{"Bearer foo"}, + }) + assert.Error(t, err) +} + +func Test_AuthMiddleware_InjectTokenCtx_ApiKey_JwtFallback(t *testing.T) { + mockJwt := NewMockAuthClient(t) + mockApiKey := NewMockAuthClient(t) + + mw := New(mockJwt, mockApiKey) + + mockApiKey.On("InjectTokenCtx", mock.Anything, mock.Anything). + Return(nil, auth_apikey.InvalidApiKeyErr) + mockJwt.On("InjectTokenCtx", mock.Anything, mock.Anything). + Return(context.Background(), nil) + + _, err := mw.InjectTokenCtx(context.Background(), http.Header{ + "Authorization": []string{"Bearer foo"}, + }) + assert.NoError(t, err) +} diff --git a/backend/internal/auth/authmw/mock_AuthClient.go b/backend/internal/auth/authmw/mock_AuthClient.go new file mode 100644 index 0000000000..95bfedcc8f --- /dev/null +++ b/backend/internal/auth/authmw/mock_AuthClient.go @@ -0,0 +1,92 @@ +// Code generated by mockery. DO NOT EDIT. + +package authmw + +import ( + context "context" + http "net/http" + + mock "github.com/stretchr/testify/mock" +) + +// MockAuthClient is an autogenerated mock type for the AuthClient type +type MockAuthClient struct { + mock.Mock +} + +type MockAuthClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockAuthClient) EXPECT() *MockAuthClient_Expecter { + return &MockAuthClient_Expecter{mock: &_m.Mock} +} + +// InjectTokenCtx provides a mock function with given fields: ctx, header +func (_m *MockAuthClient) InjectTokenCtx(ctx context.Context, header http.Header) (context.Context, error) { + ret := _m.Called(ctx, header) + + var r0 context.Context + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, http.Header) (context.Context, error)); ok { + return rf(ctx, header) + } + if rf, ok := ret.Get(0).(func(context.Context, http.Header) context.Context); ok { + r0 = rf(ctx, header) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, http.Header) error); ok { + r1 = rf(ctx, header) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockAuthClient_InjectTokenCtx_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InjectTokenCtx' +type MockAuthClient_InjectTokenCtx_Call struct { + *mock.Call +} + +// InjectTokenCtx is a helper method to define mock.On call +// - ctx context.Context +// - header http.Header +func (_e *MockAuthClient_Expecter) InjectTokenCtx(ctx interface{}, header interface{}) *MockAuthClient_InjectTokenCtx_Call { + return &MockAuthClient_InjectTokenCtx_Call{Call: _e.mock.On("InjectTokenCtx", ctx, header)} +} + +func (_c *MockAuthClient_InjectTokenCtx_Call) Run(run func(ctx context.Context, header http.Header)) *MockAuthClient_InjectTokenCtx_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(http.Header)) + }) + return _c +} + +func (_c *MockAuthClient_InjectTokenCtx_Call) Return(_a0 context.Context, _a1 error) *MockAuthClient_InjectTokenCtx_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockAuthClient_InjectTokenCtx_Call) RunAndReturn(run func(context.Context, http.Header) (context.Context, error)) *MockAuthClient_InjectTokenCtx_Call { + _c.Call.Return(run) + return _c +} + +// NewMockAuthClient creates a new instance of MockAuthClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockAuthClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockAuthClient { + mock := &MockAuthClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/backend/internal/jwt/claims.go b/backend/internal/auth/jwt/claims.go similarity index 100% rename from backend/internal/jwt/claims.go rename to backend/internal/auth/jwt/claims.go diff --git a/backend/internal/jwt/client.go b/backend/internal/auth/jwt/client.go similarity index 86% rename from backend/internal/jwt/client.go rename to backend/internal/auth/jwt/client.go index dc3e0bcca4..01e195be23 100644 --- a/backend/internal/jwt/client.go +++ b/backend/internal/auth/jwt/client.go @@ -11,6 +11,7 @@ import ( "github.com/auth0/go-jwt-middleware/v2/jwks" "github.com/auth0/go-jwt-middleware/v2/validator" nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" + "github.com/nucleuscloud/neosync/backend/internal/utils" ) type ClientConfig struct { @@ -99,28 +100,9 @@ func hasScope(scopes []string, expectedScope string) bool { return false } -func getBearerTokenFromHeader( - header http.Header, - key string, -) (string, error) { - unparsedToken := header.Get(key) - if unparsedToken == "" { - return "", nucleuserrors.NewUnauthenticated("must provide valid bearer token") - } - pieces := strings.Split(unparsedToken, " ") - if len(pieces) != 2 { - return "", nucleuserrors.NewUnauthenticated("token not in proper format") - } - if pieces[0] != "Bearer" { - return "", nucleuserrors.NewUnauthenticated("must provided bearer token") - } - token := pieces[1] - return token, nil -} - // Validates the ctx is authenticated. Stuffs the parsed token onto the context func (j *Client) InjectTokenCtx(ctx context.Context, header http.Header) (context.Context, error) { - token, err := getBearerTokenFromHeader(header, "Authorization") + token, err := utils.GetBearerTokenFromHeader(header, "Authorization") if err != nil { return nil, err } diff --git a/backend/internal/jwt/client_test.go b/backend/internal/auth/jwt/client_test.go similarity index 78% rename from backend/internal/jwt/client_test.go rename to backend/internal/auth/jwt/client_test.go index 7edc290271..a1f6248929 100644 --- a/backend/internal/jwt/client_test.go +++ b/backend/internal/auth/jwt/client_test.go @@ -71,27 +71,6 @@ func Test_New(t *testing.T) { assert.Error(t, err, "fails if api audiences is nil") } -func Test_getBearerTokenFromHeader(t *testing.T) { - _, err := getBearerTokenFromHeader(http.Header{}, "Authorization") - assert.Error(t, err) - _, err = getBearerTokenFromHeader(http.Header{"Authorization": []string{}}, "Authorization") - assert.Error(t, err) - _, err = getBearerTokenFromHeader(http.Header{"Authorization": []string{"Foo"}}, "Authorization") - assert.Error(t, err) - _, err = getBearerTokenFromHeader(http.Header{"Authorization": []string{"Foo Foo Foo"}}, "Authorization") - assert.Error(t, err) - _, err = getBearerTokenFromHeader(http.Header{"Authorization": []string{"Foo Foo"}}, "Authorization") - assert.Error(t, err) - _, err = getBearerTokenFromHeader(http.Header{"Authorization": []string{"Bearer"}}, "Authorization") - assert.Error(t, err) - _, err = getBearerTokenFromHeader(http.Header{"Authorization": []string{"Bearer 123"}}, "Authorizationn") - assert.Error(t, err) - - token, err := getBearerTokenFromHeader(http.Header{"Authorization": []string{"Bearer 123"}}, "Authorization") - assert.Nil(t, err) - assert.Equal(t, token, "123") -} - func Test_Client_InjectTokenCtx(t *testing.T) { customclaims := &CustomClaims{ Scope: "foo bar", diff --git a/backend/internal/jwt/mock_JwtValidator.go b/backend/internal/auth/jwt/mock_JwtValidator.go similarity index 100% rename from backend/internal/jwt/mock_JwtValidator.go rename to backend/internal/auth/jwt/mock_JwtValidator.go diff --git a/backend/internal/authmw/auth.go b/backend/internal/authmw/auth.go deleted file mode 100644 index 7b70083851..0000000000 --- a/backend/internal/authmw/auth.go +++ /dev/null @@ -1,39 +0,0 @@ -package authmw - -import ( - "context" - "net/http" -) - -type JwtClient interface { - InjectTokenCtx(ctx context.Context, header http.Header) (context.Context, error) -} - -type AuthMiddleware struct { - jwtClient JwtClient - // db *nucleusdb.NucleusDb -} - -func New( - jwtClient JwtClient, - // db *nucleusdb.NucleusDb, -) *AuthMiddleware { - - return &AuthMiddleware{jwtClient: jwtClient} -} - -func (n *AuthMiddleware) ValidateAndInjectAll(ctx context.Context, header http.Header) (context.Context, error) { - ctx, err := n.ValidateAndInjectJwtToken(ctx, header) - if err != nil { - return nil, err - } - return ctx, nil -} - -func (n *AuthMiddleware) ValidateAndInjectJwtToken(ctx context.Context, header http.Header) (context.Context, error) { - ctx, err := n.jwtClient.InjectTokenCtx(ctx, header) - if err != nil { - return nil, err - } - return ctx, nil -} diff --git a/backend/internal/cmds/mgmt/serve/connect/cmd.go b/backend/internal/cmds/mgmt/serve/connect/cmd.go index 544e509e26..92299be993 100644 --- a/backend/internal/cmds/mgmt/serve/connect/cmd.go +++ b/backend/internal/cmds/mgmt/serve/connect/cmd.go @@ -17,12 +17,13 @@ import ( "connectrpc.com/validate" "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" + auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey" + "github.com/nucleuscloud/neosync/backend/internal/auth/authmw" auth_client "github.com/nucleuscloud/neosync/backend/internal/auth/client" - "github.com/nucleuscloud/neosync/backend/internal/authmw" + auth_jwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt" up_cmd "github.com/nucleuscloud/neosync/backend/internal/cmds/mgmt/migrate/up" auth_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/auth" logger_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/logger" - auth_jwt "github.com/nucleuscloud/neosync/backend/internal/jwt" neosynclogger "github.com/nucleuscloud/neosync/backend/internal/logger" "github.com/nucleuscloud/neosync/backend/internal/nucleusdb" clientmanager "github.com/nucleuscloud/neosync/backend/internal/temporal/client-manager" @@ -122,25 +123,44 @@ func serve(ctx context.Context) error { return err } + otelInterceptor := otelconnect.NewInterceptor() + loggerInterceptor := logger_interceptor.NewInterceptor(logger) + stdInterceptors := []connect.Interceptor{ - otelconnect.NewInterceptor(), - logger_interceptor.NewInterceptor(logger), + otelInterceptor, + loggerInterceptor, validateInterceptor, } + // standard auth interceptors that should be applied to most services + stdAuthInterceptors := []connect.Interceptor{} + // this will only authenticate jwts, not api keys. Mostly used by just the api key service + jwtOnlyAuthInterceptors := []connect.Interceptor{} + isAuthEnabled := viper.GetBool("AUTH_ENABLED") if isAuthEnabled { jwtclient, err := auth_jwt.New(getJwtClientConfig()) if err != nil { return err } - stdInterceptors = append(stdInterceptors, auth_interceptor.NewInterceptor(authmw.New(jwtclient).ValidateAndInjectAll)) + apikeyClient := auth_apikey.New(db.Q, db.Db) + stdAuthInterceptors = append( + stdAuthInterceptors, + auth_interceptor.NewInterceptor( + authmw.New( + jwtclient, + apikeyClient, + ).InjectTokenCtx, + ), + ) + jwtOnlyAuthInterceptors = append( + jwtOnlyAuthInterceptors, + auth_interceptor.NewInterceptor( + jwtclient.InjectTokenCtx, + ), + ) } - stdInterceptorConnectOpt := connect.WithInterceptors( - stdInterceptors..., - ) - api := http.NewServeMux() useraccountService := v1alpha1_useraccountservice.New(&v1alpha1_useraccountservice.Config{ @@ -150,7 +170,8 @@ func serve(ctx context.Context) error { api.Handle( mgmtv1alpha1connect.NewUserAccountServiceHandler( useraccountService, - stdInterceptorConnectOpt, + connect.WithInterceptors(stdInterceptors...), + connect.WithInterceptors(stdAuthInterceptors...), ), ) @@ -158,7 +179,8 @@ func serve(ctx context.Context) error { api.Handle( mgmtv1alpha1connect.NewConnectionServiceHandler( connectionService, - stdInterceptorConnectOpt, + connect.WithInterceptors(stdInterceptors...), + connect.WithInterceptors(stdAuthInterceptors...), ), ) authcerts, err := getTemporalAuthCertificate() @@ -182,7 +204,8 @@ func serve(ctx context.Context) error { api.Handle( mgmtv1alpha1connect.NewJobServiceHandler( jobService, - stdInterceptorConnectOpt, + connect.WithInterceptors(stdInterceptors...), + connect.WithInterceptors(stdAuthInterceptors...), ), ) @@ -190,7 +213,8 @@ func serve(ctx context.Context) error { api.Handle( mgmtv1alpha1connect.NewTransformersServiceHandler( transformerService, - stdInterceptorConnectOpt, + connect.WithInterceptors(stdInterceptors...), + connect.WithInterceptors(stdAuthInterceptors...), ), ) @@ -220,12 +244,7 @@ func serve(ctx context.Context) error { api.Handle( mgmtv1alpha1connect.NewAuthServiceHandler( authService, - // auth service uses non-standard interceptors as we don't want to include the auth interceptor in this service - connect.WithInterceptors( - otelconnect.NewInterceptor(), - logger_interceptor.NewInterceptor(logger), - validateInterceptor, - ), + connect.WithInterceptors(stdInterceptors...), ), ) @@ -235,7 +254,8 @@ func serve(ctx context.Context) error { api.Handle( mgmtv1alpha1connect.NewApiKeyServiceHandler( apiKeyService, - stdInterceptorConnectOpt, + connect.WithInterceptors(stdInterceptors...), + connect.WithInterceptors(jwtOnlyAuthInterceptors...), ), ) diff --git a/backend/internal/utils/utils.go b/backend/internal/utils/utils.go index b4b517d82f..4e423e2938 100644 --- a/backend/internal/utils/utils.go +++ b/backend/internal/utils/utils.go @@ -3,6 +3,10 @@ package utils import ( "crypto/sha256" "fmt" + "net/http" + "strings" + + nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" ) func FilterSlice[T any](slice []T, filterFn func(T) bool) []T { @@ -29,3 +33,22 @@ func ToSha256(input string) string { bs := h.Sum(nil) return fmt.Sprintf("%x", bs) } + +func GetBearerTokenFromHeader( + header http.Header, + key string, +) (string, error) { + unparsedToken := header.Get(key) + if unparsedToken == "" { + return "", nucleuserrors.NewUnauthenticated("must provide valid bearer token") + } + pieces := strings.Split(unparsedToken, " ") + if len(pieces) != 2 { + return "", nucleuserrors.NewUnauthenticated("token not in proper format") + } + if pieces[0] != "Bearer" { + return "", nucleuserrors.NewUnauthenticated("must provided bearer token") + } + token := pieces[1] + return token, nil +} diff --git a/backend/internal/utils/utils_test.go b/backend/internal/utils/utils_test.go index b5832a19d2..4339076a48 100644 --- a/backend/internal/utils/utils_test.go +++ b/backend/internal/utils/utils_test.go @@ -2,6 +2,7 @@ package utils import ( "fmt" + "net/http" "testing" "github.com/stretchr/testify/assert" @@ -41,3 +42,24 @@ func Test_ToSha256(t *testing.T) { "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2", ) } + +func Test_GetBearerTokenFromHeader(t *testing.T) { + _, err := GetBearerTokenFromHeader(http.Header{}, "Authorization") + assert.Error(t, err) + _, err = GetBearerTokenFromHeader(http.Header{"Authorization": []string{}}, "Authorization") + assert.Error(t, err) + _, err = GetBearerTokenFromHeader(http.Header{"Authorization": []string{"Foo"}}, "Authorization") + assert.Error(t, err) + _, err = GetBearerTokenFromHeader(http.Header{"Authorization": []string{"Foo Foo Foo"}}, "Authorization") + assert.Error(t, err) + _, err = GetBearerTokenFromHeader(http.Header{"Authorization": []string{"Foo Foo"}}, "Authorization") + assert.Error(t, err) + _, err = GetBearerTokenFromHeader(http.Header{"Authorization": []string{"Bearer"}}, "Authorization") + assert.Error(t, err) + _, err = GetBearerTokenFromHeader(http.Header{"Authorization": []string{"Bearer 123"}}, "Authorizationn") + assert.Error(t, err) + + token, err := GetBearerTokenFromHeader(http.Header{"Authorization": []string{"Bearer 123"}}, "Authorization") + assert.Nil(t, err) + assert.Equal(t, token, "123") +} diff --git a/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go b/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go index 5d208e3d95..639c1cf3e7 100644 --- a/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go +++ b/backend/services/mgmt/v1alpha1/api-key-service/api-keys.go @@ -2,12 +2,11 @@ package v1alpha1_apikeyservice import ( "context" - "fmt" "connectrpc.com/connect" - "github.com/google/uuid" db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" + "github.com/nucleuscloud/neosync/backend/internal/apikey" "github.com/nucleuscloud/neosync/backend/internal/dtomaps" nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" "github.com/nucleuscloud/neosync/backend/internal/nucleusdb" @@ -83,7 +82,7 @@ func (s *Service) CreateAccountApiKey( return nil, err } - clearKeyValue := getNewKeyValue() + clearKeyValue := apikey.NewV1AccountKey() hashedKeyValue := utils.ToSha256( clearKeyValue, ) @@ -103,10 +102,6 @@ func (s *Service) CreateAccountApiKey( }), nil } -func getNewKeyValue() string { - return fmt.Sprintf("neo_at_v1_%s", uuid.New().String()) -} - func (s *Service) RegenerateAccountApiKey( ctx context.Context, req *connect.Request[mgmtv1alpha1.RegenerateAccountApiKeyRequest], @@ -131,7 +126,7 @@ func (s *Service) RegenerateAccountApiKey( if err != nil { return nil, err } - clearKeyValue := getNewKeyValue() + clearKeyValue := apikey.NewV1AccountKey() hashedKeyValue := utils.ToSha256( clearKeyValue, ) diff --git a/backend/services/mgmt/v1alpha1/user-account-service/users.go b/backend/services/mgmt/v1alpha1/user-account-service/users.go index 083e9956e9..03dadc0239 100644 --- a/backend/services/mgmt/v1alpha1/user-account-service/users.go +++ b/backend/services/mgmt/v1alpha1/user-account-service/users.go @@ -8,9 +8,9 @@ import ( "github.com/jackc/pgx/v5/pgtype" db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" + authjwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt" "github.com/nucleuscloud/neosync/backend/internal/dtomaps" nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" - authjwt "github.com/nucleuscloud/neosync/backend/internal/jwt" "github.com/nucleuscloud/neosync/backend/internal/nucleusdb" ) diff --git a/backend/services/mgmt/v1alpha1/user-account-service/users_test.go b/backend/services/mgmt/v1alpha1/user-account-service/users_test.go index bf99526a5e..dc26c22d1c 100644 --- a/backend/services/mgmt/v1alpha1/user-account-service/users_test.go +++ b/backend/services/mgmt/v1alpha1/user-account-service/users_test.go @@ -11,7 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - authjwt "github.com/nucleuscloud/neosync/backend/internal/jwt" + authjwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt" "github.com/nucleuscloud/neosync/backend/internal/nucleusdb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock"