Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/sethvargo/go-password/password"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/storage"
Expand Down Expand Up @@ -107,7 +108,7 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)
aud := a.requestAud(ctx, r)

pageParams, err := paginate(r)
pageParams, err := shared.Paginate(r)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
}
Expand All @@ -123,7 +124,7 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {
if err != nil {
return apierrors.NewInternalServerError("Database error finding users").WithInternalError(err)
}
addPaginationHeaders(w, r, pageParams)
shared.AddPaginationHeaders(w, r, pageParams)

return sendJSON(w, http.StatusOK, AdminListUsersResponse{
Users: users,
Expand Down
2 changes: 1 addition & 1 deletion internal/api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (ts *AdminTestSuite) TestAdminUsers() {
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)

assert.Equal(ts.T(), "</admin/users?page=0>; rel=\"last\"", w.Header().Get("Link"))
assert.Equal(ts.T(), "</admin/users?page=1>; rel=\"last\"", w.Header().Get("Link"))
assert.Equal(ts.T(), "0", w.Header().Get("X-Total-Count"))
}

Expand Down
5 changes: 3 additions & 2 deletions internal/api/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"strings"

"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
)

Expand All @@ -19,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
db := a.db.WithContext(ctx)

// aud := a.requestAud(ctx, r)
pageParams, err := paginate(r)
pageParams, err := shared.Paginate(r)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
}
Expand All @@ -42,7 +43,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
return apierrors.NewInternalServerError("Error searching for audit logs").WithInternalError(err)
}

addPaginationHeaders(w, r, pageParams)
shared.AddPaginationHeaders(w, r, pageParams)

return sendJSON(w, http.StatusOK, logs)
}
18 changes: 12 additions & 6 deletions internal/api/oauthserver/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,22 +231,28 @@
ctx := r.Context()
db := s.db.WithContext(ctx)

// TODO(cemal) :: Add pagination, check the `/admin/users` endpoint for reference
pageParams, err := shared.Paginate(r)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
}

var clients []models.OAuthServerClient
if err := db.Q().Where("deleted_at is null").Order("created_at desc").All(&clients); err != nil {
q := db.Q().Where("deleted_at is null").Order("created_at desc")
if err := q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&clients); err != nil { // #nosec G115

Check failure

Code scanning / CodeQL

Incorrect conversion between integer types High

Incorrect conversion of an unsigned 64-bit integer from
strconv.ParseUint
to a lower bit size type int without an upper bound check.

Check failure

Code scanning / CodeQL

Incorrect conversion between integer types High

Incorrect conversion of an unsigned 64-bit integer from
strconv.ParseUint
to a lower bit size type int without an upper bound check.
return apierrors.NewInternalServerError("Error listing OAuth clients").WithInternalError(err)
}
pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115

shared.AddPaginationHeaders(w, r, pageParams)

responses := make([]OAuthServerClientResponse, len(clients))
for i, client := range clients {
responses[i] = *oauthServerClientToResponse(&client)
}

response := OAuthServerClientListResponse{
return shared.SendJSON(w, http.StatusOK, OAuthServerClientListResponse{
Clients: responses,
}

return shared.SendJSON(w, http.StatusOK, response)
})
}

// OAuthTokenParams represents the parameters for the OAuth token endpoint
Expand Down
76 changes: 76 additions & 0 deletions internal/api/oauthserver/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,82 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListHandler() {
}
}

func (ts *OAuthClientTestSuite) TestOAuthServerClientListPagination() {
client1, _ := ts.createTestOAuthClient()
client2, _ := ts.createTestOAuthClient()
client3, _ := ts.createTestOAuthClient()
allIDs := []string{client1.ID.String(), client2.ID.String(), client3.ID.String()}

// page=1, per_page=1: returns 1 item, has next + last links
req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=1&per_page=1", nil)
w := httptest.NewRecorder()
require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req))
assert.Equal(ts.T(), http.StatusOK, w.Code)
assert.Empty(ts.T(), w.Header().Get("X-Total-Count"))
assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`)
assert.Contains(ts.T(), w.Header().Get("Link"), `rel="last"`)
var page1 OAuthServerClientListResponse
require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page1))
assert.Len(ts.T(), page1.Clients, 1)

// page=2, per_page=1: returns 1 item, still has next link
req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=2&per_page=1", nil)
w = httptest.NewRecorder()
require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req))
assert.Empty(ts.T(), w.Header().Get("X-Total-Count"))
assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`)
var page2 OAuthServerClientListResponse
require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page2))
assert.Len(ts.T(), page2.Clients, 1)

// page=3, per_page=1: last page — no next link
req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=3&per_page=1", nil)
w = httptest.NewRecorder()
require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req))
assert.Empty(ts.T(), w.Header().Get("X-Total-Count"))
assert.NotContains(ts.T(), w.Header().Get("Link"), `rel="next"`)
var page3 OAuthServerClientListResponse
require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &page3))
assert.Len(ts.T(), page3.Clients, 1)

// all three pages together cover all clients with no duplicates
pagedIDs := []string{page1.Clients[0].ClientID, page2.Clients[0].ClientID, page3.Clients[0].ClientID}
for _, id := range allIDs {
assert.Contains(ts.T(), pagedIDs, id)
}

// per_page=2: page 1 returns 2, page 2 returns 1
req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=1&per_page=2", nil)
w = httptest.NewRecorder()
require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req))
var halfPage1 OAuthServerClientListResponse
require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &halfPage1))
assert.Len(ts.T(), halfPage1.Clients, 2)
assert.Contains(ts.T(), w.Header().Get("Link"), `rel="next"`)

req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=2&per_page=2", nil)
w = httptest.NewRecorder()
require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req))
var halfPage2 OAuthServerClientListResponse
require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &halfPage2))
assert.Len(ts.T(), halfPage2.Clients, 1)
assert.NotContains(ts.T(), w.Header().Get("Link"), `rel="next"`)

// no params: returns all 3 with default page size
req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients", nil)
w = httptest.NewRecorder()
require.NoError(ts.T(), ts.Server.OAuthServerClientList(w, req))
assert.Empty(ts.T(), w.Header().Get("X-Total-Count"))
var all OAuthServerClientListResponse
require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &all))
assert.Len(ts.T(), all.Clients, 3)

// invalid page param returns an error
req = httptest.NewRequest(http.MethodGet, "/admin/oauth/clients?page=abc", nil)
w = httptest.NewRecorder()
assert.Error(ts.T(), ts.Server.OAuthServerClientList(w, req))
}

func (ts *OAuthClientTestSuite) TestOAuthServerClientUpdateHandler() {
// Create a test client first
client, _ := ts.createTestOAuthClient()
Expand Down
64 changes: 0 additions & 64 deletions internal/api/pagination.go

This file was deleted.

80 changes: 80 additions & 0 deletions internal/api/shared/pagination.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package shared

import (
"fmt"
"math"
"net/http"
"net/url"
"strconv"

"github.com/supabase/auth/internal/models"
)

const DefaultPerPage = 50
Comment thread
cemalkilic marked this conversation as resolved.
const MaxPerPage = 1000

func calculateTotalPages(perPage, total uint64) uint64 {
pages := total / perPage
if total%perPage > 0 {
return pages + 1
}
return pages
}

func AddPaginationHeaders(w http.ResponseWriter, r *http.Request, p *models.Pagination) {
totalPages := max(calculateTotalPages(p.PerPage, p.Count), 1)
u, _ := url.ParseRequestURI(r.URL.String())
query := u.Query()
header := ""
if totalPages > p.Page {
query.Set("page", fmt.Sprintf("%v", p.Page+1))
u.RawQuery = query.Encode()
header += "<" + u.String() + ">; rel=\"next\", "
}
query.Set("page", fmt.Sprintf("%v", totalPages))
u.RawQuery = query.Encode()
header += "<" + u.String() + ">; rel=\"last\""

w.Header().Add("Link", header)
if p.ShowTotalCount {
w.Header().Add("X-Total-Count", fmt.Sprintf("%v", p.Count))
}
}

func Paginate(r *http.Request) (*models.Pagination, error) {
params := r.URL.Query()
queryPage := params.Get("page")
queryPerPage := params.Get("per_page")
Comment thread
cemalkilic marked this conversation as resolved.
var page uint64 = 1
var perPage uint64 = DefaultPerPage
var err error
if queryPage != "" {
page, err = strconv.ParseUint(queryPage, 10, 64)
if err != nil {
return nil, err
}
if page == 0 {
return nil, fmt.Errorf("page must be greater than 0")
}
if page > math.MaxInt32 {
return nil, fmt.Errorf("page exceeds maximum allowed value")
}
}
if queryPerPage != "" {
perPage, err = strconv.ParseUint(queryPerPage, 10, 64)
if err != nil {
return nil, err
}
if perPage == 0 {
return nil, fmt.Errorf("per_page must be greater than 0")
}
if perPage > MaxPerPage {
return nil, fmt.Errorf("per_page must not exceed %d", MaxPerPage)
}
}

return &models.Pagination{
Page: page,
PerPage: perPage,
}, nil
}
Loading
Loading