From 23c60bc7970271a8391b43bbc7ab0ea4923ec7cb Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 17 Jun 2025 08:34:03 +0100 Subject: [PATCH 1/2] WIP: time dependent DirichletBCs --- asQ/allatonce/form.py | 31 ++++++---- tests/integration/test_paradiag.py | 94 +++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 13 deletions(-) diff --git a/asQ/allatonce/form.py b/asQ/allatonce/form.py index e6a1e755..57f6b045 100644 --- a/asQ/allatonce/form.py +++ b/asQ/allatonce/form.py @@ -97,17 +97,28 @@ def _set_bcs(self, field_bcs): is_mixed_element = isinstance(aaofunc.field_function_space.ufl_element(), fd.MixedElement) bcs_all = [] + for bc in field_bcs: - for step in range(aaofunc.nlocal_timesteps): - if is_mixed_element: - cpt = bc.function_space().index - else: - cpt = 0 - index = aaofunc._component_indices(step)[cpt] - bc_all = fd.DirichletBC(aaofunc.function_space.sub(index), - bc.function_arg, - bc.sub_domain) - bcs_all.append(bc_all) + + if isinstance(bc, Callable): + for step in range(aaofunc.nlocal_timesteps): + Vs = [aaofunc.function_space.sub(i) + for i in aaofunc._component_indices(step)] + us = aaofunc[step] + t = self.time[step] + bcs_step = bc(*Vs, *us, t) + bcs_all.extend(bcs_step) + else: + for step in range(aaofunc.nlocal_timesteps): + if is_mixed_element: + cpt = bc.function_space().index + else: + cpt = 0 + index = aaofunc._component_indices(step)[cpt] + bc_cpt = fd.DirichletBC(aaofunc.function_space.sub(index), + bc.function_arg, + bc.sub_domain) + bcs_all.append(bc_cpt) return bcs_all diff --git a/tests/integration/test_paradiag.py b/tests/integration/test_paradiag.py index 527b502e..1cc88b87 100644 --- a/tests/integration/test_paradiag.py +++ b/tests/integration/test_paradiag.py @@ -13,8 +13,12 @@ @pytest.mark.parallel(nprocs=4) -def test_Nitsche_BCs(): - # test the linear equation u_t - Delta u = 0, with u_ex = exp(0.5*x + y + 1.25*t) and weakly imposing Dirichlet BCs +def test_nitsche_bcs(): + """ + Test the linear equation u_t - Delta u = 0, + with u_ex = exp(0.5*x + y + 1.25*t) + and weakly imposed Dirichlet BCs. + """ nspatial_domains = 2 degree = 1 nx = 10 @@ -88,7 +92,91 @@ def form_function(q, phi, t): @pytest.mark.parallel(nprocs=4) -def test_Nitsche_heat_timeseries(): +def test_time_dependent_bcs(): + """ + Test the linear equation u_t - Delta u = 0, + with u_ex = exp(0.5*x + y + 1.25*t) + and strongly imposed Dirichlet BCs. + """ + nspatial_domains = 2 + degree = 1 + nx = 10 + h = 1/nx + dt = h + slice_length = 2 + nslices = fd.COMM_WORLD.size//nspatial_domains + + time_partition = [slice_length for _ in range(nslices)] + + ensemble = fd.Ensemble(fd.COMM_WORLD, nspatial_domains) + mesh = fd.UnitSquareMesh( + nx, nx, quadrilateral=False, comm=ensemble.comm, + distribution_parameters={'partitioner_type': 'simple'}) + + x, y = fd.SpatialCoordinate(mesh) + n = fd.FacetNormal(mesh) + V = fd.FunctionSpace(mesh, "CG", degree) + + def uexact(t): + return fd.exp(0.5*x + y + 1.25*t) + + def form_mass(q, phi): + return phi*q*fd.dx + + def form_function(q, phi, t): + return (fd.inner(fd.grad(q), fd.grad(phi))*fd.dx + - fd.inner(phi, fd.inner(fd.grad(q), n))*fd.ds + - fd.inner(q-uexact(t), fd.inner(fd.grad(phi), n))*fd.ds + + 20*nx*fd.inner(q-uexact(t), phi)*fd.ds) + + def form_bcs(V, q, t): + return [fd.DirichletBC(V, uexact(t), "on_boundary")] + + # Parameters for the diag + solver_parameters_diag = { + 'snes_type': 'ksponly', + 'ksp_rtol': 1e-8, + 'mat_type': 'matfree', + 'ksp_type': 'gmres', + 'pc_type': 'python', + 'pc_python_type': 'asQ.CirculantPC', + 'circulant_block': { + 'ksp_type': 'preonly', + 'pc_type': 'lu', + 'pc_factor_mat_solver_type': 'mumps' + }, + } + + w0 = fd.Function(V).interpolate(uexact(0)) + + pdg = asQ.Paradiag(ensemble=ensemble, + form_function=form_function, + form_mass=form_mass, ics=w0, + dt=dt, theta=0.5, + time_partition=time_partition, + solver_parameters=solver_parameters_diag) + pdg.solve() + q_exact = fd.Function(V) + qp = fd.Function(V) + errors = asQ.SharedArray(time_partition, comm=ensemble.ensemble_comm) + + for step in range(pdg.ntimesteps): + if pdg.aaoform.layout.is_local(step): + local_step = pdg.aaofunc.transform_index(step, from_range='window') + t = pdg.aaoform.time[local_step] + q_exact.interpolate(uexact(t)) + qp.assign(pdg.aaofunc[local_step]) + + errors.dlocal[local_step] = fd.errornorm(qp, q_exact) + + errors.synchronise() + + for step in range(pdg.ntimesteps): + assert (errors.dglobal[step] < h**(3/2)), "Error from analytical solution should be close to discretisation error" + + +@pytest.mark.parallel(nprocs=4) +def test_nitsche_heat_timeseries(): from utils.serial import ComparisonMiniapp nwindows = 1 From 3a083c95bec445090999b0aa43645129629f9f62 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 17 Jun 2025 17:22:07 +0100 Subject: [PATCH 2/2] time dependent BCs in allatonceform/jacobian --- asQ/allatonce/form.py | 6 +++--- asQ/allatonce/jacobian.py | 5 ++--- tests/integration/test_paradiag.py | 25 +++++++++++-------------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/asQ/allatonce/form.py b/asQ/allatonce/form.py index 57f6b045..bede2c24 100644 --- a/asQ/allatonce/form.py +++ b/asQ/allatonce/form.py @@ -100,11 +100,11 @@ def _set_bcs(self, field_bcs): for bc in field_bcs: - if isinstance(bc, Callable): + if callable(bc): for step in range(aaofunc.nlocal_timesteps): Vs = [aaofunc.function_space.sub(i) for i in aaofunc._component_indices(step)] - us = aaofunc[step] + us = aaofunc[step].subfunctions t = self.time[step] bcs_step = bc(*Vs, *us, t) bcs_all.extend(bcs_step) @@ -156,7 +156,7 @@ def assemble(self, func=None, tensor=None): # Set the current state if func is not None: - self.aaofunc.assign(func, update_halos=False) + self.aaofunc.assign(func, update_halos=False, blocking=True) # Assembly stage # The residual on the DirichletBC nodes is set to zero, diff --git a/asQ/allatonce/jacobian.py b/asQ/allatonce/jacobian.py index 55a3a02c..672ca348 100644 --- a/asQ/allatonce/jacobian.py +++ b/asQ/allatonce/jacobian.py @@ -141,7 +141,7 @@ def mult(self, mat, X, Y): :arg X: a PETSc Vec to apply the action on. :arg Y: a PETSc Vec for the result. """ - # we could use nonblocking here and overlap comms with assembling form + # Delay updating the halos until after we enforce the bcs self.x.assign(X, update_halos=True, blocking=True) # We use the same strategy as the implicit matrix context in firedrake @@ -159,6 +159,7 @@ def mult(self, mat, X, Y): # Zero the boundary nodes on the input so that A_ib = A_01 = 0 for bc in self.bcs: bc.zero(self.x.function) + self.x.update_time_halos() # assembly stage fd.assemble(self.action, bcs=self.bcs, @@ -166,8 +167,6 @@ def mult(self, mat, X, Y): if self._useprev: # repeat for the halo part of the matrix action - for bc in self.field_bcs: - bc.zero(self.x.uprev) fd.assemble(self.action_prev, bcs=self.bcs, tensor=self.Fprev) self.F.cofunction += self.Fprev diff --git a/tests/integration/test_paradiag.py b/tests/integration/test_paradiag.py index 1cc88b87..032ce0b5 100644 --- a/tests/integration/test_paradiag.py +++ b/tests/integration/test_paradiag.py @@ -114,7 +114,6 @@ def test_time_dependent_bcs(): distribution_parameters={'partitioner_type': 'simple'}) x, y = fd.SpatialCoordinate(mesh) - n = fd.FacetNormal(mesh) V = fd.FunctionSpace(mesh, "CG", degree) def uexact(t): @@ -124,10 +123,7 @@ def form_mass(q, phi): return phi*q*fd.dx def form_function(q, phi, t): - return (fd.inner(fd.grad(q), fd.grad(phi))*fd.dx - - fd.inner(phi, fd.inner(fd.grad(q), n))*fd.ds - - fd.inner(q-uexact(t), fd.inner(fd.grad(phi), n))*fd.ds - + 20*nx*fd.inner(q-uexact(t), phi)*fd.ds) + return fd.inner(fd.grad(q), fd.grad(phi))*fd.dx def form_bcs(V, q, t): return [fd.DirichletBC(V, uexact(t), "on_boundary")] @@ -138,21 +134,22 @@ def form_bcs(V, q, t): 'ksp_rtol': 1e-8, 'mat_type': 'matfree', 'ksp_type': 'gmres', - 'pc_type': 'python', - 'pc_python_type': 'asQ.CirculantPC', - 'circulant_block': { - 'ksp_type': 'preonly', - 'pc_type': 'lu', - 'pc_factor_mat_solver_type': 'mumps' - }, + 'pc_type': 'none', + # 'pc_type': 'python', + # 'pc_python_type': 'asQ.CirculantPC', + # 'circulant_block': { + # 'ksp_type': 'preonly', + # 'pc_type': 'lu', + # 'pc_factor_mat_solver_type': 'mumps' + # }, } w0 = fd.Function(V).interpolate(uexact(0)) pdg = asQ.Paradiag(ensemble=ensemble, form_function=form_function, - form_mass=form_mass, ics=w0, - dt=dt, theta=0.5, + form_mass=form_mass, bcs=[form_bcs], + dt=dt, theta=0.5, ics=w0, time_partition=time_partition, solver_parameters=solver_parameters_diag) pdg.solve()