Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test case where Model() is not replacing the old value within a transaction #739

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {

func RunMigrations() {
var err error
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}}
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Employee{}, &Department{}, &Audit{}}
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })

Expand Down
17 changes: 10 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
module gorm.io/playground

go 1.20
go 1.22.3

require (
github.com/google/uuid v1.3.1
gorm.io/driver/mysql v1.5.2
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.5.3
gorm.io/driver/sqlserver v1.5.1
gorm.io/gen v0.3.25
gorm.io/gorm v1.25.4
gorm.io/gorm v1.25.10
)

require (
Expand All @@ -17,15 +18,17 @@ require (
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/microsoft/go-mssqldb v1.5.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/sys v0.14.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/tools v0.15.0 // indirect
gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c // indirect
gorm.io/hints v1.1.0 // indirect
Expand Down
193 changes: 193 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
package main

import (
"errors"
"fmt"
"github.com/google/uuid"
"gorm.io/driver/postgres"
"gorm.io/gorm"
gormLogger "gorm.io/gorm/logger"
"os"
"sync"
"testing"
)

Expand All @@ -18,3 +26,188 @@ func TestGORM(t *testing.T) {
t.Errorf("Failed, got error: %v", err)
}
}

type dbConnections struct {
aDB *gorm.DB
tDB map[string]*gorm.DB
}

var db dbConnections

// TestSelectResultCaching
// Before running this test execute this query
// DELETE FROM departments
// WHERE tenant_id = ':1'
// AND department_id = ':2'
func TestSelectResultCaching(t *testing.T) {
initDB()
if err := initConstraints(); err != nil {
t.Errorf(err.Error())
}
newD := Department{
TenantID: "T1",
DepartmentID: uuid.New().String(),
DepartmentName: "FANCY DEPARTMENT NAME",
Employees: []Employee{
{EmployeeID: "001", EmployeeName: "Jinzhu"},
{EmployeeID: "002", EmployeeName: "Muchen"},
{EmployeeID: "003", EmployeeName: "Mingze"},
{EmployeeID: "004", EmployeeName: "Yichen"},
{EmployeeID: "005", EmployeeName: "Muyang"},
},
}
tDB, err := getDB(newD.TenantID)
if err != nil {
t.Errorf(err.Error())
return
}
dbDept, err := getDept(tDB, newD.TenantID, newD.DepartmentID)
if err != nil {
t.Errorf("Get department before create error: %v", err)
return
}
if len(dbDept.DepartmentID) > 0 {
t.Errorf("Get department before create but it already exists: " + dbDept.DepartmentID)
return
}
dbDept, err = createDept(tDB, newD)
if err != nil {
t.Errorf("Create department error: %v", err)
return
}
dbDept, err = getDept(tDB, dbDept.TenantID, dbDept.DepartmentID)
if err != nil {
t.Errorf("Get department after create error: %v", err)
return
}
}

func getDept(tDB *gorm.DB, tenantID string, departmentID string) (*Department, error) {
var departments []Department
output := tDB.Model(Department{}).
Preload("Employees").
Where("tenant_id = ?", tenantID).
Where("department_id = ?", departmentID).
Find(&departments)
if output.Error != nil || output.RowsAffected == 0 {
return &Department{}, output.Error
}
if output.RowsAffected == 1 {
return &departments[0], nil
}
return &Department{}, errors.New(fmt.Sprintf("too many rows affected %v, check indexes and keys", output.RowsAffected))
}

func createDept(tDB *gorm.DB, newD Department) (*Department, error) {
tx := tDB.Begin()
defer tx.Rollback()
output := tx.
Model(&Department{}).
Create(newD)
if output.Error != nil || output.RowsAffected == 0 {
return &Department{}, output.Error
}
audit := Audit{
AuditID: uuid.New().String(),
AuditDesc: "created department " + newD.DepartmentID,
}
tx.
Model(Audit{}).
Create(&audit)
if output.Error != nil || output.RowsAffected == 0 {
return &Department{}, output.Error
}
tx.Commit()
return &newD, output.Error
}

func initDB() {
db.aDB = DB
db.tDB = make(map[string]*gorm.DB)
}

const grantSQL = "grant delete, insert, references, select, trigger, truncate, update on %s to %s;"
const policySQL = "CREATE POLICY %s ON %s using (tenant_id = current_setting('myapp.current_tenant_id')) WITH CHECK (tenant_id = current_setting('myapp.current_tenant_id'));"
const rlsSQL = "ALTER TABLE %s ENABLE ROW LEVEL SECURITY;"
const setTenantIDSQL = "SET myapp.current_tenant_id = '%s'"

func initConstraints() error {
aDB, err := getDB("")
if err != nil {
return err
}

nonAdminUser := os.Getenv("NON_ADMIN_USER")

err = constrain(aDB, nonAdminUser, Employee{}.TableName())
if err != nil {
return err
}
err = constrain(aDB, nonAdminUser, Department{}.TableName())
if err != nil {
return err
}
return nil
}

func constrain(aDB *gorm.DB, nonAdminUser string, tName string) error {
//1. Grant
grantQuery := fmt.Sprintf(grantSQL, tName, nonAdminUser)
err := aDB.Exec(grantQuery).Error
if err != nil {
return err
}
//2. policy
policyQuery := fmt.Sprintf(policySQL, tName+"_rls_policy", tName)
err = aDB.Exec(policyQuery).Error
if err != nil {
return err
}
//3. enable RLS
rlsQuery := fmt.Sprintf(rlsSQL, tName)
err = aDB.Exec(rlsQuery).Error
if err != nil {
return err
}
return nil
}

var lock sync.Mutex

func getDB(tenantID string) (*gorm.DB, error) {
if len(tenantID) == 0 {
return db.aDB, nil
}
tDB, ok := db.tDB[tenantID]
if ok {
return getDBSession(tDB, tenantID)
}
lock.Lock()
defer lock.Unlock()
tDB, ok = db.tDB[tenantID]
if ok {
return getDBSession(tDB, tenantID)
}
var err error
tDB, err = gorm.Open(postgres.Open(os.Getenv("TDSN")), &gorm.Config{})
if err != nil {
return nil, err
}
db.tDB[tenantID] = tDB

tDB = db.tDB[tenantID]
return getDBSession(tDB, tenantID)
}

func getDBSession(conn *gorm.DB, tenantID string) (*gorm.DB, error) {
session := conn.Session(&gorm.Session{
// TODO add more attributes https://gorm.io/docs/session.html if necessary
Logger: gormLogger.Default.LogMode(gormLogger.Info),
}).
Exec(fmt.Sprintf(setTenantIDSQL, tenantID))
if session.Error != nil {
return nil, session.Error
}
return session, nil

}
36 changes: 35 additions & 1 deletion models.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,40 @@ type Company struct {
}

type Language struct {
Code string `gorm:"primarykey"`
Code string `gorm:"primaryKey"`
Name string
}

type Department struct {
TenantID string `gorm:"size:50;primaryKey"`
DepartmentID string `gorm:"size:50;primaryKey"`
DepartmentName string `gorm:"size:255"`
Employees []Employee `gorm:"foreignKey:TenantID,DepartmentID;references:TenantID,DepartmentID"`
}

type Employee struct {
TenantID string `gorm:"size:50;primaryKey"`
EmployeeID string `gorm:"size:50;primaryKey"`
EmployeeName string `gorm:"size:255"`
DepartmentID string `gorm:"size:50"`
}

type Audit struct {
AuditID string `gorm:"size:50;primaryKey"`
AuditDesc string `gorm:"size:255"`
}

// TableName setting
func (Employee) TableName() string {
return "employee"
}

// TableName setting
func (Department) TableName() string {
return "department"
}

// TableName setting
func (Audit) TableName() string {
return "audit"
}
Loading