diff --git a/internal/database/notification.go b/internal/database/notification.go index af738d0..9f5a425 100644 --- a/internal/database/notification.go +++ b/internal/database/notification.go @@ -25,10 +25,9 @@ type Notification struct { NotifyAfter time.Time `gorm:"type:timestamptz;not null;index"` NotifyBefore time.Time `gorm:"type:timestamptz;not null;index"` - IsNotified bool `gorm:"type:boolean;not null;default:false;index"` } -func (n *Notification) ToDomain(targetUserIDs []string) domain.Notification { +func (n *Notification) ToDomain(targetUsers []domain.NotificationTargetUser) domain.Notification { return domain.Notification{ ID: n.ID, Title: n.Title, @@ -45,8 +44,7 @@ func (n *Notification) ToDomain(targetUserIDs []string) domain.Notification { URL: n.URL, NotifyAfter: n.NotifyAfter, NotifyBefore: n.NotifyBefore, - IsNotified: n.IsNotified, - TargetUserIDs: targetUserIDs, + TargetUsers: targetUsers, } } @@ -67,6 +65,5 @@ func NotificationFromDomain(n domain.Notification) Notification { URL: n.URL, NotifyAfter: n.NotifyAfter, NotifyBefore: n.NotifyBefore, - IsNotified: n.IsNotified, } } diff --git a/internal/database/notification_target_user.go b/internal/database/notification_target_user.go index 1d4018c..1bdcdb2 100644 --- a/internal/database/notification_target_user.go +++ b/internal/database/notification_target_user.go @@ -1,8 +1,11 @@ package database +import "time" + type NotificationTargetUser struct { NotificationID string `gorm:"type:text;primaryKey"` UserID string `gorm:"type:text;primaryKey"` + NotifiedAt *time.Time `gorm:"type:timestamptz;index"` Notification Notification `gorm:"constraint:OnDelete:CASCADE"` User User `gorm:"constraint:OnDelete:CASCADE"` } diff --git a/internal/domain/notification.go b/internal/domain/notification.go index 9d2373c..d559deb 100644 --- a/internal/domain/notification.go +++ b/internal/domain/notification.go @@ -21,7 +21,11 @@ type Notification struct { NotifyAfter time.Time NotifyBefore time.Time - IsNotified bool - TargetUserIDs []string + TargetUsers []NotificationTargetUser +} + +type NotificationTargetUser struct { + UserID string + NotifiedAt *time.Time } diff --git a/internal/handler/converter.go b/internal/handler/converter.go index 82f2298..2e23ef8 100644 --- a/internal/handler/converter.go +++ b/internal/handler/converter.go @@ -82,9 +82,12 @@ func toDomainFCMToken(req api.FCMTokenRequest) domain.FCMToken { } func toAPINotification(n domain.Notification) api.Notification { - targetUsers := make([]api.NotificationTargetUser, 0, len(n.TargetUserIDs)) - for _, uid := range n.TargetUserIDs { - targetUsers = append(targetUsers, api.NotificationTargetUser{UserId: uid}) + targetUsers := make([]api.NotificationTargetUser, 0, len(n.TargetUsers)) + for _, t := range n.TargetUsers { + targetUsers = append(targetUsers, api.NotificationTargetUser{ + UserId: t.UserID, + NotifiedAt: t.NotifiedAt, + }) } return api.Notification{ Id: n.ID, @@ -115,6 +118,10 @@ func toAPINotifications(notifications []domain.Notification) []api.Notification } func toDomainNotification(id string, req api.NotificationRequest) domain.Notification { + targetUsers := make([]domain.NotificationTargetUser, 0, len(req.TargetUserIds)) + for _, uid := range req.TargetUserIds { + targetUsers = append(targetUsers, domain.NotificationTargetUser{UserID: uid}) + } return domain.Notification{ ID: id, Title: req.Title, @@ -131,7 +138,7 @@ func toDomainNotification(id string, req api.NotificationRequest) domain.Notific URL: req.Url, NotifyAfter: req.NotifyAfter, NotifyBefore: req.NotifyBefore, - TargetUserIDs: req.TargetUserIds, + TargetUsers: targetUsers, } } diff --git a/internal/repository/notification_create.go b/internal/repository/notification_create.go index 3d3c474..7d662ce 100644 --- a/internal/repository/notification_create.go +++ b/internal/repository/notification_create.go @@ -14,18 +14,20 @@ func (r *NotificationRepository) CreateNotification(ctx context.Context, notific dbNotification := database.NotificationFromDomain(notification) + uniqueTargets := uniqueTargetUsers(notification.TargetUsers) + err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Create(&dbNotification).Error; err != nil { return err } - uniqueIDs := uniqueStrings(notification.TargetUserIDs) - if len(uniqueIDs) > 0 { - targets := make([]database.NotificationTargetUser, 0, len(uniqueIDs)) - for _, userID := range uniqueIDs { + if len(uniqueTargets) > 0 { + targets := make([]database.NotificationTargetUser, 0, len(uniqueTargets)) + for _, t := range uniqueTargets { targets = append(targets, database.NotificationTargetUser{ NotificationID: notification.ID, - UserID: userID, + UserID: t.UserID, + NotifiedAt: t.NotifiedAt, }) } if err := tx.Create(&targets).Error; err != nil { @@ -39,5 +41,5 @@ func (r *NotificationRepository) CreateNotification(ctx context.Context, notific return domain.Notification{}, err } - return dbNotification.ToDomain(uniqueStrings(notification.TargetUserIDs)), nil + return dbNotification.ToDomain(uniqueTargets), nil } diff --git a/internal/repository/notification_dispatch.go b/internal/repository/notification_dispatch.go index ab77593..3fd0fba 100644 --- a/internal/repository/notification_dispatch.go +++ b/internal/repository/notification_dispatch.go @@ -2,6 +2,7 @@ package repository import ( "context" + "time" "github.com/fun-dotto/user-api/internal/database" "github.com/fun-dotto/user-api/internal/domain" @@ -31,9 +32,12 @@ func (r *NotificationRepository) GetNotificationsByIDs(ctx context.Context, ids return nil, err } - targetMap := make(map[string][]string) + targetMap := make(map[string][]domain.NotificationTargetUser) for _, t := range allTargets { - targetMap[t.NotificationID] = append(targetMap[t.NotificationID], t.UserID) + targetMap[t.NotificationID] = append(targetMap[t.NotificationID], domain.NotificationTargetUser{ + UserID: t.UserID, + NotifiedAt: t.NotifiedAt, + }) } notifications := make([]domain.Notification, 0, len(dbNotifications)) @@ -44,26 +48,32 @@ func (r *NotificationRepository) GetNotificationsByIDs(ctx context.Context, ids return notifications, nil } -func (r *NotificationRepository) DispatchNotifications(ctx context.Context, ids []string) ([]domain.Notification, error) { - uniqueIDs := uniqueStrings(ids) - if len(uniqueIDs) == 0 { +func (r *NotificationRepository) DispatchNotifications(ctx context.Context, deliveries map[string][]string) ([]domain.Notification, error) { + if len(deliveries) == 0 { return []domain.Notification{}, nil } - if err := r.db.WithContext(ctx).Model(&database.Notification{}). - Where("id IN ?", uniqueIDs). - Update("is_notified", true).Error; err != nil { - return nil, err - } - - notifications, err := r.GetNotificationsByIDs(ctx, uniqueIDs) - if err != nil { - return nil, err + now := time.Now() + notificationIDs := make([]string, 0, len(deliveries)) + for nid, userIDs := range deliveries { + uniqueUsers := uniqueStrings(userIDs) + if len(uniqueUsers) == 0 { + continue + } + db := r.db.WithContext(ctx).Model(&database.NotificationTargetUser{}). + Where("notification_id = ? AND user_id IN ?", nid, uniqueUsers). + Update("notified_at", now) + if db.Error != nil { + return nil, db.Error + } + if db.RowsAffected > 0 { + notificationIDs = append(notificationIDs, nid) + } } - for i := range notifications { - notifications[i].IsNotified = true + if len(notificationIDs) == 0 { + return []domain.Notification{}, nil } - return notifications, nil + return r.GetNotificationsByIDs(ctx, notificationIDs) } diff --git a/internal/repository/notification_list.go b/internal/repository/notification_list.go index 737c83c..495767f 100644 --- a/internal/repository/notification_list.go +++ b/internal/repository/notification_list.go @@ -17,7 +17,13 @@ func (r *NotificationRepository) ListNotifications(ctx context.Context, filter d query = query.Where("notify_after <= ?", *filter.NotifyAtTo) } if filter.IsNotified != nil { - query = query.Where("is_notified = ?", *filter.IsNotified) + if *filter.IsNotified { + query = query.Where(`NOT EXISTS (SELECT 1 FROM notification_target_users tu WHERE tu.notification_id = notifications.id AND tu.notified_at IS NULL) + AND EXISTS (SELECT 1 FROM notification_target_users tu WHERE tu.notification_id = notifications.id)`) + } else { + query = query.Where(`(EXISTS (SELECT 1 FROM notification_target_users tu WHERE tu.notification_id = notifications.id AND tu.notified_at IS NULL) + OR NOT EXISTS (SELECT 1 FROM notification_target_users tu WHERE tu.notification_id = notifications.id))`) + } } var dbNotifications []database.Notification @@ -39,9 +45,12 @@ func (r *NotificationRepository) ListNotifications(ctx context.Context, filter d return nil, err } - targetMap := make(map[string][]string) + targetMap := make(map[string][]domain.NotificationTargetUser) for _, t := range allTargets { - targetMap[t.NotificationID] = append(targetMap[t.NotificationID], t.UserID) + targetMap[t.NotificationID] = append(targetMap[t.NotificationID], domain.NotificationTargetUser{ + UserID: t.UserID, + NotifiedAt: t.NotifiedAt, + }) } notifications := make([]domain.Notification, 0, len(dbNotifications)) diff --git a/internal/repository/notification_update.go b/internal/repository/notification_update.go index 2f0435c..6273f17 100644 --- a/internal/repository/notification_update.go +++ b/internal/repository/notification_update.go @@ -3,14 +3,17 @@ package repository import ( "context" "errors" + "time" "github.com/fun-dotto/user-api/internal/database" "github.com/fun-dotto/user-api/internal/domain" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func (r *NotificationRepository) UpdateNotification(ctx context.Context, notification domain.Notification) (domain.Notification, error) { var dbNotification database.Notification + uniqueTargets := uniqueTargetUsers(notification.TargetUsers) err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var existing database.Notification @@ -23,21 +26,39 @@ func (r *NotificationRepository) UpdateNotification(ctx context.Context, notific dbNotification = database.NotificationFromDomain(notification) - if err := tx.Omit("IsNotified").Save(&dbNotification).Error; err != nil { + if err := tx.Save(&dbNotification).Error; err != nil { return err } + var existingTargets []database.NotificationTargetUser + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("notification_id = ?", notification.ID). + Find(&existingTargets).Error; err != nil { + return err + } + existingNotifiedAt := make(map[string]*time.Time, len(existingTargets)) + for _, t := range existingTargets { + existingNotifiedAt[t.UserID] = t.NotifiedAt + } + if err := tx.Where("notification_id = ?", notification.ID).Delete(&database.NotificationTargetUser{}).Error; err != nil { return err } - uniqueIDs := uniqueStrings(notification.TargetUserIDs) - if len(uniqueIDs) > 0 { - targets := make([]database.NotificationTargetUser, 0, len(uniqueIDs)) - for _, userID := range uniqueIDs { + if len(uniqueTargets) > 0 { + targets := make([]database.NotificationTargetUser, 0, len(uniqueTargets)) + for i, t := range uniqueTargets { + notifiedAt := t.NotifiedAt + if notifiedAt == nil { + if prev, ok := existingNotifiedAt[t.UserID]; ok { + notifiedAt = prev + uniqueTargets[i].NotifiedAt = prev + } + } targets = append(targets, database.NotificationTargetUser{ NotificationID: notification.ID, - UserID: userID, + UserID: t.UserID, + NotifiedAt: notifiedAt, }) } if err := tx.Create(&targets).Error; err != nil { @@ -54,5 +75,5 @@ func (r *NotificationRepository) UpdateNotification(ctx context.Context, notific return domain.Notification{}, err } - return dbNotification.ToDomain(uniqueStrings(notification.TargetUserIDs)), nil + return dbNotification.ToDomain(uniqueTargets), nil } diff --git a/internal/repository/util.go b/internal/repository/util.go index 1bf7a8d..13cd60d 100644 --- a/internal/repository/util.go +++ b/internal/repository/util.go @@ -1,5 +1,7 @@ package repository +import "github.com/fun-dotto/user-api/internal/domain" + func uniqueStrings(s []string) []string { seen := make(map[string]struct{}, len(s)) result := make([]string, 0, len(s)) @@ -12,3 +14,16 @@ func uniqueStrings(s []string) []string { } return result } + +func uniqueTargetUsers(targets []domain.NotificationTargetUser) []domain.NotificationTargetUser { + seen := make(map[string]struct{}, len(targets)) + result := make([]domain.NotificationTargetUser, 0, len(targets)) + for _, t := range targets { + if _, ok := seen[t.UserID]; ok { + continue + } + seen[t.UserID] = struct{}{} + result = append(result, t) + } + return result +} diff --git a/internal/service/notification.go b/internal/service/notification.go index 73b6b40..4a63210 100644 --- a/internal/service/notification.go +++ b/internal/service/notification.go @@ -13,23 +13,27 @@ type NotificationRepository interface { UpdateNotification(ctx context.Context, notification domain.Notification) (domain.Notification, error) DeleteNotification(ctx context.Context, id string) error GetNotificationsByIDs(ctx context.Context, ids []string) ([]domain.Notification, error) - DispatchNotifications(ctx context.Context, ids []string) ([]domain.Notification, error) + DispatchNotifications(ctx context.Context, deliveries map[string][]string) ([]domain.Notification, error) } type FCMTokenRepositoryForNotification interface { ListFCMTokens(ctx context.Context, filter domain.FCMTokenListFilter) ([]domain.FCMToken, error) } +type MessagingClient interface { + SendEachForMulticast(ctx context.Context, message *messaging.MulticastMessage) (*messaging.BatchResponse, error) +} + type NotificationService struct { - repo NotificationRepository - fcmTokenRepo FCMTokenRepositoryForNotification - messagingClient *messaging.Client + repo NotificationRepository + fcmTokenRepo FCMTokenRepositoryForNotification + messagingClient MessagingClient } func NewNotificationService( repo NotificationRepository, fcmTokenRepo FCMTokenRepositoryForNotification, - messagingClient *messaging.Client, + messagingClient MessagingClient, ) *NotificationService { return &NotificationService{ repo: repo, diff --git a/internal/service/notification_dispatch.go b/internal/service/notification_dispatch.go index 78f8a68..cc6536a 100644 --- a/internal/service/notification_dispatch.go +++ b/internal/service/notification_dispatch.go @@ -20,8 +20,8 @@ func (s *NotificationService) DispatchNotifications(ctx context.Context, ids []s userIDSet := make(map[string]struct{}) for _, n := range notifications { - for _, uid := range n.TargetUserIDs { - userIDSet[uid] = struct{}{} + for _, t := range n.TargetUsers { + userIDSet[t.UserID] = struct{}{} } } @@ -40,34 +40,57 @@ func (s *NotificationService) DispatchNotifications(ctx context.Context, ids []s } } - successIDs := make([]string, 0, len(notifications)) + deliveries := make(map[string][]string, len(notifications)) for _, n := range notifications { - tokens := collectTokens(n.TargetUserIDs, tokensByUser) + pendingUserIDs := make([]string, 0, len(n.TargetUsers)) + for _, t := range n.TargetUsers { + pendingUserIDs = append(pendingUserIDs, t.UserID) + } + if len(pendingUserIDs) == 0 { + continue + } + + tokens, tokenUserIDs := collectTokens(pendingUserIDs, tokensByUser) if len(tokens) == 0 { - successIDs = append(successIDs, n.ID) + deliveries[n.ID] = pendingUserIDs continue } - sent, err := s.sendToTokens(ctx, n, tokens) + successUserIDs, err := s.sendToTokens(ctx, n, tokens, tokenUserIDs) if err != nil { - log.Printf("FCM send failed for notification %s: %v", n.ID, err) - continue + log.Printf("FCM send partially failed for notification %s (success=%d/%d users): %v", n.ID, len(successUserIDs), len(pendingUserIDs), err) + } + + successSet := make(map[string]struct{}, len(successUserIDs)) + for _, uid := range successUserIDs { + successSet[uid] = struct{}{} + } + delivered := make([]string, 0, len(pendingUserIDs)) + for _, uid := range pendingUserIDs { + if _, ok := successSet[uid]; ok { + delivered = append(delivered, uid) + continue + } + if _, hasToken := tokensByUser[uid]; !hasToken { + delivered = append(delivered, uid) + } } - if sent > 0 { - successIDs = append(successIDs, n.ID) + if len(delivered) > 0 { + deliveries[n.ID] = delivered } } - if len(successIDs) == 0 { + if len(deliveries) == 0 { return []domain.Notification{}, nil } - return s.repo.DispatchNotifications(ctx, successIDs) + return s.repo.DispatchNotifications(ctx, deliveries) } -func collectTokens(userIDs []string, tokensByUser map[string][]string) []string { +func collectTokens(userIDs []string, tokensByUser map[string][]string) ([]string, []string) { seen := make(map[string]struct{}) tokens := make([]string, 0) + tokenUserIDs := make([]string, 0) for _, uid := range userIDs { for _, tk := range tokensByUser[uid] { if _, ok := seen[tk]; ok { @@ -75,14 +98,15 @@ func collectTokens(userIDs []string, tokensByUser map[string][]string) []string } seen[tk] = struct{}{} tokens = append(tokens, tk) + tokenUserIDs = append(tokenUserIDs, uid) } } - return tokens + return tokens, tokenUserIDs } const fcmMulticastBatchSize = 500 -func (s *NotificationService) sendToTokens(ctx context.Context, n domain.Notification, tokens []string) (int, error) { +func (s *NotificationService) sendToTokens(ctx context.Context, n domain.Notification, tokens []string, tokenUserIDs []string) ([]string, error) { data := map[string]string{"notification_id": n.ID} if n.URL != nil { data["url"] = *n.URL @@ -105,7 +129,7 @@ func (s *NotificationService) sendToTokens(ctx context.Context, n domain.Notific apnsConfig := buildAPNSConfig(n) webpushConfig := buildWebpushConfig(n) - totalSuccess := 0 + successUserSet := make(map[string]struct{}) for start := 0; start < len(tokens); start += fcmMulticastBatchSize { end := min(start+fcmMulticastBatchSize, len(tokens)) msg := &messaging.MulticastMessage{ @@ -119,18 +143,37 @@ func (s *NotificationService) sendToTokens(ctx context.Context, n domain.Notific } resp, err := s.messagingClient.SendEachForMulticast(ctx, msg) if err != nil { - return totalSuccess, err - } - totalSuccess += resp.SuccessCount - if resp.FailureCount > 0 { - for i, r := range resp.Responses { - if r.Error != nil { - log.Printf("FCM delivery failed for notification %s token=%s: %v", n.ID, tokens[start+i], r.Error) - } + return collectSuccessUserIDs(tokenUserIDs, successUserSet), err + } + for i, r := range resp.Responses { + uid := tokenUserIDs[start+i] + if r.Error != nil { + log.Printf("FCM delivery failed for notification %s token=%s: %v", n.ID, tokens[start+i], r.Error) + continue } + successUserSet[uid] = struct{}{} + } + } + return collectSuccessUserIDs(tokenUserIDs, successUserSet), nil +} + +func collectSuccessUserIDs(tokenUserIDs []string, successUserSet map[string]struct{}) []string { + if len(successUserSet) == 0 { + return nil + } + seen := make(map[string]struct{}, len(successUserSet)) + result := make([]string, 0, len(successUserSet)) + for _, uid := range tokenUserIDs { + if _, ok := successUserSet[uid]; !ok { + continue + } + if _, dup := seen[uid]; dup { + continue } + seen[uid] = struct{}{} + result = append(result, uid) } - return totalSuccess, nil + return result } func buildAndroidConfig(n domain.Notification) *messaging.AndroidConfig { diff --git a/internal/service/notification_dispatch_test.go b/internal/service/notification_dispatch_test.go new file mode 100644 index 0000000..00bfda7 --- /dev/null +++ b/internal/service/notification_dispatch_test.go @@ -0,0 +1,340 @@ +package service + +import ( + "context" + "errors" + "sort" + "testing" + "time" + + "firebase.google.com/go/v4/messaging" + "github.com/fun-dotto/user-api/internal/domain" +) + +type stubNotificationRepo struct { + getByIDs func(ctx context.Context, ids []string) ([]domain.Notification, error) + dispatch func(ctx context.Context, deliveries map[string][]string) ([]domain.Notification, error) + dispatchCalled bool + lastDeliveriesArg map[string][]string + getByIDsCalledWith []string +} + +func (s *stubNotificationRepo) ListNotifications(context.Context, domain.NotificationListFilter) ([]domain.Notification, error) { + return nil, errors.New("not implemented") +} +func (s *stubNotificationRepo) CreateNotification(context.Context, domain.Notification) (domain.Notification, error) { + return domain.Notification{}, errors.New("not implemented") +} +func (s *stubNotificationRepo) UpdateNotification(context.Context, domain.Notification) (domain.Notification, error) { + return domain.Notification{}, errors.New("not implemented") +} +func (s *stubNotificationRepo) DeleteNotification(context.Context, string) error { + return errors.New("not implemented") +} +func (s *stubNotificationRepo) GetNotificationsByIDs(ctx context.Context, ids []string) ([]domain.Notification, error) { + s.getByIDsCalledWith = ids + if s.getByIDs == nil { + return nil, nil + } + return s.getByIDs(ctx, ids) +} +func (s *stubNotificationRepo) DispatchNotifications(ctx context.Context, deliveries map[string][]string) ([]domain.Notification, error) { + s.dispatchCalled = true + s.lastDeliveriesArg = deliveries + if s.dispatch == nil { + return nil, nil + } + return s.dispatch(ctx, deliveries) +} + +type stubFCMTokenRepo struct { + list func(ctx context.Context, filter domain.FCMTokenListFilter) ([]domain.FCMToken, error) + called bool + lastFilter domain.FCMTokenListFilter +} + +func (s *stubFCMTokenRepo) ListFCMTokens(ctx context.Context, filter domain.FCMTokenListFilter) ([]domain.FCMToken, error) { + s.called = true + s.lastFilter = filter + if s.list == nil { + return nil, nil + } + return s.list(ctx, filter) +} + +type stubMessagingClient struct { + send func(ctx context.Context, msg *messaging.MulticastMessage) (*messaging.BatchResponse, error) + calls int + tokenCalls [][]string +} + +func (s *stubMessagingClient) SendEachForMulticast(ctx context.Context, msg *messaging.MulticastMessage) (*messaging.BatchResponse, error) { + s.calls++ + tokens := append([]string{}, msg.Tokens...) + s.tokenCalls = append(s.tokenCalls, tokens) + if s.send == nil { + responses := make([]*messaging.SendResponse, len(msg.Tokens)) + for i := range responses { + responses[i] = &messaging.SendResponse{Success: true, MessageID: "msg"} + } + return &messaging.BatchResponse{SuccessCount: len(responses), Responses: responses}, nil + } + return s.send(ctx, msg) +} + +func newServiceWithStubs(repo *stubNotificationRepo, tokenRepo *stubFCMTokenRepo, msg *stubMessagingClient) *NotificationService { + return &NotificationService{repo: repo, fcmTokenRepo: tokenRepo, messagingClient: msg} +} + +func notifiedAt(t time.Time) *time.Time { return &t } + +func TestDispatchNotifications_NoNotifications(t *testing.T) { + repo := &stubNotificationRepo{getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return []domain.Notification{}, nil + }} + tokenRepo := &stubFCMTokenRepo{} + msg := &stubMessagingClient{} + svc := newServiceWithStubs(repo, tokenRepo, msg) + + got, err := svc.DispatchNotifications(context.Background(), []string{"n1"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 0 { + t.Errorf("expected empty result, got %d", len(got)) + } + if tokenRepo.called { + t.Errorf("expected ListFCMTokens not called") + } + if msg.calls != 0 { + t.Errorf("expected no FCM calls, got %d", msg.calls) + } + if repo.dispatchCalled { + t.Errorf("expected DispatchNotifications not called") + } +} + +func TestDispatchNotifications_GetByIDsError(t *testing.T) { + wantErr := errors.New("db down") + repo := &stubNotificationRepo{getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return nil, wantErr + }} + svc := newServiceWithStubs(repo, &stubFCMTokenRepo{}, &stubMessagingClient{}) + + if _, err := svc.DispatchNotifications(context.Background(), []string{"n1"}); !errors.Is(err, wantErr) { + t.Errorf("expected error %v, got %v", wantErr, err) + } +} + +func TestDispatchNotifications_NoTargetUsers(t *testing.T) { + repo := &stubNotificationRepo{getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return []domain.Notification{{ID: "n1", Title: "t", Body: "b"}}, nil + }} + tokenRepo := &stubFCMTokenRepo{} + msg := &stubMessagingClient{} + svc := newServiceWithStubs(repo, tokenRepo, msg) + + got, err := svc.DispatchNotifications(context.Background(), []string{"n1"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 0 { + t.Errorf("expected empty result, got %d", len(got)) + } + if tokenRepo.called { + t.Errorf("expected ListFCMTokens not called when no target users") + } + if repo.dispatchCalled { + t.Errorf("expected DispatchNotifications not called") + } +} + +func TestDispatchNotifications_NoTokens_StillRecordsDelivery(t *testing.T) { + notification := domain.Notification{ + ID: "n1", + Title: "t", + Body: "b", + TargetUsers: []domain.NotificationTargetUser{ + {UserID: "u1"}, + {UserID: "u2", NotifiedAt: notifiedAt(time.Now())}, + }, + } + repo := &stubNotificationRepo{ + getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + dispatch: func(_ context.Context, _ map[string][]string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + } + tokenRepo := &stubFCMTokenRepo{list: func(context.Context, domain.FCMTokenListFilter) ([]domain.FCMToken, error) { + return []domain.FCMToken{}, nil + }} + msg := &stubMessagingClient{} + svc := newServiceWithStubs(repo, tokenRepo, msg) + + got, err := svc.DispatchNotifications(context.Background(), []string{"n1"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 1 { + t.Fatalf("expected 1 notification, got %d", len(got)) + } + if msg.calls != 0 { + t.Errorf("expected no FCM calls when no tokens, got %d", msg.calls) + } + gotUsers := repo.lastDeliveriesArg["n1"] + sort.Strings(gotUsers) + if len(gotUsers) != 2 || gotUsers[0] != "u1" || gotUsers[1] != "u2" { + t.Errorf("expected both u1 and u2 in deliveries (force re-send), got %v", gotUsers) + } +} + +func TestDispatchNotifications_AllTokensSucceed(t *testing.T) { + notification := domain.Notification{ + ID: "n1", + Title: "t", + Body: "b", + TargetUsers: []domain.NotificationTargetUser{ + {UserID: "u1"}, + {UserID: "u2", NotifiedAt: notifiedAt(time.Now())}, + }, + } + repo := &stubNotificationRepo{ + getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + dispatch: func(_ context.Context, _ map[string][]string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + } + tokenRepo := &stubFCMTokenRepo{list: func(context.Context, domain.FCMTokenListFilter) ([]domain.FCMToken, error) { + return []domain.FCMToken{ + {UserID: "u1", Token: "tok-u1"}, + {UserID: "u2", Token: "tok-u2"}, + }, nil + }} + msg := &stubMessagingClient{} + svc := newServiceWithStubs(repo, tokenRepo, msg) + + if _, err := svc.DispatchNotifications(context.Background(), []string{"n1"}); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if msg.calls != 1 { + t.Errorf("expected 1 FCM batch call, got %d", msg.calls) + } + gotUsers := repo.lastDeliveriesArg["n1"] + sort.Strings(gotUsers) + if len(gotUsers) != 2 || gotUsers[0] != "u1" || gotUsers[1] != "u2" { + t.Errorf("expected both users delivered, got %v", gotUsers) + } +} + +func TestDispatchNotifications_PartialFCMFailure(t *testing.T) { + notification := domain.Notification{ + ID: "n1", + Title: "t", + Body: "b", + TargetUsers: []domain.NotificationTargetUser{ + {UserID: "u1"}, + {UserID: "u2"}, + {UserID: "u3"}, + }, + } + repo := &stubNotificationRepo{ + getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + dispatch: func(_ context.Context, _ map[string][]string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + } + tokenRepo := &stubFCMTokenRepo{list: func(context.Context, domain.FCMTokenListFilter) ([]domain.FCMToken, error) { + return []domain.FCMToken{ + {UserID: "u1", Token: "tok-u1"}, + {UserID: "u2", Token: "tok-u2"}, + }, nil + }} + msg := &stubMessagingClient{send: func(_ context.Context, m *messaging.MulticastMessage) (*messaging.BatchResponse, error) { + responses := make([]*messaging.SendResponse, len(m.Tokens)) + for i, tk := range m.Tokens { + if tk == "tok-u2" { + responses[i] = &messaging.SendResponse{Error: errors.New("invalid token")} + } else { + responses[i] = &messaging.SendResponse{Success: true, MessageID: "msg"} + } + } + return &messaging.BatchResponse{Responses: responses}, nil + }} + svc := newServiceWithStubs(repo, tokenRepo, msg) + + if _, err := svc.DispatchNotifications(context.Background(), []string{"n1"}); err != nil { + t.Fatalf("unexpected error: %v", err) + } + gotUsers := repo.lastDeliveriesArg["n1"] + sort.Strings(gotUsers) + want := []string{"u1", "u3"} + if len(gotUsers) != len(want) || gotUsers[0] != want[0] || gotUsers[1] != want[1] { + t.Errorf("expected delivered=%v (u1 succeeded; u3 had no token), got %v", want, gotUsers) + } +} + +func TestDispatchNotifications_FCMBatchError_KeepsPartialSuccess(t *testing.T) { + notification := domain.Notification{ + ID: "n1", + Title: "t", + Body: "b", + TargetUsers: []domain.NotificationTargetUser{ + {UserID: "u1"}, + }, + } + repo := &stubNotificationRepo{ + getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + dispatch: func(_ context.Context, _ map[string][]string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }, + } + tokenRepo := &stubFCMTokenRepo{list: func(context.Context, domain.FCMTokenListFilter) ([]domain.FCMToken, error) { + return []domain.FCMToken{{UserID: "u1", Token: "tok-u1"}}, nil + }} + msg := &stubMessagingClient{send: func(context.Context, *messaging.MulticastMessage) (*messaging.BatchResponse, error) { + return nil, errors.New("fcm down") + }} + svc := newServiceWithStubs(repo, tokenRepo, msg) + + got, err := svc.DispatchNotifications(context.Background(), []string{"n1"}) + if err != nil { + t.Fatalf("expected nil error (FCM error is logged, not returned), got %v", err) + } + if len(got) != 0 { + t.Errorf("expected empty result when no users delivered, got %d", len(got)) + } + if repo.dispatchCalled { + t.Errorf("expected DispatchNotifications not called when no successful deliveries") + } +} + +func TestDispatchNotifications_FCMTokenListError(t *testing.T) { + notification := domain.Notification{ + ID: "n1", + Title: "t", + Body: "b", + TargetUsers: []domain.NotificationTargetUser{ + {UserID: "u1"}, + }, + } + repo := &stubNotificationRepo{getByIDs: func(context.Context, []string) ([]domain.Notification, error) { + return []domain.Notification{notification}, nil + }} + wantErr := errors.New("token db down") + tokenRepo := &stubFCMTokenRepo{list: func(context.Context, domain.FCMTokenListFilter) ([]domain.FCMToken, error) { + return nil, wantErr + }} + svc := newServiceWithStubs(repo, tokenRepo, &stubMessagingClient{}) + + if _, err := svc.DispatchNotifications(context.Background(), []string{"n1"}); !errors.Is(err, wantErr) { + t.Errorf("expected error %v, got %v", wantErr, err) + } +}