Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/api/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/InWheelOrg/inwheel-api/internal/a11y"
"github.com/InWheelOrg/inwheel-api/internal/middleware"
"github.com/InWheelOrg/inwheel-api/internal/place"
"github.com/InWheelOrg/inwheel-api/internal/testhelpers"
"github.com/InWheelOrg/inwheel-api/pkg/models"
"golang.org/x/time/rate"
Expand Down Expand Up @@ -60,6 +61,7 @@ func newTestServer(t *testing.T) *Server {
ctx := t.Context()
return &Server{
db: testDB,
places: place.NewRepository(testDB),
engine: &a11y.Engine{},
regLimiter: middleware.NewRateLimiter(ctx, rate.Every(time.Millisecond), 1000),
keyLimiter: middleware.NewRateLimiter(ctx, rate.Every(time.Millisecond), 1000),
Expand Down
59 changes: 11 additions & 48 deletions cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import (
"github.com/InWheelOrg/inwheel-api/internal/geo"
"github.com/InWheelOrg/inwheel-api/internal/middleware"
"github.com/InWheelOrg/inwheel-api/internal/pagination"
"github.com/InWheelOrg/inwheel-api/internal/place"
"github.com/InWheelOrg/inwheel-api/internal/validation"
"github.com/InWheelOrg/inwheel-api/pkg/models"
"github.com/getkin/kin-openapi/openapi3filter"
nethttp_middleware "github.com/oapi-codegen/nethttp-middleware"
"golang.org/x/time/rate"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

type ctxKeyRequest struct{}
Expand All @@ -52,6 +52,7 @@ func injectRequest() apiv1.StrictMiddlewareFunc {
// Server handles HTTP requests for the InWheel API and implements StrictServerInterface.
type Server struct {
db *gorm.DB
places *place.Repository
engine *a11y.Engine
regLimiter *middleware.RateLimiter
keyLimiter *middleware.RateLimiter
Expand Down Expand Up @@ -95,6 +96,7 @@ func main() {

srv := &Server{
db: gormDB,
places: place.NewRepository(gormDB),
engine: &a11y.Engine{},
regLimiter: middleware.NewRateLimiter(ctx, rate.Every(20*time.Minute), 3),
keyLimiter: middleware.NewRateLimiter(ctx, rate.Every(time.Second), 60),
Expand Down Expand Up @@ -183,8 +185,6 @@ func bodySizeLimiter(maxBytes int64) apiv1.MiddlewareFunc {
}
}

// ── StrictServerInterface ─────────────────────────────────────────────────────

func (s *Server) ListPlaces(ctx context.Context, request apiv1.ListPlacesRequestObject) (apiv1.ListPlacesResponseObject, error) {
q := request.Params

Expand Down Expand Up @@ -313,58 +313,23 @@ func (s *Server) PatchPlaceAccessibility(ctx context.Context, request apiv1.Patc
}
input.SubmittedAt = &now

var result models.AccessibilityProfile
var auditAction string
err := s.db.Transaction(func(tx *gorm.DB) error {
var profile models.AccessibilityProfile
err := tx.Where("place_id = ?", id).First(&profile).Error

if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if err := tx.First(&models.Place{}, "id = ?", id).Error; err != nil {
return err
}
input.PlaceID = id
input.UpdatedAt = now
if err := tx.Create(&input).Error; err != nil {
return err
}
result = input
auditAction = "create"
return nil
}

updates := map[string]any{
"overall_status": input.OverallStatus,
"components": input.Components,
"updated_at": now,
"submitted_by": input.SubmittedBy,
"submitted_at": input.SubmittedAt,
}
if err := tx.Model(&profile).Clauses(clause.Returning{}).Updates(updates).Error; err != nil {
return err
}
result = profile
auditAction = "update"
return nil
})

created, err := s.places.UpsertProfile(ctx, id, &input)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
if errors.Is(err, place.ErrPlaceNotFound) {
return apiv1.PatchPlaceAccessibility404JSONResponse{Error: "place not found"}, nil
}
return nil, err
}

audit.Log(s.db, "accessibility_profiles", result.ID, keyID, auditAction)
auditAction := "update"
if created {
auditAction = "create"
}
audit.Log(s.db, "accessibility_profiles", input.ID, keyID, auditAction)

return apiv1.PatchPlaceAccessibility200JSONResponse(result), nil
return apiv1.PatchPlaceAccessibility200JSONResponse(input), nil
}

// ── Infrastructure handlers ───────────────────────────────────────────────────

func (s *Server) handleHealthz(w http.ResponseWriter, r *http.Request) {
writeJSON(w, map[string]string{"status": "ok"}, http.StatusOK)
}
Expand All @@ -388,8 +353,6 @@ func (s *Server) handleOpenAPISpec(w http.ResponseWriter, r *http.Request) {
}
}

// ── Response type converters ──────────────────────────────────────────────────

func validationError(errs []validation.FieldError) apiv1.ValidationError {
fields := make([]apiv1.FieldError, len(errs))
for i, e := range errs {
Expand Down
119 changes: 92 additions & 27 deletions internal/place/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
* SPDX-License-Identifier: AGPL-3.0-only
*/

// Package place owns reads and writes against the places table.
package place

import (
"context"
"encoding/json"
"errors"
"fmt"
"time"

"gorm.io/gorm"
"gorm.io/gorm/clause"
Expand All @@ -18,30 +19,23 @@ import (
"github.com/InWheelOrg/inwheel-api/pkg/models"
)

// Repository is the data-access layer for places.
var ErrPlaceNotFound = errors.New("place not found")

type Repository struct {
db *gorm.DB
}

// NewRepository constructs a Repository backed by the given GORM connection.
func NewRepository(db *gorm.DB) *Repository {
return &Repository{db: db}
}

// UpsertBatch inserts or updates the given places in a single SQL statement using
// (osm_id, osm_type) as the conflict key. Existing rows have their name, location,
// category, rank, tags, external_ids, and status replaced. Returns an error if the
// underlying SQL fails. An empty or nil batch is a no-op.
func (r *Repository) UpsertBatch(ctx context.Context, places []models.Place) error {
if len(places) == 0 {
return nil
}

// TargetWhere matches the partial index predicate. The index is defined
// WHERE osm_id <> 0 so test fixtures that create places without setting
// osm_id (zero value) don't collide on the unique constraint. In production,
// every place is OSM-sourced and has a non-zero osm_id, so the predicate
// covers every real row.
// TargetWhere matches the partial index (WHERE osm_id <> 0) so zero-OSMID
// test fixtures don't collide on the unique constraint.
tx := r.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{
{Name: "osm_id"},
Expand All @@ -61,9 +55,6 @@ func (r *Repository) UpsertBatch(ctx context.Context, places []models.Place) err
return nil
}

// FindCandidates returns active places within radiusM metres of (lat, lng)
// whose category is in categories, ordered by ascending distance, capped at 32.
// Satisfies identity.CandidateRepo.
func (r *Repository) FindCandidates(
ctx context.Context,
lat, lng, radiusM float64,
Expand Down Expand Up @@ -91,9 +82,6 @@ func (r *Repository) FindCandidates(
return out, nil
}

// AttachExternalRef upserts ref into the place's external_ids map under the
// given source key, atomically via jsonb_set. Returns an error if no row has
// the given id.
func (r *Repository) AttachExternalRef(
ctx context.Context,
placeID, source string,
Expand Down Expand Up @@ -124,16 +112,93 @@ func (r *Repository) AttachExternalRef(
return nil
}

// Compile-time assertion that *Repository satisfies identity.CandidateRepo.
// The assertion lives here so a signature drift in either side fails the build
// at the boundary, not at the first caller.
// UpsertProfile creates or replaces the accessibility profile. Always overwrites — API write path.
// Returns created=true when a new row was inserted, false when an existing row was updated.
func (r *Repository) UpsertProfile(ctx context.Context, placeID string, profile *models.AccessibilityProfile) (created bool, err error) {
if profile == nil {
return false, fmt.Errorf("upsert profile: nil profile")
}
now := time.Now()
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.First(&models.Place{}, "id = ?", placeID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrPlaceNotFound
}
return fmt.Errorf("upsert profile: check place: %w", err)
}
var existing models.AccessibilityProfile
loadErr := tx.Where("place_id = ?", placeID).First(&existing).Error
if loadErr != nil && !errors.Is(loadErr, gorm.ErrRecordNotFound) {
return fmt.Errorf("upsert profile: load existing: %w", loadErr)
}
if errors.Is(loadErr, gorm.ErrRecordNotFound) {
profile.PlaceID = placeID
profile.UpdatedAt = now
created = true
return tx.Create(profile).Error
}
updates := map[string]any{
"overall_status": profile.OverallStatus,
"components": profile.Components,
"updated_at": now,
"submitted_by": profile.SubmittedBy,
"submitted_at": profile.SubmittedAt,
"user_verified": profile.UserVerified,
}
if err := tx.Model(&existing).Clauses(clause.Returning{}).Updates(updates).Error; err != nil {
return err
}
*profile = existing
return nil
})
return created, err
}

// UpsertProfileIngestion creates or updates the accessibility profile but skips
// rows where user_verified=true, preserving human-submitted corrections.
// Returns written=true when a row was actually written.
func (r *Repository) UpsertProfileIngestion(ctx context.Context, placeID string, profile *models.AccessibilityProfile) (written bool, err error) {
if profile == nil {
return false, fmt.Errorf("upsert profile ingestion: nil profile")
}
now := time.Now()
err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var existing models.AccessibilityProfile
loadErr := tx.Where("place_id = ?", placeID).First(&existing).Error
if loadErr != nil && !errors.Is(loadErr, gorm.ErrRecordNotFound) {
return fmt.Errorf("upsert profile ingestion: load existing: %w", loadErr)
}
if errors.Is(loadErr, gorm.ErrRecordNotFound) {
profile.PlaceID = placeID
profile.UpdatedAt = now
profile.UserVerified = false
if err := tx.Create(profile).Error; err != nil {
return err
}
written = true
return nil
}
if existing.UserVerified {
return nil
}
updates := map[string]any{
"overall_status": profile.OverallStatus,
"components": profile.Components,
"updated_at": now,
"submitted_by": nil,
"submitted_at": nil,
}
if err := tx.Model(&existing).Updates(updates).Error; err != nil {
return err
}
written = true
return nil
})
return written, err
}

var _ interface {
FindCandidates(
ctx context.Context,
lat, lng, radiusM float64,
categories []models.Category,
) ([]models.Place, error)
FindCandidates(ctx context.Context, lat, lng, radiusM float64, categories []models.Category) ([]models.Place, error)
} = (*Repository)(nil)

// Compile-time assertion that *Repository satisfies identity.AttachRepo.
var _ identity.AttachRepo = (*Repository)(nil)
Loading
Loading