Skip to content
14 changes: 14 additions & 0 deletions src/dctkit/dec/flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,17 @@ def flat_DPP(c: C.CochainD0V | C.CochainD0T) -> C.CochainP1:
flat_matrix = c.complex.flat_DPP_weights

return flat(c, flat_matrix, C.CochainP1(c.complex, primal_edges))


def flat_PDP(c: C.CochainP0V | C.CochainP0T) -> C.CochainP1:
"""Implements the flat PDP operator for primal 0-cochains.

Args:
c: a primal 0-cochain.
Returns:
the primal 1-cochain resulting from the application of the flat operator.
"""
primal_edges = c.complex.primal_edges_vectors[:, :c.coeffs.shape[1]]
flat_matrix = c.complex.flat_PDP_weights

return flat(c, flat_matrix, C.CochainP1(c.complex, primal_edges))
171 changes: 171 additions & 0 deletions src/dctkit/dec/wedge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import itertools
import jax.numpy as jnp
from jax import Array, vmap
from dctkit.dec import cochain as C
from scipy.special import factorial
from functools import partial
from typing import List


def compute_permutation_vectors(n: int) -> Array:
"""Computes all permutation vectors of length n.

Args:
n: The number of elements to permute.

Returns:
a JAX array of shape (n!, n) containing all permutations
of the integers from 0 to n-1.
"""
perms = list(itertools.permutations(range(n)))
perm_array = jnp.array(perms)
return perm_array


@vmap
def permutation_sign(p: Array) -> Array:
"""Computes the sign of a permutation.

The sign is +1 for even permutations and -1 for odd permutations.
It is computed as the determinant of the corresponding permutation matrix.

Args:
p: A 1D array representing a permutation of integers.

Returns:
the sign of the permutations.
"""
n = len(p)
# Permutation matrix
perm_matrix = jnp.eye(n)[p]
return round(jnp.linalg.det(perm_matrix))


@partial(vmap, in_axes=(0, None))
def find_simplex_idx(s: Array, S: Array) -> Array:
"""Finds the index of a given simplex in a set of simplices.

Args:
s: A 1D array representing a simplex (e.g., a set of vertex indices).
S: A 2D array where each row is a simplex.

Returns:
the index of the simplex `s` in `S`. If `s` is not found,
returns -1.
"""
# Broadcast and compare all rows to the given row
matches = jnp.all(S == s, axis=1)
# Find the index where all elements match
simplex_idx = jnp.where(matches, size=1, fill_value=-1)[0][0]
return simplex_idx


@partial(vmap, in_axes=(0, None, None, None, None, None, None))
def compute_wedge_coeffs(simplex: Array,
S_list: List[Array],
c_1: C.Cochain,
c_2: C.Cochain,
perm_vec: Array,
sgn_perm_vec: Array,
weight: Array) -> Array:
"""Computes the coefficients of the wedge product for a simplex.

This function computes the weighted wedge product of two cochains over
a simplex, taking into account permutations and orientation signs. It
is vectorized over the first argument (`simplex`) using `jax.vmap`.

Args:
simplex: a 1D array representing the indices of a simplex.
S_list: a list of arrays, where each array contains all simplices of
a given dimension.
c_1: the first cochain object.
c_2: the second cochain object.
perm_vec: an array representing a permutation of the simplex indices.
sgn_perm_vec: the sign (+1 or -1) associated with the permutation.
weight: a scalar weight to apply to the wedge product.

Returns:
the weighted sum of the wedge product contributions for the
permuted simplex.
"""
# perm the simplex idx vector
perm_simplex = simplex[perm_vec]
# split the perm simplices in vector of indices compatible
# with c_1 and c_2
perm_simplex_c_1 = perm_simplex[:, :c_1.dim+1]
perm_simplex_c_2 = perm_simplex[:, c_1.dim:]

# since the perm simplices may not have the same orientations
# as the original one, we need to account for that
perm_ord_c_1 = jnp.argsort(perm_simplex_c_1, axis=1)
perm_ord_c_2 = jnp.argsort(perm_simplex_c_2, axis=1)
sign_orientations_c_1 = permutation_sign(perm_ord_c_1)
sign_orientations_c_2 = permutation_sign(perm_ord_c_2)

# compute the indexes for every (ordered) perm_simplex
ord_simplex_c_1 = jnp.take_along_axis(perm_simplex_c_1, perm_ord_c_1, axis=1)
ord_simplex_c_2 = jnp.take_along_axis(perm_simplex_c_2, perm_ord_c_2, axis=1)
perm_idx_c_1 = find_simplex_idx(ord_simplex_c_1, S_list[c_1.dim])
perm_idx_c_2 = find_simplex_idx(ord_simplex_c_2, S_list[c_2.dim])

# compute the value of the cup product
cup_prod_no_sign = c_1.coeffs[perm_idx_c_1]*c_2.coeffs[perm_idx_c_2]
cup_prod = cup_prod_no_sign.ravel()*sign_orientations_c_1*sign_orientations_c_2
wedge_vec = sgn_perm_vec*cup_prod
# FIXME: fix this part of the code
if c_1.dim + c_2.dim > 1:
weight = weight[0]
return weight*jnp.sum(wedge_vec)

weight_coeffs = weight[perm_idx_c_2[0]]

weighted_wedge_vec = weight_coeffs*wedge_vec[0] + (1-weight_coeffs)*wedge_vec[1]

# compute wedge entry

# print(weighted_wedge_vec)
# assert False
return weighted_wedge_vec


def wedge(c_1: C.Cochain, c_2: C.Cochain, weight: Array = None) -> C.Cochain:
"""Computes the wedge product of two cochains.

Args:
c_1: the first cochain.
c_2: the second cochain.

Returns:
a new cochain representing the wedge product of `c_1` and `c_2`.

Raises:
Exception: If attempting a primal-dual wedge product, which is
undefined.
AssertionError: If computing a dual wedge product with dimension
greater than 1, which is not defined.
"""
wedge_coch_dim = c_1.dim + c_2.dim
S = c_1.complex
# extract the matrix of indices of the wedge_coch_dim+1-simplices (primal/dual)
if c_1.is_primal and c_2.is_primal:
# primal wedge
S_list = S.S
elif (not c_1.is_primal) and not (c_2.is_primal):
# dual wedge is only defined for wedge_coch_dim <=1
assert wedge_coch_dim <= 1
S_list = S.S_dual
else:
raise Exception("The primal-dual wedge product is not defined.")
num_c_2_dim_simplex = S_list[c_2.dim].shape[0]
if weight is None:
# standard definition
weight = 1/(wedge_coch_dim+1)*jnp.ones(num_c_2_dim_simplex)
weight *= 1/factorial(wedge_coch_dim, True)
simplices = S_list[wedge_coch_dim]
# generate the permutation vectors and compute its signs
perm_vec = compute_permutation_vectors(wedge_coch_dim+1)
sgn_perm_vec = permutation_sign(perm_vec)
# compute wedge coeffs
wedge_coch_coeffs = compute_wedge_coeffs(
simplices, S_list, c_1, c_2, perm_vec, sgn_perm_vec, weight)
return C.Cochain(wedge_coch_dim, c_1.is_primal, S, wedge_coch_coeffs)
84 changes: 73 additions & 11 deletions src/dctkit/mesh/simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,39 @@ def get_boundary_operators(self):

def get_complex_boundary_faces_indices(self):
"""Find the IDs of the boundary faces of the complex, i.e. the row indices of
the boundary faces in the matrix S[dim-1].
the boundary faces in the matrix S[dim-1]. (fix the docs)
"""
# boundary faces IDs appear only once in the matrix simplices_faces[dim]
unique_elements, counts = np.unique(
self.simplices_faces[self.dim], return_counts=True)
self.bnd_faces_indices = np.sort(unique_elements[counts == 1])
# FIXME: this routine is tested only for dim-1 boundary simplices. Fix it!
self.boundary_simplices = sl.ShiftedList([None] * (self.dim + 1), -1)

# For k = 0 to dim - 1, use boundary matrix ∂_{k+1}
for k in range(self.dim):
boundary_mat = self.boundary[k + 1] # COO format
rows = boundary_mat[0] # Each row corresponds to a k-simplex

unique, counts = np.unique(rows, return_counts=True)
bnd_k = unique[counts == 1]
self.boundary_simplices[k] = bnd_k

# For k = dim (top-level), look for top-simplices having at least one
# (dim-1)-face on the boundary

# Find boundary (dim-1)-simplices
boundary_faces = self.boundary_simplices[self.dim - 1]
# shape: (num_top_simplices, dim+1)
simplices_faces = self.simplices_faces[self.dim]
is_boundary_simplex = np.any(np.isin(simplices_faces, boundary_faces), axis=1)

self.boundary_simplices[self.dim] = np.nonzero(is_boundary_simplex)[0]

def get_tets_containing_a_boundary_face(self):
"""Compute a list in which the i-th element is the index of the top-level
simplex in which the i-th boundary face belongs."""
if not hasattr(self, "bnd_faces_indices"):
if not hasattr(self, "boundary_simplices"):
self.get_complex_boundary_faces_indices()
dim = self.dim - 1
self.tets_cont_bnd_face = get_cofaces(
self.bnd_faces_indices, dim, self)
self.boundary_simplices[dim], dim, self)

def get_circumcenters(self):
"""Compute all the circumcenters."""
Expand All @@ -105,6 +123,48 @@ def get_circumcenters(self):
self.circ[p] = C
self.bary_circ[p] = B

def get_S_dual(self):
"""
Compute S_dual[k] for all k = 0.1
Each S_dual[k] is a matrix where each row contains the indices of dual nodes
(i.e., circumcenters of top-dimensional simplices) that form a dual k-simplex.

Stores the result in self.S_dual[k].
"""
# FIXME: test properly this routine!
if not hasattr(self, "boundary_simplices"):
self.get_complex_boundary_faces_indices()
dim = self.dim
self.S_dual = [None]*2

# store dual 0-simplices
self.S_dual[0] = np.arange(
self.S[dim].shape[0], dtype=dctkit.int_dtype).reshape(-1, 1)

# dual 0-simplices are the circumcenters of top-dimensional primal simplices
# These are not stored in S_dual but are the "nodes" for the dual complex

num_codim_1 = self.S[dim - 1].shape[0]

S_dual_interior_k = []

for idx in range(num_codim_1):
# Find all top-simplices (of dim) that contain this codim-k simplex
cofaces = np.nonzero(self.simplices_faces[dim] == idx)[0]
if len(cofaces) == 2:
S_dual_interior_k.append(cofaces)

S_dual_interior_k = np.array(S_dual_interior_k, dtype=dctkit.int_dtype)
S_dual_bnd_k_idx = self.boundary_simplices[dim-1]
S_dual_interior_k_idx = np.setdiff1d(
np.arange(num_codim_1), S_dual_bnd_k_idx)
S_dual_k = np.empty((num_codim_1, 2), dtype=dctkit.int_dtype)
# set placeholder for the boundary
S_dual_k[S_dual_bnd_k_idx] = 0.
# set correct value for the interior
S_dual_k[S_dual_interior_k_idx] = S_dual_interior_k
self.S_dual[1] = S_dual_k

def get_primal_volumes(self):
"""Compute all the primal volumes."""
self.primal_volumes = [None]*(self.dim + 1)
Expand Down Expand Up @@ -221,7 +281,8 @@ def get_dual_edge_vectors(self):
else:
circ_faces = self.circ[dim-1]
circ_bnd_faces = np.zeros(circ_faces.shape, dtype=dctkit.float_dtype)
circ_bnd_faces[self.bnd_faces_indices] = circ_faces[self.bnd_faces_indices]
circ_bnd_faces[self.boundary_simplices[dim-1]
] = circ_faces[self.boundary_simplices[dim-1]]

# adjust the signs based on the appropriate entries of the dual coboundary
# NOTE: here we take the values of the boundary matrix, we fix their signs later
Expand All @@ -243,7 +304,7 @@ def get_dual_edge_vectors(self):
# coboundary matrix
sign = -vals[boundary_rows_idx][:, None]*(-1)**dim
complement = circ_bnd_faces
complement[self.bnd_faces_indices] *= sign
complement[self.boundary_simplices[dim-1]] *= sign

self.dual_edges_vectors += complement

Expand Down Expand Up @@ -310,9 +371,10 @@ def get_flat_PDP_weights(self):
num_nodes = self.num_nodes
self.flat_PDP_weights = np.zeros(
(num_edges, num_nodes), dtype=dctkit.float_dtype)
# FIXME: check if it is possible or not to optimize this routine
# FIXME: optimize this routine with jax.vmap
for i in range(num_edges):
self.flat_PDP_weights[i, self.S[1][i]] = self.primal_volumes[1][i]/2
self.flat_PDP_weights[i, self.S[1][i]] = 1/2
self.flat_PDP_weights = self.flat_PDP_weights.T

def get_current_covariant_basis(self, node_coords: npt.NDArray | Array) -> Array:
"""Compute the current covariant basis of each face of a 2D simplicial complex.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_linear_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_linear_elasticity_pure_tension(setup_test, is_primal, energy_formulatio

ref_node_coords = S.node_coords

bnd_edges_idx = S.bnd_faces_indices
bnd_edges_idx = S.boundary_simplices[S.dim-1]
left_bnd_nodes_idx = util.get_nodes_for_physical_group(mesh, 1, "left")
right_bnd_nodes_idx = util.get_nodes_for_physical_group(mesh, 1, "right")
left_bnd_edges_idx = util.get_edges_for_physical_group(S, mesh, "left")
Expand Down
16 changes: 8 additions & 8 deletions tests/test_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_simplicial_complex_1(setup_test, space_dim: int):
assert np.allclose(S.hodge_star[i], hodge_true[i])
assert np.allclose(S.hodge_star_inverse[i], hodge_inv_true[i])

assert np.allclose(S.bnd_faces_indices, bnd_faces_indices_true)
assert np.allclose(S.boundary_simplices[S.dim - 1], bnd_faces_indices_true)
assert np.allclose(S.tets_cont_bnd_face, tets_cont_bnd_face_true)
assert np.allclose(S.primal_edges_vectors, primal_edges_true)
assert np.allclose(S.dual_edges_vectors, dual_edges_true)
Expand Down Expand Up @@ -235,13 +235,13 @@ def test_simplicial_complex_2(setup_test, space_dim):
flat_DPP_weights_true = flat_DPD_weights_true
flat_PDP_weights_true = np.array([[0.5, 0.5, 0., 0., 0.],
[0.5, 0., 0., 0.5, 0.],
[0.35355339, 0., 0., 0., 0.35355339],
[0.5, 0., 0., 0., 0.5],
[0., 0.5, 0.5, 0., 0.],
[0., 0.35355339, 0., 0., 0.35355339],
[0., 0.5, 0., 0., 0.5],
[0., 0., 0.5, 0.5, 0.],
[0., 0., 0.35355339, 0., 0.35355339],
[0., 0., 0., 0.35355339, 0.35355339]],
dtype=dctkit.float_dtype)
[0., 0., 0.5, 0., 0.5],
[0., 0., 0., 0.5, 0.5]],
dtype=dctkit.float_dtype).T

# define true reference metric
metric_true = np.stack([np.identity(2)]*4)
Expand All @@ -255,7 +255,7 @@ def test_simplicial_complex_2(setup_test, space_dim):
assert np.allclose(S.hodge_star[i], hodge_true[i])

# test bnd faces indices
assert np.allclose(S.bnd_faces_indices, bnd_faces_indices_true)
assert np.allclose(S.boundary_simplices[S.dim - 1], bnd_faces_indices_true)

# test tets containing boundary face
assert np.allclose(S.tets_cont_bnd_face, tets_cont_bnd_face_true)
Expand Down Expand Up @@ -401,7 +401,7 @@ def test_simplicial_complex_3(setup_test, space_dim):
assert np.all(S.boundary[3][i] == boundary_true[3][i])

# test bnd faces indices
assert np.allclose(S.bnd_faces_indices, bnd_faces_indices_true)
assert np.allclose(S.boundary_simplices[S.dim - 1], bnd_faces_indices_true)

# test tets containing boundary face
assert np.allclose(S.tets_cont_bnd_face, tets_cont_bnd_face_true)
Expand Down
Loading