Skip to content
18 changes: 15 additions & 3 deletions backend/internal/handler/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"log/slog"
"time"

"github.com/generate/selfserve/internal/errs"
"github.com/generate/selfserve/internal/httpx"
Expand All @@ -12,7 +13,7 @@ import (
)

type NotificationsRepository interface {
FindByUserID(ctx context.Context, userID string) ([]*models.Notification, error)
FindByUserID(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error)
MarkRead(ctx context.Context, id, userID string) error
MarkAllRead(ctx context.Context, userID string) error
UpsertDeviceToken(ctx context.Context, userID, token, platform string) error
Expand All @@ -28,17 +29,28 @@ func NewNotificationsHandler(repo NotificationsRepository) *NotificationsHandler

// ListNotifications godoc
// @Summary List notifications
// @Description Returns the most recent notifications for the authenticated user
// @Description Returns the most recent notifications for the authenticated user, paginated by cursor
// @Tags notifications
// @Produce json
// @Param before query string false "Cursor: return notifications created before this RFC3339 timestamp"
// @Success 200 {array} models.Notification
// @Failure 400 {object} errs.HTTPError
// @Failure 500 {object} errs.HTTPError
// @Security BearerAuth
// @Router /notifications [get]
func (h *NotificationsHandler) ListNotifications(c *fiber.Ctx) error {
userID := c.Locals("userId").(string)

notifications, err := h.repo.FindByUserID(c.Context(), userID)
var before *time.Time
if raw := c.Query("before"); raw != "" {
t, err := time.Parse(time.RFC3339Nano, raw)
if err != nil {
return errs.BadRequest("before must be a valid RFC3339 timestamp")
}
before = &t
}

notifications, err := h.repo.FindByUserID(c.Context(), userID, before)
if err != nil {
slog.Error("failed to list notifications", "err", err)
return errs.InternalServerError()
Expand Down
14 changes: 7 additions & 7 deletions backend/internal/handler/notifications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ import (
const testUserID = "user_test_123"

type mockNotificationsRepository struct {
findByUserIDFunc func(ctx context.Context, userID string) ([]*models.Notification, error)
findByUserIDFunc func(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error)
markReadFunc func(ctx context.Context, id, userID string) error
markAllReadFunc func(ctx context.Context, userID string) error
upsertDeviceTokenFunc func(ctx context.Context, userID, token, platform string) error
}

func (m *mockNotificationsRepository) FindByUserID(ctx context.Context, userID string) ([]*models.Notification, error) {
func (m *mockNotificationsRepository) FindByUserID(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error) {
if m.findByUserIDFunc != nil {
return m.findByUserIDFunc(ctx, userID)
return m.findByUserIDFunc(ctx, userID, before)
}
return nil, nil
}
Expand Down Expand Up @@ -76,7 +76,7 @@ func TestNotificationsHandler_ListNotifications(t *testing.T) {

readAt := time.Now()
mock := &mockNotificationsRepository{
findByUserIDFunc: func(ctx context.Context, userID string) ([]*models.Notification, error) {
findByUserIDFunc: func(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error) {
return []*models.Notification{
{
ID: "notif-1",
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestNotificationsHandler_ListNotifications(t *testing.T) {
t.Parallel()

mock := &mockNotificationsRepository{
findByUserIDFunc: func(ctx context.Context, userID string) ([]*models.Notification, error) {
findByUserIDFunc: func(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error) {
return nil, nil
},
}
Expand All @@ -124,7 +124,7 @@ func TestNotificationsHandler_ListNotifications(t *testing.T) {

var capturedUserID string
mock := &mockNotificationsRepository{
findByUserIDFunc: func(ctx context.Context, userID string) ([]*models.Notification, error) {
findByUserIDFunc: func(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error) {
capturedUserID = userID
return nil, nil
},
Expand All @@ -142,7 +142,7 @@ func TestNotificationsHandler_ListNotifications(t *testing.T) {
t.Parallel()

mock := &mockNotificationsRepository{
findByUserIDFunc: func(ctx context.Context, userID string) ([]*models.Notification, error) {
findByUserIDFunc: func(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error) {
return nil, errors.New("db error")
},
}
Expand Down
30 changes: 22 additions & 8 deletions backend/internal/repository/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package repository
import (
"context"
"encoding/json"
"time"

"github.com/generate/selfserve/internal/errs"
"github.com/generate/selfserve/internal/models"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)

Expand Down Expand Up @@ -37,14 +39,26 @@ func (r *NotificationsRepository) InsertNotification(ctx context.Context, userID
return n, nil
}

func (r *NotificationsRepository) FindByUserID(ctx context.Context, userID string) ([]*models.Notification, error) {
rows, err := r.db.Query(ctx, `
SELECT id, user_id, type, title, body, data, read_at, created_at
FROM public.notifications
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT 50
`, userID)
func (r *NotificationsRepository) FindByUserID(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error) {
var rows pgx.Rows
var err error
if before != nil {
rows, err = r.db.Query(ctx, `
SELECT id, user_id, type, title, body, data, read_at, created_at
FROM public.notifications
WHERE user_id = $1 AND created_at < $2
ORDER BY created_at DESC
LIMIT 50
`, userID, before)
} else {
rows, err = r.db.Query(ctx, `
SELECT id, user_id, type, title, body, data, read_at, created_at
FROM public.notifications
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT 50
`, userID)
}
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion backend/internal/service/storage/postgres/repo_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

type NotificationsRepository interface {
InsertNotification(ctx context.Context, userID string, notifType models.NotificationType, title, body string) (*models.Notification, error)
FindByUserID(ctx context.Context, userID string) ([]*models.Notification, error)
FindByUserID(ctx context.Context, userID string, before *time.Time) ([]*models.Notification, error)
MarkRead(ctx context.Context, id, userID string) error
MarkAllRead(ctx context.Context, userID string) error
UpsertDeviceToken(ctx context.Context, userID, token, platform string) error
Expand Down
1 change: 1 addition & 0 deletions clients/mobile/app/_layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ function AppLayout() {
<ThemeProvider value={colorScheme === "dark" ? DarkTheme : DefaultTheme}>
<Stack>
<Stack.Screen name="(tabs)" options={{ headerShown: false }} />
<Stack.Screen name="notifications" options={{ headerShown: false }} />
<Stack.Screen name="create-task-ai" options={{ headerShown: false }} />
<Stack.Screen
name="create-task-manual"
Expand Down
76 changes: 76 additions & 0 deletions clients/mobile/app/notifications.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { useEffect, useMemo, useRef } from "react";
import { ActivityIndicator, FlatList, Text, View } from "react-native";
import { SafeAreaView } from "react-native-safe-area-context";

import {
useGetNotifications,
useMarkAllNotificationsRead,
} from "@shared/api/notifications";
import type { Notification } from "@shared/types/notifications";
import { NotificationItem } from "@/components/notifications/notification-item";
import { ScreenHeader } from "@/components/ui/screen-header";

export default function NotificationsScreen() {
const { data, isLoading, isFetchingNextPage, hasNextPage, fetchNextPage } =
useGetNotifications();
const { mutate: markAllRead } = useMarkAllNotificationsRead();

const notifications = useMemo(() => data?.pages.flat() ?? [], [data]);

// Snapshot which IDs were unread when the screen first loaded so dots remain
// visible while user is reading — markAllRead fires immediately in the bg.
const initialUnreadIds = useRef<Set<string> | null>(null);

useEffect(() => {
if (notifications.length > 0 && initialUnreadIds.current === null) {
initialUnreadIds.current = new Set(
notifications.filter((n) => !n.read_at).map((n) => n.id),
);
markAllRead();
}
}, [notifications, markAllRead]);

return (
<SafeAreaView className="flex-1 bg-bg-surface" edges={["top"]}>
<ScreenHeader title="Notifications" />

{isLoading ? (
<View className="flex-1 items-center justify-center">
<ActivityIndicator />
</View>
) : (
<FlatList<Notification>
data={notifications}
keyExtractor={(item) => item.id}
renderItem={({ item }) => (
<NotificationItem
notification={item}
showUnreadDot={
initialUnreadIds.current?.has(item.id) ?? !item.read_at
}
/>
)}
onEndReached={() => {
if (hasNextPage && !isFetchingNextPage) fetchNextPage();
}}
onEndReachedThreshold={0.3}
ListFooterComponent={
isFetchingNextPage ? (
<View className="py-4 items-center">
<ActivityIndicator />
</View>
) : null
}
showsVerticalScrollIndicator={false}
ListEmptyComponent={
<View className="pt-20 items-center">
<Text className="text-[15px] text-text-subtle">
No notifications
</Text>
</View>
}
/>
)}
</SafeAreaView>
);
}
Loading
Loading