diff --git a/pyba/Camera.py b/pyba/Camera.py index c7c3072..75237f8 100644 --- a/pyba/Camera.py +++ b/pyba/Camera.py @@ -142,17 +142,38 @@ def summarize(self): def plot_2d( self, img_id: int, - points2d: Optional[np.ndarray] = None, + points: Optional[np.ndarray] = None, bones: Optional[np.ndarray] = None, colors: Optional[List[Tuple]] = None, ) -> np.ndarray: + """ + Parameters + ---------- + ... + + points: either points2d to plot directly, or points3d in which case they will + be projected for plotting. + + """ img = self.get_image(img_id) - points2d = self.points2d[img_id] if points2d is None else points2d + if points is None: + + points = self.points2d[img_id] + + if points.shape[-1] == 3: + # points are given as 3d coords + points3d = points[img_id] # (n_joints, 3) + points2d = self.project(points3d[np.newaxis, ...]).squeeze() # project works only in batches + elif points.shape[-1] == 2: + points2d = points + else: + raise ValueError(f"Expected points to have shape (..., 2) or (..., 3), but got shape {points.shape}") + # bones if bones is not None: for idx, b in enumerate(bones): - if self.can_see(img_id, b[0]): + if self.can_see(img_id, b[0]) and self.can_see(img_id, b[1]): img = cv2.line( img, tuple(points2d[b[0]].astype(int)), @@ -164,7 +185,7 @@ def plot_2d( for jid in range(self.get_njoints()): if self.can_see(img_id, jid): img = cv2.circle( - img, tuple(points2d[jid].astype(int)), 5, [0, 0, 128], 5 + img, tuple(points2d[jid].T.astype(int)), 5, [0, 0, 128], 5 ) return img