diff --git a/cmd/api/integration_test.go b/cmd/api/integration_test.go index 9b7824d..c66e34c 100644 --- a/cmd/api/integration_test.go +++ b/cmd/api/integration_test.go @@ -21,6 +21,7 @@ import ( "github.com/InWheelOrg/inwheel-api/internal/a11y" "github.com/InWheelOrg/inwheel-api/internal/middleware" + "github.com/InWheelOrg/inwheel-api/internal/place" "github.com/InWheelOrg/inwheel-api/internal/testhelpers" "github.com/InWheelOrg/inwheel-api/pkg/models" "golang.org/x/time/rate" @@ -60,6 +61,7 @@ func newTestServer(t *testing.T) *Server { ctx := t.Context() return &Server{ db: testDB, + places: place.NewRepository(testDB), engine: &a11y.Engine{}, regLimiter: middleware.NewRateLimiter(ctx, rate.Every(time.Millisecond), 1000), keyLimiter: middleware.NewRateLimiter(ctx, rate.Every(time.Millisecond), 1000), diff --git a/cmd/api/main.go b/cmd/api/main.go index 2226548..e9c6df7 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -25,13 +25,13 @@ import ( "github.com/InWheelOrg/inwheel-api/internal/geo" "github.com/InWheelOrg/inwheel-api/internal/middleware" "github.com/InWheelOrg/inwheel-api/internal/pagination" + "github.com/InWheelOrg/inwheel-api/internal/place" "github.com/InWheelOrg/inwheel-api/internal/validation" "github.com/InWheelOrg/inwheel-api/pkg/models" "github.com/getkin/kin-openapi/openapi3filter" nethttp_middleware "github.com/oapi-codegen/nethttp-middleware" "golang.org/x/time/rate" "gorm.io/gorm" - "gorm.io/gorm/clause" ) type ctxKeyRequest struct{} @@ -52,6 +52,7 @@ func injectRequest() apiv1.StrictMiddlewareFunc { // Server handles HTTP requests for the InWheel API and implements StrictServerInterface. type Server struct { db *gorm.DB + places *place.Repository engine *a11y.Engine regLimiter *middleware.RateLimiter keyLimiter *middleware.RateLimiter @@ -95,6 +96,7 @@ func main() { srv := &Server{ db: gormDB, + places: place.NewRepository(gormDB), engine: &a11y.Engine{}, regLimiter: middleware.NewRateLimiter(ctx, rate.Every(20*time.Minute), 3), keyLimiter: middleware.NewRateLimiter(ctx, rate.Every(time.Second), 60), @@ -183,8 +185,6 @@ func bodySizeLimiter(maxBytes int64) apiv1.MiddlewareFunc { } } -// ── StrictServerInterface ───────────────────────────────────────────────────── - func (s *Server) ListPlaces(ctx context.Context, request apiv1.ListPlacesRequestObject) (apiv1.ListPlacesResponseObject, error) { q := request.Params @@ -313,58 +313,23 @@ func (s *Server) PatchPlaceAccessibility(ctx context.Context, request apiv1.Patc } input.SubmittedAt = &now - var result models.AccessibilityProfile - var auditAction string - err := s.db.Transaction(func(tx *gorm.DB) error { - var profile models.AccessibilityProfile - err := tx.Where("place_id = ?", id).First(&profile).Error - - if err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - if err := tx.First(&models.Place{}, "id = ?", id).Error; err != nil { - return err - } - input.PlaceID = id - input.UpdatedAt = now - if err := tx.Create(&input).Error; err != nil { - return err - } - result = input - auditAction = "create" - return nil - } - - updates := map[string]any{ - "overall_status": input.OverallStatus, - "components": input.Components, - "updated_at": now, - "submitted_by": input.SubmittedBy, - "submitted_at": input.SubmittedAt, - } - if err := tx.Model(&profile).Clauses(clause.Returning{}).Updates(updates).Error; err != nil { - return err - } - result = profile - auditAction = "update" - return nil - }) - + created, err := s.places.UpsertProfile(ctx, id, &input) if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { + if errors.Is(err, place.ErrPlaceNotFound) { return apiv1.PatchPlaceAccessibility404JSONResponse{Error: "place not found"}, nil } return nil, err } - audit.Log(s.db, "accessibility_profiles", result.ID, keyID, auditAction) + auditAction := "update" + if created { + auditAction = "create" + } + audit.Log(s.db, "accessibility_profiles", input.ID, keyID, auditAction) - return apiv1.PatchPlaceAccessibility200JSONResponse(result), nil + return apiv1.PatchPlaceAccessibility200JSONResponse(input), nil } -// ── Infrastructure handlers ─────────────────────────────────────────────────── - func (s *Server) handleHealthz(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]string{"status": "ok"}, http.StatusOK) } @@ -388,8 +353,6 @@ func (s *Server) handleOpenAPISpec(w http.ResponseWriter, r *http.Request) { } } -// ── Response type converters ────────────────────────────────────────────────── - func validationError(errs []validation.FieldError) apiv1.ValidationError { fields := make([]apiv1.FieldError, len(errs)) for i, e := range errs { diff --git a/internal/place/repository.go b/internal/place/repository.go index 633089b..85605c6 100644 --- a/internal/place/repository.go +++ b/internal/place/repository.go @@ -3,13 +3,14 @@ * SPDX-License-Identifier: AGPL-3.0-only */ -// Package place owns reads and writes against the places table. package place import ( "context" "encoding/json" + "errors" "fmt" + "time" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -18,30 +19,23 @@ import ( "github.com/InWheelOrg/inwheel-api/pkg/models" ) -// Repository is the data-access layer for places. +var ErrPlaceNotFound = errors.New("place not found") + type Repository struct { db *gorm.DB } -// NewRepository constructs a Repository backed by the given GORM connection. func NewRepository(db *gorm.DB) *Repository { return &Repository{db: db} } -// UpsertBatch inserts or updates the given places in a single SQL statement using -// (osm_id, osm_type) as the conflict key. Existing rows have their name, location, -// category, rank, tags, external_ids, and status replaced. Returns an error if the -// underlying SQL fails. An empty or nil batch is a no-op. func (r *Repository) UpsertBatch(ctx context.Context, places []models.Place) error { if len(places) == 0 { return nil } - // TargetWhere matches the partial index predicate. The index is defined - // WHERE osm_id <> 0 so test fixtures that create places without setting - // osm_id (zero value) don't collide on the unique constraint. In production, - // every place is OSM-sourced and has a non-zero osm_id, so the predicate - // covers every real row. + // TargetWhere matches the partial index (WHERE osm_id <> 0) so zero-OSMID + // test fixtures don't collide on the unique constraint. tx := r.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{ {Name: "osm_id"}, @@ -61,9 +55,6 @@ func (r *Repository) UpsertBatch(ctx context.Context, places []models.Place) err return nil } -// FindCandidates returns active places within radiusM metres of (lat, lng) -// whose category is in categories, ordered by ascending distance, capped at 32. -// Satisfies identity.CandidateRepo. func (r *Repository) FindCandidates( ctx context.Context, lat, lng, radiusM float64, @@ -91,9 +82,6 @@ func (r *Repository) FindCandidates( return out, nil } -// AttachExternalRef upserts ref into the place's external_ids map under the -// given source key, atomically via jsonb_set. Returns an error if no row has -// the given id. func (r *Repository) AttachExternalRef( ctx context.Context, placeID, source string, @@ -124,16 +112,93 @@ func (r *Repository) AttachExternalRef( return nil } -// Compile-time assertion that *Repository satisfies identity.CandidateRepo. -// The assertion lives here so a signature drift in either side fails the build -// at the boundary, not at the first caller. +// UpsertProfile creates or replaces the accessibility profile. Always overwrites — API write path. +// Returns created=true when a new row was inserted, false when an existing row was updated. +func (r *Repository) UpsertProfile(ctx context.Context, placeID string, profile *models.AccessibilityProfile) (created bool, err error) { + if profile == nil { + return false, fmt.Errorf("upsert profile: nil profile") + } + now := time.Now() + err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.First(&models.Place{}, "id = ?", placeID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrPlaceNotFound + } + return fmt.Errorf("upsert profile: check place: %w", err) + } + var existing models.AccessibilityProfile + loadErr := tx.Where("place_id = ?", placeID).First(&existing).Error + if loadErr != nil && !errors.Is(loadErr, gorm.ErrRecordNotFound) { + return fmt.Errorf("upsert profile: load existing: %w", loadErr) + } + if errors.Is(loadErr, gorm.ErrRecordNotFound) { + profile.PlaceID = placeID + profile.UpdatedAt = now + created = true + return tx.Create(profile).Error + } + updates := map[string]any{ + "overall_status": profile.OverallStatus, + "components": profile.Components, + "updated_at": now, + "submitted_by": profile.SubmittedBy, + "submitted_at": profile.SubmittedAt, + "user_verified": profile.UserVerified, + } + if err := tx.Model(&existing).Clauses(clause.Returning{}).Updates(updates).Error; err != nil { + return err + } + *profile = existing + return nil + }) + return created, err +} + +// UpsertProfileIngestion creates or updates the accessibility profile but skips +// rows where user_verified=true, preserving human-submitted corrections. +// Returns written=true when a row was actually written. +func (r *Repository) UpsertProfileIngestion(ctx context.Context, placeID string, profile *models.AccessibilityProfile) (written bool, err error) { + if profile == nil { + return false, fmt.Errorf("upsert profile ingestion: nil profile") + } + now := time.Now() + err = r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var existing models.AccessibilityProfile + loadErr := tx.Where("place_id = ?", placeID).First(&existing).Error + if loadErr != nil && !errors.Is(loadErr, gorm.ErrRecordNotFound) { + return fmt.Errorf("upsert profile ingestion: load existing: %w", loadErr) + } + if errors.Is(loadErr, gorm.ErrRecordNotFound) { + profile.PlaceID = placeID + profile.UpdatedAt = now + profile.UserVerified = false + if err := tx.Create(profile).Error; err != nil { + return err + } + written = true + return nil + } + if existing.UserVerified { + return nil + } + updates := map[string]any{ + "overall_status": profile.OverallStatus, + "components": profile.Components, + "updated_at": now, + "submitted_by": nil, + "submitted_at": nil, + } + if err := tx.Model(&existing).Updates(updates).Error; err != nil { + return err + } + written = true + return nil + }) + return written, err +} + var _ interface { - FindCandidates( - ctx context.Context, - lat, lng, radiusM float64, - categories []models.Category, - ) ([]models.Place, error) + FindCandidates(ctx context.Context, lat, lng, radiusM float64, categories []models.Category) ([]models.Place, error) } = (*Repository)(nil) -// Compile-time assertion that *Repository satisfies identity.AttachRepo. var _ identity.AttachRepo = (*Repository)(nil) diff --git a/internal/place/repository_integration_test.go b/internal/place/repository_integration_test.go index 8303683..9da2aa5 100644 --- a/internal/place/repository_integration_test.go +++ b/internal/place/repository_integration_test.go @@ -9,10 +9,13 @@ package place_test import ( "context" + "errors" "strings" "testing" "time" + "gorm.io/gorm" + "github.com/InWheelOrg/inwheel-api/internal/place" "github.com/InWheelOrg/inwheel-api/internal/testhelpers" "github.com/InWheelOrg/inwheel-api/pkg/models" @@ -208,6 +211,227 @@ func names(ps []models.Place) []string { return out } +func mustCreatePlace(ctx context.Context, t *testing.T, db *gorm.DB, osmID int64, name string) string { + t.Helper() + p := models.Place{ + OSMID: osmID, + OSMType: models.OSMNode, + Name: name, + Lat: 46.4628, + Lng: 6.8417, + Category: models.CategoryCafe, + Rank: models.RankEstablishment, + Source: "osm", + Status: models.PlaceStatusActive, + } + if err := db.WithContext(ctx).Create(&p).Error; err != nil { + t.Fatalf("mustCreatePlace %q: %v", name, err) + } + return p.ID +} + +func TestRepository_UpsertProfile_CreatesWhenAbsent(t *testing.T) { + ctx := context.Background() + gormDB, cleanup, err := testhelpers.StartPostgres(ctx) + if err != nil { + t.Fatalf("start postgres: %v", err) + } + defer cleanup() + + repo := place.NewRepository(gormDB) + placeID := mustCreatePlace(ctx, t, gormDB, 1001, "Profile Test Place") + + profile := &models.AccessibilityProfile{ + OverallStatus: models.StatusAccessible, + } + created, err := repo.UpsertProfile(ctx, placeID, profile) + if err != nil { + t.Fatalf("UpsertProfile: %v", err) + } + if !created { + t.Errorf("created = false, want true on first insert") + } + + var got models.AccessibilityProfile + if err := gormDB.Where("place_id = ?", placeID).First(&got).Error; err != nil { + t.Fatalf("load profile: %v", err) + } + if got.OverallStatus != models.StatusAccessible { + t.Errorf("OverallStatus = %q, want %q", got.OverallStatus, models.StatusAccessible) + } + if got.PlaceID != placeID { + t.Errorf("PlaceID = %q, want %q", got.PlaceID, placeID) + } +} + +func TestRepository_UpsertProfile_UpdatesWhenPresent(t *testing.T) { + ctx := context.Background() + gormDB, cleanup, err := testhelpers.StartPostgres(ctx) + if err != nil { + t.Fatalf("start postgres: %v", err) + } + defer cleanup() + + repo := place.NewRepository(gormDB) + placeID := mustCreatePlace(ctx, t, gormDB, 1002, "Profile Update Place") + + first := &models.AccessibilityProfile{OverallStatus: models.StatusUnknown} + created, err := repo.UpsertProfile(ctx, placeID, first) + if err != nil { + t.Fatalf("first UpsertProfile: %v", err) + } + if !created { + t.Errorf("created = false, want true on first insert") + } + + second := &models.AccessibilityProfile{OverallStatus: models.StatusLimited} + created, err = repo.UpsertProfile(ctx, placeID, second) + if err != nil { + t.Fatalf("second UpsertProfile: %v", err) + } + if created { + t.Errorf("created = true, want false on update") + } + + var got models.AccessibilityProfile + if err := gormDB.Where("place_id = ?", placeID).First(&got).Error; err != nil { + t.Fatalf("load profile: %v", err) + } + if got.OverallStatus != models.StatusLimited { + t.Errorf("OverallStatus = %q, want %q", got.OverallStatus, models.StatusLimited) + } + var count int64 + gormDB.Model(&models.AccessibilityProfile{}).Where("place_id = ?", placeID).Count(&count) + if count != 1 { + t.Errorf("profile row count = %d, want 1", count) + } +} + +func TestRepository_UpsertProfile_OverwritesUserVerified(t *testing.T) { + ctx := context.Background() + gormDB, cleanup, err := testhelpers.StartPostgres(ctx) + if err != nil { + t.Fatalf("start postgres: %v", err) + } + defer cleanup() + + repo := place.NewRepository(gormDB) + placeID := mustCreatePlace(ctx, t, gormDB, 1003, "User Verified Overwrite Place") + + if _, err := repo.UpsertProfile(ctx, placeID, &models.AccessibilityProfile{OverallStatus: models.StatusAccessible, UserVerified: true}); err != nil { + t.Fatalf("seed: %v", err) + } + + override := &models.AccessibilityProfile{ + OverallStatus: models.StatusInaccessible, + UserVerified: false, + } + if _, err := repo.UpsertProfile(ctx, placeID, override); err != nil { + t.Fatalf("override UpsertProfile: %v", err) + } + + var got models.AccessibilityProfile + if err := gormDB.Where("place_id = ?", placeID).First(&got).Error; err != nil { + t.Fatalf("load profile: %v", err) + } + if got.OverallStatus != models.StatusInaccessible { + t.Errorf("OverallStatus = %q, want %q", got.OverallStatus, models.StatusInaccessible) + } +} + +func TestRepository_UpsertProfile_PlaceNotFound(t *testing.T) { + ctx := context.Background() + db, cleanup, err := testhelpers.StartPostgres(ctx) + if err != nil { + t.Fatalf("start postgres: %v", err) + } + defer cleanup() + + repo := place.NewRepository(db) + _, err = repo.UpsertProfile(ctx, "00000000-0000-0000-0000-000000000000", &models.AccessibilityProfile{OverallStatus: models.StatusAccessible}) + if !errors.Is(err, place.ErrPlaceNotFound) { + t.Errorf("err = %v, want ErrPlaceNotFound", err) + } +} + +func TestRepository_UpsertProfileIngestion_InsertsWhenAbsent(t *testing.T) { + ctx := context.Background() + db, cleanup, err := testhelpers.StartPostgres(ctx) + if err != nil { + t.Fatalf("start postgres: %v", err) + } + defer cleanup() + repo := place.NewRepository(db) + placeID := mustCreatePlace(ctx, t, db, 9004, "Café Pascal Ingestion") + + written, err := repo.UpsertProfileIngestion(ctx, placeID, &models.AccessibilityProfile{OverallStatus: models.StatusAccessible}) + if err != nil { + t.Fatalf("UpsertProfileIngestion: %v", err) + } + if !written { + t.Errorf("written = false, want true on first insert") + } + var stored models.AccessibilityProfile + db.Where("place_id = ?", placeID).First(&stored) + if stored.OverallStatus != models.StatusAccessible { + t.Errorf("OverallStatus = %q, want accessible", stored.OverallStatus) + } +} + +func TestRepository_UpsertProfileIngestion_OverwritesNonVerified(t *testing.T) { + ctx := context.Background() + db, cleanup, err := testhelpers.StartPostgres(ctx) + if err != nil { + t.Fatalf("start postgres: %v", err) + } + defer cleanup() + repo := place.NewRepository(db) + placeID := mustCreatePlace(ctx, t, db, 9005, "Café Pascal Ingestion2") + + if _, err := repo.UpsertProfileIngestion(ctx, placeID, &models.AccessibilityProfile{OverallStatus: models.StatusLimited}); err != nil { + t.Fatalf("seed: %v", err) + } + written, err := repo.UpsertProfileIngestion(ctx, placeID, &models.AccessibilityProfile{OverallStatus: models.StatusAccessible}) + if err != nil { + t.Fatalf("second: %v", err) + } + if !written { + t.Errorf("written = false, want true (existing row is not user-verified)") + } + var stored models.AccessibilityProfile + db.Where("place_id = ?", placeID).First(&stored) + if stored.OverallStatus != models.StatusAccessible { + t.Errorf("OverallStatus = %q, want accessible", stored.OverallStatus) + } +} + +func TestRepository_UpsertProfileIngestion_SkipsUserVerified(t *testing.T) { + ctx := context.Background() + db, cleanup, err := testhelpers.StartPostgres(ctx) + if err != nil { + t.Fatalf("start postgres: %v", err) + } + defer cleanup() + repo := place.NewRepository(db) + placeID := mustCreatePlace(ctx, t, db, 9006, "Café Pascal Ingestion3") + + if _, err := repo.UpsertProfile(ctx, placeID, &models.AccessibilityProfile{OverallStatus: models.StatusAccessible, UserVerified: true}); err != nil { + t.Fatalf("seed: %v", err) + } + written, err := repo.UpsertProfileIngestion(ctx, placeID, &models.AccessibilityProfile{OverallStatus: models.StatusInaccessible}) + if err != nil { + t.Fatalf("machine: %v", err) + } + if written { + t.Errorf("written = true, want false (user-verified row must not be overwritten)") + } + var stored models.AccessibilityProfile + db.Where("place_id = ?", placeID).First(&stored) + if stored.OverallStatus != models.StatusAccessible { + t.Errorf("user-verified row must survive; got %q", stored.OverallStatus) + } +} + func TestUnmatchedExternal_TableRoundTrip(t *testing.T) { ctx := context.Background() db, cleanup, err := testhelpers.StartPostgres(ctx)