diff --git a/cadquery/assembly.py b/cadquery/assembly.py index 65c024702..de343411f 100644 --- a/cadquery/assembly.py +++ b/cadquery/assembly.py @@ -17,7 +17,7 @@ from warnings import warn from .cq import Workplane -from .occ_impl.shapes import Shape, Compound, isSubshape +from .occ_impl.shapes import Shape, Compound, isSubshape, compound from .occ_impl.geom import Location from .occ_impl.assembly import Color from .occ_impl.solver import ( @@ -38,7 +38,7 @@ from .occ_impl.importers.assembly import importStep as _importStep, importXbf, importXml from .selectors import _expression_grammar as _selector_grammar -from .utils import deprecate +from .utils import deprecate, BiDict # type definitions AssemblyObjects = Union[Shape, Workplane, None] @@ -48,6 +48,8 @@ PATH_DELIM = "/" # entity selector grammar definition + + def _define_grammar(): from pyparsing import ( @@ -98,9 +100,9 @@ class Assembly(object): constraints: List[Constraint] # Allows metadata to be stored for exports - _subshape_names: dict[Shape, str] - _subshape_colors: dict[Shape, Color] - _subshape_layers: dict[Shape, str] + _subshape_names: BiDict[Shape, str] + _subshape_colors: BiDict[Shape, Color] + _subshape_layers: BiDict[Shape, str] _solve_result: Optional[Dict[str, Any]] @@ -147,9 +149,9 @@ def __init__( self._solve_result = None - self._subshape_names = {} - self._subshape_colors = {} - self._subshape_layers = {} + self._subshape_names = BiDict() + self._subshape_colors = BiDict() + self._subshape_layers = BiDict() def _copy(self) -> "Assembly": """ @@ -158,9 +160,9 @@ def _copy(self) -> "Assembly": rv = self.__class__(self.obj, self.loc, self.name, self.color, self.metadata) - rv._subshape_colors = dict(self._subshape_colors) - rv._subshape_names = dict(self._subshape_names) - rv._subshape_layers = dict(self._subshape_layers) + rv._subshape_colors = BiDict(self._subshape_colors) + rv._subshape_names = BiDict(self._subshape_names) + rv._subshape_layers = BiDict(self._subshape_layers) for ch in self.children: ch_copy = ch._copy() @@ -754,12 +756,22 @@ def addSubshape( return self - def __getitem__(self, name: str) -> "Assembly": + def __getitem__(self, name: str) -> Union["Assembly", Shape]: """ [] based access to children. + """ - return self.objects[name] + if name in self.objects: + return self.objects[name] + elif name[0] in self.objects: + rv = self.objects[name[0]] + + if name[1] in rv._subshape_names.inv: + rv = compound(self._subshape_names.inv[name[1]]) + return rv[0] if len(rv) == 1 else compound(rv) + + raise KeyError def _ipython_key_completions_(self) -> List[str]: """ @@ -772,22 +784,29 @@ def __contains__(self, name: str) -> bool: return name in self.objects - def __getattr__(self, name: str) -> "Assembly": + def __getattr__(self, name: str) -> Union["Assembly", Shape]: """ . based access to children. """ if name in self.objects: return self.objects[name] + elif name in self._subshape_names.inv: + rv = self._subshape_names.inv[name] + return rv[0] if len(rv) == 1 else compound(rv) - raise AttributeError + raise AttributeError(f"{name} is not an attribute of {self}") def __dir__(self): """ Modified __dir__ for autocompletion. """ - return list(self.__dict__) + list(ch.name for ch in self.children) + return ( + list(self.__dict__) + + list(ch.name for ch in self.children) + + list(self._subshape_names.inv.keys()) + ) def __getstate__(self): """ diff --git a/cadquery/fig.py b/cadquery/fig.py index a68a43e59..b0ae3a8a9 100644 --- a/cadquery/fig.py +++ b/cadquery/fig.py @@ -9,14 +9,14 @@ from threading import Thread from itertools import chain from webbrowser import open_new_tab +from uuid import uuid1 from typish import instance_of from trame.app import get_server from trame.app.core import Server -from trame.widgets import html, vtk as vtk_widgets, client -from trame.ui.html import DivLayout - +from trame.widgets import vtk as vtk_widgets, client, trame, vuetify3 as v3 +from trame.ui.vuetify3 import SinglePageWithDrawerLayout from . import Shape from .vis import style, Showable, ShapeLike, _split_showables @@ -43,7 +43,7 @@ class Figure: ren: vtkRenderer view: vtk_widgets.VtkRemoteView shapes: dict[ShapeLike, list[vtkProp3D]] - actors: list[vtkProp3D] + actors: dict[str, tuple[vtkProp3D, ...]] loop: AbstractEventLoop thread: Thread empty: bool @@ -107,24 +107,104 @@ def __init__(self, port: int = 18081): self.ren = renderer self.shapes = {} - self.actors = [] + self.actors = {} + self.active = None # server - server = get_server("CQ-server") - server.client_type = "vue3" + server = get_server("CQ-server", client_type="vue3") + self.server = server + + # state + self.state = self.server.state + + self.state.actors = [] + self.state.selected = None # layout - with DivLayout(server): + self.layout = SinglePageWithDrawerLayout(server, show_drawer=False) + with self.layout as layout: client.Style("body { margin: 0; }") - with html.Div(style=FULL_SCREEN): - self.view = vtk_widgets.VtkRemoteView( - win, interactive_ratio=1, interactive_quality=100 + layout.title.set_text("CQ viewer") + layout.footer.hide() + + with layout.toolbar: + + BSTYLE = "display: block;" + + v3.VBtn( + click=lambda: self._fit(), + flat=True, + density="compact", + icon="mdi-crop-free", + style=BSTYLE, + ) + + v3.VBtn( + click=lambda: self._view((0, 0, 0), (1, 1, 1), (0, 0, 1)), + flat=True, + density="compact", + icon="mdi-axis-arrow", + style=BSTYLE, + ) + + v3.VBtn( + click=lambda: self._view((0, 0, 0), (1, 0, 0), (0, 0, 1)), + flat=True, + density="compact", + icon="mdi-axis-x-arrow", + style=BSTYLE, + ) + + v3.VBtn( + click=lambda: self._view((0, 0, 0), (0, 1, 0), (0, 0, 1)), + flat=True, + density="compact", + icon="mdi-axis-y-arrow", + style=BSTYLE, + ) + + v3.VBtn( + click=lambda: self._view((0, 0, 0), (0, 0, 1), (0, 1, 0)), + flat=True, + density="compact", + icon="mdi-axis-z-arrow", + style=BSTYLE, + ) + + v3.VBtn( + click=lambda: self._pop(), + flat=True, + density="compact", + icon="mdi-file-document-remove-outline", + style=BSTYLE, + ) + + v3.VBtn( + click=lambda: self._clear([]), + flat=True, + density="compact", + icon="mdi-delete-outline", + style=BSTYLE, + ) + + with layout.content: + with v3.VContainer( + fluid=True, classes="pa-0 fill-height", + ): + self.view = vtk_widgets.VtkRemoteView( + win, interactive_ratio=1, interactive_quality=100 + ) + + with layout.drawer: + self.tree = trame.GitTree( + sources=("actors",), + visibility_change=(self.onVisibility, "[$event]"), + actives_change=(self.onSelection, "[$event]"), ) server.state.flush() - self.server = server self.loop = new_event_loop() def _run_loop(): @@ -159,7 +239,20 @@ def _run(self, coro) -> Future: return run_coroutine_threadsafe(coro, self.loop) - def show(self, *showables: Showable | vtkProp3D | list[vtkProp3D], **kwargs): + def _update_state(self, name: str): + async def _(): + + self.state.dirty(name) + self.state.flush() + + self._run(_()) + + def show( + self, + *showables: Showable | vtkProp3D | list[vtkProp3D], + name: Optional[str] = None, + **kwargs, + ): """ Show objects. """ @@ -170,6 +263,9 @@ def show(self, *showables: Showable | vtkProp3D | list[vtkProp3D], **kwargs): pts = style(vecs, **kwargs) axs = style(locs, **kwargs) + # to be added to state + new_actors = [] + for s in shapes: # do not show markers by default if "markersize" not in kwargs: @@ -181,14 +277,18 @@ def show(self, *showables: Showable | vtkProp3D | list[vtkProp3D], **kwargs): for actor in actors: self.ren.AddActor(actor) + new_actors.extend(actors) + for prop in chain(props, axs): - self.actors.append(prop) self.ren.AddActor(prop) + new_actors.append(prop) + if vecs: - self.actors.append(*pts) self.ren.AddActor(*pts) + new_actors.append(*pts) + # store to enable pop self.last = (shapes, axs, pts if vecs else None, props) @@ -202,76 +302,155 @@ async def _show(): self.fit() self.empty = False + # update actors + uuid = str(uuid1()) + self.state.actors.append( + { + "id": uuid, + "parent": "0", + "visible": 1, + "name": f"{name if name else type(showables[0]).__name__} at {id(showables[0]):x}", + } + ) + self._update_state("actors") + + self.actors[uuid] = tuple(new_actors) + return self + async def _fit(self): + self.ren.ResetCamera() + self.view.update() + def fit(self): """ Update view to fit all objects. """ - async def _show(): - self.ren.ResetCamera() - self.view.update() + self._run(self._fit()) - self._run(_show()) + return self + + async def _view(self, foc, pos, up): + + cam = self.ren.GetActiveCamera() + + cam.SetViewUp(*up) + cam.SetFocalPoint(*foc) + cam.SetPosition(*pos) + + self.ren.ResetCamera() + + self.view.update() + + def iso(self): + + self._run(self._view((0, 0, 0), (1, 1, 1), (0, 0, 1))) return self - def clear(self, *shapes: Shape | vtkProp3D): - """ - Clear specified objects. If no arguments are passed, clears all objects. - """ + def up(self): - async def _clear(): + self._run(self._view((0, 0, 0), (0, 0, 1), (0, 1, 0))) - if len(shapes) == 0: - self.ren.RemoveAllViewProps() + return self - self.actors.clear() - self.shapes.clear() + pass - for s in shapes: - if instance_of(s, ShapeLike): - for a in self.shapes[s]: - self.ren.RemoveActor(a) + def front(self): - del self.shapes[s] - else: - self.actors.remove(s) - self.ren.RemoveActor(s) + self._run(self._view((0, 0, 0), (1, 0, 0), (0, 0, 1))) - self.view.update() + return self + + def side(self): + + self._run(self._view((0, 0, 0), (0, 1, 0), (0, 0, 1))) + + return self + + async def _clear(self, shapes): + + if len(shapes) == 0: + self.ren.RemoveAllViewProps() + + self.actors.clear() + self.shapes.clear() + + self.state.actors = [] + + for s in shapes: + if instance_of(s, ShapeLike): + for a in self.shapes[s]: + self.ren.RemoveActor(a) + + del self.shapes[s] + else: + for k, v in self.actors.items(): + if s in v: + for el in self.actors.pop(k): + self.ren.RemoveActor(el) + + break + + self._update_state("actors") + self.view.update() + + def clear(self, *shapes: Shape | vtkProp3D): + """ + Clear specified objects. If no arguments are passed, clears all objects. + """ # reset last, bc we don't want to keep track of what was removed self.last = None - future = self._run(_clear()) + future = self._run(self._clear(shapes)) future.result() return self + async def _pop(self): + + if self.active is None: + self.active = self.actors[-1]["id"] + + if self.active in self.actors: + for act in self.actors[self.active]: + self.ren.RemoveActor(act) + + self.actors.pop(self.active) + + # update corresponding state + for i, el in enumerate(self.state.actors): + if el["id"] == self.active: + self.state.actors.pop(i) + self._update_state("actors") + break + + self.active = None + + else: + return + + self.view.update() + def pop(self): """ - Clear the last showable. + Clear the selected showable. """ - async def _pop(): + self._run(self._pop()) - (shapes, axs, pts, props) = self.last + return self - for s in shapes: - for act in self.shapes.pop(s): - self.ren.RemoveActor(act) + def onVisibility(self, event): - for act in chain(axs, props): - self.ren.RemoveActor(act) - self.actors.remove(act) + actors = self.actors[event["id"]] - if pts: - self.ren.RemoveActor(*pts) - self.actors.remove(*pts) + for act in actors: + act.SetVisibility(event["visible"]) - self.view.update() + self.view.update() - self._run(_pop()) + def onSelection(self, event): - return self + self.active = event[0] diff --git a/cadquery/occ_impl/assembly.py b/cadquery/occ_impl/assembly.py index e08c4a6c1..579ebbd7f 100644 --- a/cadquery/occ_impl/assembly.py +++ b/cadquery/occ_impl/assembly.py @@ -52,6 +52,7 @@ from .shapes import Shape, Solid, Compound from .exporters.vtk import toString from ..cq import Workplane +from ..utils import BiDict # type definitions AssemblyObjects = Union[Shape, Workplane, None] @@ -210,15 +211,15 @@ def children(self) -> Iterable["AssemblyProtocol"]: ... @property - def _subshape_names(self) -> Dict[Shape, str]: + def _subshape_names(self) -> BiDict[Shape, str]: ... @property - def _subshape_colors(self) -> Dict[Shape, Color]: + def _subshape_colors(self) -> BiDict[Shape, Color]: ... @property - def _subshape_layers(self) -> Dict[Shape, str]: + def _subshape_layers(self) -> BiDict[Shape, str]: ... @overload @@ -276,7 +277,7 @@ def __iter__( ) -> Iterator[Tuple[Shape, str, Location, Optional[Color]]]: ... - def __getitem__(self, name: str) -> Self: + def __getitem__(self, name: str) -> Self | Shape: ... def __contains__(self, name: str) -> bool: diff --git a/cadquery/occ_impl/geom.py b/cadquery/occ_impl/geom.py index a9e3e9dff..6a36440da 100644 --- a/cadquery/occ_impl/geom.py +++ b/cadquery/occ_impl/geom.py @@ -95,6 +95,9 @@ def __init__(self, *args): fV = gp_Vec(args[0].XYZ()) elif isinstance(args[0], gp_XYZ): fV = gp_Vec(args[0]) + elif hasattr(args[0], "__array__"): + tmp = args[0].ravel() + fV = gp_Vec(tmp[0], tmp[1], tmp[2]) else: raise TypeError("Expected three floats, OCC gp_, or 3-tuple") elif len(args) == 0: @@ -1173,7 +1176,7 @@ def __setstate__(self, data: BytesIO): ls = BinTools_LocationSet() ls.Read(data) - if ls.NbLocations() > 0: - self.wrapped = ls.Location(1) + if ls.NbLocations() == 0: + self.wrapped = TopLoc_Location() else: - self.wrapped = TopLoc_Location() # identity location + self.wrapped = ls.Location(1) diff --git a/cadquery/occ_impl/importers/assembly.py b/cadquery/occ_impl/importers/assembly.py index 1ae810e60..9648b6d55 100644 --- a/cadquery/occ_impl/importers/assembly.py +++ b/cadquery/occ_impl/importers/assembly.py @@ -253,7 +253,7 @@ def _process_label(lbl: TDF_Label, parent: AssemblyProtocol): parent.add(tmp) # change the current assy to handle subshape data - current = parent[comp_name] + current = cast(AssemblyProtocol, parent[comp_name]) # iterate over subshape and handle names, layers and colors subshape_labels = TDF_LabelSequence() @@ -326,6 +326,7 @@ def _process_label(lbl: TDF_Label, parent: AssemblyProtocol): assy.objects.pop(assy.name) assy.name = str(name_attr.Get().ToExtString()) assy.objects[assy.name] = assy + if cq_color: assy.color = cq_color @@ -350,7 +351,7 @@ def _process_label(lbl: TDF_Label, parent: AssemblyProtocol): # extras on successive round-trips. exportStepMeta does not add the extra top-level # node and so does not exhibit this behavior. if assy.name in imported_assy: - imported_assy = imported_assy[assy.name] + imported_assy = cast(AssemblyProtocol, imported_assy[assy.name]) # comp_labels = TDF_LabelSequence() # shape_tool.GetComponents_s(top_level_label, comp_labels) # comp_label = comp_labels.Value(1) diff --git a/cadquery/occ_impl/nurbs.py b/cadquery/occ_impl/nurbs.py new file mode 100644 index 000000000..4e8bc7fbf --- /dev/null +++ b/cadquery/occ_impl/nurbs.py @@ -0,0 +1,1991 @@ +# %% imports +import numpy as np +import scipy.sparse as sp + +from numba import njit as _njit + +from typing import NamedTuple, Optional, Tuple, List, Union, cast + +from math import comb + +from numpy.typing import NDArray +from numpy import linspace, ndarray + +from casadi import ldl, ldl_solve + +from OCP.Geom import Geom_BSplineCurve, Geom_BSplineSurface +from OCP.TColgp import TColgp_Array1OfPnt, TColgp_Array2OfPnt +from OCP.TColStd import ( + TColStd_Array1OfInteger, + TColStd_Array1OfReal, +) +from OCP.gp import gp_Pnt +from OCP.BRepBuilderAPI import BRepBuilderAPI_MakeEdge, BRepBuilderAPI_MakeFace + +from .shapes import Face, Edge + +from multimethod import multidispatch + +njit = _njit(cache=True, error_model="numpy", fastmath=True, nogil=True, parallel=False) + +njiti = _njit( + cache=True, inline="always", error_model="numpy", fastmath=True, parallel=False +) + + +# %% internal helpers + + +def _colPtsArray(pts: NDArray) -> TColgp_Array1OfPnt: + + rv = TColgp_Array1OfPnt(1, pts.shape[0]) + + for i, p in enumerate(pts): + rv.SetValue(i + 1, gp_Pnt(*p)) + + return rv + + +def _colPtsArray2(pts: NDArray) -> TColgp_Array2OfPnt: + + assert pts.ndim == 3 + + nu, nv, _ = pts.shape + + rv = TColgp_Array2OfPnt(1, len(pts), 1, len(pts[0])) + + for i, row in enumerate(pts): + for j, pt in enumerate(row): + rv.SetValue(i + 1, j + 1, gp_Pnt(*pt)) + + return rv + + +def _colRealArray(knots: NDArray) -> TColStd_Array1OfReal: + + rv = TColStd_Array1OfReal(1, len(knots)) + + for i, el in enumerate(knots): + rv.SetValue(i + 1, el) + + return rv + + +def _colIntArray(knots: NDArray) -> TColStd_Array1OfInteger: + + rv = TColStd_Array1OfInteger(1, len(knots)) + + for i, el in enumerate(knots): + rv.SetValue(i + 1, el) + + return rv + + +# %% vocabulary types + +Array = ndarray # NDArray[np.floating] +ArrayI = ndarray # NDArray[np.int_] + + +class COO(NamedTuple): + """ + COO sparse matrix container. + """ + + i: ArrayI + j: ArrayI + v: Array + + def coo(self): + + return sp.coo_matrix((self.v, (self.i, self.j))) + + def csc(self): + + return self.coo().tocsc() + + def csr(self): + + return self.coo().tocsr() + + +class Curve(NamedTuple): + """ + B-spline curve container. + """ + + pts: Array + knots: Array + order: int + periodic: bool + + def curve(self) -> Geom_BSplineCurve: + + if self.periodic: + mults = _colIntArray(np.ones_like(self.knots, dtype=int)) + knots = _colRealArray(self.knots) + else: + unique_knots, mults_arr = np.unique(self.knots, return_counts=True) + knots = _colRealArray(unique_knots) + mults = _colIntArray(mults_arr) + + return Geom_BSplineCurve( + _colPtsArray(self.pts), knots, mults, self.order, self.periodic, + ) + + def edge(self) -> Edge: + + return Edge(BRepBuilderAPI_MakeEdge(self.curve()).Shape()) + + @classmethod + def fromEdge(cls, e: Edge): + + assert ( + e.geomType() == "BSPLINE" + ), "B-spline geometry required, try converting first." + + g = e._geomAdaptor().BSpline() + + knots = np.repeat(list(g.Knots()), list(g.Multiplicities())) + pts = np.array([(p.X(), p.Y(), p.Z()) for p in g.Poles()]) + order = g.Degree() + periodic = g.IsPeriodic() + + return cls(pts, knots, order, periodic) + + def __call__(self, us: Array) -> Array: + + return nbCurve( + np.atleast_1d(us), self.order, self.knots, self.pts, self.periodic + ) + + def der(self, us: NDArray, dorder: int) -> NDArray: + + return nbCurveDer( + np.atleast_1d(us), self.order, dorder, self.knots, self.pts, self.periodic + ) + + +class Surface(NamedTuple): + """ + B-spline surface container. + """ + + pts: Array + uknots: Array + vknots: Array + uorder: int + vorder: int + uperiodic: bool + vperiodic: bool + + def surface(self) -> Geom_BSplineSurface: + + unique_knots, mults_arr = np.unique(self.uknots, return_counts=True) + uknots = _colRealArray(unique_knots) + umults = _colIntArray(mults_arr) + + unique_knots, mults_arr = np.unique(self.vknots, return_counts=True) + vknots = _colRealArray(unique_knots) + vmults = _colIntArray(mults_arr) + + return Geom_BSplineSurface( + _colPtsArray2(self.pts), + uknots, + vknots, + umults, + vmults, + self.uorder, + self.vorder, + self.uperiodic, + self.vperiodic, + ) + + def face(self, tol: float = 1e-3) -> Face: + + return Face(BRepBuilderAPI_MakeFace(self.surface(), tol).Shape()) + + @classmethod + def fromFace(cls, f: Face): + """ + Construct a surface from a face. + """ + + assert ( + f.geomType() == "BSPLINE" + ), "B-spline geometry required, try converting first." + + g = cast(Geom_BSplineSurface, f._geomAdaptor()) + + uknots = np.repeat(list(g.UKnots()), list(g.UMultiplicities())) + vknots = np.repeat(list(g.VKnots()), list(g.VMultiplicities())) + + tmp = [] + for i in range(1, g.NbUPoles() + 1): + tmp.append( + [ + [g.Pole(i, j).X(), g.Pole(i, j).Y(), g.Pole(i, j).Z(),] + for j in range(1, g.NbVPoles() + 1) + ] + ) + + pts = np.array(tmp) + + uorder = g.UDegree() + vorder = g.VDegree() + + uperiodic = g.IsUPeriodic() + vperiodic = g.IsVPeriodic() + + return cls(pts, uknots, vknots, uorder, vorder, uperiodic, vperiodic) + + def __call__(self, u: Array, v: Array) -> Array: + """ + Evaluate surface at (u,v) points. + """ + + return nbSurface( + np.atleast_1d(u), + np.atleast_1d(v), + self.uorder, + self.vorder, + self.uknots, + self.vknots, + self.pts, + self.uperiodic, + self.vperiodic, + ) + + def der(self, u: Array, v: Array, dorder: int) -> Array: + """ + Evaluate surface and derivatives at (u,v) points. + """ + + return nbSurfaceDer( + np.atleast_1d(u), + np.atleast_1d(v), + self.uorder, + self.vorder, + dorder, + self.uknots, + self.vknots, + self.pts, + self.uperiodic, + self.vperiodic, + ) + + def normal(self, u: Array, v: Array) -> Tuple[Array, Array]: + """ + Evaluate surface normals. + """ + + ders = self.der(u, v, 1) + + du = ders[:, 1, 0, :].squeeze() + dv = ders[:, 0, 1, :].squeeze() + + rv = np.atleast_2d(np.cross(du, dv)) + rv /= np.linalg.norm(rv, axis=1)[:, np.newaxis] + + return rv, ders[:, 0, 0, :].squeeze() + + +# %% basis functions + + +@njiti +def _preprocess( + u: Array, order: int, knots: Array, periodic: float +) -> Tuple[Array, Array, Optional[int], Optional[int], int]: + """ + Helper for handling peridocity. This function extends the knot vector, + wraps the parameters and calculates the delta span. + """ + + # handle periodicity + if periodic: + period = knots[-1] - knots[0] + u_ = u % period + knots_ext = extendKnots(order, knots) + minspan = 0 + maxspan = len(knots) - 1 + deltaspan = order - 1 + else: + u_ = u + knots_ext = knots + minspan = order + maxspan = knots.shape[0] - order - 1 + deltaspan = 0 + + return u_, knots_ext, minspan, maxspan, deltaspan + + +@njiti +def extendKnots(order: int, knots: Array) -> Array: + """ + Knot vector extension for periodic b-splines. + + Parameters + ---------- + order : int + B-spline order. + knots : Array + Knot vector. + + Returns + ------- + knots_ext : Array + Extended knots vector. + + """ + + return np.concat((knots[-order:-1] - knots[-1], knots, knots[-1] + knots[1:order])) + + +@njiti +def nbFindSpan( + u: float, + order: int, + knots: Array, + low: Optional[int] = None, + high: Optional[int] = None, +) -> int: + """ + NURBS book A2.1 with modifications to handle periodic usecases. + + Parameters + ---------- + u : float + Parameter value. + order : int + Spline order. + knots : ndarray + Knot vector. + + Returns + ------- + Span index. + + """ + + if low is None: + low = order + + if high is None: + high = knots.shape[0] - order - 1 + + mid = (low + high) // 2 + + if u >= knots[-1]: + return high - 1 # handle last span + elif u < knots[0]: + return low + + while u < knots[mid] or u >= knots[mid + 1]: + if u < knots[mid]: + high = mid + else: + low = mid + + mid = (low + high) // 2 + + return mid + + +@njiti +def nbBasis(i: int, u: float, order: int, knots: Array, out: Array): + """ + NURBS book A2.2 + + Parameters + ---------- + i : int + Span index. + u : float + Parameter value. + order : int + B-spline order. + knots : ndarray + Knot vector. + out : ndarray + B-spline basis function values. + + Returns + ------- + None. + + """ + + out[0] = 1.0 + + left = np.zeros_like(out) + right = np.zeros_like(out) + + for j in range(1, order + 1): + left[j] = u - knots[i + 1 - j] + right[j] = knots[i + j] - u + + saved = 0.0 + + for r in range(j): + temp = out[r] / (right[r + 1] + left[j - r]) + out[r] = saved + right[r + 1] * temp + saved = left[j - r] * temp + + out[j] = saved + + +@njiti +def nbBasisDer(i: int, u: float, order: int, dorder: int, knots: Array, out: Array): + """ + NURBS book A2.3 + + Parameters + ---------- + i : int + Span index. + u : float + Parameter value. + order : int + B-spline order. + dorder : int + Derivative order. + knots : ndarray + Knot vector. + out : ndarray + B-spline basis function and derivative values. + + Returns + ------- + None. + + """ + + ndu = np.zeros((order + 1, order + 1)) + + left = np.zeros(order + 1) + right = np.zeros(order + 1) + + a = np.zeros((2, order + 1)) + + ndu[0, 0] = 1 + + for j in range(1, order + 1): + left[j] = u - knots[i + 1 - j] + right[j] = knots[i + j] - u + + saved = 0.0 + + for r in range(j): + ndu[j, r] = right[r + 1] + left[j - r] + temp = ndu[r, j - 1] / ndu[j, r] + + ndu[r, j] = saved + right[r + 1] * temp + saved = left[j - r] * temp + + ndu[j, j] = saved + + # store the basis functions + out[0, :] = ndu[:, order] + + # calculate and store derivatives + + # loop over basis functions + for r in range(order + 1): + s1 = 0 + s2 = 1 + + a[0, 0] = 1 + + # loop over derivative orders + for k in range(1, dorder + 1): + d = 0.0 + rk = r - k + pk = order - k + + if r >= k: + a[s2, 0] = a[s1, 0] / ndu[pk + 1, rk] + d = a[s2, 0] * ndu[rk, pk] + + if rk >= -1: + j1 = 1 + else: + j1 = -rk + + if r - 1 <= pk: + j2 = k - 1 + else: + j2 = order - r + + for j in range(j1, j2 + 1): + a[s2, j] = (a[s1, j] - a[s1, j - 1]) / ndu[pk + 1, rk + j] + d += a[s2, j] * ndu[rk + j, pk] + + if r <= pk: + a[s2, k] = -a[s1, k - 1] / ndu[pk + 1, r] + d += a[s2, k] * ndu[r, pk] + + # store the kth derivative of rth basis + out[k, r] = d + + # switch + s1, s2 = s2, s1 + + # multiply recursively by the order + r = order + + for k in range(1, dorder + 1): + out[k, :] *= r + r *= order - k + + +# %% evaluation + + +@njit +def nbCurve( + u: Array, order: int, knots: Array, pts: Array, periodic: bool = False +) -> Array: + """ + NURBS book A3.1 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + Parameter values. + order : int + B-spline order. + knots : Array + Knot vector. + pts : Array + Control points. + periodic : bool, optional + Periodicity flag. The default is False. + + Returns + ------- + Array + Curve values. + + """ + + # number of control points + nb = pts.shape[0] + + u_, knots_ext, minspan, maxspan, deltaspan = _preprocess(u, order, knots, periodic) + + # number of param values + nu = np.size(u) + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros(n) + + # initialize + out = np.zeros((nu, 3)) + + for i in range(nu): + ui = u_[i] + + # find span + span = nbFindSpan(ui, order, knots, minspan, maxspan) + deltaspan + + # evaluate chunk + nbBasis(span, ui, order, knots_ext, temp) + + # multiply by ctrl points + for j in range(order + 1): + out[i, :] += temp[j] * pts[(span - order + j) % nb, :] + + return out + + +@njit +def nbCurveDer( + u: Array, order: int, dorder: int, knots: Array, pts: Array, periodic: bool = False +) -> Array: + """ + NURBS book A3.2 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + Parameter values. + order : int + B-spline order. + dorder : int + Derivative order. + knots : Array + Knot vector. + pts : Array + Control points. + periodic : bool, optional + Periodicity flag. The default is False. + + + Returns + ------- + Array + Curve values and derivatives. + + """ + # number of control points + nb = pts.shape[0] + + # handle periodicity + u_, knots_ext, minspan, maxspan, deltaspan = _preprocess(u, order, knots, periodic) + + # number of param values + nu = np.size(u) + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros((dorder + 1, n)) + + # initialize + out = np.zeros((nu, dorder + 1, 3)) + + for i in range(nu): + ui = u_[i] + + # find span + span = nbFindSpan(ui, order, knots, minspan, maxspan) + deltaspan + + # evaluate chunk + nbBasisDer(span, ui, order, dorder, knots_ext, temp) + + # multiply by ctrl points + for j in range(order + 1): + for k in range(dorder + 1): + out[i, k, :] += temp[k, j] * pts[(span - order + j) % nb, :] + + return out + + +@njit +def nbSurface( + u: Array, + v: Array, + uorder: int, + vorder: int, + uknots: Array, + vknots: Array, + pts: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> Array: + """ + NURBS book A3.5 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + U parameter values. + v : Array + V parameter values. + uorder : int + B-spline u order. + vorder : int + B-spline v order. + uknots : Array + U knot vector.. + vknots : Array + V knot vector.. + pts : Array + Control points. + uperiodic : bool, optional + U periodicity flag. The default is False. + vperiodic : bool, optional + V periodicity flag. The default is False. + + Returns + ------- + Array + Surface values. + + """ + + # number of control points + nub = pts.shape[0] + nvb = pts.shape[1] + + # handle periodicity + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + nu = np.size(u) + + # chunck sizes + un = uorder + 1 + vn = vorder + 1 + + # temp chunck storage + utemp = np.zeros(un) + vtemp = np.zeros(vn) + + # initialize + out = np.zeros((nu, 3)) + + for i in range(nu): + ui = u_[i] + vi = v_[i] + + # find span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate chunk + nbBasis(uspan, ui, uorder, uknots_ext, utemp) + nbBasis(vspan, vi, vorder, vknots_ext, vtemp) + + uind = uspan - uorder + temp = np.empty(3) + + # multiply by ctrl points: Nu.T*P*Nv + for j in range(vorder + 1): + + temp[:] = 0.0 + vind = vspan - vorder + j + + # calculate Nu.T*P + for k in range(uorder + 1): + temp += utemp[k] * pts[(uind + k) % nub, vind % nvb, :] + + # multiple by Nv + out[i, :] += vtemp[j] * temp + + return out + + +@njit +def nbSurfaceDer( + u: Array, + v: Array, + uorder: int, + vorder: int, + dorder: int, + uknots: Array, + vknots: Array, + pts: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> Array: + """ + NURBS book A3.6 with modifications to handle periodicity. + + Parameters + ---------- + u : Array + U parameter values. + v : Array + V parameter values. + uorder : int + B-spline u order. + vorder : int + B-spline v order. + dorder : int + Maximum derivative order. + uknots : Array + U knot vector.. + vknots : Array + V knot vector.. + pts : Array + Control points. + uperiodic : bool, optional + U periodicity flag. The default is False. + vperiodic : bool, optional + V periodicity flag. The default is False. + + Returns + ------- + Array + Surface and derivative values. + + """ + + # max derivative orders + du = min(dorder, uorder) + dv = min(dorder, vorder) + + # number of control points + nub = pts.shape[0] + nvb = pts.shape[1] + + # handle periodicity + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + nu = np.size(u) + + # chunck sizes + un = uorder + 1 + vn = vorder + 1 + + # temp chunck storage + + utemp = np.zeros((du + 1, un)) + vtemp = np.zeros((dv + 1, vn)) + + # initialize + out = np.zeros((nu, du + 1, dv + 1, 3)) + + for i in range(nu): + ui = u_[i] + vi = v_[i] + + # find span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate chunk + nbBasisDer(uspan, ui, uorder, du, uknots_ext, utemp) + nbBasisDer(vspan, vi, vorder, dv, vknots_ext, vtemp) + + for k in range(du + 1): + + temp = np.zeros((vorder + 1, 3)) + + # Nu.T^(k)*pts + for s in range(vorder + 1): + for r in range(uorder + 1): + temp[s, :] += ( + utemp[k, r] + * pts[(uspan - uorder + r) % nub, (vspan - vorder + s) % nvb, :] + ) + + # ramaining derivative orders: dk + du <= dorder + dd = min(dorder - k, dv) + + # .. * Nv^(l) + for l in range(dd + 1): + for s in range(vorder + 1): + out[i, k, l, :] += vtemp[l, s] * temp[s, :] + + return out + + +# %% matrices + + +@njit +def designMatrix(u: Array, order: int, knots: Array, periodic: bool = False) -> COO: + """ + Create a sparse (possibly periodic) design matrix. + """ + + # extend the knots + u_, knots_ext, minspan, maxspan, deltaspan = _preprocess(u, order, knots, periodic) + + # number of param values + nu = len(u) + + # number of basis functions + nb = maxspan + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros(n) + + # initialize the empty matrix + rv = COO( + i=np.empty(n * nu, dtype=np.int64), + j=np.empty(n * nu, dtype=np.int64), + v=np.empty(n * nu), + ) + + # loop over param values + for i in range(nu): + ui = u_[i] + + # find the supporting span + span = nbFindSpan(ui, order, knots, minspan, maxspan) + deltaspan + + # evaluate non-zero functions + nbBasis(span, ui, order, knots_ext, temp) + + # update the matrix + rv.i[i * n : (i + 1) * n] = i + rv.j[i * n : (i + 1) * n] = ( + span - order + np.arange(n) + ) % nb # NB: this is due to peridicity + rv.v[i * n : (i + 1) * n] = temp + + return rv + + +@njit +def designMatrix2D( + u: Array, + v: Array, + uorder: int, + vorder: int, + uknots: Array, + vknots: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> COO: + """ + Create a sparse tensor product design matrix. + """ + + # extend the knots and preprocess + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + ni = len(u) + + # chunck size + nu = uorder + 1 + nv = vorder + 1 + nj = nu * nv + + # number of basis + nu_total = maxspanu + nv_total = maxspanv + + # temp chunck storage + utemp = np.zeros(nu) + vtemp = np.zeros(nv) + + # initialize the empty matrix + rv = COO( + i=np.empty(ni * nj, dtype=np.int64), + j=np.empty(ni * nj, dtype=np.int64), + v=np.empty(ni * nj), + ) + + # loop over param values + for i in range(ni): + ui, vi = u_[i], v_[i] + + # find the supporting span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate non-zero functions + nbBasis(uspan, ui, uorder, uknots_ext, utemp) + nbBasis(vspan, vi, vorder, vknots_ext, vtemp) + + # update the matrix + rv.i[i * nj : (i + 1) * nj] = i + rv.j[i * nj : (i + 1) * nj] = ( + ((uspan - uorder + np.arange(nu)) % nu_total) * nv_total + + ((vspan - vorder + np.arange(nv)) % nv_total)[:, np.newaxis] + ).ravel() + rv.v[i * nj : (i + 1) * nj] = (utemp * vtemp[:, np.newaxis]).ravel() + + return rv + + +@njit +def periodicDesignMatrix(u: Array, order: int, knots: Array) -> COO: + """ + Create a sparse periodic design matrix. + """ + + return designMatrix(u, order, knots, periodic=True) + + +@njit +def derMatrix(u: Array, order: int, dorder: int, knots: Array) -> list[COO]: + """ + Create a sparse design matrix and corresponding derivative matrices. + """ + + # number of param values + nu = np.size(u) + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros((dorder + 1, n)) + + # initialize the empty matrix + rv = [] + + for _ in range(dorder + 1): + rv.append( + COO( + i=np.empty(n * nu, dtype=np.int64), + j=np.empty(n * nu, dtype=np.int64), + v=np.empty(n * nu), + ) + ) + + # loop over param values + for i in range(nu): + ui = u[i] + + # find the supporting span + span = nbFindSpan(ui, order, knots) + + # evaluate non-zero functions + nbBasisDer(span, ui, order, dorder, knots, temp) + + # update the matrices + for di in range(dorder + 1): + rv[di].i[i * n : (i + 1) * n] = i + rv[di].j[i * n : (i + 1) * n] = span - order + np.arange(n) + rv[di].v[i * n : (i + 1) * n] = temp[di, :] + + return rv + + +@njit +def periodicDerMatrix(u: Array, order: int, dorder: int, knots: Array) -> list[COO]: + """ + Create a sparse periodic design matrix and corresponding derivative matrices. + """ + + # extend the knots + knots_ext = np.concat( + (knots[-order:-1] - knots[-1], knots, knots[-1] + knots[1:order]) + ) + + # number of param values + nu = len(u) + + # number of basis functions + nb = len(knots) - 1 + + # chunck size + n = order + 1 + + # temp chunck storage + temp = np.zeros((dorder + 1, n)) + + # initialize the empty matrix + rv = [] + + for _ in range(dorder + 1): + rv.append( + COO( + i=np.empty(n * nu, dtype=np.int64), + j=np.empty(n * nu, dtype=np.int64), + v=np.empty(n * nu), + ) + ) + + # loop over param values + for i in range(nu): + ui = u[i] + + # find the supporting span + span = nbFindSpan(ui, order, knots, 0, nb) + order - 1 + + # evaluate non-zero functions + nbBasisDer(span, ui, order, dorder, knots_ext, temp) + + # update the matrices + for di in range(dorder + 1): + rv[di].i[i * n : (i + 1) * n] = i + rv[di].j[i * n : (i + 1) * n] = ( + span - order + np.arange(n) + ) % nb # NB: this is due to peridicity + rv[di].v[i * n : (i + 1) * n] = temp[di, :] + + return rv + + +@njit +def periodicDiscretePenalty(us: Array, order: int) -> COO: + + if order not in (1, 2): + raise ValueError( + f"Only 1st and 2nd order penalty is supported, requested order {order}" + ) + + # number of rows + nb = len(us) + + # number of elements per row + ne = order + 1 + + # initialize the penlaty matrix + rv = COO( + i=np.empty(nb * ne, dtype=np.int64), + j=np.empty(nb * ne, dtype=np.int64), + v=np.empty(nb * ne), + ) + + if order == 1: + for ix in range(nb): + rv.i[ne * ix] = ix + rv.j[ne * ix] = (ix - 1) % nb + rv.v[ne * ix] = -0.5 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = (ix + 1) % nb + rv.v[ne * ix + 1] = 0.5 + + elif order == 2: + for ix in range(nb): + rv.i[ne * ix] = ix + rv.j[ne * ix] = (ix - 1) % nb + rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = (ix + 1) % nb + rv.v[ne * ix + 2] = 1 + + return rv + + +@njit +def discretePenalty(us: Array, order: int, splineorder: int = 3) -> COO: + + if order not in (1, 2): + raise ValueError( + f"Only 1st and 2nd order penalty is supported, requested order {order}" + ) + + # number of rows + nb = len(us) + + # number of elements per row + ne = order + 1 + + # initialize the penlaty matrix + rv = COO( + i=np.empty(nb * ne, dtype=np.int64), + j=np.empty(nb * ne, dtype=np.int64), + v=np.empty(nb * ne), + ) + + if order == 1: + for ix in range(nb): + if ix == 0: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix + rv.v[ne * ix] = -1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + 1 + rv.v[ne * ix + 1] = 1 + elif ix < nb - 1: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 1 + rv.v[ne * ix] = -0.5 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + 1 + rv.v[ne * ix + 1] = 0.5 + else: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 1 + rv.v[ne * ix] = -1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + rv.v[ne * ix + 1] = 1 + + elif order == 2: + for ix in range(nb): + if ix == 0: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix + rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + 1 + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = ix + 2 + rv.v[ne * ix + 2] = 1 + elif ix < nb - 1: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 1 + rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = ix + 1 + rv.v[ne * ix + 2] = 1 + else: + rv.i[ne * ix] = ix + rv.j[ne * ix] = ix - 2 + rv.v[ne * ix] = 1 + + rv.i[ne * ix + 1] = ix + rv.j[ne * ix + 1] = ix - 1 + rv.v[ne * ix + 1] = -2 + + rv.i[ne * ix + 2] = ix + rv.j[ne * ix + 2] = ix + rv.v[ne * ix + 2] = 1 + + return rv + + +@njit +def penaltyMatrix2D( + u: Array, + v: Array, + uorder: int, + vorder: int, + dorder: int, + uknots: Array, + vknots: Array, + uperiodic: bool = False, + vperiodic: bool = False, +) -> list[COO]: + """ + Create sparse tensor product 2D derivative matrices. + """ + + # extend the knots and preprocess + u_, uknots_ext, minspanu, maxspanu, deltaspanu = _preprocess( + u, uorder, uknots, uperiodic + ) + v_, vknots_ext, minspanv, maxspanv, deltaspanv = _preprocess( + v, vorder, vknots, vperiodic + ) + + # number of param values + ni = len(u) + + # chunck size + nu = uorder + 1 + nv = vorder + 1 + nj = nu * nv + + # number of basis + nu_total = maxspanu + nv_total = maxspanv + + # temp chunck storage + utemp = np.zeros((dorder + 1, nu)) + vtemp = np.zeros((dorder + 1, nv)) + + # initialize the emptry matrices + rv = [] + for i in range(dorder + 1): + rv.append( + COO( + i=np.empty(ni * nj, dtype=np.int64), + j=np.empty(ni * nj, dtype=np.int64), + v=np.empty(ni * nj), + ) + ) + + # loop over param values + for i in range(ni): + ui, vi = u_[i], v_[i] + + # find the supporting span + uspan = nbFindSpan(ui, uorder, uknots, minspanu, maxspanu) + deltaspanu + vspan = nbFindSpan(vi, vorder, vknots, minspanv, maxspanv) + deltaspanv + + # evaluate non-zero functions + nbBasisDer(uspan, ui, uorder, dorder, uknots_ext, utemp) + nbBasisDer(vspan, vi, vorder, dorder, vknots_ext, vtemp) + + # update the matrices - iterate over all derivative paris + for dv in range(dorder + 1): + + du = dorder - dv # NB: du + dv == dorder + + rv[dv].i[i * nj : (i + 1) * nj] = i + rv[dv].j[i * nj : (i + 1) * nj] = ( + ((uspan - uorder + np.arange(nu)) % nu_total) * nv_total + + ((vspan - vorder + np.arange(nv)) % nv_total)[:, np.newaxis] + ).ravel() + rv[dv].v[i * nj : (i + 1) * nj] = ( + utemp[du, :] * vtemp[dv, :, np.newaxis] + ).ravel() + + return rv + + +def uniformGrid( + uknots: Array, + vknots: Array, + uorder: int, + vorder: int, + uperiodic: bool, + vperiodic: bool, +) -> Tuple[Array, Array]: + """ + Create a uniform grid for evaluating penalties. + """ + + Up, Vp = np.meshgrid( + np.linspace( + uknots[0], uknots[-1], 2 * len(uknots) * uorder, endpoint=not uperiodic + ), + np.linspace( + vknots[0], vknots[-1], 2 * len(vknots) * vorder, endpoint=not vperiodic + ), + ) + up = Up.ravel() + vp = Vp.ravel() + + return up, vp + + +# %% construction + + +def parametrizeChord(data: Array) -> Array: + """ + Chord length parametrization. + """ + + dists = np.linalg.norm(data - np.roll(data, 1), axis=1) + params = np.cumulative_sum(dists) + + return params / params[-1] + + +@multidispatch +def periodicApproximate( + data: Array, + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, +) -> Curve: + + npts = data.shape[0] + + # parametrize the points if needed + if us is None: + us = linspace(0, 1, npts, endpoint=False) + + # construct the knot vector + if isinstance(knots, int): + knots_ = linspace(0, 1, knots) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = periodicDesignMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add the penalty if requested + if lam: + up = linspace(0, 1, order * npts, endpoint=False) + + assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = periodicDerMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = periodicDiscretePenalty(up, penalty - order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = periodicDerMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # factorize + D, L, P = ldl(CtC, True) + + # invert + pts = ldl_solve(C.T @ data, D, L, P).toarray() + + # convert to an edge + rv = Curve(pts, knots_, order, periodic=True) + + return rv + + +@periodicApproximate.register +def periodicApproximate( + data: List[Array], + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, +) -> List[Curve]: + + rv = [] + + npts = data[0].shape[0] + + # parametrize the points + us = linspace(0, 1, npts, endpoint=False) + + # construct the knot vector + if isinstance(knots, int): + knots_ = linspace(0, 1, knots) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = periodicDesignMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add the penalty if requested + if lam: + up = linspace(0, 1, order * npts, endpoint=False) + + assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = periodicDerMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = periodicDiscretePenalty(up, penalty - order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = periodicDerMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # factorize + D, L, P = ldl(CtC, True) + + # invert every dataset + for dataset in data: + pts = ldl_solve(C.T @ dataset, D, L, P).toarray() + + # convert to an edge and store + rv.append(Curve(pts, knots_, order, periodic=True)) + + return rv + + +@multidispatch +def approximate( + data: Array, + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, + tangents: Optional[Tuple[Array, Array]] = None, +) -> Curve: + + npts = data.shape[0] + + # parametrize the points + us = linspace(0, 1, npts) + + # construct the knot vector + if isinstance(knots, int): + knots_ = np.concatenate( + (np.repeat(0, order), linspace(0, 1, knots), np.repeat(1, order)) + ) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = designMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add a penalty term if requested + if lam: + up = linspace(0, 1, order * npts) + + assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = derMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = discretePenalty(up, penalty - order, order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = derMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # clamp first and last point + Cc = C[[0, -1], :] + bc = data[[0, -1], :] + nc = 2 # number of constraints + + # handle tangent constraints if needed + if tangents: + nc += 2 + + Cc2 = derMatrix(us[[0, -1]], order, 1, knots_)[-1].csc() + + Cc = sp.vstack((Cc, Cc2)) + bc = np.vstack((bc, *tangents)) + + # final matrix and vector + Aug = sp.bmat([[CtC, Cc.T], [Cc, None]]) + data_aug = np.vstack((C.T @ data, bc)) + + # factorize + D, L, P = ldl(Aug, False) + + # invert + pts = ldl_solve(data_aug, D, L, P).toarray()[:-nc, :] + + # convert to an edge + rv = Curve(pts, knots_, order, periodic=False) + + return rv + + +@approximate.register +def approximate( + data: List[Array], + us: Optional[Array] = None, + knots: int | Array = 50, + order: int = 3, + penalty: int = 4, + lam: float = 0, + tangents: Optional[Union[Tuple[Array, Array], List[Tuple[Array, Array]]]] = None, +) -> List[Curve]: + + rv = [] + + npts = data[0].shape[0] + + # parametrize the points + us = linspace(0, 1, npts) + + # construct the knot vector + if isinstance(knots, int): + knots_ = np.concatenate( + (np.repeat(0, order), linspace(0, 1, knots), np.repeat(1, order)) + ) + else: + knots_ = np.array(knots) + + # construct the design matrix + C = designMatrix(us, order, knots_).csc() + CtC = C.T @ C + + # add a penalty term if requested + if lam: + up = linspace(0, 1, order * npts) + + assert penalty <= order + 2 + + # discrete + exact derivatives + if penalty > order: + Pexact = derMatrix(up, order, order - 1, knots_)[-1].csc() + Pdiscrete = discretePenalty(up, penalty - order, order).csc() + + P = Pdiscrete @ Pexact + + # only exact derivatives + else: + P = derMatrix(up, order, penalty, knots_)[-1].csc() + + CtC += lam * P.T @ P + + # clamp first and last point + Cc = C[[0, -1], :] + + nc = 2 # number of constraints + + # handle tangent constraints if needed + if tangents: + nc += 2 + Cc2 = derMatrix(us[[0, -1]], order, 1, knots_)[-1].csc() + Cc = sp.vstack((Cc, Cc2)) + + # final matrix and vector + Aug = sp.bmat([[CtC, Cc.T], [Cc, None]]) + + # factorize + D, L, P = ldl(Aug, False) + + # invert all datasets + for ix, dataset in enumerate(data): + bc = dataset[[0, -1], :] # first and last point for clamping + + if tangents: + if len(tangents) == len(data): + bc = np.vstack((bc, *tangents[ix])) + else: + bc = np.vstack((bc, *tangents)) + + # construct the LHS of the linear system + dataset_aug = np.vstack((C.T @ dataset, bc)) + + # actual solver + pts = ldl_solve(dataset_aug, D, L, P).toarray()[:-nc, :] + + # convert to an edge + rv.append(Curve(pts, knots_, order, periodic=False)) + + return rv + + +def approximate2D( + data: Array, + u: Array, + v: Array, + uorder: int, + vorder: int, + uknots: int | Array = 50, + vknots: int | Array = 50, + uperiodic: bool = False, + vperiodic: bool = False, + penalty: int = 3, + lam: float = 0, +) -> Surface: + """ + Simple 2D surface approximation (without any penalty). + """ + + # process the knots + uknots_ = uknots if isinstance(uknots, Array) else np.linspace(0, 1, uknots) + vknots_ = vknots if isinstance(vknots, Array) else np.linspace(0, 1, vknots) + + # create the desing matrix + C = designMatrix2D( + u, v, uorder, vorder, uknots_, vknots_, uperiodic, vperiodic + ).csc() + + # handle penalties if requested + if lam: + # construct the penalty grid + up, vp = uniformGrid(uknots_, vknots_, uorder, vorder, uperiodic, vperiodic) + + # construct the derivative matrices + penalties = penaltyMatrix2D( + up, vp, uorder, vorder, penalty, uknots_, vknots_, uperiodic, vperiodic, + ) + + # augment the design matrix + tmp = [comb(penalty, i) * penalties[i].csc() for i in range(penalty + 1)] + Lu = uknots_[-1] - uknots_[0] # v lenght of the parametric domain + Lv = vknots_[-1] - vknots_[0] # u lenght of the parametric domain + P = Lu * Lv / len(up) * sp.vstack(tmp) + + CtC = C.T @ C + lam * P.T @ P + else: + CtC = C.T @ C + + # solve normal equations + D, L, P = ldl(CtC, False) + pts = ldl_solve(C.T @ data, D, L, P).toarray() + + # construt the result + rv = Surface( + pts.reshape((len(uknots_) - int(uperiodic), len(vknots_) - int(vperiodic), 3)), + uknots_, + vknots_, + uorder, + vorder, + uperiodic, + vperiodic, + ) + + return rv + + +def fairPenalty(surf: Surface, penalty: int, lam: float) -> Surface: + """ + Penalty-based surface fairing. + """ + + uknots = surf.uknots + vknots = surf.vknots + pts = surf.pts.reshape((-1, 3)) + + # generate penalty grid + up, vp = uniformGrid( + uknots, vknots, surf.uorder, surf.vorder, surf.uperiodic, surf.vperiodic + ) + + # generate penalty matrix + penalties = penaltyMatrix2D( + up, + vp, + surf.uorder, + surf.vorder, + penalty, + surf.uknots, + surf.vknots, + surf.uperiodic, + surf.vperiodic, + ) + + tmp = [comb(penalty, i) * penalties[i].csc() for i in range(penalty + 1)] + Lu = uknots[-1] - uknots[0] # v lenght of the parametric domain + Lv = vknots[-1] - vknots[0] # u lenght of the parametric domain + P = Lu * Lv / len(up) * sp.vstack(tmp) + + # form and solve normal equations + CtC = sp.identity(pts.shape[0]) + lam * P.T @ P + + D, L, P = ldl(CtC, False) + pts_new = ldl_solve(pts, D, L, P).toarray() + + # construt the result + rv = Surface( + pts_new.reshape( + (len(uknots) - int(surf.uperiodic), len(vknots) - int(surf.vperiodic), 3) + ), + uknots, + vknots, + surf.uorder, + surf.vorder, + surf.uperiodic, + surf.vperiodic, + ) + + return rv + + +def periodicLoft(*curves: Curve, order: int = 3) -> Surface: + + nknots: int = len(curves) + 1 + + # collect control pts + pts = [el for el in np.stack([c.pts for c in curves]).swapaxes(0, 1)] + + # approximate + pts_new = [el.pts for el in periodicApproximate(pts, knots=nknots, order=order)] + + # construct the final surface + rv = Surface( + np.stack(pts_new).swapaxes(0, 1), + linspace(0, 1, nknots), + curves[0].knots, + order, + curves[0].order, + True, + curves[0].periodic, + ) + + return rv + + +def loft( + *curves: Curve, + order: int = 3, + lam: float = 1e-9, + penalty: int = 4, + tangents: Optional[List[Tuple[Array, Array]]] = None, +) -> Surface: + + nknots: int = len(curves) + + # collect control pts + pts = np.stack([c.pts for c in curves]) + + # approximate + pts_new = [] + + for j in range(pts.shape[1]): + pts_new.append( + approximate( + pts[:, j, :], + knots=nknots, + order=order, + lam=lam, + penalty=penalty, + tangents=tangents[j] if tangents else None, + ).pts + ) + + # construct the final surface + rv = Surface( + np.stack(pts_new).swapaxes(0, 1), + np.concatenate( + (np.repeat(0, order), linspace(0, 1, nknots), np.repeat(1, order)) + ), + curves[0].knots, + order, + curves[0].order, + False, + curves[0].periodic, + ) + + return rv + + +def reparametrize( + *curves: Curve, n: int = 100, knots: int = 100, w1: float = 1, w2: float = 1 +) -> List[Curve]: + + from scipy.optimize import fmin_l_bfgs_b + + n_curves = len(curves) + + u0_0 = np.linspace(0, 1, n, False) + u0 = np.tile(u0_0, n_curves) + + # scaling for the second cost term + scale = n * np.linalg.norm(curves[0](u0[0]) - curves[1](u0[n])) + + def cost(u: Array) -> float: + + rv1 = 0 + us = np.split(u, n_curves) + + pts = [] + + for i, ui in enumerate(us): + + # evaluate + pts.append(curves[i](ui)) + + # parametric distance between points on the same curve + rv1 += np.sum((ui[:-1] - ui[1:]) ** 2) + np.sum((ui[0] + 1 - ui[-1]) ** 2) + + rv2 = 0 + + for p1, p2 in zip(pts, pts[1:]): + + # geometric distance between points on adjecent curves + rv2 += np.sum(((p1 - p2) / scale) ** 2) + + return w1 * rv1 + w2 * rv2 + + def grad(u: Array) -> Array: + + rv1 = np.zeros_like(u) + us = np.split(u, n_curves) + + pts = [] + tgts = [] + + for i, ui in enumerate(us): + + # evaluate up to 1st derivative + tmp = curves[i].der(ui, 1) + + pts.append(tmp[:, 0, :].squeeze()) + tgts.append(tmp[:, 1, :].squeeze()) + + # parametric distance between points on the same curve + delta = np.roll(ui, -1) - ui + delta[-1] += 1 + delta *= -2 + delta -= np.roll(delta, 1) + + rv1[i * n : (i + 1) * n] = delta + + rv2 = np.zeros_like(u) + + for i, _ in enumerate(us): + # geometric distance between points on adjecent curves + + # first profile + if i == 0: + p1, p2, t = pts[i], pts[i + 1], tgts[i] + + rv2[i * n : (i + 1) * n] = (2 / scale ** 2 * (p1 - p2) * t).sum(1) + + # middle profile + elif i + 1 < n_curves: + p1, p2, t = pts[i], pts[i + 1], tgts[i] + p0 = pts[i - 1] + + rv2[i * n : (i + 1) * n] = (2 / scale ** 2 * (p1 - p2) * t).sum(1) + rv2[i * n : (i + 1) * n] += (-2 / scale ** 2 * (p0 - p1) * t).sum(1) + + # last profile + else: + p1, p2, t = pts[i - 1], pts[i], tgts[i] + + rv2[i * n : (i + 1) * n] = (-2 / scale ** 2 * (p1 - p2) * t).sum(1) + + return w1 * rv1 + w2 * rv2 + + usol, _, _ = fmin_l_bfgs_b(cost, u0, grad) + + us = np.split(usol, n_curves) + + return periodicApproximate( + [crv(u) for crv, u in zip(curves, us)], knots=knots, lam=0 + ) + + +def offset(surf: Surface, d: float, lam: float = 1e-3) -> Surface: + """ + Simple approximate offset. + """ + + # construct the knot grid + U, V = np.meshgrid( + np.linspace(surf.uknots[0], surf.uknots[-1], surf.uorder * len(surf.uknots)), + np.linspace(surf.vknots[0], surf.vknots[-1], surf.vorder * len(surf.uknots)), + ) + + us = U.ravel() + vs = V.ravel() + + # evaluate the normals + ns, pts = surf.normal(us, vs) + + # move the control points + pts += d * ns + + return approximate2D( + pts, + us, + vs, + surf.uorder, + surf.vorder, + surf.uknots, + surf.vknots, + surf.uperiodic, + surf.vperiodic, + lam=lam, + ) + + +# %% for removal? +@njit +def findSpan(v, knots): + + return np.searchsorted(knots, v, "right") - 1 + + +@njit +def findSpanLinear(v, knots): + + for rv in range(len(knots)): + if knots[rv] <= v and knots[rv + 1] > v: + return rv + + return -1 + + +@njit +def periodicKnots(degree: int, n_pts: int): + rv = np.arange(0.0, n_pts + degree + 1, 1.0) + rv /= rv[-1] + + return rv diff --git a/cadquery/occ_impl/shape_protocols.py b/cadquery/occ_impl/shape_protocols.py index c8c98afb2..5aab57b52 100644 --- a/cadquery/occ_impl/shape_protocols.py +++ b/cadquery/occ_impl/shape_protocols.py @@ -76,6 +76,9 @@ def Area(self) -> float: def BoundingBox(self, tolerance: Optional[float] = None) -> BoundBox: ... + def distance(self, other, tol: float = 1e-6) -> float: + ... + class Shape1DProtocol(ShapeProtocol, Protocol): def tangentAt( diff --git a/cadquery/occ_impl/shapes.py b/cadquery/occ_impl/shapes.py index 70b709883..fbad475d1 100644 --- a/cadquery/occ_impl/shapes.py +++ b/cadquery/occ_impl/shapes.py @@ -321,6 +321,9 @@ from OCP.OSD import OSD_ThreadPool +from OCP.Bnd import Bnd_OBB +from OCP.BRepBndLib import BRepBndLib + from math import pi, sqrt, inf, radians, cos import warnings @@ -1433,12 +1436,12 @@ def split(self, *splitters: "Shape") -> "Shape": return self._bool_op((self,), splitters, split_op) - def distance(self, other: "Shape") -> float: + def distance(self, other: "Shape", tol: float = 1e-6) -> float: """ Minimal distance between two shapes """ - dist_calc = BRepExtrema_DistShapeShape(self.wrapped, other.wrapped) + dist_calc = BRepExtrema_DistShapeShape(self.wrapped, other.wrapped, tol) dist_calc.SetMultiThread(True) return dist_calc.Value() @@ -2397,6 +2400,13 @@ def hasPCurve(self, f: "Face") -> bool: return ShapeAnalysis_Edge().HasPCurve(self.wrapped, f.wrapped) + def reversed(self) -> "Edge": + """ + Return a reversed version of self. + """ + + return self.__class__(self.wrapped.Reversed()) + @classmethod def makeCircle( cls, @@ -5467,6 +5477,10 @@ def shell( return shell(*s, tol=tol, manifold=manifold, ctx=ctx, history=history) +# add an alias +sew = shell + + @multimethod def solid( s1: Shape, *sn: Shape, tol: float = 1e-6, history: Optional[ShapeHistory] = None, @@ -6171,7 +6185,7 @@ def chamfer(s: Shape, e: Shape, d: float) -> Shape: return _compound_or_shape(builder.Shape()) -def extrude(s: Shape, d: VectorLike) -> Shape: +def extrude(s: Shape, d: VectorLike, both: bool = False) -> Shape: """ Extrude a shape. """ @@ -6180,7 +6194,13 @@ def extrude(s: Shape, d: VectorLike) -> Shape: for el in _get(s, ("Vertex", "Edge", "Wire", "Face")): - builder = BRepPrimAPI_MakePrism(el.wrapped, Vector(d).wrapped) + if both: + builder = BRepPrimAPI_MakePrism( + el.moved(-Vector(d)).wrapped, (2 * Vector(d)).wrapped + ) + else: + builder = BRepPrimAPI_MakePrism(el.wrapped, Vector(d).wrapped) + builder.Build() results.append(builder.Shape()) @@ -6258,6 +6278,30 @@ def _offset(t): return rv +def offset2D( + s: Shape, t: float, kind: Literal["arc", "intersection", "tangent"] = "arc" +) -> Shape: + """ + 2D Offset edges, wires or faces. + """ + + kind_dict = { + "arc": GeomAbs_JoinType.GeomAbs_Arc, + "intersection": GeomAbs_JoinType.GeomAbs_Intersection, + "tangent": GeomAbs_JoinType.GeomAbs_Tangent, + } + + bldr = BRepOffsetAPI_MakeOffset() + bldr.Init(kind_dict[kind]) + + for el in _get_wires(s): + bldr.AddWire(el.wrapped) + + bldr.Perform(t) + + return _compound_or_shape(bldr.Shape()) + + @multimethod def sweep( s: Shape, path: Shape, aux: Optional[Shape] = None, cap: bool = False @@ -6501,6 +6545,7 @@ def loft( return loft(s, cap, ruled, continuity, parametrization, degree, compat) +@multimethod def project( s: Shape, base: Shape, @@ -6524,6 +6569,26 @@ def project( return _compound_or_shape(bldr.Projection()) +@project.register +def project( + s: Shape, base: Shape, direction: VectorLike, +): + """ + Project s onto base using cylindrical projection. + """ + + results = [] + + for el in _get_wires(s): + bldr = BRepProj_Projection(el.wrapped, base.wrapped, Vector(direction).toDir()) + + while bldr.More(): + results.append(_compound_or_shape(bldr.Current())) + bldr.Next() + + return _normalize(compound(results)) + + #%% diagnotics @@ -6561,7 +6626,7 @@ def check( def isSubshape(s1: Shape, s2: Shape) -> bool: """ - Check if s1 is a subshape of s2. + Check is s1 is a subshape of s2. """ shape_map = TopTools_IndexedDataMapOfShapeListOfShape() @@ -6576,7 +6641,7 @@ def isSubshape(s1: Shape, s2: Shape) -> bool: #%% properties -def closest(s1: Shape, s2: Shape) -> Tuple[Vector, Vector]: +def closest(s1: Shape, s2: Shape, tol: float = 1e-6) -> Tuple[Vector, Vector]: """ Closest points between two shapes. """ @@ -6588,7 +6653,36 @@ def closest(s1: Shape, s2: Shape) -> Tuple[Vector, Vector]: ext.LoadS1(s1.wrapped) ext.LoadS2(s2.wrapped) + ext.SetDeflection(tol) + import OCP + + ext.SetAlgo(OCP.Extrema.Extrema_ExtAlgo.Extrema_ExtAlgo_Grad) + # perform - assert ext.Perform() + ext.Perform() return Vector(ext.PointOnShape1(1)), Vector(ext.PointOnShape2(1)) + + +def obb(s: Shape) -> Shape: + + # construct the OBB + bbox = Bnd_OBB() + BRepBndLib.AddOBB_s( + s.wrapped, bbox, theIsTriangulationUsed=False, theIsOptimal=True + ) + + # convert to a shape + center = Vector(bbox.Center()) + xdir = Vector(bbox.XDirection()) + ydir = Vector(bbox.YDirection()) + zdir = Vector(bbox.ZDirection()) + + dx = bbox.XHSize() + dy = bbox.YHSize() + dz = bbox.ZHSize() + + ax = gp_Ax2(center.toPnt(), zdir.toDir(), xdir.toDir()) + ax.SetLocation((center - dx * xdir - dy * ydir - dz * zdir).toPnt()) + + return Shape.cast(BRepPrimAPI_MakeBox(ax, 2.0 * dx, 2.0 * dy, 2.0 * dz).Shape()) diff --git a/cadquery/selectors.py b/cadquery/selectors.py index efc7f5e5e..d60f55927 100644 --- a/cadquery/selectors.py +++ b/cadquery/selectors.py @@ -86,6 +86,22 @@ def dist(tShape): return [min(objectList, key=dist)] +class NearestToShapeSelector(Selector): + """ + Selects object nearest the provided Shape. + + Applicability: All Types of Shapes + + """ + + def __init__(self, s: Shape): + self.shape = s + + def filter(self, objectList: Sequence[Shape]): + + return [min(objectList, key=lambda el: self.shape.distance(el))] + + class BoxSelector(Selector): """ Selects objects inside the 3D box defined by 2 points. @@ -870,3 +886,11 @@ def filter(self, objectList: Sequence[Shape]): Filter give object list through th already constructed complex selector object """ return self.mySelector.filter(objectList) + + +# %% Aliases + +NearestToPoint = NearestToPointSelector +NearestToShape = NearestToShapeSelector +Parallel = ParallelDirSelector +Perpendicular = PerpendicularDirSelector diff --git a/cadquery/utils.py b/cadquery/utils.py index c473db677..df743d771 100644 --- a/cadquery/utils.py +++ b/cadquery/utils.py @@ -2,6 +2,7 @@ from inspect import signature, isbuiltin from typing import TypeVar, Callable, cast from warnings import warn +from collections import UserDict from multimethod import multimethod, DispatchError @@ -83,3 +84,35 @@ def get_arity(f: TCallable) -> int: rv = f.__code__.co_argcount - n_defaults return rv + + +K = TypeVar("K") +V = TypeVar("V") + + +class BiDict(UserDict[K, V]): + """ + Bi-directional dictionary. + """ + + _inv: dict[V, list[K]] + + def __init__(self, *args, **kwargs): + + self._inv = {} + + super().__init__(*args, **kwargs) + + def __setitem__(self, k: K, v: V): + + super().__setitem__(k, v) + + if v in self._inv: + self._inv[v].append(k) + else: + self._inv[v] = [k] + + @property + def inv(self) -> dict[V, list[K]]: + + return self._inv diff --git a/conda/meta.yaml b/conda/meta.yaml index 44003ec4a..be3e6a6a9 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -26,6 +26,8 @@ requirements: - multimethod >=1.11,<2.0 - casadi - typish + - numba + - scipy - trame - trame-vtk diff --git a/environment.yml b/environment.yml index f30b0fbf5..7791baeef 100644 --- a/environment.yml +++ b/environment.yml @@ -25,6 +25,8 @@ dependencies: - pathspec - click - appdirs + - numba + - scipy - trame - trame-vtk - pip diff --git a/mypy.ini b/mypy.ini index 7bc958faf..fdd66dd63 100644 --- a/mypy.ini +++ b/mypy.ini @@ -37,6 +37,9 @@ ignore_missing_imports = True [mypy-casadi.*] ignore_missing_imports = True +[mypy-numba.*] +ignore_missing_imports = True + [mypy-trame.*] ignore_missing_imports = True diff --git a/tests/test_nurbs.py b/tests/test_nurbs.py new file mode 100644 index 000000000..9f57a8a46 --- /dev/null +++ b/tests/test_nurbs.py @@ -0,0 +1,276 @@ +from cadquery.occ_impl.nurbs import ( + designMatrix, + periodicDesignMatrix, + designMatrix2D, + nbFindSpan, + nbBasis, + nbBasisDer, + Curve, + Surface, + approximate, + periodicApproximate, + periodicLoft, + loft, + reparametrize, +) + +from cadquery.func import circle + +import numpy as np +import scipy.sparse as sp + +from pytest import approx, fixture, mark + + +@fixture +def circles() -> list[Curve]: + + # u,v periodic + c1 = circle(1).toSplines() + c2 = circle(5) + + cs = [ + Curve.fromEdge(c1.moved(loc)) + for loc in c2.locations(np.linspace(0, 1, 10, False)) + ] + + return cs + + +@fixture +def trimmed_circles() -> list[Curve]: + + c1 = circle(1).trim(0, 1).toSplines() + c2 = circle(5) + + cs = [ + Curve.fromEdge(c1.moved(loc)) + for loc in c2.locations(np.linspace(0, 1, 10, False)) + ] + + return cs + + +@fixture +def rotated_circles() -> list[Curve]: + + pts1 = np.array([v.toTuple() for v in circle(1).sample(100)[0]]) + pts2 = np.array([v.toTuple() for v in circle(1).moved(z=1, rz=90).sample(100)[0]]) + + c1 = periodicApproximate(pts1) + c2 = periodicApproximate(pts2) + + return [c1, c2] + + +def test_periodic_dm(): + + knots = np.linspace(0, 1, 5) + params = np.linspace(0, 1, 100) + order = 3 + + res = periodicDesignMatrix(params, order, knots) + + C = sp.coo_array((res.v, (res.i, res.j))) + + assert C.shape[0] == len(params) + assert C.shape[1] == len(knots) - 1 + + +def test_dm_2d(): + + uknots = np.array([0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1]) + uparams = np.linspace(0, 1, 100) + uorder = 3 + + vknots = np.array([0, 0, 0, 0.5, 1, 1, 1]) + vparams = np.linspace(0, 1, 100) + vorder = 2 + + res = designMatrix2D(uparams, vparams, uorder, vorder, uknots, vknots) + + C = res.coo() + + assert C.shape[0] == len(uparams) + assert C.shape[1] == (len(uknots) - uorder - 1) * (len(vknots) - vorder - 1) + + +def test_dm(): + + knots = np.array([0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1]) + params = np.linspace(0, 1, 100) + order = 3 + + res = designMatrix(params, order, knots) + + C = sp.coo_array((res.v, (res.i, res.j))) + + assert C.shape[0] == len(params) + assert C.shape[1] == len(knots) - order - 1 + + +def test_der(): + + knots = np.array([0, 0, 0, 0, 0.25, 0.5, 0.75, 1, 1, 1, 1]) + params = np.linspace(0, 1, 100) + order = 3 + + out_der = np.zeros((order + 1, order + 1)) + out = np.zeros(order + 1) + + for p in params: + nbBasisDer(nbFindSpan(p, order, knots), p, order, order - 1, knots, out_der) + nbBasis(nbFindSpan(p, order, knots), p, order, knots, out) + + # sanity check + assert np.allclose(out_der[0, :], out) + + +def test_periodic_curve(): + + knots = np.linspace(0, 1, 5) + pts = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 2], [0, 2, 0]]) + + crv = Curve(pts, knots, 3, True) + + # is it indeed periodic? + assert crv.curve().IsPeriodic() + + # convert to an edge + e = crv.edge() + + assert e.isValid() + assert e.ShapeType() == "Edge" + + +def test_curve(): + + knots = np.array([0, 0, 0, 0, 1, 1, 1, 1]) + pts = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 2], [0, 2, 0]]) + + crv = Curve(pts, knots, 3, False) + + # sanity check + assert not crv.curve().IsPeriodic() + + # convert to an edge + e = crv.edge() + + assert e.isValid() + assert e.ShapeType() == "Edge" + + # edge to curve + crv2 = Curve.fromEdge(e) + e2 = crv2.edge() + + assert e2.isValid() + + # check roundtrip + crv3 = Curve.fromEdge(e2) + + assert np.allclose(crv2.knots, crv3.knots) + assert np.allclose(crv2.pts, crv3.pts) + + +def test_surface(): + + uknots = vknots = np.array([0, 0, 1, 1]) + pts = np.array([[[0, 0, 0], [0, 1, 0]], [[1, 0, 0], [1, 1, 0]]]) + + srf = Surface(pts, uknots, vknots, 1, 1, False, False) + + # convert to a face + f = srf.face() + + assert f.isValid() + assert f.Area() == approx(1) + + # roundtrip + srf2 = Surface.fromFace(f) + + assert np.allclose(srf.uknots, srf2.uknots) + assert np.allclose(srf.vknots, srf2.vknots) + assert np.allclose(srf.pts, srf2.pts) + + +def test_approximate(): + + pts_ = circle(1).trim(0, 1).sample(100)[0] + pts = np.array([list(p) for p in pts_]) + + # regular approximate + crv = approximate(pts) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(1) + + # approximate with a double penalty + crv = approximate(pts, penalty=4, lam=1e-9) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(1) + + # approximate with a single penalty + crv = approximate(pts, penalty=2, lam=1e-9) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(1) + + +def test_periodic_approximate(): + + pts_ = circle(1).sample(100)[0] + pts = np.array([list(p) for p in pts_]) + + crv = periodicApproximate(pts) + e = crv.edge() + + assert e.isValid() + assert e.Length() == approx(2 * np.pi) + + +def test_periodic_loft(circles, trimmed_circles): + + # u,v periodic + surf1 = periodicLoft(*circles) + + assert surf1.face().isValid() + + # u periodic + surf2 = periodicLoft(*trimmed_circles) + + assert surf2.face().isValid() + + +def test_loft(circles, trimmed_circles): + + # v periodic + surf1 = loft(*circles) + + assert surf1.face().isValid() + + # non-periodic + surf2 = loft(*trimmed_circles) + + assert surf2.face().isValid() + + +def test_reparametrize(rotated_circles): + + c1, c2 = rotated_circles + + # this surface will be twisted + surf = loft(c1, c2, order=2, lam=1e-6) + + # this should adjust the paramatrizations + c1r, c2r = reparametrize(c1, c2) + + # resulting loft should not be twisted + surfr = loft(c1r, c2r, order=2, lam=1e-6) + + # assert that the surface is indeed not twisted + assert surfr.face().Area() == approx(2 * np.pi, 1e-3) + assert surfr.face().Area() >= 1.01 * surf.face().Area()