Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 39 additions & 32 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/samber/mo"
"gorm.io/gorm"
Expand Down Expand Up @@ -169,40 +168,37 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
playerUUID = mo.Some(p.UUID)
}

tx := app.DB.Begin()
defer tx.Rollback()

var client Client
if req.ClientToken == nil {
clientToken, err := RandomHex(16)
if err != nil {
return err
}
client = Client{
UUID: uuid.New().String(),
ClientToken: clientToken,
Version: 0,
PlayerUUID: OptionToNullString(playerUUID),
client = NewClient(user, clientToken, playerUUID)
if err := tx.Create(&client).Error; err != nil {
return err
}
user.Clients = append(user.Clients, client)
} else {
clientToken := *req.ClientToken
clientExists := false

for i := range user.Clients {
if user.Clients[i].ClientToken == clientToken {
clientExists = true
user.Clients[i].Version += 1
client = user.Clients[i]
break
}
}

if !clientExists {
client = Client{
UUID: uuid.New().String(),
ClientToken: clientToken,
Version: 0,
PlayerUUID: OptionToNullString(playerUUID),
if err := tx.First(&client, "user_uuid = ? AND client_token = ?", user.UUID, clientToken).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
// Client does not exist
client = NewClient(user, clientToken, playerUUID)
if err := tx.Create(&client).Error; err != nil {
return err
}
} else {
// Client exists
client.Version += 1
if err := tx.Save(&client).Error; err != nil {
return err
}
user.Clients = append(user.Clients, client)
}
}

Expand Down Expand Up @@ -246,8 +242,7 @@ func AuthAuthenticate(app *App) func(c echo.Context) error {
return err
}

// Save changes to user.Clients
if err := app.DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
if err := tx.Commit().Error; err != nil {
return err
}

Expand Down Expand Up @@ -285,8 +280,12 @@ func AuthRefresh(app *App) func(c echo.Context) error {
return err
}

client := app.GetClient(req.AccessToken, StalePolicyAllow)
if client == nil || client.ClientToken != req.ClientToken {
client, err := app.GetClient(req.AccessToken, StalePolicyAllow)
var userError *UserError
if err != nil && !errors.As(err, &userError) {
return err
}
if err != nil || client.ClientToken != req.ClientToken {
return invalidAccessTokenError
}
user := client.User
Expand Down Expand Up @@ -376,8 +375,12 @@ func AuthValidate(app *App) func(c echo.Context) error {
return err
}

client := app.GetClient(req.AccessToken, StalePolicyDeny)
if client == nil || client.ClientToken != req.ClientToken {
client, err := app.GetClient(req.AccessToken, StalePolicyDeny)
var userError *UserError
if err != nil && !errors.As(err, &userError) {
return err
}
if err != nil || client.ClientToken != req.ClientToken {
return c.NoContent(http.StatusForbidden)
}

Expand Down Expand Up @@ -427,8 +430,12 @@ func AuthInvalidate(app *App) func(c echo.Context) error {
return err
}

client := app.GetClient(req.AccessToken, StalePolicyAllow)
if client == nil {
client, err := app.GetClient(req.AccessToken, StalePolicyAllow)
var userError *UserError
if err != nil && !errors.As(err, &userError) {
return err
}
if err != nil {
return invalidAccessTokenError
}

Expand Down
65 changes: 57 additions & 8 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"net/http"
"testing"
"time"
)

func TestAuth(t *testing.T) {
Expand All @@ -29,6 +30,7 @@ func TestAuth(t *testing.T) {
t.Run("Test /validate", ts.testValidate)

t.Run("Test authenticate with duplicate client token", ts.testDuplicateClientToken)
t.Run("Test authenticate too many client tokens", ts.testTooManyClientTokens)
}
}

Expand Down Expand Up @@ -57,8 +59,9 @@ func (ts *TestSuite) authenticate(t *testing.T, username string, password string
accessToken := authenticateRes.AccessToken

// Check that the access token is valid
client := ts.App.GetClient(accessToken, StalePolicyDeny)
client, err := ts.App.GetClient(accessToken, StalePolicyDeny)
assert.NotNil(t, client)
assert.Nil(t, err)
assert.Equal(t, client.ClientToken, clientToken)

return &authenticateRes
Expand Down Expand Up @@ -118,10 +121,12 @@ func (ts *TestSuite) testAuthenticate(t *testing.T) {
assert.NotNil(t, client.Player)
assert.Equal(t, TEST_PLAYER_NAME, client.Player.Name)

accessTokenClient := ts.App.GetClient(response0.AccessToken, StalePolicyDeny)
accessTokenClient, err := ts.App.GetClient(response0.AccessToken, StalePolicyDeny)
assert.Nil(t, err)
assert.NotNil(t, accessTokenClient)
accessTokenClient.Player = client.Player
accessTokenClient.User = client.User
accessTokenClient.LastUsedAt = client.LastUsedAt

assert.Equal(t, client, *accessTokenClient)

Expand Down Expand Up @@ -372,7 +377,8 @@ func (ts *TestSuite) testInvalidate(t *testing.T) {

// Successful invalidate
// We should start with valid clients in the database
client := ts.App.GetClient(accessToken, StalePolicyDeny)
client, err := ts.App.GetClient(accessToken, StalePolicyDeny)
assert.Nil(t, err)
assert.NotNil(t, client)
var clients []Client
result := ts.App.DB.Model(Client{}).Where("player_uuid = ?", &client.Player.UUID).Find(&clients)
Expand All @@ -394,7 +400,8 @@ func (ts *TestSuite) testInvalidate(t *testing.T) {

// The token version of each client should have been incremented,
// invalidating all previously-issued JWTs
assert.Nil(t, ts.App.GetClient(accessToken, StalePolicyDeny))
_, err = ts.App.GetClient(accessToken, StalePolicyDeny)
assert.NotNil(t, err)
result = ts.App.DB.Model(Client{}).Where("player_uuid = ?", &client.Player.UUID).Find(&clients)
assert.Nil(t, result.Error)
for _, client := range clients {
Expand Down Expand Up @@ -445,11 +452,13 @@ func (ts *TestSuite) testRefresh(t *testing.T) {
assert.NotEqual(t, accessToken, refreshRes.AccessToken)

// The old accessToken should be invalid
client := ts.App.GetClient(accessToken, StalePolicyDeny)
client, err := ts.App.GetClient(accessToken, StalePolicyDeny)
assert.NotNil(t, err)
assert.Nil(t, client)

// The new token should be valid
client = ts.App.GetClient(refreshRes.AccessToken, StalePolicyDeny)
client, err = ts.App.GetClient(refreshRes.AccessToken, StalePolicyDeny)
assert.Nil(t, err)
assert.NotNil(t, client)

// The response should include a profile
Expand Down Expand Up @@ -537,7 +546,8 @@ func (ts *TestSuite) testSignout(t *testing.T) {
assert.Nil(t, result.Error)

// We should start with valid clients in the database
client := ts.App.GetClient(accessToken, StalePolicyDeny)
client, err := ts.App.GetClient(accessToken, StalePolicyDeny)
assert.Nil(t, err)
assert.NotNil(t, client)
var clients []Client
result = ts.App.DB.Model(Client{}).Where("user_uuid = ?", client.UserUUID).Find(&clients)
Expand All @@ -559,7 +569,8 @@ func (ts *TestSuite) testSignout(t *testing.T) {

// The token version of each client should have been incremented,
// invalidating all previously-issued JWTs
assert.Nil(t, ts.App.GetClient(accessToken, StalePolicyDeny))
_, err = ts.App.GetClient(accessToken, StalePolicyDeny)
assert.NotNil(t, err)
result = ts.App.DB.Model(Client{}).Where("user_uuid = ?", client.UserUUID).Find(&clients)
assert.Nil(t, result.Error)
assert.True(t, len(clients) > 0)
Expand Down Expand Up @@ -671,3 +682,41 @@ func (ts *TestSuite) testDuplicateClientToken(t *testing.T) {
assert.Nil(t, result.Error)
assert.Equal(t, TEST_OTHER_USERNAME, otherClient.Player.Name)
}

func (ts *TestSuite) testTooManyClientTokens(t *testing.T) {
var PAST time.Time = Unwrap(time.Parse(time.RFC3339Nano, "2018-01-01T00:00:00.000000000Z"))

var user User
assert.Nil(t, ts.App.DB.First(&user, "username = ?", TEST_USERNAME).Error)
assert.Nil(t, ts.App.DB.Where("user_uuid = ?", user.UUID).Delete(&Client{}).Error)

clients := make([]Client, 0, Constants.MaxClientCount)

// Create MaxCountClient clients
for range ts.App.Constants.MaxClientCount {
clientToken, err := RandomHex(16)
assert.Nil(t, err)

client := NewClient(&user, clientToken, mo.None[string]())
client.LastUsedAt = PAST

clients = append(clients, client)
}

assert.Nil(t, ts.App.DB.Create(&clients).Error)

var count int64
assert.Nil(t, ts.App.DB.Model(&Client{}).Where("user_uuid = ?", user.UUID).Count(&count).Error)
assert.Equal(t, int64(ts.App.Constants.MaxClientCount), count)

// Add one more client
response := ts.authenticate(t, TEST_PLAYER_NAME, TEST_PASSWORD)

// There should still only be MaxClientCount clients in the database
assert.Nil(t, ts.App.DB.Model(&Client{}).Where("user_uuid = ?", user.UUID).Count(&count).Error)
assert.Equal(t, int64(ts.App.Constants.MaxClientCount), count)

// The new client should have not been evicted
var client Client
assert.Nil(t, ts.App.DB.Find(&client, "user_uuid = ? AND client_token = ?", user.UUID, response.ClientToken).Error)
}
2 changes: 2 additions & 0 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ type ConstantsType struct {
ConfigDirectory string
MaxPlayerNameLength int
MaxUsernameLength int
MaxClientCount int
Version string
License string
LicenseURL string
Expand All @@ -156,6 +157,7 @@ var Constants = &ConstantsType{
MaxPlayerCountUnlimited: -1,
MaxUsernameLength: 16,
MaxPlayerNameLength: 16,
MaxClientCount: 256,
ConfigDirectory: GetDefaultConfigDirectory(),
Version: VERSION,
License: LICENSE,
Expand Down
Loading