Skip to content

Commit 2783cef

Browse files
committed
Support FFTs
1 parent 4ccf061 commit 2783cef

File tree

9 files changed

+270
-0
lines changed

9 files changed

+270
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ docs/site/
2222
# committed for packages, but should be committed for applications that require a static
2323
# environment.
2424
Manifest.toml
25+
.DS_Store

Project.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
name = "FastTransformsForwardDiff"
2+
uuid = "77fa7db0-1c81-401d-9fde-3592fc42b8bc"
3+
authors = ["Sheehan Olver <solver@mac.com>"]
4+
version = "0.0.1"
5+
6+
[deps]
7+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
8+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
10+
[compat]
11+
AbstractFFTs = "1"
12+
FFTW = "1"
13+
ForwardDiff = "0.10"
14+
15+
[extras]
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
18+
19+
[targets]
20+
test = ["Test", "FFTW"]

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,17 @@
11
# FastTransformsForwardDiff.jl
22
A Julia package to support forward-mode auto-differentiation for fast transforms
3+
4+
5+
[![Build Status](https://github.com/JuliaApproximation/FastTransformsForwardDiff.jl/workflows/CI/badge.svg)](https://github.com/JuliaApproximation/FastTransformsForwardDiff.jl/actions)
6+
[![codecov](https://codecov.io/gh/JuliaApproximation/FastTransformsForwardDiff.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaApproximation/FastTransformsForwardDiff.jl)
7+
8+
9+
A package for forward-mode auto-differentiation for fast transforms. Currently supports the fft:
10+
```julia
11+
julia> using FastTransformsForwardDiff: derivative
12+
13+
julia> θ = range(0,2π; length=n+1)[1:end-1];
14+
15+
julia> derivative-> fft(exp.(ω .* cos.(θ)))[1]/n, 1)
16+
0.5651591039924849 + 0.0im
17+
```

src/FastTransformsForwardDiff.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module FastTransformsForwardDiff
2+
using ForwardDiff
3+
import AbstractFFTs
4+
import ForwardDiff: value, partials, npartials, Dual, tagtype, derivative, jacobian, gradient
5+
6+
@inline tagtype(::Complex{T}) where T = tagtype(T)
7+
@inline tagtype(::Type{Complex{T}}) where T = tagtype(T)
8+
9+
include("fft.jl")
10+
11+
end # module FastTransformsForwardDiff

src/fft.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value)
2+
3+
partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n))
4+
5+
npartials(x::Complex{<:Dual{T,V,N}}) where {T,V,N} = N
6+
npartials(::Type{<:Complex{<:Dual{T,V,N}}}) where {T,V,N} = N
7+
8+
# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im)
9+
AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.complexfloat.(x)
10+
AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) + 0im
11+
12+
AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x)
13+
AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d)
14+
15+
for plan in [:plan_fft, :plan_ifft, :plan_bfft]
16+
@eval begin
17+
18+
AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
19+
AbstractFFTs.$plan(value.(x), region)
20+
21+
AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, region=1:ndims(x)) =
22+
AbstractFFTs.$plan(value.(x), region)
23+
24+
end
25+
end
26+
27+
# rfft only accepts real arrays
28+
AbstractFFTs.plan_rfft(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
29+
AbstractFFTs.plan_rfft(value.(x), region)
30+
31+
for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex?
32+
@eval begin
33+
34+
AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
35+
AbstractFFTs.$plan(value.(x), region)
36+
37+
AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, region=1:ndims(x)) =
38+
AbstractFFTs.$plan(value.(x), d, region)
39+
40+
end
41+
end
42+
43+
# for f in (:dct, :idct)
44+
# pf = Symbol("plan_", f)
45+
# @eval begin
46+
# AbstractFFTs.$f(x::AbstractArray{<:Dual}) = $pf(x) * x
47+
# AbstractFFTs.$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x
48+
# AbstractFFTs.$pf(x::AbstractArray{<:Dual}, region; kws...) = $pf(value.(x), region; kws...)
49+
# AbstractFFTs.$pf(x::AbstractArray{<:Complex}, region; kws...) = $pf(value.(x), region; kws...)
50+
# end
51+
# end
52+
53+
54+
for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities
55+
@eval begin
56+
57+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Dual}) =
58+
_apply_plan(p, x)
59+
60+
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:Dual}}) =
61+
_apply_plan(p, x)
62+
63+
end
64+
end
65+
66+
function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
67+
xtil = p * value.(x)
68+
dxtils = ntuple(npartials(eltype(x))) do n
69+
p * partials.(x, n)
70+
end
71+
__apply_plan(tagtype(eltype(x)), xtil, dxtils)
72+
end
73+
74+
function __apply_plan(T, xtil, dxtils)
75+
map(xtil, dxtils...) do val, parts...
76+
Complex(
77+
Dual{T}(real(val), map(real, parts)),
78+
Dual{T}(imag(val), map(imag, parts)),
79+
)
80+
end
81+
end

test/runtests.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using FastTransformsForwardDiff, FFTW, Test
2+
using ForwardDiff: Dual, valtype, value, partials, derivative
3+
using AbstractFFTs: complexfloat, realfloat
4+
5+
6+
@testset "fft and rfft" begin
7+
x1 = Dual.(1:4.0, 2:5, 3:6)
8+
9+
@test value.(x1) == 1:4
10+
@test partials.(x1, 1) == 2:5
11+
12+
@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im
13+
@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0)
14+
15+
@test fft(x1, 1)[1] isa Complex{<:Dual}
16+
17+
@testset "$f" for f in [fft, ifft, rfft, bfft]
18+
@test value.(f(x1)) == f(value.(x1))
19+
@test partials.(f(x1), 1) == f(partials.(x1, 1))
20+
end
21+
22+
f = x -> real(fft([x; 0; 0])[1])
23+
@test derivative(f,0.1) 1
24+
25+
r = x -> real(rfft([x; 0; 0])[1])
26+
@test derivative(r,0.1) 1
27+
28+
29+
n = 100
30+
θ = range(0,2π; length=n+1)[1:end-1]
31+
# emperical from Mathematical
32+
@test derivative-> fft(exp.(ω .* cos.(θ)))[1]/n, 1) 0.565159103992485
33+
34+
# c = x -> dct([x; 0; 0])[1]
35+
# @test derivative(c,0.1) ≈ 1
36+
end

workflows/CompatHelper.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: CompatHelper
2+
on:
3+
schedule:
4+
- cron: 0 0 * * *
5+
workflow_dispatch:
6+
permissions:
7+
contents: write
8+
pull-requests: write
9+
jobs:
10+
CompatHelper:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- name: Check if Julia is already available in the PATH
14+
id: julia_in_path
15+
run: which julia
16+
continue-on-error: true
17+
- name: Install Julia, but only if it is not already available in the PATH
18+
uses: julia-actions/setup-julia@v1
19+
with:
20+
version: '1'
21+
arch: ${{ runner.arch }}
22+
if: steps.julia_in_path.outcome != 'success'
23+
- name: "Add the General registry via Git"
24+
run: |
25+
import Pkg
26+
ENV["JULIA_PKG_SERVER"] = ""
27+
Pkg.Registry.add("General")
28+
shell: julia --color=yes {0}
29+
- name: "Install CompatHelper"
30+
run: |
31+
import Pkg
32+
name = "CompatHelper"
33+
uuid = "aa819f21-2bde-4658-8897-bab36330d9b7"
34+
version = "3"
35+
Pkg.add(; name, uuid, version)
36+
shell: julia --color=yes {0}
37+
- name: "Run CompatHelper"
38+
run: |
39+
import CompatHelper
40+
CompatHelper.main()
41+
shell: julia --color=yes {0}
42+
env:
43+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
44+
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
45+
# COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }}

workflows/TagBot.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
name: TagBot
2+
on:
3+
issue_comment: # THIS BIT IS NEW
4+
types:
5+
- created
6+
workflow_dispatch:
7+
jobs:
8+
TagBot:
9+
# THIS 'if' LINE IS NEW
10+
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
11+
# NOTHING BELOW HAS CHANGED
12+
runs-on: ubuntu-latest
13+
steps:
14+
- uses: JuliaRegistries/TagBot@v1
15+
with:
16+
token: ${{ secrets.GITHUB_TOKEN }}
17+
ssh: ${{ secrets.DOCUMENTER_KEY }}

workflows/ci.yml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
name: CI
2+
on:
3+
push:
4+
branches:
5+
- master
6+
pull_request:
7+
jobs:
8+
test:
9+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
10+
runs-on: ${{ matrix.os }}
11+
strategy:
12+
fail-fast: false
13+
matrix:
14+
version:
15+
- '1.6'
16+
- '1' # Latest stable release
17+
os:
18+
- ubuntu-latest
19+
- macOS-latest
20+
- windows-latest
21+
arch:
22+
- x64
23+
steps:
24+
- uses: actions/checkout@v2
25+
- uses: julia-actions/setup-julia@v1
26+
with:
27+
version: ${{ matrix.version }}
28+
arch: ${{ matrix.arch }}
29+
- uses: actions/cache@v1
30+
env:
31+
cache-name: cache-artifacts
32+
with:
33+
path: ~/.julia/artifacts
34+
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
35+
restore-keys: |
36+
${{ runner.os }}-test-${{ env.cache-name }}-
37+
${{ runner.os }}-test-
38+
${{ runner.os }}-
39+
- uses: julia-actions/julia-buildpkg@v1
40+
- uses: julia-actions/julia-runtest@v1
41+
- uses: julia-actions/julia-processcoverage@v1
42+
- uses: codecov/codecov-action@v1
43+
with:
44+
file: lcov.info

0 commit comments

Comments
 (0)