Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions api/v1_notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@ import (
"slices"
"strings"

"api.audius.co/api/dbv1"
"api.audius.co/trashid"
"github.com/gofiber/fiber/v2"
"github.com/jackc/pgx/v5"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)

// Per-group cap on how many actions we mine for actor user IDs. Notification
// groups can fan out (e.g. one row representing 100 followers); the client
// only renders one avatar per group, so a single actor profile is enough.
// Target entity IDs (the followee, the reposted track, etc.) are duplicated
// across every action in a group, so reading just the first action still
// surfaces every target — only the actor list is bounded by this cap.
// NOTE(review): if the client ever renders more than one avatar per group,
// bump this in lockstep — presumably the avatar count; confirm with client.
const notificationRelatedActorsPerGroup = 1

type GetNotificationsQueryParams struct {
// Note that when limit is 0, we return 20 items to calculate unread count
Limit int `query:"limit" default:"20" validate:"min=0,max=100"`
Expand Down Expand Up @@ -239,6 +248,10 @@ limit @limit::int
return err
}

userIds := []int32{}
trackIds := []int32{}
playlistIds := []int32{}

unreadCount := 0
for _, notif := range notifs {

Expand All @@ -248,6 +261,16 @@ limit @limit::int
return strings.Compare(specA, specB)
})

// Mine related entity IDs from the first N actions of each group. This
// must happen BEFORE HashifyJson re-encodes ints as opaque strings.
mineLimit := len(notif.Actions)
if mineLimit > notificationRelatedActorsPerGroup {
mineLimit = notificationRelatedActorsPerGroup
}
for _, action := range notif.Actions[:mineLimit] {
collectNotificationRelatedIds(action, &userIds, &trackIds, &playlistIds)
}

// each row from notification table has `actions`
// which is a jsonb field that is an array of objects.
// we need to hash encode all id fields (HashifyJson)
Expand Down Expand Up @@ -306,11 +329,111 @@ limit @limit::int
}
}

related, err := app.queries.Parallel(c.Context(), dbv1.ParallelParams{
UserIds: userIds,
TrackIds: trackIds,
PlaylistIds: playlistIds,
MyID: app.getMyId(c),
AuthedWallet: app.tryGetAuthedWallet(c),
IncludeUnlisted: true,
})
if err != nil {
return err
}

return c.JSON(fiber.Map{
"data": fiber.Map{
"notifications": notifs,
"unread_count": unreadCount,
},
"related": fiber.Map{
"users": related.UserList(),
"tracks": related.TrackList(),
"playlists": related.PlaylistList(),
},
})

}

// collectNotificationRelatedIds extracts user/track/playlist IDs from a single
// raw (pre-hashify) notification action's data so the caller can batch-load
// the related entities in one shot. Field names mirror the Python
// extend_notification.py mapping; *_item_id and content_id fields are
// polymorphic and disambiguated by the sibling type field.
// collectNotificationRelatedIds extracts user/track/playlist IDs from a single
// raw (pre-hashify) notification action's data so the caller can batch-load
// the related entities in one shot. Field names mirror the Python
// extend_notification.py mapping; *_item_id and content_id fields are
// polymorphic and disambiguated by the sibling type field.
//
// IDs must be mined before HashifyJson runs, while the fields are still JSON
// numbers; already-hashified string IDs are deliberately skipped.
func collectNotificationRelatedIds(action json.RawMessage, userIds, trackIds, playlistIds *[]int32) {
	// appendInt pushes val onto target only when it is a genuine JSON number;
	// absent fields and non-numeric values are ignored.
	appendInt := func(target *[]int32, val gjson.Result) {
		if val.Exists() && val.Type == gjson.Number {
			*target = append(*target, int32(val.Int()))
		}
	}

	// routeByType buckets a polymorphic entity ID into tracks or playlists
	// based on its (case-insensitive) type discriminator. Albums are stored
	// as playlists. Unknown or missing types drop the ID.
	routeByType := func(val gjson.Result, itemType string) {
		if !val.Exists() || val.Type != gjson.Number {
			return
		}
		switch strings.ToLower(itemType) {
		case "track":
			*trackIds = append(*trackIds, int32(val.Int()))
		case "playlist", "album":
			*playlistIds = append(*playlistIds, int32(val.Int()))
		}
	}

	// Unambiguous user ID fields.
	for _, path := range []string{
		"data.user_id",
		"data.follower_user_id",
		"data.followee_user_id",
		"data.comment_user_id",
		"data.entity_user_id",
		"data.reacter_user_id",
		"data.sender_user_id",
		"data.receiver_user_id",
		"data.dethroned_user_id",
		"data.grantee_user_id",
		"data.tastemaker_user_id",
		"data.tastemaker_item_owner_id",
		"data.track_owner_id",
		"data.parent_track_owner_id",
		"data.playlist_owner_id",
		"data.buyer_user_id",
		"data.seller_user_id",
	} {
		appendInt(userIds, gjson.GetBytes(action, path))
	}

	// Unambiguous track/playlist ID fields.
	appendInt(trackIds, gjson.GetBytes(action, "data.track_id"))
	appendInt(trackIds, gjson.GetBytes(action, "data.parent_track_id"))
	appendInt(playlistIds, gjson.GetBytes(action, "data.playlist_id"))

	// Polymorphic fields: split by sibling type discriminator. The repost/save
	// family shares a single "data.type" discriminator; tastemaker and
	// content each carry their own.
	itemType := gjson.GetBytes(action, "data.type").String()
	for _, path := range []string{
		"data.repost_item_id",
		"data.save_item_id",
		"data.repost_of_repost_item_id",
		"data.save_of_repost_item_id",
	} {
		routeByType(gjson.GetBytes(action, path), itemType)
	}

	routeByType(gjson.GetBytes(action, "data.tastemaker_item_id"),
		gjson.GetBytes(action, "data.tastemaker_item_type").String())
	routeByType(gjson.GetBytes(action, "data.content_id"),
		gjson.GetBytes(action, "data.content_type").String())

	// Comment notifications: entity_id is a track when entity_type is Track.
	if val := gjson.GetBytes(action, "data.entity_id"); val.Exists() && val.Type == gjson.Number {
		if strings.EqualFold(gjson.GetBytes(action, "data.entity_type").String(), "track") {
			*trackIds = append(*trackIds, int32(val.Int()))
		}
	}
}
116 changes: 116 additions & 0 deletions api/v1_notifications_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"strconv"
"testing"

"api.audius.co/database"
Expand Down Expand Up @@ -468,3 +469,118 @@ func TestV1Notifications_AnnouncementRequiresUserIdInUserIds(t *testing.T) {
"data.notifications.0.actions.0.data.title": "For user 1",
})
}

// TestV1Notifications_RelatedEntities covers the `related` block in the
// notifications response:
//
//   - users/tracks/playlists referenced by notification action data are
//     hydrated server-side so the client doesn't need follow-up round trips
//   - actor IDs are capped at notificationRelatedActorsPerGroup per group so
//     a fan-out notification (e.g. 100 followers) doesn't bloat the response;
//     the target entity (the followee, in this case) is duplicated in every
//     action's data so it's still picked up under the cap
//   - polymorphic *_item_id fields (repost_item_id here) are routed to the
//     right bucket based on the sibling `type` discriminator
func TestV1Notifications_RelatedEntities(t *testing.T) {
	app := emptyTestApp(t)

	const recipient = 1
	const reposter = 300
	const repostedTrackID = 50
	const repostedTrackOwner = 200
	const savedPlaylistID = 60
	const saver = 400
	// Five followers, but the per-group cap should drop us to
	// notificationRelatedActorsPerGroup followers + the followee.
	followerIDs := []int{100, 101, 102, 103, 104}

	users := []map[string]any{
		{"user_id": recipient},
		{"user_id": reposter},
		{"user_id": repostedTrackOwner},
		{"user_id": saver},
	}
	for _, followerID := range followerIDs {
		users = append(users, map[string]any{"user_id": followerID})
	}

	notifs := []map[string]any{
		{
			"id":        10,
			"specifier": "300",
			"group_id":  "repost:track:50",
			"type":      "repost",
			"user_ids":  []int{recipient},
			"data":      []byte(`{"type": "track", "user_id": 300, "repost_item_id": 50}`),
			"timestamp": "2025-01-01 00:00:00",
		},
		{
			"id":        11,
			"specifier": "400",
			"group_id":  "save:playlist:60",
			"type":      "save",
			"user_ids":  []int{recipient},
			"data":      []byte(`{"type": "playlist", "user_id": 400, "save_item_id": 60}`),
			"timestamp": "2025-01-02 00:00:00",
		},
	}
	// Five follow notifications, all in the same group (one logical
	// "you got followed by 5 people" notification after json_agg).
	for idx, followerID := range followerIDs {
		payload := `{"follower_user_id": ` + strconv.Itoa(followerID) +
			`, "followee_user_id": ` + strconv.Itoa(recipient) + `}`
		notifs = append(notifs, map[string]any{
			"id":        20 + idx,
			"specifier": strconv.Itoa(followerID),
			"group_id":  "follow:1",
			"type":      "follow",
			"user_ids":  []int{recipient},
			"data":      []byte(payload),
			"timestamp": "2025-01-03 00:00:00",
		})
	}

	database.Seed(app.pool.Replicas[0], database.FixtureMap{
		"users":  users,
		"tracks": []map[string]any{{"track_id": repostedTrackID, "owner_id": repostedTrackOwner}},
		"playlists": []map[string]any{
			{"playlist_id": savedPlaylistID, "playlist_owner_id": recipient},
		},
		"notification": notifs,
	})

	status, body := testGet(t, app, "/v1/notifications/"+trashid.MustEncodeHashID(recipient))
	assert.Equal(t, 200, status)

	assert.ElementsMatch(t,
		[]string{trashid.MustEncodeHashID(repostedTrackID)},
		pluckStrings(body, "related.tracks.#.id"),
		"reposted track must be hydrated under related.tracks",
	)

	assert.ElementsMatch(t,
		[]string{trashid.MustEncodeHashID(savedPlaylistID)},
		pluckStrings(body, "related.playlists.#.id"),
		"saved playlist must be hydrated under related.playlists",
	)

	gotUserIds := pluckStrings(body, "related.users.#.id")

	// Fan-out cap: at most notificationRelatedActorsPerGroup followers from the
	// follow group, plus the reposter, the saver, and the followee (recipient).
	maxExpected := notificationRelatedActorsPerGroup + 3 // reposter, saver, followee
	assert.LessOrEqual(t, len(gotUserIds), maxExpected,
		"actor cap must bound the related.users size for fan-out groups; got %v", gotUserIds)

	// Always-included targets: the recipient (followee), the reposter, the saver.
	for _, want := range []struct {
		id  int
		msg string
	}{
		{recipient, "followee (recipient) must appear in related.users"},
		{reposter, "reposter must appear in related.users"},
		{saver, "saver must appear in related.users"},
	} {
		assert.Contains(t, gotUserIds, trashid.MustEncodeHashID(want.id), want.msg)
	}
}

Loading