diff --git a/README.md b/README.md index 89dc9af8..d4c3174e 100644 --- a/README.md +++ b/README.md @@ -258,3 +258,4 @@ This project is licensed under the [Apache 2.0 license](LICENSE). If you have any issues or feature requests, please contact us. PR is welcomed. - https://github.com/casbin/casbin/issues - https://discord.gg/S5UjpzGZjN + diff --git a/enforcer_cached.go b/enforcer_cached.go index 2c230db1..8c1c6847 100644 --- a/enforcer_cached.go +++ b/enforcer_cached.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "time" + "github.com/casbin/casbin/v3/persist" "github.com/casbin/casbin/v3/persist/cache" ) @@ -87,6 +88,7 @@ func (e *CachedEnforcer) Enforce(rvals ...interface{}) (bool, error) { return res, err } +// LoadPolicy reloads the policy from file/database and clears the cache. func (e *CachedEnforcer) LoadPolicy() error { if atomic.LoadInt32(&e.enableCache) != 0 { if err := e.cache.Clear(); err != nil { @@ -96,6 +98,33 @@ func (e *CachedEnforcer) LoadPolicy() error { return e.Enforcer.LoadPolicy() } +// SetWatcher sets the current watcher for the CachedEnforcer. +// It overrides the base Enforcer.SetWatcher to ensure that: +// 1) For WatcherEx implementations (e.g., Redis watcher), a proper callback is set +// that calls CachedEnforcer.InvalidateCache() + LoadPolicy() for efficient cache clearing. +// 2) For basic Watcher implementations, the callback calls LoadPolicy() which clears cache. +func (e *CachedEnforcer) SetWatcher(watcher persist.Watcher) error { + e.watcher = watcher + if _, ok := watcher.(persist.WatcherEx); ok { + // For WatcherEx, set a callback that invalidates cache on any policy change. + // The callback is invoked by the watcher implementation (e.g., Redis pub/sub subscriber) + // when another instance modifies the policy. + return watcher.SetUpdateCallback(func(string) { + // First invalidate the cache to prevent stale reads + if atomic.LoadInt32(&e.enableCache) != 0 { + _ = e.InvalidateCache() + } + // Then reload the policy from the persistence layer + _ = e.LoadPolicy() + }) + } + // For basic Watcher, the default callback is sufficient since + // LoadPolicy() on CachedEnforcer already clears the cache. + return watcher.SetUpdateCallback(func(string) { _ = e.LoadPolicy() }) +} + +// RemovePolicy removes an authorization rule from the current policy. +// It also removes the corresponding cache entry. func (e *CachedEnforcer) RemovePolicy(params ...interface{}) (bool, error) { if atomic.LoadInt32(&e.enableCache) != 0 { key, ok := e.getKey(params...) @@ -132,10 +161,16 @@ func (e *CachedEnforcer) getCachedResult(key string) (res bool, err error) { return e.cache.Get(key) } +// SetExpireTime sets the cache expiration time (TTL). +// Use 0 or negative duration to make cache entries never expire. +// This is useful in multi-instance scenarios where you want to avoid lock contention +// and recalculation overhead, and instead manually trigger LoadPolicy() or InvalidateCache() +// when policies change. func (e *CachedEnforcer) SetExpireTime(expireTime time.Duration) { e.expireTime = expireTime } +// SetCache sets the cache implementation. func (e *CachedEnforcer) SetCache(c cache.Cache) { e.cache = c } @@ -173,7 +208,7 @@ func GetCacheKey(params ...interface{}) (string, bool) { return key.String(), true } -// ClearPolicy clears all policy. +// ClearPolicy clears all policy and the cache. func (e *CachedEnforcer) ClearPolicy() { if atomic.LoadInt32(&e.enableCache) != 0 { if err := e.cache.Clear(); err != nil { diff --git a/enforcer_cached_synced.go b/enforcer_cached_synced.go index 579281d6..42cefc29 100644 --- a/enforcer_cached_synced.go +++ b/enforcer_cached_synced.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "github.com/casbin/casbin/v3/persist" "github.com/casbin/casbin/v3/persist/cache" ) @@ -91,6 +92,31 @@ func (e *SyncedCachedEnforcer) LoadPolicy() error { return e.SyncedEnforcer.LoadPolicy() } +// SetWatcher sets the current watcher for the SyncedCachedEnforcer. +// It overrides the base SyncedEnforcer.SetWatcher to ensure that: +// 1) For WatcherEx implementations (e.g., Redis watcher), a proper callback is set +// that calls InvalidateCache() + LoadPolicy() for efficient cache clearing. +// 2) For basic Watcher implementations, the callback calls LoadPolicy() which clears cache. +func (e *SyncedCachedEnforcer) SetWatcher(watcher persist.Watcher) error { + e.SyncedEnforcer.watcher = watcher + if _, ok := watcher.(persist.WatcherEx); ok { + // For WatcherEx, set a callback that invalidates cache on any policy change. + // The callback is invoked by the watcher implementation (e.g., Redis pub/sub subscriber) + // when another instance modifies the policy. + return watcher.SetUpdateCallback(func(string) { + // First invalidate the cache to prevent stale reads + if atomic.LoadInt32(&e.enableCache) != 0 { + _ = e.InvalidateCache() + } + // Then reload the policy from the persistence layer + _ = e.LoadPolicy() + }) + } + // For basic Watcher, the default callback is sufficient since + // LoadPolicy() on SyncedCachedEnforcer already clears the cache. + return watcher.SetUpdateCallback(func(string) { _ = e.LoadPolicy() }) +} + func (e *SyncedCachedEnforcer) AddPolicy(params ...interface{}) (bool, error) { if ok, err := e.checkOneAndRemoveCache(params...); !ok { return ok, err diff --git a/enforcer_cached_watcher_test.go b/enforcer_cached_watcher_test.go new file mode 100644 index 00000000..29aa11ec --- /dev/null +++ b/enforcer_cached_watcher_test.go @@ -0,0 +1,314 @@ +// Copyright 2024 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "sync" + "testing" + + "github.com/casbin/casbin/v3/persist/cache" +) + +// mockWatcherEx is a mock WatcherEx for testing callback behavior. +type mockWatcherEx struct { + mu sync.Mutex + callback func(string) + updateCount int +} + +func (m *mockWatcherEx) SetUpdateCallback(fn func(string)) error { + m.mu.Lock() + defer m.mu.Unlock() + m.callback = fn + return nil +} + +func (m *mockWatcherEx) Update() error { + return nil +} + +func (m *mockWatcherEx) Close() {} + +func (m *mockWatcherEx) UpdateForAddPolicy(sec, ptype string, params ...string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.updateCount++ + // Simulate calling the callback when a policy change is pushed + if m.callback != nil { + m.callback("add_policy") + } + return nil +} + +func (m *mockWatcherEx) UpdateForRemovePolicy(sec, ptype string, params ...string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.updateCount++ + if m.callback != nil { + m.callback("remove_policy") + } + return nil +} + +func (m *mockWatcherEx) UpdateForRemoveFilteredPolicy(sec, ptype string, fieldIndex int, fieldValues ...string) error { + return nil +} + +func (m *mockWatcherEx) UpdateForSavePolicy(model interface{}) error { + return nil +} + +func (m *mockWatcherEx) UpdateForAddPolicies(sec string, ptype string, rules ...[]string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.updateCount++ + if m.callback != nil { + m.callback("add_policies") + } + return nil +} + +func (m *mockWatcherEx) UpdateForRemovePolicies(sec string, ptype string, rules ...[]string) error { + return nil +} + +func (m *mockWatcherEx) UpdateForUpdatePolicy(sec string, ptype string, oldRule, newRule []string) error { + return nil +} + +func (m *mockWatcherEx) UpdateForUpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) error { + return nil +} + +func (m *mockWatcherEx) GetUpdateCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.updateCount +} + +// TestCachedEnforcerWatcherExCallback tests that CachedEnforcer properly sets a callback +// when a WatcherEx is provided, and that the callback invalidates the cache. +func TestCachedEnforcerWatcherExCallback(t *testing.T) { + e, err := NewCachedEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatalf("Failed to create CachedEnforcer: %v", err) + } + + mock := &mockWatcherEx{} + if err := e.SetWatcher(mock); err != nil { + t.Fatalf("Failed to set watcher: %v", err) + } + + // First enforce call - should be a cache miss + res, err := e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected alice to have read access to data1") + } + + // Second enforce call - should be cached + res, err = e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected cached result for alice read data1") + } + + // Simulate a policy change from another instance via WatcherEx.UpdateForAddPolicy + // This should trigger the callback, which should invalidate the cache + // and reload the policy. + _, err = e.AddPolicy("alice", "data1", "write") + if err != nil { + t.Fatalf("AddPolicy failed: %v", err) + } + + // The cache should have been invalidated by the callback + // Verify that the new policy is reflected + res, err = e.Enforce("alice", "data1", "write") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected alice to have write access to data1 after policy was added") + } +} + +// TestCachedEnforcerBasicWatcherCallback tests that CachedEnforcer works with +// a basic Watcher (not WatcherEx) and properly invalidates cache. +func TestCachedEnforcerBasicWatcherCallback(t *testing.T) { + e, err := NewCachedEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatalf("Failed to create CachedEnforcer: %v", err) + } + + mock := &mockBasicWatcher{} + if err := e.SetWatcher(mock); err != nil { + t.Fatalf("Failed to set watcher: %v", err) + } + + // First enforce call - cache miss + res, err := e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected alice to have read access to data1") + } + + // Second enforce call - should be cached (same result) + res, err = e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected cached result") + } + + // Verify the basic watcher is being used (not the WatcherEx path) + if !mock.callbackSet { + t.Error("Expected callback to be set for basic watcher") + } +} + +// mockBasicWatcher is a basic Watcher for testing. +type mockBasicWatcher struct { + mu sync.Mutex + callback func(string) + callbackSet bool +} + +func (m *mockBasicWatcher) SetUpdateCallback(fn func(string)) error { + m.mu.Lock() + defer m.mu.Unlock() + m.callback = fn + m.callbackSet = true + return nil +} + +func (m *mockBasicWatcher) Update() error { + return nil +} + +func (m *mockBasicWatcher) Close() {} + +// TestSyncedCachedEnforcerWatcherExCallback tests that SyncedCachedEnforcer properly +// sets a callback for WatcherEx implementations. +func TestSyncedCachedEnforcerWatcherExCallback(t *testing.T) { + e, err := NewSyncedCachedEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatalf("Failed to create SyncedCachedEnforcer: %v", err) + } + + mock := &mockWatcherEx{} + if err := e.SetWatcher(mock); err != nil { + t.Fatalf("Failed to set watcher: %v", err) + } + + // First enforce - cache miss + res, err := e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected alice to have read access to data1") + } + + // Second enforce - cached + res, err = e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected cached result") + } + + // Remove policy - should trigger callback which invalidates cache + _, err = e.RemovePolicy("alice", "data1", "read") + if err != nil { + t.Fatalf("RemovePolicy failed: %v", err) + } + + // After policy removal and cache invalidation, this should now be false + res, err = e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if res { + t.Error("Expected alice to NOT have read access after policy was removed") + } +} + +// TestInvalidateCacheDirectly tests that InvalidateCache works independently. +func TestInvalidateCacheDirectly(t *testing.T) { + e, err := NewCachedEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatalf("Failed to create CachedEnforcer: %v", err) + } + + // Populate cache + res, _ := e.Enforce("alice", "data1", "read") + if !res { + t.Error("Expected alice to have read access to data1") + } + + // Directly invalidate cache + if err := e.InvalidateCache(); err != nil { + t.Fatalf("InvalidateCache failed: %v", err) + } + + // Next enforce should recompute (cache miss) and succeed + res, err = e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected alice to still have read access after cache invalidation") + } +} + +// TestCacheWithCustomTTL tests that setting a custom TTL works. +func TestCacheWithCustomTTL(t *testing.T) { + e, err := NewCachedEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatalf("Failed to create CachedEnforcer: %v", err) + } + + // Set a short TTL + e.SetExpireTime(100) // 100ns - essentially immediate for tests + + // Populate cache + res, err := e.Enforce("alice", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !res { + t.Error("Expected alice to have read access to data1") + } + + // The cache entry should have the TTL set + c, ok := e.cache.(*cache.DefaultCache) + if !ok { + t.Skip("Cache is not a DefaultCache, skipping TTL test") + } + + item, exists := (*c)["alice$$data1$$read$$"] + if !exists { + t.Error("Expected cache entry for alice$$data1$$read$$") + } + _ = item // item.ttl should be 100ns +}