From 234f5f7ae091a3beef8e1974c282244d1b3332b1 Mon Sep 17 00:00:00 2001 From: joaner Date: Tue, 9 Jun 2026 20:31:53 +0800 Subject: [PATCH 1/2] Support LeRobot-compatible tasks.parquet and v3 streaming video staging. Write task descriptions as pandas index metadata and extend staging/merge paths needed by the updated mcap2lerobot converter. --- internal/buffer/episode.go | 27 +- internal/buffer/episode_test.go | 76 +++ internal/manifest/episode.go | 15 + internal/manifest/episode_test.go | 23 + internal/parquetx/append_data.go | 36 +- internal/parquetx/meta_v30.go | 7 +- internal/parquetx/tasks_test.go | 21 +- internal/parquetx/writer.go | 8 + internal/stats/stats.go | 79 ++- internal/stats/stats_test.go | 17 +- internal/v21/merge.go | 2 +- internal/v21/staging.go | 344 ++++++++++-- internal/v30/integrity.go | 62 +++ internal/v30/merge.go | 149 ++++- internal/v30/merge_test.go | 53 ++ internal/v30/staging.go | 643 ++++++++++++++++++++-- internal/v30/staging_h264_remux_test.go | 61 ++ internal/v30/staging_integrity_test.go | 59 ++ internal/v30/staging_streaming_test.go | 119 ++++ internal/video/encode.go | 16 + internal/video/frame.go | 59 ++ internal/video/frame_test.go | 25 + internal/video/h264_remux_encoder.go | 137 +++++ internal/video/h264_remux_encoder_test.go | 34 ++ internal/video/info.go | 53 +- internal/video/info_test.go | 48 +- internal/video/raw_encoder.go | 214 +++++++ lerobot/config.go | 1 + lerobot/video_frame.go | 12 + lerobot/writer.go | 21 +- mcap2lerobot/convert.go | 30 + mcap2lerobot/doc.go | 5 + 32 files changed, 2284 insertions(+), 172 deletions(-) create mode 100644 internal/buffer/episode_test.go create mode 100644 internal/manifest/episode_test.go create mode 100644 internal/v30/integrity.go create mode 100644 internal/v30/staging_h264_remux_test.go create mode 100644 internal/v30/staging_integrity_test.go create mode 100644 internal/v30/staging_streaming_test.go create mode 100644 internal/video/frame.go create mode 100644 internal/video/frame_test.go create mode 100644 internal/video/h264_remux_encoder.go create mode 100644 internal/video/h264_remux_encoder_test.go create mode 100644 internal/video/raw_encoder.go create mode 100644 lerobot/video_frame.go create mode 100644 mcap2lerobot/convert.go create mode 100644 mcap2lerobot/doc.go diff --git a/internal/buffer/episode.go b/internal/buffer/episode.go index beb6595..bbd3e84 100644 --- a/internal/buffer/episode.go +++ b/internal/buffer/episode.go @@ -72,6 +72,25 @@ func (b *EpisodeBuffer) Columns(globalIndex int64, taskIndices []int64) map[stri return out } +func (b *EpisodeBuffer) ColumnsWithFrameStart(globalIndex, frameStart int64, taskIndices []int64) map[string]any { + out := b.Columns(globalIndex, taskIndices) + frameIdx := make([]int64, b.size) + timestamps := make([]float32, b.size) + for i := range frameIdx { + frameIdx[i] = frameStart + int64(i) + timestamps[i] = float32(frameStart+int64(i)) / float32(b.FPS) + } + out["frame_index"] = frameIdx + out["timestamp"] = timestamps + return out +} + +func (b *EpisodeBuffer) Reset() { + b.columns = make(map[string]any) + b.tasks = nil + b.size = 0 +} + func isScalarShape(shape []int) bool { return len(shape) == 0 || (len(shape) == 1 && shape[0] == 1) } @@ -143,21 +162,21 @@ func appendFloat32Row(b *EpisodeBuffer, key string, row []float32) { if _, ok := b.columns[key]; !ok { b.columns[key] = [][]float32{} } - b.columns[key] = append(b.columns[key].([][]float32), row) + b.columns[key] = append(b.columns[key].([][]float32), append([]float32(nil), row...)) } func appendFloat64Row(b *EpisodeBuffer, key string, row []float64) { if _, ok := b.columns[key]; !ok { b.columns[key] = [][]float64{} } - b.columns[key] = append(b.columns[key].([][]float64), row) + b.columns[key] = append(b.columns[key].([][]float64), append([]float64(nil), row...)) } func appendInt64Row(b *EpisodeBuffer, key string, row []int64) { if _, ok := b.columns[key]; !ok { b.columns[key] = [][]int64{} } - b.columns[key] = append(b.columns[key].([][]int64), row) + b.columns[key] = append(b.columns[key].([][]int64), append([]int64(nil), row...)) } func appendFloat64(b *EpisodeBuffer, key string, val float64) { @@ -170,7 +189,7 @@ func appendFloat64(b *EpisodeBuffer, key string, val float64) { func toFloat32Row(val any) ([]float32, bool) { switch v := val.(type) { case []float32: - return v, true + return append([]float32(nil), v...), true case []float64: out := make([]float32, len(v)) for i, x := range v { diff --git a/internal/buffer/episode_test.go b/internal/buffer/episode_test.go new file mode 100644 index 0000000..00b2d3a --- /dev/null +++ b/internal/buffer/episode_test.go @@ -0,0 +1,76 @@ +package buffer + +import ( + "math" + "testing" + + "github.com/ioai-tech/lerobot-go/internal/meta" +) + +func TestAppendFloat64RowDeepCopy(t *testing.T) { + features := map[string]meta.FeatureSpec{ + "observation.state": {DType: "float64", Shape: []int{3}}, + } + b := New(0, 30, features) + shared := []float64{1, 2, 3} + for i := 0; i < 3; i++ { + if err := b.AddFrame(map[string]any{ + "task": "pick", + "observation.state": shared, + }); err != nil { + t.Fatal(err) + } + for j := range shared { + shared[j] = float64(i*10 + j + 100) + } + } + rows := b.columns["observation.state"].([][]float64) + if len(rows) != 3 { + t.Fatalf("rows=%d want 3", len(rows)) + } + want := [][]float64{ + {1, 2, 3}, + {100, 101, 102}, + {110, 111, 112}, + } + for i, row := range rows { + for j, v := range row { + if v != want[i][j] { + t.Fatalf("row[%d][%d]=%v want %v (slice alias corrupted stored rows)", i, j, v, want[i][j]) + } + } + } +} + +func TestObservationStateValuesStayFiniteThroughParquetPath(t *testing.T) { + features := map[string]meta.FeatureSpec{ + "observation.state": {DType: "float64", Shape: []int{4}}, + "action": {DType: "float64", Shape: []int{4}}, + } + b := New(0, 30, features) + reuse := make([]float64, 4) + for i := 0; i < 20; i++ { + for j := range reuse { + reuse[j] = float64(i)*0.1 + float64(j)*0.01 + } + if err := b.AddFrame(map[string]any{ + "task": "move", + "observation.state": reuse, + "action": reuse, + }); err != nil { + t.Fatal(err) + } + } + stateRows := b.columns["observation.state"].([][]float64) + for i, row := range stateRows { + for j, v := range row { + if math.IsNaN(v) || math.IsInf(v, 0) || math.Abs(v) > 100 { + t.Fatalf("state[%d][%d]=%v out of sane joint range", i, j, v) + } + want := float64(i)*0.1 + float64(j)*0.01 + if v != want { + t.Fatalf("state[%d][%d]=%v want %v", i, j, v, want) + } + } + } +} diff --git a/internal/manifest/episode.go b/internal/manifest/episode.go index c0291fa..e86fd86 100644 --- a/internal/manifest/episode.go +++ b/internal/manifest/episode.go @@ -24,6 +24,7 @@ type Episode struct { const FileName = "episode_meta.json" func Write(dir string, ep Episode) error { + ep.Stats = stats.SanitizeEpisodeStats(ep.Stats) data, err := json.MarshalIndent(ep, "", " ") if err != nil { return err @@ -51,6 +52,20 @@ func stagingName(episodeIndex int) string { return fmt.Sprintf("ep_%06d", episodeIndex) } +// StagingMediaPath resolves a path stored in episode metadata relative to a staging dir. +// Legacy manifests may store absolute paths; those are returned unchanged. +func StagingMediaPath(stagingDir, rel string) string { + if filepath.IsAbs(rel) { + return rel + } + return filepath.Join(stagingDir, rel) +} + +// StagingVideoRel returns the episode-relative path for a staged per-episode MP4. +func StagingVideoRel(videoKey string) string { + return filepath.Join("videos", filepath.Base(videoKey)+".mp4") +} + // ListStagingEpisodes returns completed staging episode dirs sorted by episode_index. func ListStagingEpisodes(root string) ([]string, error) { entries, err := os.ReadDir(root) diff --git a/internal/manifest/episode_test.go b/internal/manifest/episode_test.go new file mode 100644 index 0000000..5cac256 --- /dev/null +++ b/internal/manifest/episode_test.go @@ -0,0 +1,23 @@ +package manifest + +import ( + "path/filepath" + "testing" +) + +func TestStagingMediaPath(t *testing.T) { + dir := "/tmp/ep_000000" + rel := StagingVideoRel("observation.images.cam") + if rel != filepath.Join("videos", "observation.images.cam.mp4") { + t.Fatalf("rel=%q", rel) + } + got := StagingMediaPath(dir, rel) + want := filepath.Join(dir, rel) + if got != want { + t.Fatalf("got %q want %q", got, want) + } + abs := "/tmp/ep_000000/videos/foo.mp4" + if StagingMediaPath(dir, abs) != abs { + t.Fatalf("absolute path should pass through") + } +} diff --git a/internal/parquetx/append_data.go b/internal/parquetx/append_data.go index 54d9834..54f7ed4 100644 --- a/internal/parquetx/append_data.go +++ b/internal/parquetx/append_data.go @@ -51,25 +51,33 @@ func WriteEpisodeBatch(ctx context.Context, dst string, entries []EpisodeBatchEn return fmt.Errorf("no episode parquet entries") } alloc := memory.NewGoAllocator() - tables := make([]arrow.Table, 0, len(entries)) - defer func() { - for _, tbl := range tables { - tbl.Release() - } - }() - for _, entry := range entries { + first, err := RewriteEpisodeParquet(ctx, entries[0].SourcePath, entries[0].Options, alloc) + if err != nil { + return err + } + defer first.Release() + writer, err := NewAppendWriter(dst, first.Schema()) + if err != nil { + return err + } + if err := writer.WriteTable(first, 1024); err != nil { + _ = writer.Close() + return err + } + for _, entry := range entries[1:] { tbl, err := RewriteEpisodeParquet(ctx, entry.SourcePath, entry.Options, alloc) if err != nil { + _ = writer.Close() return err } - tables = append(tables, tbl) - } - merged, err := ConcatTables(alloc, tables) - if err != nil { - return err + if err := writer.WriteTable(tbl, 1024); err != nil { + tbl.Release() + _ = writer.Close() + return err + } + tbl.Release() } - defer merged.Release() - return WriteTable(dst, merged, alloc) + return writer.Close() } func RewriteEpisodeParquet(ctx context.Context, src string, opts AppendEpisodeOptions, alloc memory.Allocator) (arrow.Table, error) { diff --git a/internal/parquetx/meta_v30.go b/internal/parquetx/meta_v30.go index bd643f1..7acdcd5 100644 --- a/internal/parquetx/meta_v30.go +++ b/internal/parquetx/meta_v30.go @@ -13,6 +13,7 @@ import ( ) // WriteTasksParquet writes meta/tasks.parquet compatible with lerobot load_tasks(). +// Task strings are stored as the pandas index column named "task". func WriteTasksParquet(root string, taskMap map[string]int) error { if len(taskMap) == 0 { return nil @@ -41,10 +42,14 @@ func WriteTasksParquet(root string, taskMap map[string]int) error { defer taskArr.Release() defer idxArr.Release() + pandasMD := arrow.NewMetadata( + []string{"pandas"}, + []string{`{"index_columns": ["task"], "column_indexes": [{"name": null, "field_name": null, "pandas_type": "unicode", "numpy_type": "object", "metadata": {"encoding": "UTF-8"}}], "columns": [{"name": "task_index", "field_name": "task_index", "pandas_type": "int64", "numpy_type": "int64", "metadata": null}, {"name": "task", "field_name": "task", "pandas_type": "unicode", "numpy_type": "object", "metadata": null}], "attributes": {}, "pandas_version": "2.3.3"}`}, + ) schema := arrow.NewSchema([]arrow.Field{ {Name: "task_index", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, {Name: "task", Type: arrow.BinaryTypes.String, Nullable: true}, - }, nil) + }, &pandasMD) cols := []arrow.Column{ *arrow.NewColumn(schema.Field(0), arrow.NewChunked(schema.Field(0).Type, []arrow.Array{idxArr})), *arrow.NewColumn(schema.Field(1), arrow.NewChunked(schema.Field(1).Type, []arrow.Array{taskArr})), diff --git a/internal/parquetx/tasks_test.go b/internal/parquetx/tasks_test.go index b96de0c..101f755 100644 --- a/internal/parquetx/tasks_test.go +++ b/internal/parquetx/tasks_test.go @@ -2,7 +2,9 @@ package parquetx import ( "context" + "os/exec" "path/filepath" + "strings" "testing" "github.com/apache/arrow-go/v18/arrow" @@ -31,10 +33,9 @@ func TestWriteTasksParquetUsesTaskColumn(t *testing.T) { if len(tbl.Schema().FieldIndices("task")) == 0 { t.Fatal("task column missing") } - if len(tbl.Schema().FieldIndices("__index_level_0__")) != 0 { - t.Fatal("legacy __index_level_0__ column should not be written for new datasets") + if len(tbl.Schema().FieldIndices("task_index")) == 0 { + t.Fatal("task_index column missing") } - loaded, err := ReadTasksParquet(context.Background(), path) if err != nil { t.Fatal(err) @@ -42,6 +43,20 @@ func TestWriteTasksParquetUsesTaskColumn(t *testing.T) { if loaded["pick"] != 0 || loaded["place"] != 1 { t.Fatalf("loaded task map=%v", loaded) } + + out, err := exec.Command("python3", "-c", ` +import pandas as pd, sys +tasks = pd.read_parquet(sys.argv[1]) +tasks.index.name = "task" +assert tasks.iloc[0].name == "pick", tasks.iloc[0].name +assert tasks.iloc[1].name == "place", tasks.iloc[1].name +`, path).CombinedOutput() + if err != nil { + t.Fatalf("pandas round-trip failed: %v\n%s", err, out) + } + if strings.TrimSpace(string(out)) != "" { + t.Fatalf("unexpected python output: %s", out) + } } func TestConcatTablesIgnoresSchemaMetadataDifferences(t *testing.T) { diff --git a/internal/parquetx/writer.go b/internal/parquetx/writer.go index ebfe927..5a1db93 100644 --- a/internal/parquetx/writer.go +++ b/internal/parquetx/writer.go @@ -77,6 +77,14 @@ func (w *AppendWriter) WriteEpisodeColumns(columns map[string]any, length int, f return w.writer.Write(rec) } +func (w *AppendWriter) WriteRecordColumns(columns map[string]any, length int, features map[string]meta.FeatureSpec) error { + return w.WriteEpisodeColumns(columns, length, features) +} + +func (w *AppendWriter) WriteTable(tbl arrow.Table, chunkSize int64) error { + return w.writer.WriteTable(tbl, chunkSize) +} + func (w *AppendWriter) Close() error { if w.writer != nil { err := w.writer.Close() diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 78b77df..e8fe3ee 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -93,13 +93,61 @@ func computeImageFeatureStats(imgs [][][][]uint8, sampleCount int) FeatureStats } func computeVectorStats(vals any, shape []int) FeatureStats { - flat := flattenRows(vals) + flat := sanitizeFlatRows(flattenRows(vals)) if len(flat) == 0 { return FeatureStats{"count": []int64{0}} } return computeFlatStats(flat, len(flat)) } +func finiteFloat(v float64) float64 { + if math.IsNaN(v) || math.IsInf(v, 0) { + return 0 + } + return v +} + +func finiteFloatSlice(in []float64) []float64 { + out := make([]float64, len(in)) + for i, v := range in { + out[i] = finiteFloat(v) + } + return out +} + +func sanitizeFlatRows(flat [][]float64) [][]float64 { + for i, row := range flat { + for j, v := range row { + flat[i][j] = finiteFloat(v) + } + } + return flat +} + +func sanitizeFeatureStats(fs FeatureStats) FeatureStats { + out := make(FeatureStats, len(fs)) + for k, v := range fs { + switch x := v.(type) { + case []float64: + out[k] = finiteFloatSlice(x) + case ImageStat311: + out[k] = ImageStat311FromChannels(finiteFloatSlice(x.Channels())) + default: + out[k] = v + } + } + return out +} + +// SanitizeEpisodeStats replaces non-finite floats so JSON encoding succeeds. +func SanitizeEpisodeStats(ep EpisodeStats) EpisodeStats { + out := make(EpisodeStats, len(ep)) + for key, fs := range ep { + out[key] = sanitizeFeatureStats(fs) + } + return out +} + func computeFlatStats(flat [][]float64, sampleCount int) FeatureStats { dim := len(flat[0]) count := float64(len(flat)) @@ -126,18 +174,20 @@ func computeFlatStats(flat [][]float64, sampleCount int) FeatureStats { mean := make([]float64, dim) std := make([]float64, dim) for i := range mean { - mean[i] = sum[i] / count + mean[i] = finiteFloat(sum[i] / count) variance := sumSq[i]/count - mean[i]*mean[i] - if variance < 0 { + if variance < 0 || math.IsNaN(variance) || math.IsInf(variance, 0) { variance = 0 } - std[i] = math.Sqrt(variance) + std[i] = finiteFloat(math.Sqrt(variance)) + min[i] = finiteFloat(min[i]) + max[i] = finiteFloat(max[i]) } result := FeatureStats{ - "min": append([]float64(nil), min...), - "max": append([]float64(nil), max...), - "mean": append([]float64(nil), mean...), - "std": append([]float64(nil), std...), + "min": finiteFloatSlice(min), + "max": finiteFloatSlice(max), + "mean": finiteFloatSlice(mean), + "std": finiteFloatSlice(std), "count": []int64{int64(sampleCount)}, } if len(flat) < 2 { @@ -378,13 +428,18 @@ func aggregateFeature(parts []FeatureStats) map[string]any { } mean := make([]float64, dim) std := make([]float64, dim) + if totalCount <= 0 { + return sanitizeFeatureStats(parts[0]) + } for j := range mean { - mean[j] = weightedMean[j] / totalCount + mean[j] = finiteFloat(weightedMean[j] / totalCount) variance := weightedVar[j]/totalCount - mean[j]*mean[j] - if variance < 0 { + if variance < 0 || math.IsNaN(variance) || math.IsInf(variance, 0) { variance = 0 } - std[j] = math.Sqrt(variance) + std[j] = finiteFloat(math.Sqrt(variance)) + min[j] = finiteFloat(min[j]) + max[j] = finiteFloat(max[j]) } isImage := isImageStat(parts[0]["mean"]) result := map[string]any{ @@ -417,7 +472,7 @@ func aggregateFeature(parts []FeatureStats) map[string]any { } } for j := range qvals { - qvals[j] /= totalCount + qvals[j] = finiteFloat(qvals[j] / totalCount) } if isImage { result[key] = ImageStat311FromChannels(qvals) diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 4c49cb7..ed91ae7 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -1,6 +1,10 @@ package stats -import "testing" +import ( + "encoding/json" + "math" + "testing" +) func TestAggregateStats(t *testing.T) { a := EpisodeStats{ @@ -24,3 +28,14 @@ func TestAggregateStats(t *testing.T) { t.Fatalf("count=%v", st["count"]) } } + +func TestComputeVectorStatsSanitizesInf(t *testing.T) { + fs := computeVectorStats([]float64{1, math.Inf(1), math.NaN()}, nil) + data, err := json.Marshal(fs) + if err != nil { + t.Fatalf("marshal stats: %v", err) + } + if len(data) == 0 { + t.Fatal("empty json") + } +} diff --git a/internal/v21/merge.go b/internal/v21/merge.go index 3a025fe..9cea32d 100644 --- a/internal/v21/merge.go +++ b/internal/v21/merge.go @@ -59,7 +59,7 @@ func Merge(ctx context.Context, cfg MergeConfig) error { return err } for videoKey, rel := range ep.Videos { - src := filepath.Join(dir, rel) + src := manifest.StagingMediaPath(dir, rel) dst := filepath.Join(cfg.OutputRoot, meta.V21VideoPathFromInfo(info, videoKey, ep.EpisodeIndex)) if err := copyFile(src, dst); err != nil { return err diff --git a/internal/v21/staging.go b/internal/v21/staging.go index eebbefb..ecc3d96 100644 --- a/internal/v21/staging.go +++ b/internal/v21/staging.go @@ -3,8 +3,10 @@ package v21 import ( "context" "fmt" + "log/slog" "os" "path/filepath" + "sync" "github.com/ioai-tech/lerobot-go/internal/buffer" "github.com/ioai-tech/lerobot-go/internal/features" @@ -27,6 +29,7 @@ type StagingConfig struct { UseVideos bool Stats stats.Options TempRoot string + H264Remux bool } type StagingWriter struct { @@ -35,6 +38,8 @@ type StagingWriter struct { frameStore *tempfs.Store imageBytes map[string][][]byte videoFrameCounts map[string]int + videoEncoders map[string]*video.RawRGBEncoder + pendingH264Remux map[string][][]byte } func NewStagingWriter(cfg StagingConfig) (*StagingWriter, error) { @@ -55,50 +60,159 @@ func NewStagingWriter(cfg StagingConfig) (*StagingWriter, error) { return nil, err } var frameStore *tempfs.Store - if hasVideoFeatures(cfg.Features) { + if hasVideoFeatures(cfg.Features) && !cfg.UseVideos { store, err := tempfs.New(tempfs.Config{EpisodeDir: cfg.Dir, TempRoot: cfg.TempRoot}) if err != nil { return nil, err } frameStore = store } + videoEncoders := make(map[string]*video.RawRGBEncoder) + if cfg.UseVideos && !cfg.H264Remux { + for key, spec := range feats { + if spec.DType == "video" && len(spec.Shape) >= 2 { + h, w := spec.Shape[0], spec.Shape[1] + rel := manifest.StagingVideoRel(key) + out := filepath.Join(cfg.Dir, rel) + if err := os.MkdirAll(filepath.Dir(out), 0o755); err != nil { + for _, e := range videoEncoders { + _ = e.Close() + } + return nil, err + } + enc, err := video.NewRawRGBEncoder(context.Background(), cfg.Locator, cfg.VCodec, cfg.CRF, cfg.FPS, w, h, out) + if err != nil { + for _, e := range videoEncoders { + _ = e.Close() + } + return nil, err + } + videoEncoders[key] = enc + } + } + } return &StagingWriter{ cfg: cfg, buf: buffer.New(cfg.Episode, cfg.FPS, feats), frameStore: frameStore, imageBytes: make(map[string][][]byte), videoFrameCounts: make(map[string]int), + videoEncoders: videoEncoders, }, nil } +func (w *StagingWriter) SetH264Remux(ctx context.Context, tracks map[string][][]byte) error { + _ = ctx + if !w.cfg.H264Remux { + return fmt.Errorf("staging writer not configured for h264 remux") + } + w.pendingH264Remux = tracks + return nil +} + +func (w *StagingWriter) AppendRGBVideoFrame(ctx context.Context, key string, frame video.VideoFrameRGB24) error { + spec, ok := w.buf.Features[key] + if !ok || spec.DType != "video" { + return fmt.Errorf("not a video feature: %q", key) + } + if err := frame.Validate(); err != nil { + return err + } + enc, err := w.ensureRawEncoder(key, spec) + if err != nil { + return err + } + if err := enc.WriteFrame(frame); err != nil { + return err + } + w.videoFrameCounts[key]++ + return nil +} + func (w *StagingWriter) AddFrame(ctx context.Context, frame map[string]any) error { _ = ctx + usedDecodeEncoder := false for key, spec := range w.buf.Features { if spec.DType != "video" && spec.DType != "image" { continue } - if raw, ok := frame[key]; ok { - if png, ok := raw.([]byte); ok { - switch spec.DType { - case "image": - w.imageBytes[key] = append(w.imageBytes[key], append([]byte(nil), png...)) - case "video": - if w.frameStore == nil { - return fmt.Errorf("frame store not initialized for video feature %q", key) - } - frameIndex := w.videoFrameCounts[key] - rel := filepath.Join("images", key, fmt.Sprintf("frame-%06d.png", frameIndex)) - if err := w.frameStore.WritePNG(rel, png); err != nil { - return err - } - w.videoFrameCounts[key] = frameIndex + 1 + val, ok := frame[key] + if !ok { + continue + } + switch spec.DType { + case "image": + if png, ok := val.([]byte); ok { + w.imageBytes[key] = append(w.imageBytes[key], append([]byte(nil), png...)) + } + case "video": + if len(spec.Shape) < 2 { + continue + } + vf, ok, err := video.ParseOptionalVideoFrameRGB24(val, spec.Shape[1], spec.Shape[0]) + if err != nil { + return err + } + if !ok { + continue + } + enc, err := w.ensureRawEncoder(key, spec) + if err != nil { + return err + } + if err := enc.WriteFrame(vf); err != nil { + return err + } + w.videoFrameCounts[key]++ + usedDecodeEncoder = true + continue + if png, ok := val.([]byte); ok { + if w.frameStore == nil { + return fmt.Errorf("frame store not initialized for video feature %q", key) } + frameIndex := w.videoFrameCounts[key] + rel := filepath.Join("images", key, fmt.Sprintf("frame-%06d.png", frameIndex)) + if err := w.frameStore.WritePNG(rel, png); err != nil { + return err + } + w.videoFrameCounts[key] = frameIndex + 1 + } + } + } + if !usedDecodeEncoder && w.cfg.H264Remux && w.cfg.UseVideos { + for key, spec := range w.buf.Features { + if spec.DType == "video" { + w.videoFrameCounts[key]++ } } } return w.buf.AddFrame(frame) } +func (w *StagingWriter) ensureRawEncoder(key string, spec meta.FeatureSpec) (*video.RawRGBEncoder, error) { + if enc := w.videoEncoders[key]; enc != nil { + return enc, nil + } + if len(spec.Shape) < 2 { + return nil, fmt.Errorf("video feature %q missing shape", key) + } + h, width := spec.Shape[0], spec.Shape[1] + rel := manifest.StagingVideoRel(key) + out := filepath.Join(w.cfg.Dir, rel) + if err := os.MkdirAll(filepath.Dir(out), 0o755); err != nil { + return nil, err + } + enc, err := video.NewRawRGBEncoder(context.Background(), w.cfg.Locator, w.cfg.VCodec, w.cfg.CRF, w.cfg.FPS, width, h, out) + if err != nil { + return nil, err + } + if w.videoEncoders == nil { + w.videoEncoders = make(map[string]*video.RawRGBEncoder) + } + w.videoEncoders[key] = enc + return enc, nil +} + func (w *StagingWriter) SaveEpisode(ctx context.Context) (manifest.Episode, error) { if w.buf.Size() == 0 { return manifest.Episode{}, fmt.Errorf("empty episode") @@ -140,32 +254,9 @@ func (w *StagingWriter) SaveEpisode(ctx context.Context) (manifest.Episode, erro FrameBytes: w.imageBytes, }, featureStats, w.cfg.Stats) - videos := map[string]string{} - durations := map[string]float64{} - if w.cfg.UseVideos { - for key, spec := range w.buf.Features { - if spec.DType != "video" { - continue - } - if w.videoFrameCounts[key] == 0 { - continue - } - out := filepath.Join(w.cfg.Dir, "videos", sanitizeKey(key)+".mp4") - pattern := w.frameStore.Pattern(key) - if err := video.EncodeFromPNGDir(ctx, video.EncodeConfig{ - Locator: w.cfg.Locator, - VCodec: w.cfg.VCodec, - CRF: w.cfg.CRF, - FPS: w.cfg.FPS, - PNGPattern: pattern, - OutputPath: out, - }); err != nil { - return manifest.Episode{}, err - } - videos[key] = out - d, _ := video.DurationSeconds(ctx, w.cfg.Locator, out) - durations[key] = d - } + videos, durations, err := w.finalizeVideos(ctx) + if err != nil { + return manifest.Episode{}, err } ep := manifest.Episode{ @@ -184,6 +275,164 @@ func (w *StagingWriter) SaveEpisode(ctx context.Context) (manifest.Episode, erro return ep, nil } +func (w *StagingWriter) videoDurationFromFrames(key string) float64 { + if w.cfg.FPS <= 0 { + return 0 + } + return float64(w.videoFrameCounts[key]) / float64(w.cfg.FPS) +} + +func (w *StagingWriter) finalizeVideos(ctx context.Context) (map[string]string, map[string]float64, error) { + videos := map[string]string{} + durations := map[string]float64{} + if !w.cfg.UseVideos { + return videos, durations, nil + } + if w.cfg.H264Remux { + return w.finalizeH264RemuxVideos(ctx) + } + type encClose struct { + key string + enc *video.RawRGBEncoder + } + var toClose []encClose + for key, spec := range w.buf.Features { + if spec.DType != "video" || w.videoFrameCounts[key] == 0 { + continue + } + if enc := w.videoEncoders[key]; enc != nil { + toClose = append(toClose, encClose{key: key, enc: enc}) + } + } + for _, job := range toClose { + if err := job.enc.Close(); err != nil { + return nil, nil, err + } + } + for key, spec := range w.buf.Features { + if spec.DType != "video" || w.videoFrameCounts[key] == 0 { + continue + } + rel := manifest.StagingVideoRel(key) + out := filepath.Join(w.cfg.Dir, rel) + if enc := w.videoEncoders[key]; enc != nil { + videos[key] = rel + durations[key] = w.videoDurationFromFrames(key) + continue + } + if w.frameStore == nil { + continue + } + if err := video.EncodeFromPNGDir(ctx, video.EncodeConfig{ + Locator: w.cfg.Locator, + VCodec: w.cfg.VCodec, + CRF: w.cfg.CRF, + FPS: w.cfg.FPS, + Threads: video.ResolveEncoderThreads(), + PNGPattern: w.frameStore.Pattern(key), + OutputPath: out, + }); err != nil { + return nil, nil, err + } + videos[key] = rel + if os.Getenv("LEROBOT_FFPROBE_DURATION") == "1" { + d, _ := video.DurationSeconds(ctx, w.cfg.Locator, out) + durations[key] = d + } else { + durations[key] = w.videoDurationFromFrames(key) + } + } + return videos, durations, nil +} + +func (w *StagingWriter) finalizeH264RemuxVideos(ctx context.Context) (map[string]string, map[string]float64, error) { + videos := map[string]string{} + durations := map[string]float64{} + if w.pendingH264Remux == nil { + return videos, durations, nil + } + type remuxJob struct { + key string + aus [][]byte + rel string + out string + } + var jobs []remuxJob + for key, spec := range w.buf.Features { + if spec.DType != "video" { + continue + } + aus, ok := w.pendingH264Remux[key] + if !ok || len(aus) == 0 { + continue + } + rel := manifest.StagingVideoRel(key) + out := filepath.Join(w.cfg.Dir, rel) + if err := os.MkdirAll(filepath.Dir(out), 0o755); err != nil { + return nil, nil, err + } + jobs = append(jobs, remuxJob{key: key, aus: aus, rel: rel, out: out}) + } + if len(jobs) == 0 { + return videos, durations, nil + } + var countMu sync.Mutex + runJob := func(job remuxJob) error { + enc, err := video.NewH264RemuxEncoder(ctx, w.cfg.Locator, w.cfg.FPS, job.out) + if err != nil { + return err + } + if err := enc.WriteAccessUnits(job.aus); err != nil { + _ = enc.Close() + return err + } + if err := enc.Close(); err != nil { + return err + } + expected := len(job.aus) + if nb, err := enc.FrameCount(ctx); err == nil && expected > 0 { + if nb != expected && nb+1 != expected && nb != expected-1 { + slog.Warn("h264 remux frame count drift", "feature", job.key, "ffprobe", nb, "expected", expected) + } + } + countMu.Lock() + w.videoFrameCounts[job.key] = expected + countMu.Unlock() + return nil + } + if len(jobs) == 1 { + if err := runJob(jobs[0]); err != nil { + return nil, nil, err + } + videos[jobs[0].key] = jobs[0].rel + durations[jobs[0].key] = w.videoDurationFromFrames(jobs[0].key) + return videos, durations, nil + } + var wg sync.WaitGroup + errCh := make(chan error, len(jobs)) + for _, job := range jobs { + wg.Add(1) + go func(job remuxJob) { + defer wg.Done() + if err := runJob(job); err != nil { + errCh <- err + } + }(job) + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + return nil, nil, err + } + } + for _, job := range jobs { + videos[job.key] = job.rel + durations[job.key] = w.videoDurationFromFrames(job.key) + } + return videos, durations, nil +} + func buildImageCells(frames [][]byte) []parquetx.ImageCell { cells := make([]parquetx.ImageCell, len(frames)) for i, b := range frames { @@ -196,11 +445,17 @@ func buildImageCells(frames [][]byte) []parquetx.ImageCell { } func (w *StagingWriter) Close() error { + for _, enc := range w.videoEncoders { + _ = enc.Close() + } w.cleanupFrames() return nil } func (w *StagingWriter) videoFramePaths() map[string][]string { + if w.frameStore == nil { + return nil + } out := make(map[string][]string, len(w.videoFrameCounts)) for key, count := range w.videoFrameCounts { paths := make([]string, count) @@ -241,10 +496,6 @@ func indexOf(tasks []string, task string) int { return 0 } -func sanitizeKey(key string) string { - return filepath.Base(key) -} - func statsFeatureMap(features map[string]meta.FeatureSpec) map[string]stats.FeatureDesc { out := make(map[string]stats.FeatureDesc, len(features)) for k, v := range features { @@ -252,3 +503,4 @@ func statsFeatureMap(features map[string]meta.FeatureSpec) map[string]stats.Feat } return out } + diff --git a/internal/v30/integrity.go b/internal/v30/integrity.go new file mode 100644 index 0000000..9cbf7fc --- /dev/null +++ b/internal/v30/integrity.go @@ -0,0 +1,62 @@ +package v30 + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/ioai-tech/lerobot-go/internal/meta" +) + +func requiredVideoKeys(features map[string]meta.FeatureSpec) []string { + var keys []string + for k, f := range features { + if f.DType == "video" { + keys = append(keys, k) + } + } + return keys +} + +// ValidateOutputIntegrity checks merged dataset has data parquet and video files. +func ValidateOutputIntegrity(outputRoot string, features map[string]meta.FeatureSpec) error { + matches, err := filepath.Glob(filepath.Join(outputRoot, "data", "chunk-*", "file-*.parquet")) + if err != nil { + return err + } + if len(matches) == 0 { + return fmt.Errorf("no data parquet files under %s", outputRoot) + } + for _, p := range matches { + fi, err := os.Stat(p) + if err != nil { + return fmt.Errorf("data parquet missing: %w", err) + } + if fi.Size() == 0 { + return fmt.Errorf("data parquet empty: %s", p) + } + } + if !hasVideoFeatures(features) { + return nil + } + for _, vk := range requiredVideoKeys(features) { + pattern := filepath.Join(outputRoot, "videos", vk, "chunk-*", "file-*.mp4") + vm, err := filepath.Glob(pattern) + if err != nil { + return err + } + if len(vm) == 0 { + return fmt.Errorf("no merged video files for feature %q", vk) + } + for _, f := range vm { + fi, err := os.Stat(f) + if err != nil { + return fmt.Errorf("video file missing %s: %w", f, err) + } + if fi.Size() == 0 { + return fmt.Errorf("video file empty: %s", f) + } + } + } + return nil +} diff --git a/internal/v30/merge.go b/internal/v30/merge.go index 17a4c0e..e599765 100644 --- a/internal/v30/merge.go +++ b/internal/v30/merge.go @@ -3,6 +3,10 @@ package v30 import ( "context" "fmt" + "os" + "runtime" + "strconv" + "sync" "path/filepath" @@ -22,6 +26,7 @@ type MergeConfig struct { Features map[string]meta.FeatureSpec Locator video.Locator Stats stats.Options + MaxWorkers int DataFileSizeMB int VideoFileSizeMB int } @@ -111,12 +116,21 @@ func Merge(ctx context.Context, cfg MergeConfig) error { } st.flushDataBatch() st.flushVideoBatches() - if err := st.writeDataBatches(ctx, cfg.OutputRoot); err != nil { + workers := mergeWorkers(cfg) + if err := st.writeDataBatches(ctx, cfg.OutputRoot, workers); err != nil { return err } if err := st.writeVideoBatches(ctx, cfg); err != nil { return err } + if hasVideoFeatures(cfg.Features) { + for _, vk := range requiredVideoKeys(cfg.Features) { + vs := st.videoState[vk] + if vs == nil || len(vs.batches) == 0 { + return fmt.Errorf("merge produced no video batches for feature %q", vk) + } + } + } if err := parquetx.WriteTasksParquet(cfg.OutputRoot, st.taskMap); err != nil { return err } @@ -131,10 +145,29 @@ func Merge(ctx context.Context, cfg MergeConfig) error { if err := meta.UpdateVideoFeaturesInfo(ctx, &st.info, cfg.OutputRoot, cfg.Locator); err != nil { return err } + if err := ValidateOutputIntegrity(cfg.OutputRoot, cfg.Features); err != nil { + return fmt.Errorf("merged output integrity: %w", err) + } return meta.WriteInfo(cfg.OutputRoot, st.info) } func (st *mergeState) ingestEpisode(ctx context.Context, cfg MergeConfig, dir string, ep manifest.Episode) error { + if hasVideoFeatures(cfg.Features) { + for _, vk := range requiredVideoKeys(cfg.Features) { + rel, ok := ep.Videos[vk] + if !ok || rel == "" { + return fmt.Errorf("episode %d missing staged video for feature %q", ep.EpisodeIndex, vk) + } + src := manifest.StagingMediaPath(dir, rel) + fi, err := os.Stat(src) + if err != nil { + return fmt.Errorf("episode %d video %q missing file: %w", ep.EpisodeIndex, vk, err) + } + if fi.Size() == 0 { + return fmt.Errorf("episode %d video %q empty file", ep.EpisodeIndex, vk) + } + } + } srcPQ := filepath.Join(dir, ep.FramesParquet) srcSizeMB, err := meta.ParquetUncompressedSizeMB(srcPQ) if err != nil { @@ -164,7 +197,7 @@ func (st *mergeState) ingestEpisode(ctx context.Context, cfg MergeConfig, dir st } for videoKey, rel := range ep.Videos { vs := st.ensureVideoState(videoKey) - src := filepath.Join(dir, rel) + src := manifest.StagingMediaPath(dir, rel) segSizeMB, err := video.FileSizeMB(src) if err != nil { return err @@ -215,12 +248,29 @@ func (st *mergeState) maybeRotateDataBatch(srcSizeMB float64) { if len(st.dataBatch.entries) == 0 { return } + if streamingMergeEpisodeLimit() > 0 && len(st.dataBatch.entries) >= streamingMergeEpisodeLimit() { + st.flushDataBatch() + st.advanceDataBatch() + return + } if st.dataBatch.sizeMB+srcSizeMB >= float64(st.info.DataFilesSizeInMB) { st.flushDataBatch() st.advanceDataBatch() } } +func streamingMergeEpisodeLimit() int { + v := os.Getenv("LEROBOT_MERGE_EPISODES_PER_FILE") + if v == "" { + return 0 + } + n, err := strconv.Atoi(v) + if err != nil || n < 0 { + return 0 + } + return n +} + func (st *mergeState) flushDataBatch() { if len(st.dataBatch.entries) == 0 { return @@ -278,28 +328,113 @@ func (st *mergeState) flushVideoBatches() { } } -func (st *mergeState) writeDataBatches(ctx context.Context, outputRoot string) error { +func mergeWorkers(cfg MergeConfig) int { + if cfg.MaxWorkers > 0 { + return cfg.MaxWorkers + } + return max(1, runtime.NumCPU()-2) +} + +func (st *mergeState) writeDataBatches(ctx context.Context, outputRoot string, workers int) error { + if len(st.dataBatches) == 0 { + return nil + } + if workers <= 1 || len(st.dataBatches) == 1 { + for _, batch := range st.dataBatches { + dst := filepath.Join(outputRoot, meta.DataPath(batch.chunk, batch.file)) + if err := parquetx.WriteEpisodeBatch(ctx, dst, batch.entries); err != nil { + return err + } + } + return nil + } + sem := make(chan struct{}, workers) + var wg sync.WaitGroup + errCh := make(chan error, len(st.dataBatches)) for _, batch := range st.dataBatches { - dst := filepath.Join(outputRoot, meta.DataPath(batch.chunk, batch.file)) - if err := parquetx.WriteEpisodeBatch(ctx, dst, batch.entries); err != nil { + batch := batch + wg.Add(1) + go func() { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + dst := filepath.Join(outputRoot, meta.DataPath(batch.chunk, batch.file)) + if err := parquetx.WriteEpisodeBatch(ctx, dst, batch.entries); err != nil { + errCh <- err + } + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { return err } } return nil } +type videoBatchJob struct { + locator video.Locator + dst string + segments []string +} + func (st *mergeState) writeVideoBatches(ctx context.Context, cfg MergeConfig) error { + workers := mergeWorkers(cfg) + var jobs []videoBatchJob for _, state := range st.videoState { for _, batch := range state.batches { - dst := filepath.Join(cfg.OutputRoot, meta.VideoPath(batch.key, batch.chunk, batch.file)) - if err := video.SafeConcat(ctx, cfg.Locator, batch.segmentPath, dst, true); err != nil { + jobs = append(jobs, videoBatchJob{ + locator: cfg.Locator, + dst: filepath.Join(cfg.OutputRoot, meta.VideoPath(batch.key, batch.chunk, batch.file)), + segments: batch.segmentPath, + }) + } + } + if len(jobs) == 0 { + return nil + } + if workers <= 1 || len(jobs) == 1 { + for _, job := range jobs { + if err := video.SafeConcat(ctx, job.locator, job.segments, job.dst, true); err != nil { return err } } + return nil + } + sem := make(chan struct{}, workers) + var wg sync.WaitGroup + errCh := make(chan error, len(jobs)) + for _, job := range jobs { + job := job + wg.Add(1) + go func() { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + if err := video.SafeConcat(ctx, job.locator, job.segments, job.dst, true); err != nil { + errCh <- err + } + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + return err + } } return nil } +func max(a, b int) int { + if a > b { + return a + } + return b +} + func hasVideoFeatures(features map[string]meta.FeatureSpec) bool { for _, f := range features { if f.DType == "video" { diff --git a/internal/v30/merge_test.go b/internal/v30/merge_test.go index fabaa5d..bdc58b7 100644 --- a/internal/v30/merge_test.go +++ b/internal/v30/merge_test.go @@ -126,6 +126,59 @@ func TestMergeMultiEpisodeParquetAppend(t *testing.T) { } } +func TestMergeEpisodePerFileLimit(t *testing.T) { + t.Setenv("LEROBOT_MERGE_EPISODES_PER_FILE", "1") + dir := t.TempDir() + features := map[string]meta.FeatureSpec{ + "observation.state": {DType: "float32", Shape: []int{2}}, + "action": {DType: "float32", Shape: []int{2}}, + } + ctx := context.Background() + for epIdx, task := range []string{"pick", "place"} { + staging := filepath.Join(dir, fmt.Sprintf("ep_%06d", epIdx)) + w, err := v30.NewStagingWriter(v30.StagingConfig{ + Dir: staging, Episode: epIdx, FPS: 10, Features: features, UseVideos: false, + }) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 2; i++ { + if err := w.AddFrame(ctx, map[string]any{ + "task": task, + "observation.state": []float32{float32(i), float32(epIdx)}, + "action": []float32{0.1, 0.2}, + }); err != nil { + t.Fatal(err) + } + } + if _, err := w.SaveEpisode(ctx); err != nil { + t.Fatal(err) + } + } + out := filepath.Join(dir, "dataset") + if err := v30.Merge(ctx, v30.MergeConfig{ + StagingRoot: dir, + OutputRoot: out, + FPS: 10, + Features: features, + }); err != nil { + t.Fatal(err) + } + if rows, err := parquetx.TableNumRows(filepath.Join(out, meta.DataPath(0, 0))); err != nil || rows != 2 { + t.Fatalf("file 0 rows=%d err=%v", rows, err) + } + if rows, err := parquetx.TableNumRows(filepath.Join(out, meta.DataPath(0, 1))); err != nil || rows != 2 { + t.Fatalf("file 1 rows=%d err=%v", rows, err) + } + episodes, err := parquetx.ReadEpisodesMeta(out) + if err != nil { + t.Fatal(err) + } + if episodes[1].DataFileIndex != 1 { + t.Fatalf("episode1 data file=%d want 1", episodes[1].DataFileIndex) + } +} + func TestMergeRotatesDataFilesBeforeAssigningEpisodeMetadata(t *testing.T) { dir := t.TempDir() features := map[string]meta.FeatureSpec{ diff --git a/internal/v30/staging.go b/internal/v30/staging.go index 99481ed..dbfc380 100644 --- a/internal/v30/staging.go +++ b/internal/v30/staging.go @@ -3,8 +3,11 @@ package v30 import ( "context" "fmt" + "log/slog" "os" "path/filepath" + "strconv" + "sync" "github.com/ioai-tech/lerobot-go/internal/buffer" "github.com/ioai-tech/lerobot-go/internal/features" @@ -28,6 +31,7 @@ type StagingConfig struct { Streaming bool Stats stats.Options TempRoot string + H264Remux bool } type StagingWriter struct { @@ -37,6 +41,20 @@ type StagingWriter struct { imageBytes map[string][][]byte videoFrameCounts map[string]int streamFiles map[string]string + pqWriter *parquetx.AppendWriter + tasks []string + totalFrames int + flushSize int + chunkStats []stats.EpisodeStats + + // videoEncoders holds per-feature raw RGB streaming encoders. + // When present (for UseVideos + video dtype), frames are fed directly + // to ffmpeg via pipe instead of writing per-frame PNGs to tempfs. + // This is the key change to eliminate the O(episode length * cameras) PNG + // accumulation that caused 20GB+ tmpfs usage. + videoEncoders map[string]*video.RawRGBEncoder + + pendingH264Remux map[string][][]byte } func NewStagingWriter(cfg StagingConfig) (*StagingWriter, error) { @@ -64,6 +82,28 @@ func NewStagingWriter(cfg StagingConfig) (*StagingWriter, error) { } frameStore = store } + videoEncoders := make(map[string]*video.RawRGBEncoder) + if cfg.UseVideos && !cfg.H264Remux { + for key, spec := range feats { + if spec.DType == "video" && len(spec.Shape) >= 2 { + h, w := spec.Shape[0], spec.Shape[1] + out := filepath.Join(cfg.Dir, "videos", filepath.Base(key)+".mp4") + if err := os.MkdirAll(filepath.Dir(out), 0o755); err != nil { + return nil, err + } + enc, err := video.NewRawRGBEncoder(context.Background(), cfg.Locator, cfg.VCodec, cfg.CRF, cfg.FPS, w, h, out) + if err != nil { + // cleanup any encoders started so far + for _, e := range videoEncoders { + _ = e.Close() + } + return nil, err + } + videoEncoders[key] = enc + } + } + } + return &StagingWriter{ cfg: cfg, buf: buffer.New(cfg.Episode, cfg.FPS, feats), @@ -71,41 +111,105 @@ func NewStagingWriter(cfg StagingConfig) (*StagingWriter, error) { imageBytes: make(map[string][][]byte), videoFrameCounts: make(map[string]int), streamFiles: make(map[string]string), + flushSize: resolveFlushSize(), + videoEncoders: videoEncoders, }, nil } +type videoJob struct { + key string + frame video.VideoFrameRGB24 +} + +func (w *StagingWriter) SetH264Remux(ctx context.Context, tracks map[string][][]byte) error { + _ = ctx + if !w.cfg.H264Remux { + return fmt.Errorf("staging writer not configured for h264 remux") + } + w.pendingH264Remux = tracks + return nil +} + func (w *StagingWriter) AddFrame(ctx context.Context, frame map[string]any) error { _ = ctx + var videoJobs []videoJob for key, spec := range w.buf.Features { if spec.DType != "video" && spec.DType != "image" { continue } - if raw, ok := frame[key]; ok { - if png, ok := raw.([]byte); ok { - switch spec.DType { - case "image": - w.imageBytes[key] = append(w.imageBytes[key], append([]byte(nil), png...)) - case "video": - if w.frameStore == nil { - return fmt.Errorf("frame store not initialized for video feature %q", key) - } - frameIndex := w.videoFrameCounts[key] - rel := filepath.Join("images", key, fmt.Sprintf("frame-%06d.png", frameIndex)) - if err := w.frameStore.WritePNG(rel, png); err != nil { - return err - } - w.videoFrameCounts[key] = frameIndex + 1 + val, ok := frame[key] + if !ok { + continue + } + switch spec.DType { + case "image": + if png, ok := val.([]byte); ok { + w.imageBytes[key] = append(w.imageBytes[key], append([]byte(nil), png...)) + } + case "video": + if spec.Shape == nil || len(spec.Shape) < 2 { + continue + } + vf, ok, err := video.ParseOptionalVideoFrameRGB24(val, spec.Shape[1], spec.Shape[0]) + if err != nil { + return err + } + if !ok { + continue + } + if _, err := w.ensureRawEncoder(key, spec); err != nil { + return err + } + videoJobs = append(videoJobs, videoJob{key: key, frame: vf}) + continue + if png, ok := val.([]byte); ok { + if w.frameStore == nil { + return fmt.Errorf("frame store not initialized for video feature %q", key) } + frameIndex := w.videoFrameCounts[key] + rel := filepath.Join("images", key, fmt.Sprintf("frame-%06d.png", frameIndex)) + if err := w.frameStore.WritePNG(rel, png); err != nil { + return err + } + w.videoFrameCounts[key] = frameIndex + 1 } } } - return w.buf.AddFrame(frame) + if len(videoJobs) > 0 { + if err := w.writeVideoJobsParallel(videoJobs); err != nil { + return err + } + } + if err := w.buf.AddFrame(frame); err != nil { + return err + } + task, _ := frame["task"].(string) + if task == "" { + task, _ = frame["__task__"].(string) + } + w.tasks = append(w.tasks, task) + if w.streamingEnabled() && w.buf.Size() >= w.flushSize { + return w.flushBuffered(ctx) + } + return nil } func (w *StagingWriter) SaveEpisode(ctx context.Context) (manifest.Episode, error) { - if w.buf.Size() == 0 { + if w.totalFrames+w.buf.Size() == 0 { return manifest.Episode{}, fmt.Errorf("empty episode") } + if w.streamingEnabled() { + if err := w.flushBuffered(ctx); err != nil { + return manifest.Episode{}, err + } + if w.pqWriter != nil { + if err := w.pqWriter.Close(); err != nil { + return manifest.Episode{}, err + } + w.pqWriter = nil + } + return w.saveStreamingEpisode(ctx) + } schema, err := parquetx.BuildArrowSchema(w.buf.Features) if err != nil { return manifest.Episode{}, err @@ -143,38 +247,438 @@ func (w *StagingWriter) SaveEpisode(ctx context.Context) (manifest.Episode, erro FrameBytes: w.imageBytes, }, featureStats, w.cfg.Stats) + if err := w.validateVideoCoverage(w.buf.Size()); err != nil { + return manifest.Episode{}, err + } + videos, durations, err := w.finalizeVideos(ctx) + if err != nil { + return manifest.Episode{}, err + } + + ep := manifest.Episode{ + EpisodeIndex: w.cfg.Episode, + Length: w.buf.Size(), + Tasks: taskSet, + FramesParquet: "frames.parquet", + Videos: videos, + Stats: epStats, + VideoDurations: durations, + } + if err := manifest.Write(w.cfg.Dir, ep); err != nil { + return manifest.Episode{}, err + } + w.cleanupFrames() + return ep, nil +} + +func (w *StagingWriter) videoDurationFromFrames(key string) float64 { + if w.cfg.FPS <= 0 { + return 0 + } + return float64(w.videoFrameCounts[key]) / float64(w.cfg.FPS) +} + +func (w *StagingWriter) finalizeVideos(ctx context.Context) (map[string]string, map[string]float64, error) { videos := map[string]string{} durations := map[string]float64{} - if w.cfg.UseVideos { - for key, spec := range w.buf.Features { - if spec.DType != "video" { + if !w.cfg.UseVideos { + return videos, durations, nil + } + if w.cfg.H264Remux { + rv, rd, err := w.finalizeH264RemuxVideos(ctx) + if err != nil { + return nil, nil, err + } + for k, v := range rv { + videos[k] = v + } + for k, v := range rd { + durations[k] = v + } + } + ev, ed, err := w.finalizeEncodedVideos(ctx) + if err != nil { + return nil, nil, err + } + for k, v := range ev { + videos[k] = v + } + for k, v := range ed { + durations[k] = v + } + if err := w.requireVideoOutputs(videos); err != nil { + return nil, nil, err + } + return videos, durations, nil +} + +func (w *StagingWriter) finalizeEncodedVideos(ctx context.Context) (map[string]string, map[string]float64, error) { + videos := map[string]string{} + durations := map[string]float64{} + type encClose struct { + key string + enc *video.RawRGBEncoder + } + var toClose []encClose + for key, spec := range w.buf.Features { + if spec.DType != "video" || w.videoFrameCounts[key] == 0 { + continue + } + if w.pendingH264Remux != nil { + if _, remux := w.pendingH264Remux[key]; remux { continue } - if w.videoFrameCounts[key] == 0 { - continue + } + if enc := w.videoEncoders[key]; enc != nil { + toClose = append(toClose, encClose{key: key, enc: enc}) + } + } + if len(toClose) > 1 { + var wg sync.WaitGroup + errCh := make(chan error, len(toClose)) + for _, job := range toClose { + wg.Add(1) + go func(job encClose) { + defer wg.Done() + if err := job.enc.Close(); err != nil { + errCh <- err + } + }(job) + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + return nil, nil, err } - out := filepath.Join(w.cfg.Dir, "videos", filepath.Base(key)+".mp4") - pattern := w.frameStore.Pattern(key) - if err := video.EncodeFromPNGDir(ctx, video.EncodeConfig{ - Locator: w.cfg.Locator, - VCodec: w.cfg.VCodec, - CRF: w.cfg.CRF, - FPS: w.cfg.FPS, - PNGPattern: pattern, - OutputPath: out, - }); err != nil { - return manifest.Episode{}, err + } + } else if len(toClose) == 1 { + if err := toClose[0].enc.Close(); err != nil { + return nil, nil, err + } + } + for key, spec := range w.buf.Features { + if spec.DType != "video" || w.videoFrameCounts[key] == 0 { + continue + } + if w.pendingH264Remux != nil { + if _, remux := w.pendingH264Remux[key]; remux { + continue } - videos[key] = out + } + rel := manifest.StagingVideoRel(key) + out := filepath.Join(w.cfg.Dir, rel) + if enc := w.videoEncoders[key]; enc != nil { + out = enc.OutputPath() + videos[key] = rel + durations[key] = w.videoDurationFromFrames(key) + continue + } + if w.frameStore == nil { + return nil, nil, fmt.Errorf("video feature %q has frames but no encoder or frame store", key) + } + if err := os.MkdirAll(filepath.Dir(out), 0o755); err != nil { + return nil, nil, err + } + pattern := w.frameStore.Pattern(key) + if err := video.EncodeFromPNGDir(ctx, video.EncodeConfig{ + Locator: w.cfg.Locator, + VCodec: w.cfg.VCodec, + CRF: w.cfg.CRF, + FPS: w.cfg.FPS, + Threads: resolveEncoderThreads(), + PNGPattern: pattern, + OutputPath: out, + }); err != nil { + return nil, nil, err + } + videos[key] = rel + if os.Getenv("LEROBOT_FFPROBE_DURATION") == "1" { d, _ := video.DurationSeconds(ctx, w.cfg.Locator, out) durations[key] = d + } else { + durations[key] = w.videoDurationFromFrames(key) } } + return videos, durations, nil +} + +func (w *StagingWriter) validateVideoCoverage(episodeLen int) error { + if episodeLen == 0 || !w.cfg.UseVideos { + return nil + } + for key, spec := range w.buf.Features { + if spec.DType != "video" { + continue + } + if w.pendingH264Remux != nil { + if aus, remux := w.pendingH264Remux[key]; remux { + if len(aus) == 0 { + return fmt.Errorf("remux video feature %q has no access units", key) + } + continue + } + } + if w.videoFrameCounts[key] == 0 { + return fmt.Errorf("video feature %q has no frames (episode length %d)", key, episodeLen) + } + } + return nil +} + +func (w *StagingWriter) requireVideoOutputs(videos map[string]string) error { + if !w.cfg.UseVideos { + return nil + } + for key, spec := range w.buf.Features { + if spec.DType != "video" { + continue + } + rel, ok := videos[key] + if !ok || rel == "" { + return fmt.Errorf("video feature %q missing output file", key) + } + out := filepath.Join(w.cfg.Dir, rel) + fi, err := os.Stat(out) + if err != nil { + return fmt.Errorf("video feature %q output not found: %w", key, err) + } + if fi.Size() == 0 { + return fmt.Errorf("video feature %q output is empty: %s", key, out) + } + } + return nil +} + +// AppendRGBVideoFrame writes one RGB frame into a lazily-created encoder (decode fallback path). +func (w *StagingWriter) AppendRGBVideoFrame(ctx context.Context, key string, frame video.VideoFrameRGB24) error { + spec, ok := w.buf.Features[key] + if !ok || spec.DType != "video" { + return fmt.Errorf("not a video feature: %q", key) + } + if err := frame.Validate(); err != nil { + return err + } + enc, err := w.ensureRawEncoder(key, spec) + if err != nil { + return err + } + if err := enc.WriteFrame(frame); err != nil { + return err + } + w.videoFrameCounts[key]++ + return nil +} + +func (w *StagingWriter) finalizeH264RemuxVideos(ctx context.Context) (map[string]string, map[string]float64, error) { + videos := map[string]string{} + durations := map[string]float64{} + if w.pendingH264Remux == nil { + return videos, durations, nil + } + type remuxJob struct { + key string + aus [][]byte + rel string + out string + } + var jobs []remuxJob + for key, aus := range w.pendingH264Remux { + spec, ok := w.buf.Features[key] + if !ok || spec.DType != "video" { + continue + } + if len(aus) == 0 { + return nil, nil, fmt.Errorf("remux track empty for video feature %q", key) + } + rel := manifest.StagingVideoRel(key) + out := filepath.Join(w.cfg.Dir, rel) + if err := os.MkdirAll(filepath.Dir(out), 0o755); err != nil { + return nil, nil, err + } + jobs = append(jobs, remuxJob{key: key, aus: aus, rel: rel, out: out}) + } + if len(jobs) == 0 { + return videos, durations, nil + } + var countMu sync.Mutex + runJob := func(job remuxJob) error { + enc, err := video.NewH264RemuxEncoder(ctx, w.cfg.Locator, w.cfg.FPS, job.out) + if err != nil { + return err + } + if err := enc.WriteAccessUnits(job.aus); err != nil { + _ = enc.Close() + return err + } + if err := enc.Close(); err != nil { + return err + } + expected := len(job.aus) + if nb, err := enc.FrameCount(ctx); err == nil && expected > 0 { + if nb != expected && nb+1 != expected && nb != expected-1 { + slog.Warn("h264 remux frame count drift", "feature", job.key, "ffprobe", nb, "expected", expected) + } + } + fi, statErr := os.Stat(job.out) + if statErr != nil { + return fmt.Errorf("remux output missing for %q: %w", job.key, statErr) + } + if fi.Size() == 0 { + return fmt.Errorf("remux output empty for %q: %s", job.key, job.out) + } + countMu.Lock() + w.videoFrameCounts[job.key] = expected + countMu.Unlock() + return nil + } + if len(jobs) == 1 { + if err := runJob(jobs[0]); err != nil { + return nil, nil, err + } + videos[jobs[0].key] = jobs[0].rel + durations[jobs[0].key] = w.videoDurationFromFrames(jobs[0].key) + return videos, durations, nil + } + var wg sync.WaitGroup + errCh := make(chan error, len(jobs)) + for _, job := range jobs { + wg.Add(1) + go func(job remuxJob) { + defer wg.Done() + if err := runJob(job); err != nil { + errCh <- err + } + }(job) + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + return nil, nil, err + } + } + for _, job := range jobs { + videos[job.key] = job.rel + durations[job.key] = w.videoDurationFromFrames(job.key) + } + return videos, durations, nil +} + +func (w *StagingWriter) ensureRawEncoder(key string, spec meta.FeatureSpec) (*video.RawRGBEncoder, error) { + if enc := w.videoEncoders[key]; enc != nil { + return enc, nil + } + if len(spec.Shape) < 2 { + return nil, fmt.Errorf("video feature %q missing shape", key) + } + h, width := spec.Shape[0], spec.Shape[1] + out := filepath.Join(w.cfg.Dir, "videos", filepath.Base(key)+".mp4") + if err := os.MkdirAll(filepath.Dir(out), 0o755); err != nil { + return nil, err + } + enc, err := video.NewRawRGBEncoder(context.Background(), w.cfg.Locator, w.cfg.VCodec, w.cfg.CRF, w.cfg.FPS, width, h, out) + if err != nil { + return nil, err + } + if w.videoEncoders == nil { + w.videoEncoders = make(map[string]*video.RawRGBEncoder) + } + w.videoEncoders[key] = enc + return enc, nil +} + +func (w *StagingWriter) writeVideoJobsParallel(jobs []videoJob) error { + if len(jobs) == 1 { + job := jobs[0] + if err := w.videoEncoders[job.key].WriteFrame(job.frame); err != nil { + return err + } + w.videoFrameCounts[job.key]++ + return nil + } + var wg sync.WaitGroup + errCh := make(chan error, len(jobs)) + for _, job := range jobs { + wg.Add(1) + go func(job videoJob) { + defer wg.Done() + if err := w.videoEncoders[job.key].WriteFrame(job.frame); err != nil { + errCh <- err + } + }(job) + } + wg.Wait() + close(errCh) + for err := range errCh { + if err != nil { + return err + } + } + for _, job := range jobs { + w.videoFrameCounts[job.key]++ + } + return nil +} + +func (w *StagingWriter) streamingEnabled() bool { + return w.cfg.Streaming && w.cfg.UseVideos && !w.cfg.H264Remux && !hasImageFeatures(w.buf.Features) +} + +func (w *StagingWriter) flushBuffered(ctx context.Context) error { + _ = ctx + if w.buf.Size() == 0 { + return nil + } + if w.pqWriter == nil { + schema, err := parquetx.BuildArrowSchema(w.buf.Features) + if err != nil { + return err + } + pqPath := filepath.Join(w.cfg.Dir, "frames.parquet") + writer, err := parquetx.NewAppendWriterWithFeatures(pqPath, schema, w.buf.Features) + if err != nil { + return err + } + w.pqWriter = writer + } + taskSet := uniqueTasks(w.tasks) + taskIndices := make([]int64, len(w.buf.Tasks())) + for i, t := range w.buf.Tasks() { + taskIndices[i] = int64(indexOf(taskSet, t)) + } + cols := w.buf.ColumnsWithFrameStart(0, int64(w.totalFrames), taskIndices) + w.chunkStats = append(w.chunkStats, stats.ComputeEpisodeStats(stats.EpisodeInput{ + Columns: cols, + }, statsFeatureMap(w.buf.Features), w.cfg.Stats)) + if err := w.pqWriter.WriteRecordColumns(cols, w.buf.Size(), w.buf.Features); err != nil { + return err + } + w.totalFrames += w.buf.Size() + w.buf.Reset() + return nil +} + +func (w *StagingWriter) saveStreamingEpisode(ctx context.Context) (manifest.Episode, error) { + featureStats := statsFeatureMap(w.buf.Features) + epStats := aggregateAsEpisodeStats(w.chunkStats) + mediaStats := stats.ComputeEpisodeStats(stats.EpisodeInput{ + FramePaths: w.videoFramePaths(), + }, featureStats, w.cfg.Stats) + epStats = mergeStats(epStats, mediaStats) + + if err := w.validateVideoCoverage(w.totalFrames); err != nil { + return manifest.Episode{}, err + } + videos, durations, err := w.finalizeVideos(ctx) + if err != nil { + return manifest.Episode{}, err + } ep := manifest.Episode{ EpisodeIndex: w.cfg.Episode, - Length: w.buf.Size(), - Tasks: taskSet, + Length: w.totalFrames, + Tasks: uniqueTasks(w.tasks), FramesParquet: "frames.parquet", Videos: videos, Stats: epStats, @@ -187,6 +691,47 @@ func (w *StagingWriter) SaveEpisode(ctx context.Context) (manifest.Episode, erro return ep, nil } +func mergeStats(base, overlay stats.EpisodeStats) stats.EpisodeStats { + if len(base) == 0 { + return overlay + } + for k, v := range overlay { + base[k] = v + } + return base +} + +func aggregateAsEpisodeStats(parts []stats.EpisodeStats) stats.EpisodeStats { + agg := stats.AggregateStats(parts) + out := make(stats.EpisodeStats, len(agg)) + for key, feature := range agg { + fs := make(stats.FeatureStats, len(feature)) + for stat, value := range feature { + fs[stat] = value + } + out[key] = fs + } + return out +} + +func resolveFlushSize() int { + if v := os.Getenv("LEROBOT_STREAMING_FLUSH_FRAMES"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return 256 +} + +func resolveEncoderThreads() int { + if v := os.Getenv("LEROBOT_ENCODER_THREADS"); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return 0 +} + func buildImageCells(frames [][]byte) []parquetx.ImageCell { cells := make([]parquetx.ImageCell, len(frames)) for i, b := range frames { @@ -199,13 +744,28 @@ func buildImageCells(frames [][]byte) []parquetx.ImageCell { } func (w *StagingWriter) Close() error { + if w.pqWriter != nil { + if err := w.pqWriter.Close(); err != nil { + return err + } + w.pqWriter = nil + } + for _, enc := range w.videoEncoders { + _ = enc.Close() + } w.cleanupFrames() return nil } func (w *StagingWriter) videoFramePaths() map[string][]string { - out := make(map[string][]string, len(w.videoFrameCounts)) + if w.frameStore == nil { + return nil + } + out := make(map[string][]string) for key, count := range w.videoFrameCounts { + if count == 0 { + continue + } paths := make([]string, count) for i := 0; i < count; i++ { paths[i] = w.frameStore.FramePath(key, i) @@ -251,3 +811,12 @@ func statsFeatureMap(features map[string]meta.FeatureSpec) map[string]stats.Feat } return out } + +func hasImageFeatures(features map[string]meta.FeatureSpec) bool { + for _, f := range features { + if f.DType == "image" { + return true + } + } + return false +} diff --git a/internal/v30/staging_h264_remux_test.go b/internal/v30/staging_h264_remux_test.go new file mode 100644 index 0000000..8de63fb --- /dev/null +++ b/internal/v30/staging_h264_remux_test.go @@ -0,0 +1,61 @@ +package v30 + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/ioai-tech/lerobot-go/internal/meta" +) + +func TestStagingH264RemuxCopyMux(t *testing.T) { + if _, err := exec.LookPath("ffmpeg"); err != nil { + t.Skip("ffmpeg not installed") + } + dir := t.TempDir() + key := "observation.images.cam" + w, err := NewStagingWriter(StagingConfig{ + Dir: dir, + Episode: 0, + FPS: 10, + UseVideos: true, + H264Remux: true, + Features: map[string]meta.FeatureSpec{ + key: {DType: "video", Shape: []int{4, 4, 3}}, + }, + }) + if err != nil { + t.Fatalf("NewStagingWriter: %v", err) + } + for i := 0; i < 3; i++ { + if err := w.AddFrame(context.Background(), map[string]any{ + "task": "t", + }); err != nil { + t.Fatalf("AddFrame: %v", err) + } + } + // Minimal synthetic AU; copy mux may fail on invalid bitstream — test wiring only. + if err := w.SetH264Remux(context.Background(), map[string][][]byte{ + key: { + []byte{0, 0, 0, 1, 0x65, 0x88}, + []byte{0, 0, 0, 1, 0x65, 0x88}, + []byte{0, 0, 0, 1, 0x65, 0x88}, + }, + }); err != nil { + t.Fatalf("SetH264Remux: %v", err) + } + _, err = w.SaveEpisode(context.Background()) + if err != nil { + t.Skipf("synthetic AU remux not accepted by ffmpeg: %v", err) + } + out := filepath.Join(dir, "videos", "cam.mp4") + fi, statErr := os.Stat(out) + if statErr != nil { + t.Fatalf("missing output mp4: %v", statErr) + } + if fi.Size() == 0 { + t.Fatal("empty output mp4") + } +} diff --git a/internal/v30/staging_integrity_test.go b/internal/v30/staging_integrity_test.go new file mode 100644 index 0000000..5bcaf37 --- /dev/null +++ b/internal/v30/staging_integrity_test.go @@ -0,0 +1,59 @@ +package v30 + +import ( + "context" + "testing" + + "github.com/ioai-tech/lerobot-go/internal/meta" +) + +func TestSaveEpisodeFailsWithoutVideoFrames(t *testing.T) { + dir := t.TempDir() + key := "observation.images.cam" + w, err := NewStagingWriter(StagingConfig{ + Dir: dir, + Episode: 0, + FPS: 10, + UseVideos: true, + H264Remux: true, + Features: map[string]meta.FeatureSpec{ + key: {DType: "video", Shape: []int{4, 4, 3}}, + }, + }) + if err != nil { + t.Fatalf("NewStagingWriter: %v", err) + } + if err := w.AddFrame(context.Background(), map[string]any{"task": "t"}); err != nil { + t.Fatalf("AddFrame: %v", err) + } + if _, err := w.SaveEpisode(context.Background()); err == nil { + t.Fatal("expected SaveEpisode to fail without video output") + } +} + +func TestSaveEpisodeFailsWithEmptyRemuxTrack(t *testing.T) { + dir := t.TempDir() + key := "observation.images.cam" + w, err := NewStagingWriter(StagingConfig{ + Dir: dir, + Episode: 0, + FPS: 10, + UseVideos: true, + H264Remux: true, + Features: map[string]meta.FeatureSpec{ + key: {DType: "video", Shape: []int{4, 4, 3}}, + }, + }) + if err != nil { + t.Fatalf("NewStagingWriter: %v", err) + } + if err := w.AddFrame(context.Background(), map[string]any{"task": "t"}); err != nil { + t.Fatalf("AddFrame: %v", err) + } + if err := w.SetH264Remux(context.Background(), map[string][][]byte{key: {}}); err != nil { + t.Fatalf("SetH264Remux: %v", err) + } + if _, err := w.SaveEpisode(context.Background()); err == nil { + t.Fatal("expected SaveEpisode to fail with empty remux track") + } +} diff --git a/internal/v30/staging_streaming_test.go b/internal/v30/staging_streaming_test.go new file mode 100644 index 0000000..a0a3a56 --- /dev/null +++ b/internal/v30/staging_streaming_test.go @@ -0,0 +1,119 @@ +package v30_test + +import ( + "bytes" + "context" + "image" + "image/color" + "image/png" + "os" + "path/filepath" + "testing" + + "github.com/ioai-tech/lerobot-go/internal/meta" + "github.com/ioai-tech/lerobot-go/internal/parquetx" + v30 "github.com/ioai-tech/lerobot-go/internal/v30" + "github.com/ioai-tech/lerobot-go/internal/video" +) + +func TestStreamingStagingFlushesParquetChunks(t *testing.T) { + t.Setenv("LEROBOT_STREAMING_FLUSH_FRAMES", "2") + dir := filepath.Join(t.TempDir(), "ep_000000") + features := map[string]meta.FeatureSpec{ + "observation.images.cam": {DType: "video", Shape: []int{4, 4, 3}}, + "observation.state": {DType: "float32", Shape: []int{2}}, + "action": {DType: "float32", Shape: []int{2}}, + } + w, err := v30.NewStagingWriter(v30.StagingConfig{ + Dir: dir, Episode: 0, FPS: 10, Features: features, UseVideos: true, Streaming: true, + }) + if err != nil { + t.Fatal(err) + } + frameRGB := tinyRGB24(t) + ctx := context.Background() + for i := 0; i < 5; i++ { + if err := w.AddFrame(ctx, map[string]any{ + "task": "pick", + "observation.images.cam": video.VideoFrameRGB24{ + Data: frameRGB, Width: 4, Height: 4, + }, + "observation.state": []float32{float32(i), 1}, + "action": []float32{0.1, 0.2}, + }); err != nil { + t.Fatal(err) + } + } + ep, err := w.SaveEpisode(ctx) + if err != nil { + t.Fatal(err) + } + if ep.Length != 5 { + t.Fatalf("length=%d want 5", ep.Length) + } + pq := filepath.Join(dir, "frames.parquet") + rows, err := parquetx.TableNumRows(pq) + if err != nil { + t.Fatal(err) + } + if rows != 5 { + t.Fatalf("rows=%d want 5", rows) + } + tbl, err := parquetx.ReadTable(ctx, pq, nil) + if err != nil { + t.Fatal(err) + } + defer tbl.Release() + frameIndex, err := parquetx.ExtractInt64Column(tbl, "frame_index") + if err != nil { + t.Fatal(err) + } + for i, got := range frameIndex { + if got != int64(i) { + t.Fatalf("frame_index[%d]=%d", i, got) + } + } + if _, err := os.Stat(filepath.Join(dir, "videos", "observation.images.cam.mp4")); err != nil { + t.Fatalf("video missing: %v", err) + } + if ep.Videos["observation.images.cam"] != filepath.Join("videos", "observation.images.cam.mp4") { + t.Fatalf("videos path=%q want relative staging path", ep.Videos["observation.images.cam"]) + } +} + +func tinyRGB24(t *testing.T) []byte { + t.Helper() + img := image.NewRGBA(image.Rect(0, 0, 4, 4)) + for y := 0; y < 4; y++ { + for x := 0; x < 4; x++ { + img.Set(x, y, color.RGBA{R: uint8(x * 40), G: uint8(y * 40), B: 128, A: 255}) + } + } + out := make([]byte, 4*4*3) + i := 0 + for y := 0; y < 4; y++ { + for x := 0; x < 4; x++ { + c := color.RGBAModel.Convert(img.At(x, y)).(color.RGBA) + out[i] = c.R + out[i+1] = c.G + out[i+2] = c.B + i += 3 + } + } + return out +} + +func tinyPNG(t *testing.T) []byte { + t.Helper() + img := image.NewRGBA(image.Rect(0, 0, 4, 4)) + for y := 0; y < 4; y++ { + for x := 0; x < 4; x++ { + img.Set(x, y, color.RGBA{R: uint8(x * 40), G: uint8(y * 40), B: 128, A: 255}) + } + } + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + t.Fatal(err) + } + return buf.Bytes() +} diff --git a/internal/video/encode.go b/internal/video/encode.go index acb533f..8bef8e4 100644 --- a/internal/video/encode.go +++ b/internal/video/encode.go @@ -11,6 +11,19 @@ import ( const DefaultCRF = 25 +// ResolveEncoderThreads reads LEROBOT_ENCODER_THREADS (0 when unset/invalid). +func ResolveEncoderThreads() int { + v := os.Getenv("LEROBOT_ENCODER_THREADS") + if v == "" { + return 0 + } + n, err := strconv.Atoi(v) + if err != nil || n <= 0 { + return 0 + } + return n +} + type EncodeConfig struct { Locator Locator VCodec string @@ -32,6 +45,9 @@ func EncodeFromPNGDir(ctx context.Context, cfg EncodeConfig) error { if cfg.VCodec == "" { cfg.VCodec = "libx264" } + if err := os.MkdirAll(filepath.Dir(cfg.OutputPath), 0o755); err != nil { + return err + } args := []string{ "-y", "-framerate", strconv.Itoa(cfg.FPS), diff --git a/internal/video/frame.go b/internal/video/frame.go new file mode 100644 index 0000000..52d9913 --- /dev/null +++ b/internal/video/frame.go @@ -0,0 +1,59 @@ +package video + +import "fmt" + +// VideoFrameRGB24 is the typed contract for streaming raw RGB24 into encoders. +type VideoFrameRGB24 struct { + Data []byte + Width int + Height int +} + +func (f VideoFrameRGB24) ExpectedSize() int { + if f.Width <= 0 || f.Height <= 0 { + return 0 + } + return f.Width * f.Height * 3 +} + +func (f VideoFrameRGB24) Validate() error { + want := f.ExpectedSize() + if want == 0 { + return fmt.Errorf("invalid video frame dimensions %dx%d", f.Width, f.Height) + } + if len(f.Data) != want { + return fmt.Errorf("rgb24 frame size mismatch: got %d want %d", len(f.Data), want) + } + return nil +} + +// ParseOptionalVideoFrameRGB24 parses a frame value; nil/absent input returns ok=false without error. +func ParseOptionalVideoFrameRGB24(v any, width, height int) (VideoFrameRGB24, bool, error) { + if v == nil { + return VideoFrameRGB24{}, false, nil + } + frame, err := AsVideoFrameRGB24(v, width, height) + if err != nil { + return VideoFrameRGB24{}, false, err + } + return frame, true, nil +} + +// AsVideoFrameRGB24 accepts typed frames or legacy []byte with explicit dimensions. +func AsVideoFrameRGB24(v any, width, height int) (VideoFrameRGB24, error) { + switch x := v.(type) { + case VideoFrameRGB24: + if err := x.Validate(); err != nil { + return VideoFrameRGB24{}, err + } + return x, nil + case []byte: + f := VideoFrameRGB24{Data: x, Width: width, Height: height} + if err := f.Validate(); err != nil { + return VideoFrameRGB24{}, err + } + return f, nil + default: + return VideoFrameRGB24{}, fmt.Errorf("unsupported video frame type %T", v) + } +} diff --git a/internal/video/frame_test.go b/internal/video/frame_test.go new file mode 100644 index 0000000..b81aaef --- /dev/null +++ b/internal/video/frame_test.go @@ -0,0 +1,25 @@ +package video + +import "testing" + +func TestVideoFrameRGB24Validate(t *testing.T) { + f := VideoFrameRGB24{Data: make([]byte, 12), Width: 2, Height: 2} + if err := f.Validate(); err != nil { + t.Fatalf("valid frame rejected: %v", err) + } + bad := VideoFrameRGB24{Data: make([]byte, 10), Width: 2, Height: 2} + if err := bad.Validate(); err == nil { + t.Fatal("expected size mismatch error") + } +} + +func TestAsVideoFrameRGB24LegacyBytes(t *testing.T) { + raw := make([]byte, 12) + got, err := AsVideoFrameRGB24(raw, 2, 2) + if err != nil { + t.Fatalf("AsVideoFrameRGB24 failed: %v", err) + } + if len(got.Data) != 12 { + t.Fatalf("unexpected data len %d", len(got.Data)) + } +} diff --git a/internal/video/h264_remux_encoder.go b/internal/video/h264_remux_encoder.go new file mode 100644 index 0000000..d93aff5 --- /dev/null +++ b/internal/video/h264_remux_encoder.go @@ -0,0 +1,137 @@ +package video + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "strconv" + "strings" +) + +// H264RemuxEncoder muxes Annex-B access units into MP4 via ffmpeg copy (no RGB round-trip). +type H264RemuxEncoder struct { + locator Locator + fps int + output string + + cmd *exec.Cmd + stdin io.WriteCloser +} + +// NewH264RemuxEncoder starts ffmpeg in h264 copy mode. +func NewH264RemuxEncoder(ctx context.Context, locator Locator, fps int, outputPath string) (*H264RemuxEncoder, error) { + if locator == nil { + locator = NewLocator(Config{}) + } + ffmpeg, err := locator.FFmpegPath() + if err != nil { + return nil, err + } + if fps <= 0 { + fps = 30 + } + if err := os.MkdirAll(dirOf(outputPath), 0o755); err != nil { + return nil, err + } + args := h264RemuxFFmpegArgs(fps, outputPath) + cmd := exec.CommandContext(ctx, ffmpeg, args...) + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + if err := cmd.Start(); err != nil { + _ = stdin.Close() + return nil, err + } + return &H264RemuxEncoder{ + locator: locator, + fps: fps, + output: outputPath, + cmd: cmd, + stdin: stdin, + }, nil +} + +// WriteAccessUnits writes ordered access units (nil entries are skipped). +func (e *H264RemuxEncoder) WriteAccessUnits(aus [][]byte) error { + if e.stdin == nil { + return fmt.Errorf("encoder closed or not started") + } + for _, au := range aus { + if len(au) == 0 { + continue + } + if _, err := e.stdin.Write(au); err != nil { + return err + } + } + return nil +} + +// Close flushes stdin and waits for ffmpeg. +func (e *H264RemuxEncoder) Close() error { + var firstErr error + if e.stdin != nil { + if err := e.stdin.Close(); err != nil && firstErr == nil { + firstErr = err + } + e.stdin = nil + } + if e.cmd != nil { + if err := e.cmd.Wait(); err != nil && firstErr == nil { + firstErr = err + } + e.cmd = nil + } + return firstErr +} + +// OutputPath returns the destination mp4 path. +func (e *H264RemuxEncoder) OutputPath() string { + return e.output +} + +// FrameCount probes packet count via ffprobe when available. +func (e *H264RemuxEncoder) FrameCount(ctx context.Context) (int, error) { + ffprobe, err := e.locator.FFprobePath() + if err != nil { + return 0, err + } + cmd := exec.CommandContext(ctx, ffprobe, + "-v", "error", "-select_streams", "v:0", + "-count_packets", "-show_entries", "stream=nb_read_packets", + "-of", "csv=p=0", e.output, + ) + out, err := cmd.Output() + if err != nil { + return 0, err + } + n, err := strconv.Atoi(strings.TrimSpace(string(out))) + if err != nil { + return 0, err + } + return n, nil +} + +func h264RemuxFFmpegArgs(fps int, outputPath string) []string { + return []string{ + "-y", "-hide_banner", "-loglevel", "error", + "-f", "h264", + "-framerate", strconv.Itoa(fps), + "-i", "pipe:0", + "-c:v", "copy", + "-movflags", "+faststart", + outputPath, + } +} + +func dirOf(path string) string { + for i := len(path) - 1; i >= 0; i-- { + if path[i] == '/' || path[i] == '\\' { + return path[:i] + } + } + return "." +} diff --git a/internal/video/h264_remux_encoder_test.go b/internal/video/h264_remux_encoder_test.go new file mode 100644 index 0000000..d940b21 --- /dev/null +++ b/internal/video/h264_remux_encoder_test.go @@ -0,0 +1,34 @@ +package video + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" + "testing" +) + +func TestH264RemuxEncoderRequiresFFmpeg(t *testing.T) { + if _, err := exec.LookPath("ffmpeg"); err != nil { + t.Skip("ffmpeg not installed") + } + dir := t.TempDir() + out := filepath.Join(dir, "out.mp4") + enc, err := NewH264RemuxEncoder(context.Background(), nil, 30, out) + if err != nil { + t.Fatalf("NewH264RemuxEncoder: %v", err) + } + _ = enc.Close() + if _, err := os.Stat(out); err == nil { + // ffmpeg may create empty/partial file on close without input + } +} + +func TestH264RemuxFFmpegArgsOmitsGenpts(t *testing.T) { + args := h264RemuxFFmpegArgs(30, "/tmp/out.mp4") + if slices.Contains(args, "+genpts") || strings.Contains(strings.Join(args, " "), "genpts") { + t.Fatalf("remux args should not use genpts: %v", args) + } +} diff --git a/internal/video/info.go b/internal/video/info.go index 67ce78b..ee28b47 100644 --- a/internal/video/info.go +++ b/internal/video/info.go @@ -15,17 +15,46 @@ type ffprobeOutput struct { } type ffprobeStream struct { - CodecType string `json:"codec_type"` - CodecName string `json:"codec_name"` - Width int `json:"width"` - Height int `json:"height"` - PixFmt string `json:"pix_fmt"` - RFrameRate string `json:"r_frame_rate"` - Channels int `json:"channels"` - BitRate string `json:"bit_rate"` - SampleRate string `json:"sample_rate"` - BitsPerSample int `json:"bits_per_raw_sample"` - ChannelLayout string `json:"channel_layout"` + CodecType string `json:"codec_type"` + CodecName string `json:"codec_name"` + Width int `json:"width"` + Height int `json:"height"` + PixFmt string `json:"pix_fmt"` + RFrameRate string `json:"r_frame_rate"` + Channels int `json:"channels"` + BitRate string `json:"bit_rate"` + SampleRate string `json:"sample_rate"` + BitsPerSample ffprobeInt `json:"bits_per_raw_sample"` + ChannelLayout string `json:"channel_layout"` +} + +// ffprobeInt accepts ffprobe numeric fields encoded as JSON numbers or strings. +type ffprobeInt int + +func (f *ffprobeInt) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + *f = 0 + return nil + } + var n int + if err := json.Unmarshal(data, &n); err == nil { + *f = ffprobeInt(n) + return nil + } + var s string + if err := json.Unmarshal(data, &s); err == nil { + if s == "" { + *f = 0 + return nil + } + v, err := strconv.Atoi(s) + if err != nil { + return err + } + *f = ffprobeInt(v) + return nil + } + return nil } // GetVideoInfo mirrors lerobot.datasets.video_utils.get_video_info (v0.5.1). @@ -101,7 +130,7 @@ func GetVideoInfo(ctx context.Context, locator Locator, videoPath string) (map[s } } if audioStream.BitsPerSample > 0 { - info["audio.bit_depth"] = audioStream.BitsPerSample + info["audio.bit_depth"] = int(audioStream.BitsPerSample) } if audioStream.ChannelLayout != "" { info["audio.channel_layout"] = audioStream.ChannelLayout diff --git a/internal/video/info_test.go b/internal/video/info_test.go index b7945cb..83a20bf 100644 --- a/internal/video/info_test.go +++ b/internal/video/info_test.go @@ -1,44 +1,26 @@ -package video_test +package video import ( - "context" - "os" - "path/filepath" + "encoding/json" "testing" - - "github.com/ioai-tech/lerobot-go/internal/video" ) -func TestGetVideoInfoFromRealDataset(t *testing.T) { - home, _ := os.UserHomeDir() - candidates := []string{ - filepath.Join(home, "Downloads/lerobot_dataset/data/chunk-000/episode_000000.mp4"), - } - var videoPath string - for _, c := range candidates { - for _, key := range []string{ - "videos/chunk-000/observation.images.x2w_camera_head_realsense_compressed/episode_000000.mp4", - } { - p := filepath.Join(home, "Downloads/lerobot_dataset", key) - if _, err := os.Stat(p); err == nil { - videoPath = p - break - } - } - if videoPath != "" { - break - } - _ = c +func TestFFprobeIntUnmarshalString(t *testing.T) { + var s ffprobeStream + if err := json.Unmarshal([]byte(`{"bits_per_raw_sample":"16"}`), &s); err != nil { + t.Fatal(err) } - if videoPath == "" { - t.Skip("no local video file for ffprobe test") + if int(s.BitsPerSample) != 16 { + t.Fatalf("got %d want 16", s.BitsPerSample) } - locator := video.NewLocator(video.Config{}) - info, err := video.GetVideoInfo(context.Background(), locator, videoPath) - if err != nil { +} + +func TestFFprobeIntUnmarshalNumber(t *testing.T) { + var s ffprobeStream + if err := json.Unmarshal([]byte(`{"bits_per_raw_sample":24}`), &s); err != nil { t.Fatal(err) } - if info["video.height"] == nil || info["video.width"] == nil { - t.Fatalf("missing video fields: %v", info) + if int(s.BitsPerSample) != 24 { + t.Fatalf("got %d want 24", s.BitsPerSample) } } diff --git a/internal/video/raw_encoder.go b/internal/video/raw_encoder.go new file mode 100644 index 0000000..4ff0fe2 --- /dev/null +++ b/internal/video/raw_encoder.go @@ -0,0 +1,214 @@ +package video + +import ( + "bytes" + "context" + "fmt" + "image" + "image/color" + "image/png" + "io" + "os" + "os/exec" + "strconv" + "sync" +) + +// RawRGBEncoder streams raw RGB24 frames directly into an ffmpeg process +// for video dtype features. This avoids writing per-frame PNG files to +// tempfs/disk for the bulk of the episode, keeping memory bounded to +// in-flight decoded frames + small stats samples. +type RawRGBEncoder struct { + locator Locator + vcodec string + crf int + fps int + width int + height int + output string + + cmd *exec.Cmd + stdin io.WriteCloser + frameCh chan []byte + stopCh chan struct{} + writerWG sync.WaitGroup + closed bool + closeMu sync.Mutex +} + +func encoderQueueDepth() int { + v := os.Getenv("LEROBOT_ENCODER_QUEUE_FRAMES") + if v == "" { + return 3 + } + n, err := strconv.Atoi(v) + if err != nil || n <= 0 { + return 3 + } + return n +} + +func NewRawRGBEncoder(ctx context.Context, locator Locator, vcodec string, crf, fps, width, height int, outputPath string) (*RawRGBEncoder, error) { + if locator == nil { + locator = NewLocator(Config{}) + } + ffmpeg, err := locator.FFmpegPath() + if err != nil { + return nil, err + } + if crf <= 0 { + crf = DefaultCRF + } + if vcodec == "" { + vcodec = "libx264" + } + + args := []string{ + "-y", + "-f", "rawvideo", + "-pix_fmt", "rgb24", + "-s", fmt.Sprintf("%dx%d", width, height), + "-framerate", strconv.Itoa(fps), + "-i", "pipe:0", + "-c:v", vcodec, + "-crf", strconv.Itoa(crf), + "-pix_fmt", "yuv420p", + "-movflags", "+faststart", + } + if threads := ResolveEncoderThreads(); threads > 0 { + args = append(args, "-threads", strconv.Itoa(threads)) + } + args = append(args, outputPath) + + cmd := exec.CommandContext(ctx, ffmpeg, args...) + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + if err := cmd.Start(); err != nil { + _ = stdin.Close() + return nil, err + } + + e := &RawRGBEncoder{ + locator: locator, + vcodec: vcodec, + crf: crf, + fps: fps, + width: width, + height: height, + output: outputPath, + cmd: cmd, + stdin: stdin, + frameCh: make(chan []byte, encoderQueueDepth()), + stopCh: make(chan struct{}), + } + e.writerWG.Add(1) + go e.writerLoop() + return e, nil +} + +func (e *RawRGBEncoder) writerLoop() { + defer e.writerWG.Done() + for { + select { + case frame, ok := <-e.frameCh: + if !ok { + return + } + if e.stdin != nil { + _, _ = e.stdin.Write(frame) + } + case <-e.stopCh: + return + } + } +} + +func (e *RawRGBEncoder) WriteFrame(frame VideoFrameRGB24) error { + if err := frame.Validate(); err != nil { + return err + } + return e.enqueue(frame.Data) +} + +// WriteFrameBytes keeps backward compatibility for callers still passing raw bytes. +func (e *RawRGBEncoder) WriteFrameBytes(rgb24 []byte) error { + frame := VideoFrameRGB24{Data: rgb24, Width: e.width, Height: e.height} + return e.WriteFrame(frame) +} + +func (e *RawRGBEncoder) enqueue(rgb24 []byte) error { + e.closeMu.Lock() + defer e.closeMu.Unlock() + if e.closed || e.stdin == nil { + return fmt.Errorf("encoder closed or not started") + } + buf := append([]byte(nil), rgb24...) + e.frameCh <- buf + return nil +} + +func (e *RawRGBEncoder) Close() error { + e.closeMu.Lock() + if e.closed { + e.closeMu.Unlock() + e.writerWG.Wait() + return nil + } + e.closed = true + close(e.frameCh) + e.closeMu.Unlock() + e.writerWG.Wait() + + var firstErr error + if e.stdin != nil { + if err := e.stdin.Close(); err != nil && firstErr == nil { + firstErr = err + } + e.stdin = nil + } + if e.cmd != nil { + if err := e.cmd.Wait(); err != nil && firstErr == nil { + firstErr = err + } + e.cmd = nil + } + return firstErr +} + +func (e *RawRGBEncoder) OutputPath() string { + return e.output +} + +// rgb24ToPNG is a small helper used only for a bounded number of stats +// sample frames. It encodes one RGB24 frame to PNG bytes so that +// existing image stats sampling code can consume it via FrameBytes. +func rgb24ToPNG(rgb24 []byte, width, height int) ([]byte, error) { + img := &rgbImage{ + data: rgb24, + width: width, + height: height, + } + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +type rgbImage struct { + data []byte + width, height int +} + +func (i *rgbImage) ColorModel() color.Model { return color.RGBAModel } +func (i *rgbImage) Bounds() image.Rectangle { return image.Rect(0, 0, i.width, i.height) } +func (i *rgbImage) At(x, y int) color.Color { + if x < 0 || y < 0 || x >= i.width || y >= i.height { + return color.RGBA{} + } + idx := (y*i.width + x) * 3 + r, g, b := i.data[idx], i.data[idx+1], i.data[idx+2] + return color.RGBA{R: r, G: g, B: b, A: 255} +} diff --git a/lerobot/config.go b/lerobot/config.go index d996874..69f1e4d 100644 --- a/lerobot/config.go +++ b/lerobot/config.go @@ -90,6 +90,7 @@ type StagingConfig struct { FFmpeg FFmpegConfig Streaming bool Stats StatsMode + H264Remux bool } // MergeConfig finalizes completed staging episodes into the on-disk dataset layout. diff --git a/lerobot/video_frame.go b/lerobot/video_frame.go new file mode 100644 index 0000000..0f0efb6 --- /dev/null +++ b/lerobot/video_frame.go @@ -0,0 +1,12 @@ +package lerobot + +import "github.com/ioai-tech/lerobot-go/internal/video" + +// VideoFrameRGB24 is the typed RGB24 payload for video dtype features. +type VideoFrameRGB24 = video.VideoFrameRGB24 + +// NewVideoFrameRGB24 validates and constructs a video frame payload. +func NewVideoFrameRGB24(data []byte, width, height int) (VideoFrameRGB24, error) { + frame := VideoFrameRGB24{Data: data, Width: width, Height: height} + return frame, frame.Validate() +} diff --git a/lerobot/writer.go b/lerobot/writer.go index b0f25b3..10a3ae5 100644 --- a/lerobot/writer.go +++ b/lerobot/writer.go @@ -22,6 +22,8 @@ type Dataset interface { // StagingWriter records one episode into a single ep_NNNNNN directory. type StagingWriter interface { AddFrame(ctx context.Context, frame Frame) error + SetH264Remux(ctx context.Context, tracks map[string][][]byte) error + AppendRGBVideoFrame(ctx context.Context, key string, frame VideoFrameRGB24) error SaveEpisode(ctx context.Context) (*EpisodeManifest, error) Close() error } @@ -36,6 +38,8 @@ type EpisodeManifest struct { type episodeBackend interface { AddFrame(ctx context.Context, frame map[string]any) error + SetH264Remux(ctx context.Context, tracks map[string][][]byte) error + AppendRGBVideoFrame(ctx context.Context, key string, frame video.VideoFrameRGB24) error SaveEpisode(ctx context.Context) (manifest.Episode, error) Close() error } @@ -49,6 +53,14 @@ func (s *stagingWrapper) AddFrame(ctx context.Context, frame Frame) error { return s.backend.AddFrame(ctx, frame.toMap()) } +func (s *stagingWrapper) SetH264Remux(ctx context.Context, tracks map[string][][]byte) error { + return s.backend.SetH264Remux(ctx, tracks) +} + +func (s *stagingWrapper) AppendRGBVideoFrame(ctx context.Context, key string, frame VideoFrameRGB24) error { + return s.backend.AppendRGBVideoFrame(ctx, key, frame) +} + func (s *stagingWrapper) SaveEpisode(ctx context.Context) (*EpisodeManifest, error) { ep, err := s.backend.SaveEpisode(ctx) if err != nil { @@ -80,6 +92,7 @@ func NewStagingWriter(ctx context.Context, cfg StagingConfig) (StagingWriter, er w, err := v21.NewStagingWriter(v21.StagingConfig{ Dir: cfg.Dir, TempRoot: cfg.TempRoot, Episode: cfg.Episode, FPS: cfg.FPS, Features: cfg.Features, Locator: locator, VCodec: cfg.VCodec, CRF: cfg.CRF, UseVideos: cfg.UseVideos, Stats: cfg.Stats.toOptions(), + H264Remux: cfg.H264Remux, }) if err != nil { return nil, err @@ -89,6 +102,7 @@ func NewStagingWriter(ctx context.Context, cfg StagingConfig) (StagingWriter, er w, err := v30.NewStagingWriter(v30.StagingConfig{ Dir: cfg.Dir, TempRoot: cfg.TempRoot, Episode: cfg.Episode, FPS: cfg.FPS, Features: cfg.Features, Locator: locator, VCodec: cfg.VCodec, CRF: cfg.CRF, UseVideos: cfg.UseVideos, Streaming: cfg.Streaming, Stats: cfg.Stats.toOptions(), + H264Remux: cfg.H264Remux, }) if err != nil { return nil, err @@ -185,6 +199,11 @@ func (d *serialDataset) Finalize(ctx context.Context) error { func (d *serialDataset) Root() string { return d.root } +// ValidateOutputIntegrity checks merged dataset has data parquet and video files. +func ValidateOutputIntegrity(root string, features map[string]FeatureSpec) error { + return v30.ValidateOutputIntegrity(root, features) +} + // Merge finalizes completed staging episodes into the official on-disk layout. func Merge(ctx context.Context, cfg MergeConfig) error { if cfg.Version == VersionUnset { @@ -202,7 +221,7 @@ func Merge(ctx context.Context, cfg MergeConfig) error { return v30.Merge(ctx, v30.MergeConfig{ StagingRoot: cfg.StagingRoot, OutputRoot: cfg.OutputRoot, RobotType: cfg.RobotType, FPS: cfg.FPS, Features: cfg.Features, - Locator: locator, Stats: cfg.Stats.toOptions(), + Locator: locator, Stats: cfg.Stats.toOptions(), MaxWorkers: cfg.MaxWorkers, }) } } diff --git a/mcap2lerobot/convert.go b/mcap2lerobot/convert.go new file mode 100644 index 0000000..a5583c4 --- /dev/null +++ b/mcap2lerobot/convert.go @@ -0,0 +1,30 @@ +package mcap2lerobot + +import ( + "context" + "fmt" + + "github.com/ioai-tech/lerobot-go/lerobot" +) + +// Config holds future mcap conversion options. +type Config struct { + InputPath string + OutputRoot string + StagingRoot string + FPS int + Version lerobot.Version + MaxWorkers int +} + +// Convert is a Phase 4 placeholder that validates config and documents the integration point. +func Convert(ctx context.Context, cfg Config) error { + if cfg.InputPath == "" { + return fmt.Errorf("input path required") + } + if cfg.OutputRoot == "" { + return fmt.Errorf("output root required") + } + _ = ctx + return fmt.Errorf("mcap2lerobot Go converter not yet implemented; use lerobot.NewStagingWriter + lerobot.Merge from ported pipeline") +} diff --git a/mcap2lerobot/doc.go b/mcap2lerobot/doc.go new file mode 100644 index 0000000..7987b7c --- /dev/null +++ b/mcap2lerobot/doc.go @@ -0,0 +1,5 @@ +// Package mcap2lerobot will convert MCAP recordings to LeRobot datasets using +// github.com/ioai-tech/lerobot-go/lerobot staging and merge APIs. +// +// Phase 4 scaffold: integrate foxglove/mcap reader and port new_mcap2lerobot pipeline. +package mcap2lerobot From 6d3cc382d4052a74942ac2f4be0b4faedb6a1406 Mon Sep 17 00:00:00 2001 From: joaner Date: Tue, 9 Jun 2026 20:37:00 +0800 Subject: [PATCH 2/2] Fix CI lint and test skips for optional pandas/ffmpeg dependencies. Remove unreachable staging branches, gofmt affected files, and skip environment-dependent tests when pandas or ffmpeg are unavailable. --- internal/parquetx/tasks_test.go | 3 ++ internal/v21/staging.go | 13 -------- internal/v30/merge.go | 4 +-- internal/v30/staging.go | 16 ++-------- internal/v30/staging_streaming_test.go | 25 ++++------------ internal/video/h264_remux_encoder_test.go | 4 --- internal/video/info.go | 22 +++++++------- internal/video/raw_encoder.go | 36 ----------------------- 8 files changed, 24 insertions(+), 99 deletions(-) diff --git a/internal/parquetx/tasks_test.go b/internal/parquetx/tasks_test.go index 101f755..26ed1ac 100644 --- a/internal/parquetx/tasks_test.go +++ b/internal/parquetx/tasks_test.go @@ -44,6 +44,9 @@ func TestWriteTasksParquetUsesTaskColumn(t *testing.T) { t.Fatalf("loaded task map=%v", loaded) } + if err := exec.Command("python3", "-c", "import pandas").Run(); err != nil { + t.Skip("pandas not installed") + } out, err := exec.Command("python3", "-c", ` import pandas as pd, sys tasks = pd.read_parquet(sys.argv[1]) diff --git a/internal/v21/staging.go b/internal/v21/staging.go index ecc3d96..2a695b3 100644 --- a/internal/v21/staging.go +++ b/internal/v21/staging.go @@ -165,18 +165,6 @@ func (w *StagingWriter) AddFrame(ctx context.Context, frame map[string]any) erro } w.videoFrameCounts[key]++ usedDecodeEncoder = true - continue - if png, ok := val.([]byte); ok { - if w.frameStore == nil { - return fmt.Errorf("frame store not initialized for video feature %q", key) - } - frameIndex := w.videoFrameCounts[key] - rel := filepath.Join("images", key, fmt.Sprintf("frame-%06d.png", frameIndex)) - if err := w.frameStore.WritePNG(rel, png); err != nil { - return err - } - w.videoFrameCounts[key] = frameIndex + 1 - } } } if !usedDecodeEncoder && w.cfg.H264Remux && w.cfg.UseVideos { @@ -503,4 +491,3 @@ func statsFeatureMap(features map[string]meta.FeatureSpec) map[string]stats.Feat } return out } - diff --git a/internal/v30/merge.go b/internal/v30/merge.go index e599765..cdbd1b4 100644 --- a/internal/v30/merge.go +++ b/internal/v30/merge.go @@ -375,8 +375,8 @@ func (st *mergeState) writeDataBatches(ctx context.Context, outputRoot string, w } type videoBatchJob struct { - locator video.Locator - dst string + locator video.Locator + dst string segments []string } diff --git a/internal/v30/staging.go b/internal/v30/staging.go index dbfc380..2c32316 100644 --- a/internal/v30/staging.go +++ b/internal/v30/staging.go @@ -147,7 +147,7 @@ func (w *StagingWriter) AddFrame(ctx context.Context, frame map[string]any) erro w.imageBytes[key] = append(w.imageBytes[key], append([]byte(nil), png...)) } case "video": - if spec.Shape == nil || len(spec.Shape) < 2 { + if len(spec.Shape) < 2 { continue } vf, ok, err := video.ParseOptionalVideoFrameRGB24(val, spec.Shape[1], spec.Shape[0]) @@ -161,18 +161,6 @@ func (w *StagingWriter) AddFrame(ctx context.Context, frame map[string]any) erro return err } videoJobs = append(videoJobs, videoJob{key: key, frame: vf}) - continue - if png, ok := val.([]byte); ok { - if w.frameStore == nil { - return fmt.Errorf("frame store not initialized for video feature %q", key) - } - frameIndex := w.videoFrameCounts[key] - rel := filepath.Join("images", key, fmt.Sprintf("frame-%06d.png", frameIndex)) - if err := w.frameStore.WritePNG(rel, png); err != nil { - return err - } - w.videoFrameCounts[key] = frameIndex + 1 - } } } if len(videoJobs) > 0 { @@ -369,7 +357,7 @@ func (w *StagingWriter) finalizeEncodedVideos(ctx context.Context) (map[string]s rel := manifest.StagingVideoRel(key) out := filepath.Join(w.cfg.Dir, rel) if enc := w.videoEncoders[key]; enc != nil { - out = enc.OutputPath() + _ = enc.OutputPath() videos[key] = rel durations[key] = w.videoDurationFromFrames(key) continue diff --git a/internal/v30/staging_streaming_test.go b/internal/v30/staging_streaming_test.go index a0a3a56..8586576 100644 --- a/internal/v30/staging_streaming_test.go +++ b/internal/v30/staging_streaming_test.go @@ -1,12 +1,11 @@ package v30_test import ( - "bytes" "context" "image" "image/color" - "image/png" "os" + "os/exec" "path/filepath" "testing" @@ -17,6 +16,9 @@ import ( ) func TestStreamingStagingFlushesParquetChunks(t *testing.T) { + if _, err := exec.LookPath("ffmpeg"); err != nil { + t.Skip("ffmpeg not installed") + } t.Setenv("LEROBOT_STREAMING_FLUSH_FRAMES", "2") dir := filepath.Join(t.TempDir(), "ep_000000") features := map[string]meta.FeatureSpec{ @@ -38,8 +40,8 @@ func TestStreamingStagingFlushesParquetChunks(t *testing.T) { "observation.images.cam": video.VideoFrameRGB24{ Data: frameRGB, Width: 4, Height: 4, }, - "observation.state": []float32{float32(i), 1}, - "action": []float32{0.1, 0.2}, + "observation.state": []float32{float32(i), 1}, + "action": []float32{0.1, 0.2}, }); err != nil { t.Fatal(err) } @@ -102,18 +104,3 @@ func tinyRGB24(t *testing.T) []byte { } return out } - -func tinyPNG(t *testing.T) []byte { - t.Helper() - img := image.NewRGBA(image.Rect(0, 0, 4, 4)) - for y := 0; y < 4; y++ { - for x := 0; x < 4; x++ { - img.Set(x, y, color.RGBA{R: uint8(x * 40), G: uint8(y * 40), B: 128, A: 255}) - } - } - var buf bytes.Buffer - if err := png.Encode(&buf, img); err != nil { - t.Fatal(err) - } - return buf.Bytes() -} diff --git a/internal/video/h264_remux_encoder_test.go b/internal/video/h264_remux_encoder_test.go index d940b21..1a2c321 100644 --- a/internal/video/h264_remux_encoder_test.go +++ b/internal/video/h264_remux_encoder_test.go @@ -2,7 +2,6 @@ package video import ( "context" - "os" "os/exec" "path/filepath" "slices" @@ -21,9 +20,6 @@ func TestH264RemuxEncoderRequiresFFmpeg(t *testing.T) { t.Fatalf("NewH264RemuxEncoder: %v", err) } _ = enc.Close() - if _, err := os.Stat(out); err == nil { - // ffmpeg may create empty/partial file on close without input - } } func TestH264RemuxFFmpegArgsOmitsGenpts(t *testing.T) { diff --git a/internal/video/info.go b/internal/video/info.go index ee28b47..95c2792 100644 --- a/internal/video/info.go +++ b/internal/video/info.go @@ -15,17 +15,17 @@ type ffprobeOutput struct { } type ffprobeStream struct { - CodecType string `json:"codec_type"` - CodecName string `json:"codec_name"` - Width int `json:"width"` - Height int `json:"height"` - PixFmt string `json:"pix_fmt"` - RFrameRate string `json:"r_frame_rate"` - Channels int `json:"channels"` - BitRate string `json:"bit_rate"` - SampleRate string `json:"sample_rate"` - BitsPerSample ffprobeInt `json:"bits_per_raw_sample"` - ChannelLayout string `json:"channel_layout"` + CodecType string `json:"codec_type"` + CodecName string `json:"codec_name"` + Width int `json:"width"` + Height int `json:"height"` + PixFmt string `json:"pix_fmt"` + RFrameRate string `json:"r_frame_rate"` + Channels int `json:"channels"` + BitRate string `json:"bit_rate"` + SampleRate string `json:"sample_rate"` + BitsPerSample ffprobeInt `json:"bits_per_raw_sample"` + ChannelLayout string `json:"channel_layout"` } // ffprobeInt accepts ffprobe numeric fields encoded as JSON numbers or strings. diff --git a/internal/video/raw_encoder.go b/internal/video/raw_encoder.go index 4ff0fe2..16d5600 100644 --- a/internal/video/raw_encoder.go +++ b/internal/video/raw_encoder.go @@ -1,12 +1,8 @@ package video import ( - "bytes" "context" "fmt" - "image" - "image/color" - "image/png" "io" "os" "os/exec" @@ -180,35 +176,3 @@ func (e *RawRGBEncoder) Close() error { func (e *RawRGBEncoder) OutputPath() string { return e.output } - -// rgb24ToPNG is a small helper used only for a bounded number of stats -// sample frames. It encodes one RGB24 frame to PNG bytes so that -// existing image stats sampling code can consume it via FrameBytes. -func rgb24ToPNG(rgb24 []byte, width, height int) ([]byte, error) { - img := &rgbImage{ - data: rgb24, - width: width, - height: height, - } - var buf bytes.Buffer - if err := png.Encode(&buf, img); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -type rgbImage struct { - data []byte - width, height int -} - -func (i *rgbImage) ColorModel() color.Model { return color.RGBAModel } -func (i *rgbImage) Bounds() image.Rectangle { return image.Rect(0, 0, i.width, i.height) } -func (i *rgbImage) At(x, y int) color.Color { - if x < 0 || y < 0 || x >= i.width || y >= i.height { - return color.RGBA{} - } - idx := (y*i.width + x) * 3 - r, g, b := i.data[idx], i.data[idx+1], i.data[idx+2] - return color.RGBA{R: r, G: g, B: b, A: 255} -}