Skip to content
15 changes: 8 additions & 7 deletions cmd/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@ import (
"encoding/json"
"fmt"
"log/slog"
"maglev.onebusaway.org/internal/app"
"maglev.onebusaway.org/internal/appconf"
"maglev.onebusaway.org/internal/gtfs"
"maglev.onebusaway.org/internal/logging"
"maglev.onebusaway.org/internal/restapi"
"maglev.onebusaway.org/internal/webui"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

"maglev.onebusaway.org/internal/app"
"maglev.onebusaway.org/internal/appconf"
"maglev.onebusaway.org/internal/gtfs"
"maglev.onebusaway.org/internal/logging"
"maglev.onebusaway.org/internal/restapi"
"maglev.onebusaway.org/internal/webui"
)

// ParseAPIKeys splits a comma-separated string of API keys and trims whitespace from each key.
Expand Down Expand Up @@ -46,7 +47,7 @@ func BuildApplication(cfg appconf.Config, gtfsCfg gtfs.Config) (*app.Application

var directionCalculator *gtfs.DirectionCalculator
if gtfsManager != nil {
directionCalculator = gtfs.NewDirectionCalculator(gtfsManager.GtfsDB.Queries)
directionCalculator = gtfs.NewDirectionCalculator(gtfsManager)
}

coreApp := &app.Application{
Expand Down
4 changes: 4 additions & 0 deletions gtfsdb/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ func (c *Client) Close() error {
return c.DB.Close()
}

func (c *Client) GetDBPath() string {
return c.config.DBPath
}

// DownloadAndStore downloads GTFS data from the given URL and stores it in the database
func (c *Client) DownloadAndStore(ctx context.Context, url, authHeaderKey, authHeaderValue string) error {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
Expand Down
16 changes: 8 additions & 8 deletions internal/gtfs/direction_calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ import (
const unknownDirection = models.UnknownValue

type DirectionCalculator struct {
queries *gtfsdb.Queries
gtfsManager *Manager
}

func NewDirectionCalculator(queries *gtfsdb.Queries) *DirectionCalculator {
func NewDirectionCalculator(gtfsManager *Manager) *DirectionCalculator {
return &DirectionCalculator{
queries: queries,
gtfsManager: gtfsManager,
}
}

// CalculateStopDirection determines the compass direction for a stop
// First checks the database for precomputed direction, falls back to calculation if needed
func (dc *DirectionCalculator) CalculateStopDirection(ctx context.Context, stopID string) string {
// Strategy 1: Check database for precomputed direction (O(1) lookup)
stop, err := dc.queries.GetStop(ctx, stopID)
stop, err := dc.gtfsManager.GtfsDB.Queries.GetStop(ctx, stopID)
if err == nil && stop.Direction.Valid && stop.Direction.String != "" {
return stop.Direction.String
}
Expand All @@ -45,7 +45,7 @@ func (dc *DirectionCalculator) CalculateStopDirection(ctx context.Context, stopI

func (dc *DirectionCalculator) calculateFromShape(ctx context.Context, stopID string) string {
// Get trips serving this stop
stopTrips, err := dc.queries.GetStopsWithTripContext(ctx, stopID)
stopTrips, err := dc.gtfsManager.GtfsDB.Queries.GetStopsWithTripContext(ctx, stopID)
if err != nil || len(stopTrips) == 0 {
return unknownDirection
}
Expand All @@ -58,7 +58,7 @@ func (dc *DirectionCalculator) calculateFromShape(ctx context.Context, stopID st
}

// Get shape points for this trip
shapePoints, err := dc.queries.GetShapePointsForTrip(ctx, stopTrip.TripID)
shapePoints, err := dc.gtfsManager.GtfsDB.Queries.GetShapePointsForTrip(ctx, stopTrip.TripID)
if err != nil || len(shapePoints) < 2 {
continue
}
Expand All @@ -82,15 +82,15 @@ func (dc *DirectionCalculator) calculateFromShape(ctx context.Context, stopID st
}

func (dc *DirectionCalculator) calculateFromNextStop(ctx context.Context, stopID string) string {
stopTrips, err := dc.queries.GetStopsWithTripContext(ctx, stopID)
stopTrips, err := dc.gtfsManager.GtfsDB.Queries.GetStopsWithTripContext(ctx, stopID)
if err != nil || len(stopTrips) == 0 {
return unknownDirection
}

directions := make(map[string]int)

for _, stopTrip := range stopTrips {
nextStop, err := dc.queries.GetNextStopInTrip(ctx, gtfsdb.GetNextStopInTripParams{
nextStop, err := dc.gtfsManager.GtfsDB.Queries.GetNextStopInTrip(ctx, gtfsdb.GetNextStopInTripParams{
TripID: stopTrip.TripID,
StopSequence: stopTrip.StopSequence,
})
Expand Down
62 changes: 41 additions & 21 deletions internal/gtfs/direction_calculator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ func TestNewDirectionCalculator(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

require.NotNil(t, dc)
assert.NotNil(t, dc.queries)
assert.NotNil(t, dc.gtfsManager)
}

func TestCalculateStopDirection_PrecomputedDirection(t *testing.T) {
Expand All @@ -33,7 +34,8 @@ func TestCalculateStopDirection_PrecomputedDirection(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.CalculateStopDirection(ctx, "STOP1")

assert.Equal(t, "N", direction)
Expand Down Expand Up @@ -92,7 +94,8 @@ func TestCalculateStopDirection_NoPrecomputedDirection_FallsBackToShape(t *testi
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.CalculateStopDirection(ctx, "STOP2")

// Should calculate direction from shape (northbound based on coordinates)
Expand Down Expand Up @@ -147,7 +150,8 @@ func TestCalculateStopDirection_NoShapeData_FallsBackToNextStop(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.CalculateStopDirection(ctx, "STOP3")

// Should calculate direction from next stop (northbound)
Expand All @@ -159,7 +163,8 @@ func TestCalculateStopDirection_NoData_ReturnsUnknown(t *testing.T) {
defer func() { _ = client.Close() }()
ctx := context.Background()

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.CalculateStopDirection(ctx, "NONEXISTENT")

assert.Equal(t, models.UnknownValue, direction)
Expand All @@ -177,7 +182,8 @@ func TestCalculateStopDirection_StopWithoutTrips_ReturnsUnknown(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.CalculateStopDirection(ctx, "STOP_ALONE")

assert.Equal(t, models.UnknownValue, direction)
Expand All @@ -186,7 +192,8 @@ func TestCalculateStopDirection_StopWithoutTrips_ReturnsUnknown(t *testing.T) {
func TestFindClosestShapePoint_EmptyPoints(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

result := dc.findClosestShapePoint([]gtfsdb.GetShapePointsForTripRow{}, 40.7128, -74.0060)

Expand All @@ -196,7 +203,8 @@ func TestFindClosestShapePoint_EmptyPoints(t *testing.T) {
func TestFindClosestShapePoint_SinglePoint(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

points := []gtfsdb.GetShapePointsForTripRow{
{Lat: 40.7128, Lon: -74.0060},
Expand All @@ -210,7 +218,8 @@ func TestFindClosestShapePoint_SinglePoint(t *testing.T) {
func TestFindClosestShapePoint_MultiplePoints(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

points := []gtfsdb.GetShapePointsForTripRow{
{Lat: 40.7128, Lon: -74.0060}, // Point 0 - far
Expand All @@ -227,7 +236,8 @@ func TestFindClosestShapePoint_MultiplePoints(t *testing.T) {
func TestFindClosestShapePoint_ClosestIsFirst(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

points := []gtfsdb.GetShapePointsForTripRow{
{Lat: 40.7128, Lon: -74.0060}, // Point 0 - closest
Expand All @@ -243,7 +253,8 @@ func TestFindClosestShapePoint_ClosestIsFirst(t *testing.T) {
func TestFindClosestShapePoint_ClosestIsLast(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

points := []gtfsdb.GetShapePointsForTripRow{
{Lat: 40.7128, Lon: -74.0060}, // Point 0 - far
Expand All @@ -259,7 +270,8 @@ func TestFindClosestShapePoint_ClosestIsLast(t *testing.T) {
func TestGetMostCommonDirection_EmptyMap(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

directions := make(map[string]int)

Expand All @@ -271,7 +283,8 @@ func TestGetMostCommonDirection_EmptyMap(t *testing.T) {
func TestGetMostCommonDirection_SingleDirection(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

directions := map[string]int{
"N": 5,
Expand All @@ -285,7 +298,8 @@ func TestGetMostCommonDirection_SingleDirection(t *testing.T) {
func TestGetMostCommonDirection_MultipleDirections(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

directions := map[string]int{
"N": 3,
Expand All @@ -301,7 +315,8 @@ func TestGetMostCommonDirection_MultipleDirections(t *testing.T) {
func TestGetMostCommonDirection_TieGoesToFirst(t *testing.T) {
client := setupTestClient(t)
defer func() { _ = client.Close() }()
dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)

// Map iteration order is random, but one will be selected
directions := map[string]int{
Expand All @@ -327,7 +342,8 @@ func TestCalculateFromShape_NoTrips(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.calculateFromShape(ctx, "STOP_NO_TRIPS")

assert.Equal(t, models.UnknownValue, direction)
Expand Down Expand Up @@ -375,7 +391,8 @@ func TestCalculateFromShape_TripWithoutShape(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.calculateFromShape(ctx, "STOP_NO_SHAPE")

assert.Equal(t, models.UnknownValue, direction)
Expand Down Expand Up @@ -429,7 +446,8 @@ func TestCalculateFromShape_ShapeWithOnePoint(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.calculateFromShape(ctx, "STOP_ONE_POINT")

assert.Equal(t, models.UnknownValue, direction)
Expand Down Expand Up @@ -486,7 +504,8 @@ func TestCalculateFromShape_ClosestPointIsLast(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.calculateFromShape(ctx, "STOP_AT_END")

// When stop is at last shape point, can't calculate direction
Expand Down Expand Up @@ -535,7 +554,8 @@ func TestCalculateFromNextStop_NoNextStop(t *testing.T) {
`)
require.NoError(t, err)

dc := NewDirectionCalculator(client.Queries)
manager := &Manager{GtfsDB: client}
dc := NewDirectionCalculator(manager)
direction := dc.calculateFromNextStop(ctx, "LAST_STOP")

assert.Equal(t, models.UnknownValue, direction)
Expand Down
2 changes: 1 addition & 1 deletion internal/gtfs/gtfs_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func InitGTFSManager(config Config) (*Manager, error) {
}
manager.setStaticGTFS(staticData)

gtfsDB, err := buildGtfsDB(config, isLocalFile)
gtfsDB, err := buildGtfsDB(config, isLocalFile, "")
if err != nil {
return nil, fmt.Errorf("error building GTFS database: %w", err)
}
Expand Down
Loading
Loading