Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@

`BatchNLPKernels.jl` provides [`KernelAbstractions.jl`](https://github.com/JuliaGPU/KernelAbstractions.jl) kernels for evaluating problem data from a (parametric) [`ExaModel`](https://github.com/exanauts/ExaModels.jl) for batches of solutions (and parameters). Currently the following functions (as well as their non-parametric variants) are exported:

- `obj_batch!(::BatchModel, X, Θ)`
- `grad_batch!(::BatchModel, X, Θ)`
- `cons_nln_batch!(::BatchModel, X, Θ)`
- `jac_coord_batch!(::BatchModel, X, Θ)`
- `hess_coord_batch!(::BatchModel, X, Θ, Y; obj_weight=1.0)`
- `jprod_nln_batch!(::BatchModel, X, Θ, V)`
- `jtprod_nln_batch!(::BatchModel, X, Θ, V)`
- `hprod_batch!(::BatchModel, X, Θ, Y, V; obj_weight=1.0)`

- `objective!(::BatchModel, X, Θ)`
- `objective_gradient!(::BatchModel, X, Θ)`
- `constraints!(::BatchModel, X, Θ)`
- `constraints_jacobian!(::BatchModel, X, Θ)`
- `lagrangian_hessian!(::BatchModel, X, Θ, Y; obj_weight=1.0)`
- `constraints_jprod!(::BatchModel, X, Θ, V)`
- `constraints_jtprod!(::BatchModel, X, Θ, V)`
- `lagrangian_hprod!(::BatchModel, X, Θ, Y, V; obj_weight=1.0)`
- `all_violations!(::BatchModel, X, Θ)`
- `constraint_violations!(::BatchModel, X, Θ)`
- `bound_violations!(::BatchModel, X)`

To use these functions, first wrap your `ExaModel` in a `BatchModel`:

Expand All @@ -30,7 +32,7 @@ This pre-allocates work and output buffers. By default, only the buffers to supp
Then, you can call the batch functions as follows:

```julia
objs = obj_batch!(bm, X, Θ)
objs = objective!(bm, X, Θ)
```

where `X` and `Θ` are (device) matrices with dimensions `(nvar, batch_size)` and `(nθ, batch_size)` respectively.
Expand Down
28 changes: 14 additions & 14 deletions ext/BNKChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ module BNKChainRulesCore
using BatchNLPKernels
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchModel, X, Θ)
y = BatchNLPKernels.obj_batch!(bm, X, Θ)
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.objective!), bm::BatchModel, X, Θ)
y = BatchNLPKernels.objective!(bm, X, Θ)

function obj_batch_pullback(Ȳ)
Ȳ = ChainRulesCore.unthunk(Ȳ)
gradients = BatchNLPKernels.grad_batch!(bm, X, Θ)
gradients = BatchNLPKernels.objective_gradient!(bm, X, Θ)

X̄ = gradients .* Ȳ'

Expand All @@ -17,12 +17,12 @@ function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchMod

return y, obj_batch_pullback
end
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchModel, X)
y = BatchNLPKernels.obj_batch!(bm, X)
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.objective!), bm::BatchModel, X)
y = BatchNLPKernels.objective!(bm, X)

function obj_batch_pullback(Ȳ)
Ȳ = ChainRulesCore.unthunk(Ȳ)
gradients = BatchNLPKernels.grad_batch!(bm, X)
gradients = BatchNLPKernels.objective_gradient!(bm, X)

X̄ = gradients .* Ȳ'

Expand All @@ -33,32 +33,32 @@ function ChainRulesCore.rrule(::typeof(BatchNLPKernels.obj_batch!), bm::BatchMod
end


function ChainRulesCore.rrule(::typeof(BatchNLPKernels.cons_nln_batch!), bm::BatchModel, X, Θ)
y = BatchNLPKernels.cons_nln_batch!(bm, X, Θ)
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.constraints!), bm::BatchModel, X, Θ)
y = BatchNLPKernels.constraints!(bm, X, Θ)

function cons_nln_batch_pullback(Ȳ)
Ȳ = ChainRulesCore.unthunk(Ȳ)
X̄ = BatchNLPKernels.jtprod_nln_batch!(bm, X, Θ, Ȳ)
X̄ = BatchNLPKernels.constraints_jtprod!(bm, X, Θ, Ȳ)
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), X̄, ChainRulesCore.NoTangent()
end

return y, cons_nln_batch_pullback
end
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.cons_nln_batch!), bm::BatchModel, X)
y = BatchNLPKernels.cons_nln_batch!(bm, X)
function ChainRulesCore.rrule(::typeof(BatchNLPKernels.constraints!), bm::BatchModel, X)
y = BatchNLPKernels.constraints!(bm, X)

function cons_nln_batch_pullback(Ȳ)
Ȳ = ChainRulesCore.unthunk(Ȳ)
X̄ = BatchNLPKernels.jtprod_nln_batch!(bm, X, Ȳ)
X̄ = BatchNLPKernels.constraints_jtprod!(bm, X, Ȳ)
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), X̄
end

return y, cons_nln_batch_pullback
end


function ChainRulesCore.rrule(::typeof(BatchNLPKernels.constraint_violations!), bm::BatchModel, V)
Vc = BatchNLPKernels.constraint_violations!(bm, V)
function ChainRulesCore.rrule(::typeof(BatchNLPKernels._constraint_violations!), bm::BatchModel, V)
Vc = BatchNLPKernels._constraint_violations!(bm, V)

function constraint_violations_pullback(Ȳ)
Ȳ = ChainRulesCore.unthunk(Ȳ)
Expand Down
9 changes: 5 additions & 4 deletions src/BatchNLPKernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ const KAExtension = ExaKA.KAExtension
include("interval.jl")
include("batch_model.jl")

const BOI = BatchNLPKernels
export BOI, BatchModel, BatchModelConfig
export obj_batch!, grad_batch!, cons_nln_batch!, jac_coord_batch!, hess_coord_batch!
export jprod_nln_batch!, jtprod_nln_batch!, hprod_batch!
const BNK = BatchNLPKernels
export BNK, BatchModel, BatchModelConfig
export objective!, objective_gradient!, constraints!, constraints_jacobian!, lagrangian_hessian!
export constraints_jprod!, constraints_jtprod!, lagrangian_hprod!
export all_violations!, constraint_violations!, bound_violations!

include("utils.jl")
include("kernels.jl")
Expand Down
26 changes: 13 additions & 13 deletions src/api/cons.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
"""
cons_nln_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
constraints!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)

Evaluate constraints for a batch of solutions and parameters.
"""
function cons_nln_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
function constraints!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
C = _maybe_view(bm, :cons_out, X)
cons_nln_batch!(bm, X, Θ, C)
constraints!(bm, X, Θ, C)
return C
end

"""
cons_nln_batch!(bm::BatchModel, X::AbstractMatrix)
constraints!(bm::BatchModel, X::AbstractMatrix)

Evaluate constraints for a batch of solutions.
"""
function cons_nln_batch!(bm::BatchModel, X::AbstractMatrix)
function constraints!(bm::BatchModel, X::AbstractMatrix)
Θ = _repeat_params(bm, X)
cons_nln_batch!(bm, X, Θ)
constraints!(bm, X, Θ)
end


function cons_nln_batch!(
function constraints!(
bm::BatchModel,
X::AbstractMatrix,
Θ::AbstractMatrix,
Expand All @@ -34,7 +34,7 @@
_assert_batch_size(batch_size, bm.batch_size)
backend = _get_backend(bm.model)

_cons_nln_batch!(backend, C, bm.model.cons, X, Θ)
_constraints!(backend, C, bm.model.cons, X, Θ)

conbuffers_batch = _maybe_view(bm, :cons_work, X)

Expand All @@ -53,17 +53,17 @@
return C
end

function _cons_nln_batch!(backend, C, con::ExaModels.Constraint, X, Θ)
function _constraints!(backend, C, con::ExaModels.Constraint, X, Θ)
if !isempty(con.itr)
batch_size = size(X, 2)
kerf_batch(backend)(C, con.f, con.itr, X, Θ; ndrange = (length(con.itr), batch_size))
end
_cons_nln_batch!(backend, C, con.inner, X, Θ)
_constraints!(backend, C, con.inner, X, Θ)
synchronize(backend)
end
function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintNull, X, Θ) end
function _cons_nln_batch!(backend, C, con::ExaModels.ConstraintAug, X, Θ)
_cons_nln_batch!(backend, C, con.inner, X, Θ)
function _constraints!(backend, C, con::ExaModels.ConstraintNull, X, Θ) end

Check warning on line 64 in src/api/cons.jl

View check run for this annotation

Codecov / codecov/patch

src/api/cons.jl#L64

Added line #L64 was not covered by tests
function _constraints!(backend, C, con::ExaModels.ConstraintAug, X, Θ)
_constraints!(backend, C, con.inner, X, Θ)
end

function _conaugs_batch!(backend, conbuffers, con::ExaModels.ConstraintAug, X, Θ)
Expand Down
24 changes: 12 additions & 12 deletions src/api/grad.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
"""
grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)

Evaluate objective gradient for a batch of points.
"""
function grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
function objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix)
G = _maybe_view(bm, :grad_out, X)
grad_batch!(bm, X, Θ, G)
objective_gradient!(bm, X, Θ, G)
return G
end

"""
grad_batch!(bm::BatchModel, X::AbstractMatrix)
objective_gradient!(bm::BatchModel, X::AbstractMatrix)

Evaluate objective gradient for a batch of points.
"""
function grad_batch!(bm::BatchModel, X::AbstractMatrix)
function objective_gradient!(bm::BatchModel, X::AbstractMatrix)
Θ = _repeat_params(bm, X)
grad_batch!(bm, X, Θ)
objective_gradient!(bm, X, Θ)
end

function _grad_batch!(backend, grad_work, objs, X, Θ)
function _objective_gradient!(backend, grad_work, objs, X, Θ)
sgradient_batch!(backend, grad_work, objs, X, Θ, one(eltype(grad_work)))
_grad_batch!(backend, grad_work, objs.inner, X, Θ)
_objective_gradient!(backend, grad_work, objs.inner, X, Θ)
synchronize(backend)
end
function _grad_batch!(backend, grad_work, objs::ExaModels.ObjectiveNull, X, Θ) end
function _objective_gradient!(backend, grad_work, objs::ExaModels.ObjectiveNull, X, Θ) end

Check warning on line 27 in src/api/grad.jl

View check run for this annotation

Codecov / codecov/patch

src/api/grad.jl#L27

Added line #L27 was not covered by tests

function sgradient_batch!(
backend::B,
Expand All @@ -41,11 +41,11 @@
end

"""
grad_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, G::AbstractMatrix)
objective_gradient!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, G::AbstractMatrix)

Evaluate gradients for a batch of points with different parameters.
"""
function grad_batch!(
function objective_gradient!(
bm::BatchModel,
X::AbstractMatrix,
Θ::AbstractMatrix,
Expand All @@ -63,7 +63,7 @@
if !isempty(grad_work)
fill!(grad_work, zero(eltype(grad_work)))

_grad_batch!(backend, grad_work, bm.model.objs, X, Θ)
_objective_gradient!(backend, grad_work, bm.model.objs, X, Θ)

fill!(G, zero(eltype(G)))
compress_to_dense_batch(backend)(
Expand Down
30 changes: 15 additions & 15 deletions src/api/hess.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
"""
hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)

Evaluate Hessian coordinates for a batch of points.
"""
function hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
function lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
H_view = _maybe_view(bm, :hprod_work, X)
hess_coord_batch!(bm, X, Θ, Y, H_view; obj_weight=obj_weight)
lagrangian_hessian!(bm, X, Θ, Y, H_view; obj_weight=obj_weight)
return H_view
end

"""
hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)

Evaluate Hessian coordinates for a batch of points.
"""
function hess_coord_batch!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)
function lagrangian_hessian!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix; obj_weight=1.0)

Check warning on line 17 in src/api/hess.jl

View check run for this annotation

Codecov / codecov/patch

src/api/hess.jl#L17

Added line #L17 was not covered by tests
Θ = _repeat_params(bm, X)
hess_coord_batch!(bm, X, Θ, Y; obj_weight=obj_weight)
lagrangian_hessian!(bm, X, Θ, Y; obj_weight=obj_weight)

Check warning on line 19 in src/api/hess.jl

View check run for this annotation

Codecov / codecov/patch

src/api/hess.jl#L19

Added line #L19 was not covered by tests
end

function hess_coord_batch!(
function lagrangian_hessian!(
bm::BatchModel,
X::AbstractMatrix,
Θ::AbstractMatrix,
Expand All @@ -37,24 +37,24 @@
backend = _get_backend(bm.model)

fill!(H, zero(eltype(H)))
_obj_hess_coord_batch!(backend, H, bm.model.objs, X, Θ, obj_weight)
_con_hess_coord_batch!(backend, H, bm.model.cons, X, Θ, Y)
_obj_lagrangian_hessian!(backend, H, bm.model.objs, X, Θ, obj_weight)
_con_lagrangian_hessian!(backend, H, bm.model.cons, X, Θ, Y)
return H
end

function _obj_hess_coord_batch!(backend, H, objs, X, Θ, obj_weight)
function _obj_lagrangian_hessian!(backend, H, objs, X, Θ, obj_weight)
shessian_batch!(backend, H, nothing, objs, X, Θ, obj_weight, zero(eltype(H)))
_obj_hess_coord_batch!(backend, H, objs.inner, X, Θ, obj_weight)
_obj_lagrangian_hessian!(backend, H, objs.inner, X, Θ, obj_weight)
synchronize(backend)
end
function _obj_hess_coord_batch!(backend, H, objs::ExaModels.ObjectiveNull, X, Θ, obj_weight) end
function _obj_lagrangian_hessian!(backend, H, objs::ExaModels.ObjectiveNull, X, Θ, obj_weight) end

Check warning on line 50 in src/api/hess.jl

View check run for this annotation

Codecov / codecov/patch

src/api/hess.jl#L50

Added line #L50 was not covered by tests

function _con_hess_coord_batch!(backend, H, cons, X, Θ, Y)
function _con_lagrangian_hessian!(backend, H, cons, X, Θ, Y)
shessian_batch!(backend, H, nothing, cons, X, Θ, Y, zero(eltype(H)))
_con_hess_coord_batch!(backend, H, cons.inner, X, Θ, Y)
_con_lagrangian_hessian!(backend, H, cons.inner, X, Θ, Y)
synchronize(backend)
end
function _con_hess_coord_batch!(backend, H, cons::ExaModels.ConstraintNull, X, Θ, Y) end
function _con_lagrangian_hessian!(backend, H, cons::ExaModels.ConstraintNull, X, Θ, Y) end

Check warning on line 57 in src/api/hess.jl

View check run for this annotation

Codecov / codecov/patch

src/api/hess.jl#L57

Added line #L57 was not covered by tests

function shessian_batch!(
backend::B,
Expand Down
16 changes: 8 additions & 8 deletions src/api/hprod.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
"""
hprod_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)

Evaluate Hessian-vector products for a batch of points.
"""
function hprod_batch!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
function lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Θ::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
Hv = _maybe_view(bm, :hprod_out, X)
hprod_batch!(bm, X, Θ, Y, V, Hv; obj_weight=obj_weight)
lagrangian_hprod!(bm, X, Θ, Y, V, Hv; obj_weight=obj_weight)
return Hv
end

"""
hprod_batch!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)

Evaluate Hessian-vector products for a batch of points.
"""
function lagrangian_hprod!(bm::BatchModel, X::AbstractMatrix, Y::AbstractMatrix, V::AbstractMatrix; obj_weight=1.0)
    # Non-parametric variant: obtain Θ from `_repeat_params` and delegate to the
    # parametric method.
    Θ = _repeat_params(bm, X)
    # Return the delegate's result directly. `Hv` is never bound in this method,
    # so the previous trailing `return Hv` would raise an UndefVarError.
    return lagrangian_hprod!(bm, X, Θ, Y, V; obj_weight=obj_weight)
end

function hprod_batch!(
function lagrangian_hprod!(
bm::BatchModel,
X::AbstractMatrix,
Θ::AbstractMatrix,
Expand All @@ -40,7 +40,7 @@ function hprod_batch!(

H_batch = _maybe_view(bm, :hprod_work, X)

hess_coord_batch!(bm, X, Θ, Y, H_batch; obj_weight=obj_weight)
lagrangian_hessian!(bm, X, Θ, Y, H_batch; obj_weight=obj_weight)

fill!(Hv, zero(eltype(Hv)))
kersyspmv_batch(backend)(
Expand Down
Loading