Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions internal/buffer/episode.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@ func (b *EpisodeBuffer) Columns(globalIndex int64, taskIndices []int64) map[stri
return out
}

// ColumnsWithFrameStart returns the map produced by Columns for the given
// globalIndex and taskIndices, with two entries set on top of it:
//
//   - "frame_index": a []int64 of length b.size holding frameStart,
//     frameStart+1, ...
//   - "timestamp": a []float32 of the same length holding each frame index
//     divided by b.FPS (computed in float32).
//
// Any values already stored under those two keys are overwritten.
func (b *EpisodeBuffer) ColumnsWithFrameStart(globalIndex, frameStart int64, taskIndices []int64) map[string]any {
	cols := b.Columns(globalIndex, taskIndices)

	indices := make([]int64, b.size)
	stamps := make([]float32, b.size)
	for i := 0; i < len(indices); i++ {
		frame := frameStart + int64(i)
		indices[i] = frame
		stamps[i] = float32(frame) / float32(b.FPS)
	}

	cols["frame_index"] = indices
	cols["timestamp"] = stamps
	return cols
}

// Reset clears the buffer: the size goes back to zero, the task list is
// dropped, and the column store is replaced by a fresh empty map.
func (b *EpisodeBuffer) Reset() {
	b.size = 0
	b.tasks = nil
	b.columns = map[string]any{}
}

// isScalarShape reports whether shape describes a scalar: either no
// dimensions at all, or exactly one dimension of extent 1.
func isScalarShape(shape []int) bool {
	switch len(shape) {
	case 0:
		return true
	case 1:
		return shape[0] == 1
	default:
		return false
	}
}
Expand Down Expand Up @@ -143,21 +162,21 @@ func appendFloat32Row(b *EpisodeBuffer, key string, row []float32) {
if _, ok := b.columns[key]; !ok {
b.columns[key] = [][]float32{}
}
b.columns[key] = append(b.columns[key].([][]float32), row)
b.columns[key] = append(b.columns[key].([][]float32), append([]float32(nil), row...))
}

// appendFloat64Row appends one row to the [][]float64 column stored under
// key in b.columns, creating an empty column first when the key is absent.
//
// The row is stored exactly once, as a private copy, so a caller that reuses
// or mutates its slice after the call cannot change rows already buffered.
// It panics if the value already stored under key is not a [][]float64.
func appendFloat64Row(b *EpisodeBuffer, key string, row []float64) {
	if _, ok := b.columns[key]; !ok {
		b.columns[key] = [][]float64{}
	}
	// Detach from the caller's backing array before storing.
	rowCopy := append([]float64(nil), row...)
	b.columns[key] = append(b.columns[key].([][]float64), rowCopy)
}

// appendInt64Row appends one row to the [][]int64 column stored under key in
// b.columns, creating an empty column first when the key is absent.
//
// The row is stored exactly once, as a private copy, so a caller that reuses
// or mutates its slice after the call cannot change rows already buffered.
// It panics if the value already stored under key is not a [][]int64.
func appendInt64Row(b *EpisodeBuffer, key string, row []int64) {
	if _, ok := b.columns[key]; !ok {
		b.columns[key] = [][]int64{}
	}
	// Detach from the caller's backing array before storing.
	rowCopy := append([]int64(nil), row...)
	b.columns[key] = append(b.columns[key].([][]int64), rowCopy)
}

func appendFloat64(b *EpisodeBuffer, key string, val float64) {
Expand All @@ -170,7 +189,7 @@ func appendFloat64(b *EpisodeBuffer, key string, val float64) {
func toFloat32Row(val any) ([]float32, bool) {
switch v := val.(type) {
case []float32:
return v, true
return append([]float32(nil), v...), true
case []float64:
out := make([]float32, len(v))
for i, x := range v {
Expand Down
76 changes: 76 additions & 0 deletions internal/buffer/episode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package buffer

import (
"math"
"testing"

"github.com/ioai-tech/lerobot-go/internal/meta"
)

// TestAppendFloat64RowDeepCopy passes the same []float64 to AddFrame three
// times, overwriting its contents after every call, and checks that each
// stored "observation.state" row still holds the values it had when it was
// added.
func TestAppendFloat64RowDeepCopy(t *testing.T) {
	features := map[string]meta.FeatureSpec{
		"observation.state": {DType: "float64", Shape: []int{3}},
	}
	b := New(0, 30, features)

	shared := []float64{1, 2, 3}
	for i := 0; i < 3; i++ {
		frame := map[string]any{
			"task":              "pick",
			"observation.state": shared,
		}
		if err := b.AddFrame(frame); err != nil {
			t.Fatal(err)
		}
		// Overwrite the caller-owned slice after it was handed to the buffer;
		// an aliased row would pick up these values.
		for j := range shared {
			shared[j] = float64(i*10 + j + 100)
		}
	}

	// Fail with a message instead of panicking when the column is missing or
	// has an unexpected type.
	rows, ok := b.columns["observation.state"].([][]float64)
	if !ok {
		t.Fatalf("observation.state column has type %T want [][]float64", b.columns["observation.state"])
	}
	if len(rows) != 3 {
		t.Fatalf("rows=%d want 3", len(rows))
	}
	want := [][]float64{
		{1, 2, 3},
		{100, 101, 102},
		{110, 111, 112},
	}
	for i, row := range rows {
		// Without this check a short row would pass silently and a long row
		// would panic on want[i][j].
		if len(row) != len(want[i]) {
			t.Fatalf("row[%d] has %d values want %d", i, len(row), len(want[i]))
		}
		for j, v := range row {
			if v != want[i][j] {
				t.Fatalf("row[%d][%d]=%v want %v (slice alias corrupted stored rows)", i, j, v, want[i][j])
			}
		}
	}
}

// TestObservationStateValuesStayFiniteThroughParquetPath adds 20 frames whose
// "observation.state" and "action" values both come from one reused slice,
// then checks that every stored "observation.state" value is finite, within
// +/-100, and equal to the value written for that frame.
func TestObservationStateValuesStayFiniteThroughParquetPath(t *testing.T) {
	const (
		frames = 20
		width  = 4
	)
	features := map[string]meta.FeatureSpec{
		"observation.state": {DType: "float64", Shape: []int{width}},
		"action":            {DType: "float64", Shape: []int{width}},
	}
	b := New(0, 30, features)

	reuse := make([]float64, width)
	for i := 0; i < frames; i++ {
		for j := range reuse {
			reuse[j] = float64(i)*0.1 + float64(j)*0.01
		}
		if err := b.AddFrame(map[string]any{
			"task":              "move",
			"observation.state": reuse,
			"action":            reuse,
		}); err != nil {
			t.Fatal(err)
		}
	}

	// Fail with a message instead of panicking when the column is missing or
	// has an unexpected type.
	stateRows, ok := b.columns["observation.state"].([][]float64)
	if !ok {
		t.Fatalf("observation.state column has type %T want [][]float64", b.columns["observation.state"])
	}
	// Guard against a vacuous pass: the value loop below never runs if no rows
	// (or empty rows) were stored.
	if len(stateRows) != frames {
		t.Fatalf("state rows=%d want %d", len(stateRows), frames)
	}
	for i, row := range stateRows {
		if len(row) != width {
			t.Fatalf("state[%d] has %d values want %d", i, len(row), width)
		}
		for j, v := range row {
			if math.IsNaN(v) || math.IsInf(v, 0) || math.Abs(v) > 100 {
				t.Fatalf("state[%d][%d]=%v out of sane joint range", i, j, v)
			}
			want := float64(i)*0.1 + float64(j)*0.01
			if v != want {
				t.Fatalf("state[%d][%d]=%v want %v", i, j, v, want)
			}
		}
	}
}
15 changes: 15 additions & 0 deletions internal/manifest/episode.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Episode struct {
const FileName = "episode_meta.json"

func Write(dir string, ep Episode) error {
ep.Stats = stats.SanitizeEpisodeStats(ep.Stats)
data, err := json.MarshalIndent(ep, "", " ")
if err != nil {
return err
Expand Down Expand Up @@ -51,6 +52,20 @@ func stagingName(episodeIndex int) string {
return fmt.Sprintf("ep_%06d", episodeIndex)
}

// StagingMediaPath turns a media path recorded in episode metadata into a
// usable path. A relative rel is joined onto stagingDir; an absolute rel
// (as legacy manifests may contain) is handed back untouched.
func StagingMediaPath(stagingDir, rel string) string {
	if !filepath.IsAbs(rel) {
		return filepath.Join(stagingDir, rel)
	}
	return rel
}

// StagingVideoRel builds the episode-relative location of a staged
// per-episode MP4: the last path element of videoKey with ".mp4" appended,
// placed under the "videos" directory.
func StagingVideoRel(videoKey string) string {
	fileName := filepath.Base(videoKey) + ".mp4"
	return filepath.Join("videos", fileName)
}

// ListStagingEpisodes returns completed staging episode dirs sorted by episode_index.
func ListStagingEpisodes(root string) ([]string, error) {
entries, err := os.ReadDir(root)
Expand Down
23 changes: 23 additions & 0 deletions internal/manifest/episode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package manifest

import (
"path/filepath"
"testing"
)

// TestStagingMediaPath checks that StagingVideoRel yields
// videos/<key>.mp4, that StagingMediaPath joins a relative path onto the
// staging directory, and that an absolute path is returned as given.
func TestStagingMediaPath(t *testing.T) {
	const stagingDir = "/tmp/ep_000000"

	rel := StagingVideoRel("observation.images.cam")
	if wantRel := filepath.Join("videos", "observation.images.cam.mp4"); rel != wantRel {
		t.Fatalf("rel=%q", rel)
	}

	if got, want := StagingMediaPath(stagingDir, rel), filepath.Join(stagingDir, rel); got != want {
		t.Fatalf("got %q want %q", got, want)
	}

	const absPath = "/tmp/ep_000000/videos/foo.mp4"
	if got := StagingMediaPath(stagingDir, absPath); got != absPath {
		t.Fatalf("absolute path should pass through")
	}
}
36 changes: 22 additions & 14 deletions internal/parquetx/append_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,33 @@ func WriteEpisodeBatch(ctx context.Context, dst string, entries []EpisodeBatchEn
return fmt.Errorf("no episode parquet entries")
}
alloc := memory.NewGoAllocator()
tables := make([]arrow.Table, 0, len(entries))
defer func() {
for _, tbl := range tables {
tbl.Release()
}
}()
for _, entry := range entries {
first, err := RewriteEpisodeParquet(ctx, entries[0].SourcePath, entries[0].Options, alloc)
if err != nil {
return err
}
defer first.Release()
writer, err := NewAppendWriter(dst, first.Schema())
if err != nil {
return err
}
if err := writer.WriteTable(first, 1024); err != nil {
_ = writer.Close()
return err
}
for _, entry := range entries[1:] {
tbl, err := RewriteEpisodeParquet(ctx, entry.SourcePath, entry.Options, alloc)
if err != nil {
_ = writer.Close()
return err
}
tables = append(tables, tbl)
}
merged, err := ConcatTables(alloc, tables)
if err != nil {
return err
if err := writer.WriteTable(tbl, 1024); err != nil {
tbl.Release()
_ = writer.Close()
return err
}
tbl.Release()
}
defer merged.Release()
return WriteTable(dst, merged, alloc)
return writer.Close()
}

func RewriteEpisodeParquet(ctx context.Context, src string, opts AppendEpisodeOptions, alloc memory.Allocator) (arrow.Table, error) {
Expand Down
7 changes: 6 additions & 1 deletion internal/parquetx/meta_v30.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

// WriteTasksParquet writes meta/tasks.parquet compatible with lerobot load_tasks().
// Task strings are stored as the pandas index column named "task".
func WriteTasksParquet(root string, taskMap map[string]int) error {
if len(taskMap) == 0 {
return nil
Expand Down Expand Up @@ -41,10 +42,14 @@ func WriteTasksParquet(root string, taskMap map[string]int) error {
defer taskArr.Release()
defer idxArr.Release()

pandasMD := arrow.NewMetadata(
[]string{"pandas"},
[]string{`{"index_columns": ["task"], "column_indexes": [{"name": null, "field_name": null, "pandas_type": "unicode", "numpy_type": "object", "metadata": {"encoding": "UTF-8"}}], "columns": [{"name": "task_index", "field_name": "task_index", "pandas_type": "int64", "numpy_type": "int64", "metadata": null}, {"name": "task", "field_name": "task", "pandas_type": "unicode", "numpy_type": "object", "metadata": null}], "attributes": {}, "pandas_version": "2.3.3"}`},
)
schema := arrow.NewSchema([]arrow.Field{
{Name: "task_index", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
{Name: "task", Type: arrow.BinaryTypes.String, Nullable: true},
}, nil)
}, &pandasMD)
cols := []arrow.Column{
*arrow.NewColumn(schema.Field(0), arrow.NewChunked(schema.Field(0).Type, []arrow.Array{idxArr})),
*arrow.NewColumn(schema.Field(1), arrow.NewChunked(schema.Field(1).Type, []arrow.Array{taskArr})),
Expand Down
24 changes: 21 additions & 3 deletions internal/parquetx/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package parquetx

import (
"context"
"os/exec"
"path/filepath"
"strings"
"testing"

"github.com/apache/arrow-go/v18/arrow"
Expand Down Expand Up @@ -31,17 +33,33 @@ func TestWriteTasksParquetUsesTaskColumn(t *testing.T) {
if len(tbl.Schema().FieldIndices("task")) == 0 {
t.Fatal("task column missing")
}
if len(tbl.Schema().FieldIndices("__index_level_0__")) != 0 {
t.Fatal("legacy __index_level_0__ column should not be written for new datasets")
if len(tbl.Schema().FieldIndices("task_index")) == 0 {
t.Fatal("task_index column missing")
}

loaded, err := ReadTasksParquet(context.Background(), path)
if err != nil {
t.Fatal(err)
}
if loaded["pick"] != 0 || loaded["place"] != 1 {
t.Fatalf("loaded task map=%v", loaded)
}

if err := exec.Command("python3", "-c", "import pandas").Run(); err != nil {
t.Skip("pandas not installed")
}
out, err := exec.Command("python3", "-c", `
import pandas as pd, sys
tasks = pd.read_parquet(sys.argv[1])
tasks.index.name = "task"
assert tasks.iloc[0].name == "pick", tasks.iloc[0].name
assert tasks.iloc[1].name == "place", tasks.iloc[1].name
`, path).CombinedOutput()
if err != nil {
t.Fatalf("pandas round-trip failed: %v\n%s", err, out)
}
if strings.TrimSpace(string(out)) != "" {
t.Fatalf("unexpected python output: %s", out)
}
}

func TestConcatTablesIgnoresSchemaMetadataDifferences(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions internal/parquetx/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ func (w *AppendWriter) WriteEpisodeColumns(columns map[string]any, length int, f
return w.writer.Write(rec)
}

// WriteRecordColumns forwards columns, length and features unchanged to
// WriteEpisodeColumns and returns whatever error that call reports.
func (w *AppendWriter) WriteRecordColumns(columns map[string]any, length int, features map[string]meta.FeatureSpec) error {
return w.WriteEpisodeColumns(columns, length, features)
}

// WriteTable hands tbl and chunkSize to the wrapped writer's WriteTable and
// returns its error.
// NOTE(review): w.writer is dereferenced without a nil check here, whereas
// Close guards against a nil writer — confirm callers never invoke this on a
// writer that was not initialised.
func (w *AppendWriter) WriteTable(tbl arrow.Table, chunkSize int64) error {
return w.writer.WriteTable(tbl, chunkSize)
}

func (w *AppendWriter) Close() error {
if w.writer != nil {
err := w.writer.Close()
Expand Down
Loading
Loading