diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 9e24082..36d1f4b 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -78,6 +78,12 @@ var ( ) ) +// channelDrainTimeout is the maximum time to wait for in-flight +// measurements after the test duration expires. The sender goroutine +// typically produces the final measurement within microseconds of +// ctx.Done(), so this is a generous upper bound. +const channelDrainTimeout = 500 * time.Millisecond + type Handler struct { archivalDataDir string } @@ -237,7 +243,9 @@ func (h *Handler) upgradeAndRunMeasurement(kind model.TestDirection, rw http.Res ClientOptions: clientOptions, } defer func() { - archivalData.EndTime = time.Now() + if archivalData.EndTime.IsZero() { + archivalData.EndTime = time.Now() + } h.writeResult(uuid, kind, &archivalData) }() @@ -258,7 +266,37 @@ func (h *Handler) upgradeAndRunMeasurement(kind model.TestDirection, rw http.Res for { select { case <-timeout.Done(): - // If the test has timed out count it as a success and return. + // Record EndTime before draining so it reflects the + // actual measurement period, not the drain wait. + archivalData.EndTime = time.Now() + // The test has timed out. Before returning, drain any + // remaining measurements from the sender/receiver + // goroutines. The sender sends a final Measure() on + // ctx.Done() which may still be in-flight. + drainTimer := time.NewTimer(channelDrainTimeout) + defer drainTimer.Stop() + for draining := true; draining; { + select { + case m := <-senderCh: + if kind == model.DirectionDownload && m.CC != "" { + archivalData.CCAlgorithm = m.CC + } + archivalData.ServerMeasurements = append( + archivalData.ServerMeasurements, m.Measurement) + // Reset to a short timeout: the final measurement + // has arrived, just check for any stragglers. + drainTimer.Reset(5 * time.Millisecond) + case m := <-receiverCh: + if kind == model.DirectionUpload && m.CC != "" { + archivalData.CCAlgorithm = m.CC + } + archivalData.ClientMeasurements = append( + archivalData.ClientMeasurements, m.Measurement) + drainTimer.Reset(5 * time.Millisecond) + case <-drainTimer.C: + draining = false + } + } testsTotal.WithLabelValues(string(kind), "ok-timeout").Inc() return case m := <-senderCh: diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index f0ae4c0..150f71a 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -2,11 +2,14 @@ package handler_test import ( "context" + "encoding/json" + "math" "net" "net/http" "net/http/httptest" "net/url" "os" + "path/filepath" "strings" "testing" "time" @@ -83,13 +86,7 @@ func TestHandler_Upload(t *testing.T) { drain(t, timeout, senderCh, receiverCh, errCh) // Check that the output JSON file has been created. - files, err := os.ReadDir(tempDir) - if err != nil { - t.Fatalf("reading output folder failed: %v", err) - } - if len(files) != 1 { - t.Fatalf("invalid number of files in output folder") - } + waitForArchivalFile(t, tempDir, 2*time.Second) } func TestHandler_Download(t *testing.T) { @@ -124,19 +121,13 @@ func TestHandler_Download(t *testing.T) { } proto := throughput1.New(conn) - timeout, cancel := context.WithTimeout(context.Background(), 1*time.Second) + timeout, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() senderCh, receiverCh, errCh := proto.ReceiverLoop(timeout) drain(t, timeout, senderCh, receiverCh, errCh) // Check that the output JSON file has been created. - files, err := os.ReadDir(tempDir) - if err != nil { - t.Fatalf("reading output folder failed: %v", err) - } - if len(files) != 1 { - t.Fatalf("invalid number of files in output folder") - } + waitForArchivalFile(t, tempDir, 2*time.Second) } func TestHandler_DownloadInvalidCC(t *testing.T) { @@ -169,6 +160,29 @@ func TestHandler_DownloadInvalidCC(t *testing.T) { } } +// waitForArchivalFile polls until at least one JSON file appears in the +// directory tree, or the timeout is exceeded. The drain loop in the handler +// delays the deferred writeResult, so we need to poll. +func waitForArchivalFile(t *testing.T, dir string, timeout time.Duration) string { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + var found string + filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err == nil && !info.IsDir() && filepath.Ext(path) == ".json" { + found = path + } + return nil + }) + if found != "" { + return found + } + time.Sleep(50 * time.Millisecond) + } + t.Fatalf("no archival JSON file found in %s within %v", dir, timeout) + return "" +} + // Utility function to drain sender/receiver channels in tests. func drain(t *testing.T, timeout context.Context, senderCh, receiverCh <-chan model.WireMeasurement, errCh <-chan error) { @@ -189,6 +203,68 @@ func drain(t *testing.T, timeout context.Context, senderCh, } } +func TestHandler_DownloadFinalMeasurement(t *testing.T) { + tempDir := t.TempDir() + h := handler.New(tempDir) + + server := setupTestServer(tempDir, http.HandlerFunc(h.Download)) + server.Start() + defer server.Close() + + u, err := url.Parse(server.URL) + rtx.Must(err, "cannot get server URL") + u.Scheme = "ws" + q := u.Query() + q.Add("mid", "test-mid") + q.Add("streams", "1") + q.Add("duration", "500") + u.RawQuery = q.Encode() + + dialer := setupTestWSDialer(u) + + headers := http.Header{} + headers.Add("Sec-WebSocket-Protocol", spec.SecWebSocketProtocol) + + conn, _, err := dialer.Dial(u.String(), headers) + if err != nil { + t.Fatalf("websocket dial failed: %v", err) + } + proto := throughput1.New(conn) + timeout, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + senderCh, receiverCh, errCh := proto.ReceiverLoop(timeout) + drain(t, timeout, senderCh, receiverCh, errCh) + + // Wait for the archival JSON file to be written. + jsonFile := waitForArchivalFile(t, tempDir, 2*time.Second) + + data, err := os.ReadFile(jsonFile) + if err != nil { + t.Fatalf("failed to read archival file: %v", err) + } + + var result model.Throughput1Result + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to unmarshal archival data: %v", err) + } + + if len(result.ServerMeasurements) == 0 { + t.Fatalf("expected at least one server measurement") + } + + // The last server measurement's ElapsedTime should be close to the + // requested duration (500ms = 500_000 microseconds). Allow 100ms + // tolerance. + last := result.ServerMeasurements[len(result.ServerMeasurements)-1] + requestedDurationUs := int64(500_000) // 500ms in microseconds + toleranceUs := int64(100_000) // 100ms + diff := int64(math.Abs(float64(last.ElapsedTime - requestedDurationUs))) + if diff > toleranceUs { + t.Errorf("last ServerMeasurement.ElapsedTime = %d us, want within %d us of %d us (diff = %d us)", + last.ElapsedTime, toleranceUs, requestedDurationUs, diff) + } +} + func TestHandler_Validation(t *testing.T) { // This string exceeds the maximum metadata key length. longKey := strings.Repeat("longkey", 10)