Skip to content
4 changes: 2 additions & 2 deletions cmd/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func BuildApplication(cfg appconf.Config, gtfsCfg gtfs.Config) (*app.Application
return nil, fmt.Errorf("failed to initialize GTFS manager: %w", err)
}

var directionCalculator *gtfs.DirectionCalculator
var directionCalculator *gtfs.AdvancedDirectionCalculator
if gtfsManager != nil {
directionCalculator = gtfs.NewDirectionCalculator(gtfsManager.GtfsDB.Queries)
directionCalculator = gtfs.NewAdvancedDirectionCalculator(gtfsManager.GtfsDB.Queries)
}

// Select clock implementation based on environment
Expand Down
20 changes: 20 additions & 0 deletions gtfsdb/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 17 additions & 1 deletion gtfsdb/query.sql
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,6 @@ WHERE service_id NOT IN (SELECT service_id FROM removed_services)
UNION
SELECT DISTINCT service_id FROM added_services;


-- name: GetTripsForRouteInActiveServiceIDs :many
SELECT DISTINCT *
FROM trips t
Expand Down Expand Up @@ -636,6 +635,18 @@ WHERE
ORDER BY
s.shape_pt_sequence ASC;

-- name: GetStopsWithShapeContextByIDs :many
SELECT
st.stop_id,
t.shape_id,
s.lat,
s.lon,
st.shape_dist_traveled
FROM stop_times st
JOIN trips t ON st.trip_id = t.id
JOIN stops s ON st.stop_id = s.id
WHERE st.stop_id IN (sqlc.slice('stop_ids'));

-- name: GetTripsByBlockIDOrdered :many
SELECT
t.id,
Expand Down Expand Up @@ -969,6 +980,11 @@ WHERE bte.block_trip_index_id IN (sqlc.slice('index_ids'))
AND bte.service_id IN (sqlc.slice('service_ids'));


-- name: GetShapePointsByIDs :many
SELECT * FROM shapes
WHERE shape_id IN (sqlc.slice('shape_ids'))
ORDER BY shape_id, shape_pt_sequence;

-- name: SearchStopsByName :many
SELECT
s.id,
Expand Down
106 changes: 106 additions & 0 deletions gtfsdb/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions gtfsdb/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,6 @@ CREATE INDEX IF NOT EXISTS idx_block_trip_entry_service_id ON block_trip_entry (

-- migrate
CREATE INDEX IF NOT EXISTS idx_trips_block_id ON trips (block_id);

-- migrate
CREATE INDEX IF NOT EXISTS idx_shapes_shape_id ON shapes (shape_id);
2 changes: 1 addition & 1 deletion internal/app/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ type Application struct {
GtfsConfig gtfs.Config
Logger *slog.Logger
GtfsManager *gtfs.Manager
DirectionCalculator *gtfs.DirectionCalculator
DirectionCalculator *gtfs.AdvancedDirectionCalculator
Clock clock.Clock
}
46 changes: 34 additions & 12 deletions internal/gtfs/advanced_direction_calculator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gtfs
import (
"context"
"database/sql"
"log/slog"
"math"
"sort"
"strconv"
Expand All @@ -22,6 +23,7 @@ const (
type AdvancedDirectionCalculator struct {
queries *gtfsdb.Queries
varianceThreshold float64
contextCache map[string][]gtfsdb.GetStopsWithShapeContextRow // Cache of stop shape context data
shapeCache map[string][]gtfsdb.GetShapePointsWithDistanceRow // Cache of all shape data for bulk operations
initialized atomic.Bool // Tracks whether concurrent operations have started
}
Expand Down Expand Up @@ -55,19 +57,28 @@ func (adc *AdvancedDirectionCalculator) SetShapeCache(cache map[string][]gtfsdb.
adc.shapeCache = cache
}

// CalculateStopDirection computes the direction for a stop using the Java algorithm
func (adc *AdvancedDirectionCalculator) CalculateStopDirection(ctx context.Context, stopID string, gtfsDirection sql.NullString) string {
// Mark as initialized on first use to prevent configuration changes during concurrent operations
adc.initialized.Store(true)
// SetContextCache injects the bulk-loaded context data.
// IMPORTANT: This must be called before any concurrent calculation operations begin.
// Panics if called after internal state has been initialized (i.e., after the first
// fallback to shape-based calculation).
func (adc *AdvancedDirectionCalculator) SetContextCache(cache map[string][]gtfsdb.GetStopsWithShapeContextRow) {
if adc.initialized.Load() {
panic("SetContextCache called after concurrent operations have started")
}
adc.contextCache = cache
}

// Step 1: Try to use GTFS direction field if provided
if gtfsDirection.Valid && gtfsDirection.String != "" {
if direction := adc.translateGtfsDirection(gtfsDirection.String); direction != "" {
// CalculateStopDirection computes the direction for a stop using the Java algorithm
func (adc *AdvancedDirectionCalculator) CalculateStopDirection(ctx context.Context, stopID string, gtfsDirection ...sql.NullString) string {
if len(gtfsDirection) > 0 && gtfsDirection[0].Valid && gtfsDirection[0].String != "" {
if direction := adc.translateGtfsDirection(gtfsDirection[0].String); direction != "" {
return direction
}
}

// Step 2: Calculate from shape data
// Mark as initialized for concurrency safety
adc.initialized.Store(true)

return adc.computeFromShapes(ctx, stopID)
}

Expand Down Expand Up @@ -118,10 +129,21 @@ func (adc *AdvancedDirectionCalculator) translateGtfsDirection(direction string)

// computeFromShapes calculates direction from shape data using the Java algorithm
func (adc *AdvancedDirectionCalculator) computeFromShapes(ctx context.Context, stopID string) string {
// Get trips with shape context for this stop
stopTrips, err := adc.queries.GetStopsWithShapeContext(ctx, stopID)
if err != nil || len(stopTrips) == 0 {
return ""

var stopTrips []gtfsdb.GetStopsWithShapeContextRow

// Use cache if available, otherwise hit DB
if adc.contextCache != nil {
stopTrips = adc.contextCache[stopID]
} else {
var err error
stopTrips, err = adc.queries.GetStopsWithShapeContext(ctx, stopID)
if err != nil {
slog.Warn("failed to get stop shape context",
slog.String("stopID", stopID),
slog.String("error", err.Error()))
return ""
}
}

// Collect orientations from all trips, using cache to avoid duplicates
Expand Down
Loading
Loading