Skip to content

Commit 13a9143

Browse files
serbobaclaude
andcommitted
Add MAB-RRT planner with adaptive cylinder sampling
- MAB-RRT: Multi-Armed Bandit RRT with PCA-guided cylindrical sampling - Adaptive burn-in phase discovers valid motion directions - 3-arm MAB: UNIFORM, CYLINDER_UP, CYLINDER_DOWN - CylinderSampler with 2D/3D PCA support - Demo with 6D assembly and 2D bug trap scenarios - Requires adaptive_max_radius in YAML config (no hardcoded fallbacks) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent bdf975a commit 13a9143

11 files changed

Lines changed: 3322 additions & 127 deletions

File tree

demos/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ if (OMPL_BUILD_DEMOS)
6262
add_ompl_demo(demo_MAB_RRT disassembly/MAB_RRT_Demo.cpp)
6363
target_link_libraries(demo_MAB_RRT PRIVATE yaml-cpp::yaml-cpp)
6464
65+
# MAB-RRT Benchmark - Compare planners on occupancy grid
66+
add_ompl_demo(demo_benchmark_occupancy_grid disassembly/benchmark_occupancy_grid.cpp)
67+
target_link_libraries(demo_benchmark_occupancy_grid PRIVATE yaml-cpp::yaml-cpp)
68+
6569
# Copy config file to build directory
6670
configure_file(
6771
${CMAKE_CURRENT_SOURCE_DIR}/disassembly/benchmark_baseline.yaml

demos/disassembly/MAB_RRT_Demo.cpp

Lines changed: 537 additions & 14 deletions
Large diffs are not rendered by default.

demos/disassembly/benchmark_baseline.yaml

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
### MAB-SSRRT Configuration
2-
# Configuration file for MAB-SSRRT (Multi-Armed Bandit Sphere-Sampled RRT) planner
1+
### MAB-RRT Configuration
2+
# Configuration file for MAB-RRT (Multi-Armed Bandit RRT) planner
33

44
# Goal Bias
5-
uniform_goal_bias : 0.75
5+
uniform_goal_bias : 0.1
66
sphere_goal_bias : 0.0000001
77

88
# MAB Parameters
99
#
10-
# MAB-SSRRT uses a flat 3-arm MAB:
10+
# MAB-RRT uses a flat 3-arm MAB:
1111
# - Arm 0: UNIFORM sampling
1212
# - Arm 1: CYLINDER_UP sampling (along +Z axis)
1313
# - Arm 2: CYLINDER_DOWN sampling (along -Z axis)
@@ -16,53 +16,65 @@ mab_uniform_sphere_window_size : 256 # Sliding window size for MAB arm select
1616

1717
# If > 0, after this many consecutive valid CYLINDER samples, the planner forces one UNIFORM sample on the next iteration.
1818
# 0 = feature disabled (default, preserves current behavior)
19-
forced_uniform_after_cylinder_valid_streak: 6
19+
forced_uniform_after_cylinder_valid_streak: 5
2020

2121
# Adaptive Sphere Sampling Parameters
22-
adaptive_quasirandom_sample_size : 8
22+
# Adaptive Sphere Sampling Parameters
23+
adaptive_quasirandom_sample_size : 64
2324
adaptive_start_radius : 1.0
24-
adaptive_min_radius : 0.00001
25-
adaptive_shrink_step : 0.9
26-
adaptive_grow_step : 0.1
25+
adaptive_min_radius : 0.000001
26+
adaptive_max_radius : 25.0
27+
# Direct multipliers (no exp() applied)
28+
# To maintain same behavior as exp(-0.7) and exp(0.9):
29+
# exp(-0.7) ≈ 0.496585 (shrinks by this factor)
30+
# exp(0.9) ≈ 2.459603 (grows by this factor)
31+
adaptive_shrink_step : 0.496585
32+
adaptive_grow_step : 2.459603
2733
adaptive_min_expected_validity_rate : 0.10
2834
adaptive_max_expected_validity_rate: 0.50
29-
adaptive_burnin_max_steps: 10
35+
adaptive_burnin_max_steps: 50
3036

3137
pca_filter_top_percent: 1.0 # Percentage of top PCA components to keep, default 1.0 to use all valid samples for online PCA recalibration
3238

3339
# Initial Uniform Sampling Check
3440
# Before starting adaptive sphere burn-in, try k uniform samples.
3541
# If validity rate > probability, skip adaptive sphere and use uniform only.
3642
initialFreeSamplingProbability: 0.20
37-
initialNumberOfUniformSampleTrials: 10
43+
initialNumberOfUniformSampleTrials: 0
3844

3945
# Cylinder Sampling Configuration
4046
enableDynamicCylinderPCA: true # Enable dynamic recomputation of cylinder PCA axis
4147

4248
# Burn-in Early Exit
43-
# If true, MAB-SSRRT will abort planning after burn-in when not enough valid samples were found around the origin.
44-
burnin_early_exit_enabled: true
49+
# If true, MAB-RRT will abort planning after burn-in when not enough valid samples were found around the origin.
50+
burnin_early_exit_enabled: false
4551
# Minimum number of valid burn-in samples required to continue planning.
4652
# If total_valid_samples < this value AND burnin_early_exit_enabled == true, the planner will exit early from solve().
47-
burnin_min_valid_samples_for_continue: 3
53+
burnin_min_valid_samples_for_continue: 2
4854

4955
# If true, burn-in will exit early if validity rate is 1.0 for consecutiveFullValidityMaxIterations steps
50-
earlyExitOnConsecutiveFullValidity: true
56+
earlyExitOnConsecutiveFullValidity: false
5157
consecutiveFullValidityMaxIterations: 5
5258

5359
# Sphere sampler values
5460

55-
sphere_extension_eps : 0.3 # if sphere sample was valid: std::exp(sphere_extension_eps)*currentBestRadius
56-
# to grow the cylinder and sphere as far as possible
57-
# also used for kappa maybe add another param for kappa
61+
# Extension factor for cylinder height extension (EXPONENTIAL GROWTH)
62+
# extensionHeight = bestRadius * (exp(sphere_extension_eps - 1.0) - 1.0)
63+
# This allows fast exponential growth even when bestRadius is very small (e.g., 0.001)
64+
# Examples:
65+
# sphere_extension_eps = 1.0 → extensionHeight = 0 (no extension)
66+
# sphere_extension_eps = 2.0 → extensionHeight = bestRadius * (exp(1.0) - 1.0) ≈ bestRadius * 1.718
67+
# sphere_extension_eps = 3.0 → extensionHeight = bestRadius * (exp(2.0) - 1.0) ≈ bestRadius * 6.389
68+
# sphere_extension_eps = 4.0 → extensionHeight = bestRadius * (exp(3.0) - 1.0) ≈ bestRadius * 19.085
69+
sphere_extension_eps : 2.0 # 1.0 = no extension, >1.0 = exponential growth via exp(value - 1.0)
5870

5971
# MAB Reward System
6072
# Fixed reward parameters for MAB arm selection
6173
uniform_sampler_fixed_valid_reward: 9999999.0 # Fixed reward for valid uniform samples (high to encourage exploration when disassembled)
6274
uniform_sampler_invalid_reward: 0.0 # Fixed reward for invalid uniform samples
6375

6476
sphere_sampler_fixed_valid_reward: 5.0 # Fixed reward for valid cylinder samples
65-
sphere_sampler_invalid_reward: 1.00 # Fixed reward for invalid cylinder samples (small positive to discourage uniform when assembled)
77+
sphere_sampler_invalid_reward: 0.00 # Fixed reward for invalid cylinder samples (0.0 = no reward; a small positive value would discourage uniform when assembled)
6678

6779
# Cylinder Sampling Parameters
6880
cylinder_radius_offset_multiplier: 0.1 # Offset multiplier for cylinder radius in PCA fitting

src/ompl/base/samplers/sphere/AdaptiveSphereSampler.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ namespace ompl
7474
/** @brief Forward @p val to the owned cylinder sampler as its radius offset multiplier. */
void setCylinderRadiusOffsetMultiplier(double val) { cylinderSampler_.setRadiusOffsetMultiplier(val); }
7575
/** @brief Forward @p val to the owned Fibonacci sampler as its jitter radius. */
void setFibonacciJitterRadius(double val) { fibSampler_.setJitterRadius(val); }
7676
/** @brief Store @p val as the PCA filter setting that is later passed to the cylinder sampler's fit() call. */
void setPCAFilterTopPercent(double val) { pcaFilterTopPercent_ = val; }
77+
/** @brief Set the frozen flag so getRandomSampleFromCylinder(extensionHeight, chosenDirection) reuses the
 *  existing cylinder instead of refitting it. The overload that takes an explicit axis does not check this flag. */
void freezeCylinder() { cylinderFrozen_ = true; } // Prevent cylinder refitting during planning
78+
/** @brief Forward @p height to the owned cylinder sampler, overriding the height of its current cylinder. */
void setCylinderHeight(double height) { cylinderSampler_.setCylinderHeight(height); }
7779

7880
void updateComponentCenter(int componentIndex);
7981
void appendFibonacciSamples(double customRadius = -1.0);
@@ -108,6 +110,7 @@ namespace ompl
108110
bool validPointsDirty_ = true;
109111
Eigen::Vector3d cylinderAxis_ = Eigen::Vector3d::UnitZ();
110112
bool hasCylinderAxis_ = false;
113+
bool cylinderFrozen_ = false; // If true, don't refit cylinder during planning
111114

112115
/** @brief Project 3D sphere point to 2D circle (in-place modification) */
113116
void projectSphereToCircle(Point& p, double radius);

src/ompl/base/samplers/sphere/CylinderSampler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace ompl
5151
// Accessors
5252
/** @brief Read-only access to the currently stored cylinder. */
const Cylinder& getCylinder() const { return cylinder_; }
5353
/** @brief Report the validity flag; sampleCylinder() throws std::runtime_error while this is false. */
bool hasValidCylinder() const { return hasValidCylinder_; }
54+
/** @brief Overwrite the stored cylinder's height with @p height. The value is not validated and the
 *  validity flag is left unchanged. */
void setCylinderHeight(double height) { cylinder_.height = height; }
5455

5556
private:
5657
Cylinder cylinder_;

src/ompl/base/samplers/sphere/src/AdaptiveSphereSampler.cpp

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,42 @@ double AdaptiveSphereSampler::getValidSampleRate() const
187187

188188
AdaptiveSphereSampler::Point AdaptiveSphereSampler::getRandomSampleFromCylinder(double extensionHeight, int chosenDirection)
189189
{
190-
// Use cached cylinder if available and valid
191-
if (hasCylinderAxis_)
192-
{
193-
cylinderSampler_.fitWithAxis(allValidPoints_, cylinderAxis_, cylinderSampler_.getRadiusOffsetMultiplier());
194-
}
195-
else if (validPointsDirty_)
190+
// DEBUG: Print cylinder state before sampling
191+
const auto& cyl = cylinderSampler_.getCylinder();
192+
OMPL_INFORM("DEBUG ADAPTIVE_SPHERE: getRandomSampleFromCylinder called with extensionHeight=%.6f, direction=%d",
193+
extensionHeight, chosenDirection);
194+
OMPL_INFORM("DEBUG ADAPTIVE_SPHERE: Current cylinder height=%.6f, radius=%.6f, hasValidCylinder=%d",
195+
cyl.height, cyl.radius, cylinderSampler_.hasValidCylinder());
196+
OMPL_INFORM("DEBUG ADAPTIVE_SPHERE: All valid points count=%zu", allValidPoints_.size());
197+
198+
// If cylinder is frozen (after burn-in), don't refit - use existing cylinder
199+
if (!cylinderFrozen_)
196200
{
197-
cylinderSampler_.fit(allValidPoints_, pcaFilterTopPercent_);
198-
validPointsDirty_ = false;
201+
// Preserve the height before refitting (it was set to bestRadius during burn-in)
202+
double preservedHeight = cylinderSampler_.getCylinder().height;
203+
204+
// Use cached cylinder if available and valid
205+
if (hasCylinderAxis_)
206+
{
207+
cylinderSampler_.fitWithAxis(allValidPoints_, cylinderAxis_, cylinderSampler_.getRadiusOffsetMultiplier());
208+
// Restore the preserved height (only update axis, not height)
209+
if (preservedHeight > 0.0) {
210+
cylinderSampler_.setCylinderHeight(preservedHeight);
211+
}
212+
OMPL_INFORM("DEBUG ADAPTIVE_SPHERE: Refitted cylinder with axis, preserved height=%.6f",
213+
preservedHeight);
214+
}
215+
else if (validPointsDirty_)
216+
{
217+
cylinderSampler_.fit(allValidPoints_, pcaFilterTopPercent_);
218+
validPointsDirty_ = false;
219+
// Restore the preserved height (only update axis, not height)
220+
if (preservedHeight > 0.0) {
221+
cylinderSampler_.setCylinderHeight(preservedHeight);
222+
}
223+
OMPL_INFORM("DEBUG ADAPTIVE_SPHERE: Refitted cylinder, preserved height=%.6f",
224+
preservedHeight);
225+
}
199226
}
200227

201228
Point sample = cylinderSampler_.sampleCylinder(extensionHeight, chosenDirection, cylinderSampler_.getSamplingRadiusMultiplier(), rng_);
@@ -210,7 +237,16 @@ AdaptiveSphereSampler::Point AdaptiveSphereSampler::getRandomSampleFromCylinder(
210237

211238
AdaptiveSphereSampler::Point AdaptiveSphereSampler::getRandomSampleFromCylinder(double extensionHeight, int chosenDirection, const Eigen::Vector3d &axis)
212239
{
240+
// Preserve the height before refitting (it was set to bestRadius during burn-in)
241+
double preservedHeight = cylinderSampler_.getCylinder().height;
242+
213243
cylinderSampler_.fitWithAxis(allValidPoints_, axis, cylinderSampler_.getRadiusOffsetMultiplier());
244+
245+
// Restore the preserved height (only update axis, not height)
246+
if (preservedHeight > 0.0) {
247+
cylinderSampler_.setCylinderHeight(preservedHeight);
248+
}
249+
214250
Point sample = cylinderSampler_.sampleCylinder(extensionHeight, chosenDirection, cylinderSampler_.getSamplingRadiusMultiplier(), rng_);
215251

216252
// Project to 2D circle if needed

src/ompl/base/samplers/sphere/src/CylinderSampler.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <ompl/base/samplers/sphere/CylinderSampler.h>
1111
#include <ompl/datastructures/geometry/GeometryUtils.h>
12+
#include <ompl/util/Console.h>
1213
#include <algorithm>
1314
#include <cmath>
1415
#include <limits>
@@ -154,6 +155,15 @@ void CylinderSampler::fitWithAxis(const std::vector<Point>& points, const Eigen:
154155
/**
 * @brief Draw one point from the currently stored cylinder.
 *
 * @param extensionHeight  Extra height, forwarded unchanged to samplePointInternal().
 * @param direction        Direction selector, forwarded unchanged to samplePointInternal().
 * @param radiusMultiplier Factor applied to the stored cylinder radius to obtain the sampling radius.
 * @param rng              Random engine forwarded to samplePointInternal().
 * @return The point produced by samplePointInternal().
 * @throws std::runtime_error if no valid cylinder is available.
 */
CylinderSampler::Point CylinderSampler::sampleCylinder(double extensionHeight, int direction, double radiusMultiplier, std::mt19937& rng)
{
    if (!hasValidCylinder_) throw std::runtime_error("Cylinder not initialized");

    // Diagnostics are emitted once per sample, i.e. inside the planner's sampling loop.
    // They are logged at DEBUG level so they can be filtered out with
    // ompl::msg::setLogLevel() without also hiding regular INFO messages.
    OMPL_DEBUG("DEBUG CYLINDER_SAMPLER: sampleCylinder called with extensionHeight=%.6f, direction=%d, radiusMultiplier=%.6f",
               extensionHeight, direction, radiusMultiplier);
    OMPL_DEBUG("DEBUG CYLINDER_SAMPLER: Current cylinder height=%.6f, radius=%.6f",
               cylinder_.height, cylinder_.radius);
    OMPL_DEBUG("DEBUG CYLINDER_SAMPLER: Will sample from height %.6f to %.6f (extension adds %.6f)",
               cylinder_.height, cylinder_.height + extensionHeight, extensionHeight);

    // The sampling radius is the stored radius scaled by the caller-supplied multiplier.
    return samplePointInternal(cylinder_, extensionHeight, direction, cylinder_.radius * radiusMultiplier, rng);
}
159169

@@ -177,34 +187,30 @@ CylinderSampler::Point CylinderSampler::samplePointInternal(const Cylinder& cyl,
177187
bool is2D = (std::fabs(cyl.az) < 1e-6);
178188

179189
if (is2D) {
180-
// 2D case: Cylinder becomes a line
181-
// Sample along the line: h * axis + perpendicular jitter
182-
// For 2D, perpendicular to axis (ax, ay) is: (-ay, ax) or (ay, -ax)
183-
// Normalize the axis first
190+
// 2D case: Cylinder becomes a line segment
191+
// Sample only along the axis, NO perpendicular jitter
192+
// In 2D, adding perpendicular jitter would make total distance = sqrt(h² + r²),
193+
// which is incorrect. We want samples to lie exactly on the line.
184194
double axis_norm = std::sqrt(cyl.ax * cyl.ax + cyl.ay * cyl.ay);
185195
if (axis_norm < 1e-10) {
186-
// Degenerate axis, use default
187-
double px = r * std::cos(theta);
188-
double py = r * std::sin(theta);
196+
// Degenerate axis, sample at origin
197+
double px = 0.0;
198+
double py = 0.0;
189199
double pz = 0.0;
190-
double newRadius_ = std::sqrt(px * px + py * py);
200+
double newRadius_ = 0.0;
191201
return Point{px, py, pz, newRadius_, true};
192202
}
193203

194204
double ax_norm = cyl.ax / axis_norm;
195205
double ay_norm = cyl.ay / axis_norm;
196206

197-
// Perpendicular vector: (-ay, ax) - this is 90 degree rotation
198-
double perp_x = -ay_norm;
199-
double perp_y = ax_norm;
200-
201-
// Sample: h * axis + r * (cos(θ) * perp_x + sin(θ) * perp_y)
202-
// For 2D line, we sample along axis with perpendicular jitter
203-
double px = h * ax_norm + r * (std::cos(theta) * perp_x + std::sin(theta) * perp_y);
204-
double py = h * ay_norm + r * (std::cos(theta) * perp_y - std::sin(theta) * perp_x);
207+
// Sample only along the axis: h * axis (no perpendicular jitter)
208+
double px = h * ax_norm;
209+
double py = h * ay_norm;
205210
double pz = 0.0; // Always 0 for 2D
206211

207-
double newRadius_ = std::sqrt(px * px + py * py);
212+
// Distance from origin is just |h| (absolute value of distance along axis)
213+
double newRadius_ = std::abs(h);
208214

209215
return Point{px, py, pz, newRadius_, true};
210216
}

0 commit comments

Comments
 (0)