diff --git a/src/data_morph/shapes/points/scatter.py b/src/data_morph/shapes/points/scatter.py index 3b847ca5..5d7b02de 100644 --- a/src/data_morph/shapes/points/scatter.py +++ b/src/data_morph/shapes/points/scatter.py @@ -30,21 +30,16 @@ class Scatter(PointCollection): def __init__(self, dataset: Dataset) -> None: rng = np.random.default_rng(1) - center = (dataset.data.x.mean(), dataset.data.y.mean()) + morph_range = dataset.morph_bounds.range + center = dataset.morph_bounds.center points = [center] - max_radius = max(dataset.data.x.std(), dataset.data.y.std()) points.extend( [ ( - center[0] - + np.cos(angle) * radius - + rng.standard_normal() * max_radius, - center[1] - + np.sin(angle) * radius - + rng.standard_normal() * max_radius, + center[0] + np.cos(angle) * rng.uniform(0, morph_range[0] / 2), + center[1] + np.sin(angle) * rng.uniform(0, morph_range[1] / 2), ) - for radius in np.linspace(max_radius // 5, max_radius, num=5) - for angle in np.linspace(0, 360, num=50, endpoint=False) + for angle in np.linspace(0, 720, num=100, endpoint=False) ] ) super().__init__(*points)