Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions asQ/allatonce/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,28 @@ def _set_bcs(self, field_bcs):
is_mixed_element = isinstance(aaofunc.field_function_space.ufl_element(), fd.MixedElement)

bcs_all = []

for bc in field_bcs:
for step in range(aaofunc.nlocal_timesteps):
if is_mixed_element:
cpt = bc.function_space().index
else:
cpt = 0
index = aaofunc._component_indices(step)[cpt]
bc_all = fd.DirichletBC(aaofunc.function_space.sub(index),
bc.function_arg,
bc.sub_domain)
bcs_all.append(bc_all)

if callable(bc):
for step in range(aaofunc.nlocal_timesteps):
Vs = [aaofunc.function_space.sub(i)
for i in aaofunc._component_indices(step)]
us = aaofunc[step].subfunctions
t = self.time[step]
bcs_step = bc(*Vs, *us, t)
bcs_all.extend(bcs_step)
else:
for step in range(aaofunc.nlocal_timesteps):
if is_mixed_element:
cpt = bc.function_space().index
else:
cpt = 0
index = aaofunc._component_indices(step)[cpt]
bc_cpt = fd.DirichletBC(aaofunc.function_space.sub(index),
bc.function_arg,
bc.sub_domain)
bcs_all.append(bc_cpt)

return bcs_all

Expand Down Expand Up @@ -145,7 +156,7 @@ def assemble(self, func=None, tensor=None):

# Set the current state
if func is not None:
self.aaofunc.assign(func, update_halos=False)
self.aaofunc.assign(func, update_halos=False, blocking=True)

# Assembly stage
# The residual on the DirichletBC nodes is set to zero,
Expand Down
5 changes: 2 additions & 3 deletions asQ/allatonce/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def mult(self, mat, X, Y):
:arg X: a PETSc Vec to apply the action on.
:arg Y: a PETSc Vec for the result.
"""
# we could use nonblocking here and overlap comms with assembling form
# Delay updating the halos until after we enforce the bcs
self.x.assign(X, update_halos=True, blocking=True)

# We use the same strategy as the implicit matrix context in firedrake
Expand All @@ -159,15 +159,14 @@ def mult(self, mat, X, Y):
# Zero the boundary nodes on the input so that A_ib = A_01 = 0
for bc in self.bcs:
bc.zero(self.x.function)
self.x.update_time_halos()

# assembly stage
fd.assemble(self.action, bcs=self.bcs,
tensor=self.F.cofunction)

if self._useprev:
# repeat for the halo part of the matrix action
for bc in self.field_bcs:
bc.zero(self.x.uprev)
fd.assemble(self.action_prev, bcs=self.bcs,
tensor=self.Fprev)
self.F.cofunction += self.Fprev
Expand Down
91 changes: 88 additions & 3 deletions tests/integration/test_paradiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@


@pytest.mark.parallel(nprocs=4)
def test_Nitsche_BCs():
# test the linear equation u_t - Delta u = 0, with u_ex = exp(0.5*x + y + 1.25*t) and weakly imposing Dirichlet BCs
def test_nitsche_bcs():
"""
Test the linear equation u_t - Delta u = 0,
with u_ex = exp(0.5*x + y + 1.25*t)
and weakly imposed Dirichlet BCs.
"""
nspatial_domains = 2
degree = 1
nx = 10
Expand Down Expand Up @@ -88,7 +92,88 @@ def form_function(q, phi, t):


@pytest.mark.parallel(nprocs=4)
def test_Nitsche_heat_timeseries():
def test_time_dependent_bcs():
"""
Test the linear equation u_t - Delta u = 0,
with u_ex = exp(0.5*x + y + 1.25*t)
and strongly imposed Dirichlet BCs.
"""
nspatial_domains = 2
degree = 1
nx = 10
h = 1/nx
dt = h
slice_length = 2
nslices = fd.COMM_WORLD.size//nspatial_domains

time_partition = [slice_length for _ in range(nslices)]

ensemble = fd.Ensemble(fd.COMM_WORLD, nspatial_domains)
mesh = fd.UnitSquareMesh(
nx, nx, quadrilateral=False, comm=ensemble.comm,
distribution_parameters={'partitioner_type': 'simple'})

x, y = fd.SpatialCoordinate(mesh)
V = fd.FunctionSpace(mesh, "CG", degree)

def uexact(t):
return fd.exp(0.5*x + y + 1.25*t)

def form_mass(q, phi):
return phi*q*fd.dx

def form_function(q, phi, t):
return fd.inner(fd.grad(q), fd.grad(phi))*fd.dx

def form_bcs(V, q, t):
return [fd.DirichletBC(V, uexact(t), "on_boundary")]

# Parameters for the diag
solver_parameters_diag = {
'snes_type': 'ksponly',
'ksp_rtol': 1e-8,
'mat_type': 'matfree',
'ksp_type': 'gmres',
'pc_type': 'none',
# 'pc_type': 'python',
# 'pc_python_type': 'asQ.CirculantPC',
# 'circulant_block': {
# 'ksp_type': 'preonly',
# 'pc_type': 'lu',
# 'pc_factor_mat_solver_type': 'mumps'
# },
}

w0 = fd.Function(V).interpolate(uexact(0))

pdg = asQ.Paradiag(ensemble=ensemble,
form_function=form_function,
form_mass=form_mass, bcs=[form_bcs],
dt=dt, theta=0.5, ics=w0,
time_partition=time_partition,
solver_parameters=solver_parameters_diag)
pdg.solve()
q_exact = fd.Function(V)
qp = fd.Function(V)
errors = asQ.SharedArray(time_partition, comm=ensemble.ensemble_comm)

for step in range(pdg.ntimesteps):
if pdg.aaoform.layout.is_local(step):
local_step = pdg.aaofunc.transform_index(step, from_range='window')
t = pdg.aaoform.time[local_step]
q_exact.interpolate(uexact(t))
qp.assign(pdg.aaofunc[local_step])

errors.dlocal[local_step] = fd.errornorm(qp, q_exact)

errors.synchronise()

for step in range(pdg.ntimesteps):
assert (errors.dglobal[step] < h**(3/2)), "Error from analytical solution should be close to discretisation error"


@pytest.mark.parallel(nprocs=4)
def test_nitsche_heat_timeseries():
from utils.serial import ComparisonMiniapp

nwindows = 1
Expand Down
Loading