Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions examples/4_rollout_neuracore_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,26 +661,27 @@ def update_visualization(
default=None,
help="IP address of Meta Quest device (optional, defaults to None for auto-discovery)",
)
parser.add_argument(
policy_group = parser.add_mutually_exclusive_group(required=True)
policy_group.add_argument(
"--train-run-name",
type=str,
default=None,
help="Name of the training run to load policy from (for cloud training). Mutually exclusive with --model-path.",
help="Name of the training run to load policy from (for cloud training).",
)
parser.add_argument(
policy_group.add_argument(
"--model-path",
type=str,
default=None,
help="Path to local model file to load policy from. Mutually exclusive with --train-run-name.",
help="Path to local model file to load policy from.",
)
policy_group.add_argument(
"--remote-endpoint-name",
type=str,
default=None,
help="Name of remote Neuracore policy endpoint.",
)
args = parser.parse_args()

# Validate that exactly one of train-run-name or model-path is provided
if (args.train_run_name is None) == (args.model_path is None):
parser.error(
"Exactly one of --train-run-name or --model-path must be provided (not both, not neither)"
)

print("=" * 60)
print("PIPER ROBOT TEST WITH NEURACORE POLICY")
print("=" * 60)
Expand Down Expand Up @@ -722,7 +723,19 @@ def update_visualization(
for data_type, names in model_output_order.items():
print(f" {data_type.name}: {names}")

if args.train_run_name is not None:
if args.remote_endpoint_name is not None:
print(
f"\n🤖 Connecting to remote policy endpoint: {args.remote_endpoint_name}..."
)
try:
policy = nc.policy_remote_server(args.remote_endpoint_name)
except nc.EndpointError:
print(
f"❌ Endpoint '{args.remote_endpoint_name}' not available. "
"Please start it from the Neuracore dashboard."
)
sys.exit(1)
elif args.train_run_name is not None:
print(f"\n🤖 Loading policy from training run: {args.train_run_name}...")
policy = nc.policy(
train_run_name=args.train_run_name,
Expand Down
35 changes: 24 additions & 11 deletions examples/5_rollout_neuracore_policy_minimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,24 @@ def execute_horizon(

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Minimal Piper Policy Test")
parser.add_argument(
policy_group = parser.add_mutually_exclusive_group(required=True)
policy_group.add_argument(
"--train-run-name",
type=str,
default=None,
help="Name of the training run to load policy from (for cloud training). Mutually exclusive with --model-path.",
help="Name of the training run to load policy from (for cloud training).",
)
parser.add_argument(
policy_group.add_argument(
"--model-path",
type=str,
default=None,
help="Path to local model file to load policy from. Mutually exclusive with --train-run-name.",
help="Path to local model file to load policy from.",
)
policy_group.add_argument(
"--remote-endpoint-name",
type=str,
default=None,
help="Name of remote Neuracore policy endpoint.",
)
parser.add_argument(
"--frequency",
Expand All @@ -204,12 +211,6 @@ def execute_horizon(
)
args = parser.parse_args()

# Validate that exactly one of train-run-name or model-path is provided
if (args.train_run_name is None) == (args.model_path is None):
parser.error(
"Exactly one of --train-run-name or --model-path must be provided (not both, not neither)"
)

print("=" * 60)
print("PIPER POLICY ROLLOUT")
print("=" * 60)
Expand Down Expand Up @@ -244,7 +245,19 @@ def execute_horizon(
for data_type, names in model_output_order.items():
print(f" {data_type.name}: {names}")

if args.train_run_name is not None:
if args.remote_endpoint_name is not None:
print(
f"\n🤖 Connecting to remote policy endpoint: {args.remote_endpoint_name}..."
)
try:
policy = nc.policy_remote_server(args.remote_endpoint_name)
except nc.EndpointError:
print(
f"❌ Endpoint '{args.remote_endpoint_name}' not available. "
"Please start it from the Neuracore dashboard."
)
sys.exit(1)
elif args.train_run_name is not None:
print(f"\n🤖 Loading policy from training run: {args.train_run_name}...")
policy = nc.policy(
train_run_name=args.train_run_name,
Expand Down
28 changes: 22 additions & 6 deletions examples/6_visualize_policy_from_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,24 @@
description="Visualize policy predictions from dataset"
)
parser.add_argument("--dataset-name", type=str, required=True, help="Dataset name")
parser.add_argument(
policy_group = parser.add_mutually_exclusive_group(required=True)
policy_group.add_argument(
"--train-run-name", type=str, default=None, help="Training run name"
)
parser.add_argument("--model-path", type=str, default=None, help="Model file path")
policy_group.add_argument(
"--model-path", type=str, default=None, help="Model file path"
)
policy_group.add_argument(
"--remote-endpoint-name",
type=str,
default=None,
help="Name of remote Neuracore policy endpoint to use instead of a local policy.",
)
parser.add_argument(
"--frequency", type=int, default=100, help="Frequency of visualization"
)
args = parser.parse_args()

if (args.train_run_name is None) == (args.model_path is None):
parser.error("Exactly one of --train-run-name or --model-path must be provided")

# Connect to Neuracore
print("🔧 Initializing Neuracore...")
nc.login()
Expand All @@ -68,7 +74,17 @@
DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS: [GRIPPER_LOGGING_NAME],
}

if args.train_run_name:
if args.remote_endpoint_name:
print(f"🤖 Connecting to remote policy endpoint: {args.remote_endpoint_name}...")
try:
policy = nc.policy_remote_server(args.remote_endpoint_name)
except nc.EndpointError:
print(
f"❌ Endpoint '{args.remote_endpoint_name}' not available. "
"Please start it from the Neuracore dashboard."
)
sys.exit(1)
elif args.train_run_name:
print(f"🤖 Loading policy from training run: {args.train_run_name}...")
policy = nc.policy(
train_run_name=args.train_run_name,
Expand Down