Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions internal/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ var (
)
)

// channelDrainTimeout is the maximum time to wait for in-flight
// measurements after the test duration expires. The sender goroutine
// typically produces the final measurement within microseconds of
// ctx.Done(), so this is a generous upper bound.
const channelDrainTimeout = 500 * time.Millisecond

type Handler struct {
archivalDataDir string
}
Expand Down Expand Up @@ -237,7 +243,9 @@ func (h *Handler) upgradeAndRunMeasurement(kind model.TestDirection, rw http.Res
ClientOptions: clientOptions,
}
defer func() {
archivalData.EndTime = time.Now()
if archivalData.EndTime.IsZero() {
archivalData.EndTime = time.Now()
}
h.writeResult(uuid, kind, &archivalData)
}()

Expand All @@ -258,7 +266,37 @@ func (h *Handler) upgradeAndRunMeasurement(kind model.TestDirection, rw http.Res
for {
select {
case <-timeout.Done():
// If the test has timed out count it as a success and return.
// Record EndTime before draining so it reflects the
// actual measurement period, not the drain wait.
archivalData.EndTime = time.Now()
// The test has timed out. Before returning, drain any
// remaining measurements from the sender/receiver
// goroutines. The sender sends a final Measure() on
// ctx.Done() which may still be in-flight.
drainTimer := time.NewTimer(channelDrainTimeout)
defer drainTimer.Stop()
for draining := true; draining; {
select {
case m := <-senderCh:
if kind == model.DirectionDownload && m.CC != "" {
archivalData.CCAlgorithm = m.CC
}
archivalData.ServerMeasurements = append(
archivalData.ServerMeasurements, m.Measurement)
// Reset to a short timeout: the final measurement
// has arrived, just check for any stragglers.
drainTimer.Reset(5 * time.Millisecond)
case m := <-receiverCh:
if kind == model.DirectionUpload && m.CC != "" {
archivalData.CCAlgorithm = m.CC
}
archivalData.ClientMeasurements = append(
archivalData.ClientMeasurements, m.Measurement)
drainTimer.Reset(5 * time.Millisecond)
case <-drainTimer.C:
draining = false
}
}
testsTotal.WithLabelValues(string(kind), "ok-timeout").Inc()
return
case m := <-senderCh:
Expand Down
106 changes: 91 additions & 15 deletions internal/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package handler_test

import (
"context"
"encoding/json"
"math"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -83,13 +86,7 @@ func TestHandler_Upload(t *testing.T) {
drain(t, timeout, senderCh, receiverCh, errCh)

// Check that the output JSON file has been created.
files, err := os.ReadDir(tempDir)
if err != nil {
t.Fatalf("reading output folder failed: %v", err)
}
if len(files) != 1 {
t.Fatalf("invalid number of files in output folder")
}
waitForArchivalFile(t, tempDir, 2*time.Second)
}

func TestHandler_Download(t *testing.T) {
Expand Down Expand Up @@ -124,19 +121,13 @@ func TestHandler_Download(t *testing.T) {
}

proto := throughput1.New(conn)
timeout, cancel := context.WithTimeout(context.Background(), 1*time.Second)
timeout, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
senderCh, receiverCh, errCh := proto.ReceiverLoop(timeout)
drain(t, timeout, senderCh, receiverCh, errCh)

// Check that the output JSON file has been created.
files, err := os.ReadDir(tempDir)
if err != nil {
t.Fatalf("reading output folder failed: %v", err)
}
if len(files) != 1 {
t.Fatalf("invalid number of files in output folder")
}
waitForArchivalFile(t, tempDir, 2*time.Second)
}

func TestHandler_DownloadInvalidCC(t *testing.T) {
Expand Down Expand Up @@ -169,6 +160,29 @@ func TestHandler_DownloadInvalidCC(t *testing.T) {
}
}

// waitForArchivalFile polls until at least one JSON file appears in the
// directory tree, or the timeout is exceeded. The drain loop in the handler
// delays the deferred writeResult, so we need to poll.
func waitForArchivalFile(t *testing.T, dir string, timeout time.Duration) string {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
var found string
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err == nil && !info.IsDir() && filepath.Ext(path) == ".json" {
found = path
}
return nil
})
if found != "" {
return found
}
time.Sleep(50 * time.Millisecond)
}
t.Fatalf("no archival JSON file found in %s within %v", dir, timeout)
return ""
}

// Utility function to drain sender/receiver channels in tests.
func drain(t *testing.T, timeout context.Context, senderCh,
receiverCh <-chan model.WireMeasurement, errCh <-chan error) {
Expand All @@ -189,6 +203,68 @@ func drain(t *testing.T, timeout context.Context, senderCh,
}
}

func TestHandler_DownloadFinalMeasurement(t *testing.T) {
tempDir := t.TempDir()
h := handler.New(tempDir)

server := setupTestServer(tempDir, http.HandlerFunc(h.Download))
server.Start()
defer server.Close()

u, err := url.Parse(server.URL)
rtx.Must(err, "cannot get server URL")
u.Scheme = "ws"
q := u.Query()
q.Add("mid", "test-mid")
q.Add("streams", "1")
q.Add("duration", "500")
u.RawQuery = q.Encode()

dialer := setupTestWSDialer(u)

headers := http.Header{}
headers.Add("Sec-WebSocket-Protocol", spec.SecWebSocketProtocol)

conn, _, err := dialer.Dial(u.String(), headers)
if err != nil {
t.Fatalf("websocket dial failed: %v", err)
}
proto := throughput1.New(conn)
timeout, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
senderCh, receiverCh, errCh := proto.ReceiverLoop(timeout)
drain(t, timeout, senderCh, receiverCh, errCh)

// Wait for the archival JSON file to be written.
jsonFile := waitForArchivalFile(t, tempDir, 2*time.Second)

data, err := os.ReadFile(jsonFile)
if err != nil {
t.Fatalf("failed to read archival file: %v", err)
}

var result model.Throughput1Result
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal archival data: %v", err)
}

if len(result.ServerMeasurements) == 0 {
t.Fatalf("expected at least one server measurement")
}

// The last server measurement's ElapsedTime should be close to the
// requested duration (500ms = 500_000 microseconds). Allow 100ms
// tolerance.
last := result.ServerMeasurements[len(result.ServerMeasurements)-1]
requestedDurationUs := int64(500_000) // 500ms in microseconds
toleranceUs := int64(100_000) // 100ms
diff := int64(math.Abs(float64(last.ElapsedTime - requestedDurationUs)))
if diff > toleranceUs {
t.Errorf("last ServerMeasurement.ElapsedTime = %d us, want within %d us of %d us (diff = %d us)",
last.ElapsedTime, toleranceUs, requestedDurationUs, diff)
}
}

func TestHandler_Validation(t *testing.T) {
// This string exceeds the maximum metadata key length.
longKey := strings.Repeat("longkey", 10)
Expand Down
Loading