diff --git a/go.mod b/go.mod index 59112e71..1a8be523 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.24.1 require ( firebase.google.com/go/v4 v4.12.1 github.com/android-sms-gateway/client-go v1.9.5 - github.com/android-sms-gateway/core v1.0.1 github.com/ansrivas/fiberprometheus/v2 v2.6.1 github.com/capcom6/go-helpers v0.3.0 github.com/capcom6/go-infra-fx v0.4.0 diff --git a/go.sum b/go.sum index 2738cf6f..9c4626b1 100644 --- a/go.sum +++ b/go.sum @@ -34,8 +34,6 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/android-sms-gateway/client-go v1.9.5 h1:fHrE1Pi3rKUdPVMmI9evKW0iyjB5bMIhFRxyq1wVQ+o= github.com/android-sms-gateway/client-go v1.9.5/go.mod h1:DQsReciU1xcaVW3T5Z2bqslNdsAwCFCtghawmA6g6L4= -github.com/android-sms-gateway/core v1.0.1 h1:7QyqyW3UQSQmEXQuUgXjZwHSnOd65DTxHUyhXQi6gpc= -github.com/android-sms-gateway/core v1.0.1/go.mod h1:HXczGDCKxTeuiwadPElczCx/y3Y6Wamc5kl5nFp5rVM= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/ansrivas/fiberprometheus/v2 v2.6.1 h1:wac3pXaE6BYYTF04AC6K0ktk6vCD+MnDOJZ3SK66kXM= diff --git a/internal/config/config.go b/internal/config/config.go index 49b5df5a..f470f859 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -16,6 +16,7 @@ type Config struct { SSE SSE `yaml:"sse"` // server-sent events config Messages Messages `yaml:"messages"` // messages config Cache Cache `yaml:"cache"` // cache (memory or redis) config + PubSub PubSub `yaml:"pubsub"` // pubsub (memory or redis) config } type Gateway struct { @@ -81,6 +82,10 @@ type Cache struct { URL string `yaml:"url" envconfig:"CACHE__URL"` } +type PubSub struct { + URL string `yaml:"url" envconfig:"PUBSUB__URL"` +} + var defaultConfig = Config{ Gateway: Gateway{Mode: GatewayModePublic}, HTTP: HTTP{ @@ -113,4 +118,7 @@ var defaultConfig = Config{ Cache: Cache{ URL: "memory://", }, + PubSub: PubSub{ + URL: "memory://", + }, } diff --git a/internal/config/module.go b/internal/config/module.go index e4ca2136..5b5f7ad1 100644 --- a/internal/config/module.go +++ b/internal/config/module.go @@ -11,6 +11,7 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/modules/messages" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/sse" + "github.com/android-sms-gateway/server/internal/sms-gateway/pubsub" "github.com/capcom6/go-infra-fx/config" "github.com/capcom6/go-infra-fx/db" "github.com/capcom6/go-infra-fx/http" @@ -122,4 +123,10 @@ var Module = fx.Module( URL: cfg.Cache.URL, } }), + fx.Provide(func(cfg Config) pubsub.Config { + return pubsub.Config{ + URL: cfg.PubSub.URL, + BufferSize: 128, + } + }), ) diff --git a/internal/sms-gateway/app.go b/internal/sms-gateway/app.go index 37287cbf..cf38fbfd 100644 --- a/internal/sms-gateway/app.go +++ b/internal/sms-gateway/app.go @@ -21,6 +21,7 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/modules/webhooks" "github.com/android-sms-gateway/server/internal/sms-gateway/online" "github.com/android-sms-gateway/server/internal/sms-gateway/openapi" + "github.com/android-sms-gateway/server/internal/sms-gateway/pubsub" "github.com/capcom6/go-infra-fx/cli" "github.com/capcom6/go-infra-fx/db" "github.com/capcom6/go-infra-fx/http" @@ -45,6 +46,7 @@ var Module = fx.Module( push.Module, db.Module, cache.Module(), + pubsub.Module(), events.Module, messages.Module(), health.Module, diff --git a/internal/sms-gateway/cache/factory.go b/internal/sms-gateway/cache/factory.go index 76895199..16659d89 100644 --- a/internal/sms-gateway/cache/factory.go +++ b/internal/sms-gateway/cache/factory.go @@ -4,7 +4,6 @@ import ( "fmt" "net/url" - "github.com/android-sms-gateway/core/redis" "github.com/android-sms-gateway/server/pkg/cache" ) @@ -40,13 +39,14 @@ func NewFactory(config Config) (Factory, error) { }, }, nil case "redis": - client, err := redis.New(redis.Config{URL: config.URL}) - if err != nil { - return nil, fmt.Errorf("can't create redis client: %w", err) - } return &factory{ new: func(name string) (Cache, error) { - return cache.NewRedis(client, name, 0), nil + return cache.NewRedis(cache.RedisConfig{ + Client: nil, + URL: config.URL, + Prefix: keyPrefix + name, + TTL: 0, + }) }, }, nil default: @@ -56,5 +56,5 @@ func NewFactory(config Config) (Factory, error) { // New implements Factory. func (f *factory) New(name string) (Cache, error) { - return f.new(keyPrefix + name) + return f.new(name) } diff --git a/internal/sms-gateway/modules/events/events.go b/internal/sms-gateway/modules/events/events.go index 89c90941..e3d2b043 100644 --- a/internal/sms-gateway/modules/events/events.go +++ b/internal/sms-gateway/modules/events/events.go @@ -6,15 +6,15 @@ import ( "github.com/android-sms-gateway/client-go/smsgateway" ) -func NewMessageEnqueuedEvent() *Event { +func NewMessageEnqueuedEvent() Event { return NewEvent(smsgateway.PushMessageEnqueued, nil) } -func NewWebhooksUpdatedEvent() *Event { +func NewWebhooksUpdatedEvent() Event { return NewEvent(smsgateway.PushWebhooksUpdated, nil) } -func NewMessagesExportRequestedEvent(since, until time.Time) *Event { +func NewMessagesExportRequestedEvent(since, until time.Time) Event { return NewEvent( smsgateway.PushMessagesExportRequested, map[string]string{ @@ -24,6 +24,6 @@ func NewMessagesExportRequestedEvent(since, until time.Time) *Event { ) } -func NewSettingsUpdatedEvent() *Event { +func NewSettingsUpdatedEvent() Event { return NewEvent(smsgateway.PushSettingsUpdated, nil) } diff --git a/internal/sms-gateway/modules/events/metrics.go b/internal/sms-gateway/modules/events/metrics.go index 13afdc84..e2734679 100644 --- a/internal/sms-gateway/modules/events/metrics.go +++ b/internal/sms-gateway/modules/events/metrics.go @@ -19,8 +19,11 @@ const ( DeliveryTypeSSE = "sse" DeliveryTypeUnknown = "unknown" - FailureReasonQueueFull = "queue_full" - FailureReasonProviderFailed = "provider_failed" + FailureReasonSerializationError = "serialization_error" + FailureReasonPublishError = "publish_error" + FailureReasonProviderFailed = "provider_failed" + + EventTypeUnknown = "unknown" ) // metrics contains all Prometheus metrics for the events module diff --git a/internal/sms-gateway/modules/events/module.go b/internal/sms-gateway/modules/events/module.go index 3b6ba7e8..8f6cf71a 100644 --- a/internal/sms-gateway/modules/events/module.go +++ b/internal/sms-gateway/modules/events/module.go @@ -14,11 +14,18 @@ var Module = fx.Module( }), fx.Provide(newMetrics, fx.Private), fx.Provide(NewService), - fx.Invoke(func(lc fx.Lifecycle, svc *Service) { + fx.Invoke(func(lc fx.Lifecycle, svc *Service, logger *zap.Logger, sh fx.Shutdowner) { ctx, cancel := context.WithCancel(context.Background()) lc.Append(fx.Hook{ OnStart: func(_ context.Context) error { - go svc.Run(ctx) + go func() { + if err := svc.Run(ctx); err != nil { + logger.Error("Error running events service", zap.Error(err)) + if err := sh.Shutdown(fx.ExitCode(1)); err != nil { + logger.Error("Failed to shutdown", zap.Error(err)) + } + } + }() return nil }, OnStop: func(_ context.Context) error { diff --git a/internal/sms-gateway/modules/events/service.go b/internal/sms-gateway/modules/events/service.go index 202b6bb6..384dbfa1 100644 --- a/internal/sms-gateway/modules/events/service.go +++ b/internal/sms-gateway/modules/events/service.go @@ -3,27 +3,33 @@ package events import ( "context" "fmt" + "time" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/sse" + "github.com/android-sms-gateway/server/internal/sms-gateway/pubsub" "go.uber.org/zap" ) +const ( + pubsubTopic = "events" +) + type Service struct { deviceSvc *devices.Service sseSvc *sse.Service pushSvc *push.Service - queue chan eventWrapper + pubsub pubsub.PubSub metrics *metrics logger *zap.Logger } -func NewService(devicesSvc *devices.Service, sseSvc *sse.Service, pushSvc *push.Service, metrics *metrics, logger *zap.Logger) *Service { +func NewService(devicesSvc *devices.Service, sseSvc *sse.Service, pushSvc *push.Service, pubsub pubsub.PubSub, metrics *metrics, logger *zap.Logger) *Service { return &Service{ deviceSvc: devicesSvc, sseSvc: sseSvc, @@ -31,44 +37,72 @@ func NewService(devicesSvc *devices.Service, sseSvc *sse.Service, pushSvc *push. metrics: metrics, - queue: make(chan eventWrapper, 128), + pubsub: pubsub, logger: logger, } } -func (s *Service) Notify(userID string, deviceID *string, event *Event) error { +func (s *Service) Notify(userID string, deviceID *string, event Event) error { + if event.EventType == "" { + return fmt.Errorf("event type is empty") + } + + subCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wrapper := eventWrapper{ UserID: userID, DeviceID: deviceID, Event: event, } - select { - case s.queue <- wrapper: - // Successfully enqueued - s.metrics.IncrementEnqueued(string(event.eventType)) - default: - s.metrics.IncrementFailed(string(event.eventType), DeliveryTypeUnknown, FailureReasonQueueFull) - return fmt.Errorf("event queue is full") + wrapperBytes, err := wrapper.serialize() + if err != nil { + s.metrics.IncrementFailed(string(event.EventType), DeliveryTypeUnknown, FailureReasonSerializationError) + return fmt.Errorf("can't serialize event wrapper: %w", err) + } + + if err := s.pubsub.Publish(subCtx, pubsubTopic, wrapperBytes); err != nil { + s.metrics.IncrementFailed(string(event.EventType), DeliveryTypeUnknown, FailureReasonPublishError) + return fmt.Errorf("can't publish event: %w", err) } + s.metrics.IncrementEnqueued(string(event.EventType)) + return nil } -func (s *Service) Run(ctx context.Context) { +func (s *Service) Run(ctx context.Context) error { + sub, err := s.pubsub.Subscribe(ctx, pubsubTopic) + if err != nil { + return fmt.Errorf("can't subscribe to pubsub: %w", err) + } + defer sub.Close() + + ch := sub.Receive() for { select { - case wrapper := <-s.queue: - s.processEvent(wrapper) case <-ctx.Done(): s.logger.Info("Event service stopped") - return + return nil + case msg, ok := <-ch: + if !ok { + s.logger.Info("Subscription closed") + return nil + } + wrapper := new(eventWrapper) + if err := wrapper.deserialize(msg.Data); err != nil { + s.metrics.IncrementFailed(EventTypeUnknown, DeliveryTypeUnknown, FailureReasonSerializationError) + s.logger.Error("Failed to deserialize event wrapper", zap.Error(err)) + continue + } + s.processEvent(wrapper) } } } -func (s *Service) processEvent(wrapper eventWrapper) { +func (s *Service) processEvent(wrapper *eventWrapper) { // Load devices from database filters := []devices.SelectFilter{} if wrapper.DeviceID != nil { @@ -91,26 +125,26 @@ func (s *Service) processEvent(wrapper eventWrapper) { if device.PushToken != nil && *device.PushToken != "" { // Device has push token, use push service if err := s.pushSvc.Enqueue(*device.PushToken, push.Event{ - Type: wrapper.Event.eventType, - Data: wrapper.Event.data, + Type: wrapper.Event.EventType, + Data: wrapper.Event.Data, }); err != nil { s.logger.Error("Failed to enqueue push notification", zap.String("user_id", wrapper.UserID), zap.String("device_id", device.ID), zap.Error(err)) - s.metrics.IncrementFailed(string(wrapper.Event.eventType), DeliveryTypePush, FailureReasonProviderFailed) + s.metrics.IncrementFailed(string(wrapper.Event.EventType), DeliveryTypePush, FailureReasonProviderFailed) } else { - s.metrics.IncrementSent(string(wrapper.Event.eventType), DeliveryTypePush) + s.metrics.IncrementSent(string(wrapper.Event.EventType), DeliveryTypePush) } continue } // No push token, use SSE service if err := s.sseSvc.Send(device.ID, sse.Event{ - Type: wrapper.Event.eventType, - Data: wrapper.Event.data, + Type: wrapper.Event.EventType, + Data: wrapper.Event.Data, }); err != nil { s.logger.Error("Failed to send SSE notification", zap.String("user_id", wrapper.UserID), zap.String("device_id", device.ID), zap.Error(err)) - s.metrics.IncrementFailed(string(wrapper.Event.eventType), DeliveryTypeSSE, FailureReasonProviderFailed) + s.metrics.IncrementFailed(string(wrapper.Event.EventType), DeliveryTypeSSE, FailureReasonProviderFailed) } else { - s.metrics.IncrementSent(string(wrapper.Event.eventType), DeliveryTypeSSE) + s.metrics.IncrementSent(string(wrapper.Event.EventType), DeliveryTypeSSE) } } } diff --git a/internal/sms-gateway/modules/events/types.go b/internal/sms-gateway/modules/events/types.go index 76755e17..76e4d89e 100644 --- a/internal/sms-gateway/modules/events/types.go +++ b/internal/sms-gateway/modules/events/types.go @@ -1,23 +1,33 @@ package events import ( + "encoding/json" + "github.com/android-sms-gateway/client-go/smsgateway" ) type Event struct { - eventType smsgateway.PushEventType - data map[string]string + EventType smsgateway.PushEventType `json:"event_type"` + Data map[string]string `json:"data"` } -func NewEvent(eventType smsgateway.PushEventType, data map[string]string) *Event { - return &Event{ - eventType: eventType, - data: data, +func NewEvent(eventType smsgateway.PushEventType, data map[string]string) Event { + return Event{ + EventType: eventType, + Data: data, } } type eventWrapper struct { - UserID string - DeviceID *string - Event *Event + UserID string `json:"user_id"` + DeviceID *string `json:"device_id,omitempty"` + Event Event `json:"event"` +} + +func (w *eventWrapper) serialize() ([]byte, error) { + return json.Marshal(w) +} + +func (w *eventWrapper) deserialize(data []byte) error { + return json.Unmarshal(data, w) } diff --git a/internal/sms-gateway/pubsub/config.go b/internal/sms-gateway/pubsub/config.go new file mode 100644 index 00000000..932ad789 --- /dev/null +++ b/internal/sms-gateway/pubsub/config.go @@ -0,0 +1,7 @@ +package pubsub + +// Config controls the PubSub backend via a URL (e.g., "memory://", "redis://..."). +type Config struct { + URL string + BufferSize uint +} diff --git a/internal/sms-gateway/pubsub/module.go b/internal/sms-gateway/pubsub/module.go new file mode 100644 index 00000000..4f3d0bca --- /dev/null +++ b/internal/sms-gateway/pubsub/module.go @@ -0,0 +1,29 @@ +package pubsub + +import ( + "context" + + "go.uber.org/fx" + "go.uber.org/zap" +) + +func Module() fx.Option { + return fx.Module( + "pubsub", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("pubsub") + }), + fx.Provide(New), + fx.Invoke(func(ps PubSub, logger *zap.Logger, lc fx.Lifecycle) { + lc.Append(fx.Hook{ + OnStop: func(_ context.Context) error { + if err := ps.Close(); err != nil { + logger.Error("pubsub close failed", zap.Error(err)) + return err + } + return nil + }, + }) + }), + ) +} diff --git a/internal/sms-gateway/pubsub/pubsub.go b/internal/sms-gateway/pubsub/pubsub.go new file mode 100644 index 00000000..84ca2e9f --- /dev/null +++ b/internal/sms-gateway/pubsub/pubsub.go @@ -0,0 +1,41 @@ +package pubsub + +import ( + "fmt" + "net/url" + + "github.com/android-sms-gateway/server/pkg/pubsub" +) + +const ( + topicPrefix = "sms-gateway:" +) + +type PubSub = pubsub.PubSub + +func New(config Config) (PubSub, error) { + if config.URL == "" { + config.URL = "memory://" + } + + u, err := url.Parse(config.URL) + if err != nil { + return nil, fmt.Errorf("can't parse url: %w", err) + } + + opts := []pubsub.Option{} + opts = append(opts, pubsub.WithBufferSize(config.BufferSize)) + + switch u.Scheme { + case "memory": + return pubsub.NewMemory(opts...), nil + case "redis": + return pubsub.NewRedis(pubsub.RedisConfig{ + Client: nil, + URL: config.URL, + Prefix: topicPrefix, + }, opts...) + default: + return nil, fmt.Errorf("invalid scheme: %s", u.Scheme) + } +} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index da5eb703..e50e2898 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -33,4 +33,8 @@ type Cache interface { // The cache is cleared after the call. // The operation is safe for concurrent use. Drain(ctx context.Context) (map[string][]byte, error) + + // Close closes the cache. + // The operation is safe for concurrent use. + Close() error } diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go index 8088c5fb..b4d8d7cc 100644 --- a/pkg/cache/memory.go +++ b/pkg/cache/memory.go @@ -192,4 +192,8 @@ func (m *memoryCache) cleanup(cb func()) { m.mux.Unlock() } +func (m *memoryCache) Close() error { + return nil +} + var _ Cache = (*memoryCache)(nil) diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index 3cdc47a8..32e122d9 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -48,26 +48,59 @@ return value ` ) +// RedisConfig configures the Redis cache backend. +type RedisConfig struct { + // Client is the Redis client to use. + // If nil, a client is created from the URL. + Client *redis.Client + + // URL is the Redis URL to use. + // If empty, the Redis client is not created. + URL string + + // Prefix is the prefix to use for all keys in the Redis cache. + Prefix string + + // TTL is the time-to-live for all cache entries. + TTL time.Duration +} + type redisCache struct { - client *redis.Client + client *redis.Client + ownedClient bool key string ttl time.Duration } -func NewRedis(client *redis.Client, prefix string, ttl time.Duration) *redisCache { - if prefix != "" && !strings.HasSuffix(prefix, ":") { - prefix += ":" +func NewRedis(config RedisConfig) (*redisCache, error) { + if config.Prefix != "" && !strings.HasSuffix(config.Prefix, ":") { + config.Prefix += ":" } - return &redisCache{ - client: client, + if config.Client == nil && config.URL == "" { + return nil, fmt.Errorf("no redis client or url provided") + } - key: prefix + redisCacheKey, + client := config.Client + if client == nil { + opt, err := redis.ParseURL(config.URL) + if err != nil { + return nil, fmt.Errorf("failed to parse redis url: %w", err) + } - ttl: ttl, + client = redis.NewClient(opt) } + + return &redisCache{ + client: client, + ownedClient: config.Client == nil, + + key: config.Prefix + redisCacheKey, + + ttl: config.TTL, + }, nil } // Cleanup implements Cache. @@ -218,4 +251,12 @@ func (r *redisCache) SetOrFail(ctx context.Context, key string, value []byte, op return nil } +func (r *redisCache) Close() error { + if r.ownedClient { + return r.client.Close() + } + + return nil +} + var _ Cache = (*redisCache)(nil) diff --git a/pkg/pubsub/memory.go b/pkg/pubsub/memory.go new file mode 100644 index 00000000..04ce2ef3 --- /dev/null +++ b/pkg/pubsub/memory.go @@ -0,0 +1,155 @@ +package pubsub + +import ( + "context" + "sync" + + "github.com/google/uuid" +) + +type memoryPubSub struct { + bufferSize uint + + wg sync.WaitGroup + mu sync.RWMutex + topics map[string]map[string]subscriber + closeCh chan struct{} +} + +type subscriber struct { + ch chan Message + ctx context.Context +} + +func NewMemory(opts ...Option) *memoryPubSub { + o := options{ + bufferSize: 0, + } + o.apply(opts...) + + return &memoryPubSub{ + bufferSize: o.bufferSize, + + topics: make(map[string]map[string]subscriber), + closeCh: make(chan struct{}), + } +} + +// Publish sends a message to all subscribers of the given topic. +// This method blocks until all subscribers have received the message +// or until ctx is cancelled or the pubsub instance is closed. +func (m *memoryPubSub) Publish(ctx context.Context, topic string, data []byte) error { + select { + case <-m.closeCh: + return ErrPubSubClosed + default: + } + + if topic == "" { + return ErrInvalidTopic + } + + m.mu.RLock() + defer m.mu.RUnlock() + + subscribers, exists := m.topics[topic] + if !exists { + return nil + } + + wg := &sync.WaitGroup{} + msg := Message{Topic: topic, Data: data} + + for _, sub := range subscribers { + wg.Add(1) + go func(sub subscriber) { + defer wg.Done() + + select { + case sub.ch <- msg: + case <-ctx.Done(): + return + case <-m.closeCh: + return + case <-sub.ctx.Done(): + return + } + }(sub) + } + + wg.Wait() + + return nil +} + +func (m *memoryPubSub) Subscribe(ctx context.Context, topic string) (*Subscription, error) { + select { + case <-m.closeCh: + return nil, ErrPubSubClosed + default: + } + + if topic == "" { + return nil, ErrInvalidTopic + } + + id := uuid.NewString() + subCtx, cancel := context.WithCancel(ctx) + ch := make(chan Message, m.bufferSize) + + m.subscribe(id, topic, subscriber{ch: ch, ctx: subCtx}) + + m.wg.Add(1) + go func() { + select { + case <-subCtx.Done(): + case <-m.closeCh: + } + + cancel() + m.unsubscribe(id, topic) + close(ch) + + m.wg.Done() + }() + + return &Subscription{id: id, ctx: subCtx, cancel: cancel, ch: ch}, nil +} + +func (m *memoryPubSub) subscribe(id, topic string, sub subscriber) { + m.mu.Lock() + defer m.mu.Unlock() + + subscriptions, ok := m.topics[topic] + if !ok { + subscriptions = make(map[string]subscriber) + m.topics[topic] = subscriptions + } + subscriptions[id] = sub +} + +func (m *memoryPubSub) unsubscribe(id, topic string) { + m.mu.Lock() + defer m.mu.Unlock() + + subscriptions, ok := m.topics[topic] + if !ok { + return + } + delete(subscriptions, id) +} + +func (m *memoryPubSub) Close() error { + select { + case <-m.closeCh: + return nil + default: + } + close(m.closeCh) + + m.wg.Wait() + + return nil +} + +var _ PubSub = (*memoryPubSub)(nil) diff --git a/pkg/pubsub/options.go b/pkg/pubsub/options.go new file mode 100644 index 00000000..e62d1d70 --- /dev/null +++ b/pkg/pubsub/options.go @@ -0,0 +1,21 @@ +package pubsub + +type Option func(*options) + +type options struct { + bufferSize uint +} + +func (o *options) apply(opts ...Option) *options { + for _, opt := range opts { + opt(o) + } + + return o +} + +func WithBufferSize(bufferSize uint) Option { + return func(o *options) { + o.bufferSize = bufferSize + } +} diff --git a/pkg/pubsub/pubsub.go b/pkg/pubsub/pubsub.go new file mode 100644 index 00000000..c8954518 --- /dev/null +++ b/pkg/pubsub/pubsub.go @@ -0,0 +1,50 @@ +package pubsub + +import ( + "context" + "errors" +) + +var ( + ErrPubSubClosed = errors.New("pubsub is closed") + ErrInvalidTopic = errors.New("invalid topic name") +) + +type Message struct { + Topic string + Data []byte +} + +type Subscription struct { + id string + ch <-chan Message + ctx context.Context + cancel context.CancelFunc +} + +func (s *Subscription) Receive() <-chan Message { + return s.ch +} + +func (s *Subscription) Close() { + s.cancel() +} + +type Subscriber interface { + // Subscribe subscribes to a topic and returns a channel for receiving messages. + // The channel will be closed when the context is cancelled. + Subscribe(ctx context.Context, topic string) (*Subscription, error) +} + +type Publisher interface { + // Publish publishes a message to a topic. + // All subscribers to the topic will receive the message (fan-out). + Publish(ctx context.Context, topic string, data []byte) error +} + +type PubSub interface { + Publisher + Subscriber + // Close closes the pubsub instance and releases all resources. + Close() error +} diff --git a/pkg/pubsub/redis.go b/pkg/pubsub/redis.go new file mode 100644 index 00000000..4cd99d0b --- /dev/null +++ b/pkg/pubsub/redis.go @@ -0,0 +1,180 @@ +package pubsub + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +// RedisConfig configures the Redis pubsub backend. +type RedisConfig struct { + // Client is the Redis client to use. + // If nil, a client is created from the URL. + // If both Client and URL are provided, Client takes precedence. + Client *redis.Client + + // URL is the Redis URL to use. + // If empty, the Redis client is not created. + URL string + + // Prefix is the prefix to use for all topics. + Prefix string +} + +type redisPubSub struct { + prefix string + bufferSize uint + + client *redis.Client + ownedClient bool + + wg sync.WaitGroup + mu sync.Mutex + subscribers map[string]context.CancelFunc + closeCh chan struct{} +} + +func NewRedis(config RedisConfig, opts ...Option) (*redisPubSub, error) { + if config.Prefix != "" && !strings.HasSuffix(config.Prefix, ":") { + config.Prefix += ":" + } + + if config.Client == nil && config.URL == "" { + return nil, fmt.Errorf("no redis client or url provided") + } + + client := config.Client + if client == nil { + opt, err := redis.ParseURL(config.URL) + if err != nil { + return nil, fmt.Errorf("failed to parse redis url: %w", err) + } + + client = redis.NewClient(opt) + } + + o := options{ + bufferSize: 0, + } + o.apply(opts...) + + return &redisPubSub{ + prefix: config.Prefix, + bufferSize: o.bufferSize, + + client: client, + ownedClient: config.Client == nil, + + subscribers: make(map[string]context.CancelFunc), + closeCh: make(chan struct{}), + }, nil +} + +func (r *redisPubSub) Publish(ctx context.Context, topic string, data []byte) error { + select { + case <-r.closeCh: + return ErrPubSubClosed + default: + } + + if topic == "" { + return ErrInvalidTopic + } + + return r.client.Publish(ctx, r.prefix+topic, data).Err() +} + +func (r *redisPubSub) Subscribe(ctx context.Context, topic string) (*Subscription, error) { + select { + case <-r.closeCh: + return nil, ErrPubSubClosed + default: + } + + if topic == "" { + return nil, ErrInvalidTopic + } + + ps := r.client.Subscribe(ctx, r.prefix+topic) + _, err := ps.Receive(ctx) + if err != nil { + closeErr := ps.Close() + return nil, errors.Join(fmt.Errorf("can't subscribe: %w", err), closeErr) + } + + id := uuid.NewString() + subCtx, cancel := context.WithCancel(ctx) + ch := make(chan Message, r.bufferSize) + + // Track this subscriber + r.mu.Lock() + r.subscribers[id] = cancel + r.mu.Unlock() + + r.wg.Add(1) + go func() { + defer func() { + _ = ps.Close() + close(ch) + + r.mu.Lock() + delete(r.subscribers, id) + r.mu.Unlock() + + r.wg.Done() + }() + + for { + select { + case <-r.closeCh: + return + case <-subCtx.Done(): + return + case msg, ok := <-ps.Channel(): + if !ok { + return + } + if msg == nil { + continue + } + + select { + case ch <- Message{ + Topic: topic, + Data: []byte(msg.Payload), + }: + case <-r.closeCh: + return + case <-subCtx.Done(): + return + } + } + } + }() + + return &Subscription{id: id, ctx: subCtx, cancel: cancel, ch: ch}, nil +} + +func (r *redisPubSub) Close() error { + select { + case <-r.closeCh: + return nil + default: + close(r.closeCh) + } + + r.wg.Wait() + + if r.ownedClient { + return r.client.Close() + } + + return nil +} + +var _ PubSub = (*redisPubSub)(nil)