66 changes: 62 additions & 4 deletions catgrad/bidirectional/operation.py
@@ -1,12 +1,12 @@
 import numpy as np
 from dataclasses import dataclass
 from abc import abstractmethod
-from open_hypergraphs import OpenHypergraph, FiniteFunction, IndexedCoproduct, FrobeniusFunctor
+from open_hypergraphs import OpenHypergraph, FiniteFunction

-from catgrad.signature import NdArrayType, obj, op, sigma_0, sigma_1
+from catgrad.signature import NdArrayType, obj, op
 import catgrad.core.operation as ops
 from catgrad.special.definition import Definition
-from catgrad.combinators import *
+from catgrad.combinators import identity, twist, permutation, canonical

 class Optic:
     @abstractmethod
@@ -280,7 +280,6 @@ def __post_init__(self):
     def arrow(self):
         # here we write a morphism in *core*!
         T = self.T
-        U = NdArrayType((), T.dtype)

         full1 = op(ops.Constant(T, 1))

@@ -324,6 +323,65 @@ def rev(self):

 sigmoid = canonical(lambda T: op(Sigmoid(T)))

+@dataclass(frozen=True)
+class Tanh(Definition, Lens):
+    T: NdArrayType
+    def source(self): return obj(self.T)
+    def target(self): return obj(self.T)
+
+    def __post_init__(self):
+        if not self.T.dtype.is_floating():
+            raise ValueError("Tanh is not defined for non-floating-point dtypes")
+
+    ########################################
+    # Tanh as a Core definition
+
+    # The definition of the Tanh function in terms of Core ops:
+    #   tanh(x) = 2*sigmoid(2*x) - 1
+    #
+    def arrow(self):
+        # here we write a morphism in *core*!
+        T = self.T
+
+        # The constants -1 and 2
+        minus1 = op(ops.Constant(T, -1))
+        two = op(ops.Constant(T, 2))
+
+        # 2*x
+        twox = (two @ identity(obj(T))) >> op(ops.Multiply(T))
+
+        # sigmoid(2*x)
+        sig = twox >> sigmoid(obj(T))
+
+        # 2 * sigmoid(2*x)
+        sig = (two @ sig) >> op(ops.Multiply(T))
+
+        # 2 * sigmoid(2*x) - 1
+        return (sig @ minus1) >> op(ops.Add(T))
+
+    ########################################
+    # Tanh as an Optic
+
+    # we want this to appear as a Definition in core, so we just return the op
+    # as a singleton diagram.
+    def to_core(self):
+        return op(self)
+
+    # The forward map is like Lens, but we copy the *output*, not the input.
+    def fwd(self):
+        return op(self) >> copy(self.source())
+
+    # The reverse map scales the incoming gradient dy by 1 - tanh(x)^2.
+    # Since fwd copies the *output*, rev receives y = tanh(x), so squaring
+    # its input gives tanh(x)^2 directly.
+    def rev(self):
+        T = obj(self.T)
+        pow = exponentiate(2)(T) # y^2 = tanh(x)^2
+        grad = (constant(1)(T) @ pow) >> subtract(T) # 1 - tanh(x)^2
+        return (grad @ identity(T)) >> multiply(T) # (1 - tanh(x)^2) * dy


+tanh = canonical(lambda T: op(Tanh(T)))
+
 def relu(X):
     return copy(X) >> (gt_constant(0)(X) @ identity(X)) >> multiply(X)

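A quick numerical sanity check (not part of this diff) of the two facts the new Tanh op relies on: the Core decomposition tanh(x) = 2*sigmoid(2*x) - 1 used in arrow(), and the reverse-map rule that the incoming gradient dy is scaled by 1 - tanh(x)^2 computed from the copied forward output. It uses only NumPy, which operation.py already imports:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-4.0, 4.0, 9)

# Core decomposition used by Tanh.arrow: tanh(x) = 2*sigmoid(2*x) - 1
assert np.allclose(2.0 * sigmoid(2.0 * x) - 1.0, np.tanh(x))

# Reverse map used by Tanh.rev: dy is scaled by 1 - tanh(x)^2, and because
# fwd copies the *output*, that factor can be computed from y = tanh(x) alone
y = np.tanh(x)
dy = np.ones_like(x)
assert np.allclose((1.0 - y**2) * dy, dy / np.cosh(x)**2)
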
3 changes: 3 additions & 0 deletions catgrad/layers.py
@@ -18,6 +18,9 @@ def linear(A: NdArrayType, B: NdArrayType, C: NdArrayType):
 def bias(A: NdArrayType):
     return (parameter(obj(A)) @ identity(obj(A))) >> add(obj(A))

+def linear_with_bias(A: NdArrayType, B: NdArrayType, C: NdArrayType):
+    return linear(A, B, C) >> bias(A+C)
+
 sigmoid = canonical(lambda T: op(Sigmoid(T)))

 def dense(A: NdArrayType, B: NdArrayType, C: NdArrayType, activation=sigmoid):
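A usage sketch for linear_with_bias (hypothetical, not part of this diff). It assumes catgrad's Dtype enum exposes float32, and that linear(A, B, C) maps inputs of type A+B to outputs of type A+C via a learned weight, which is what the composition with bias(A+C) suggests:

from catgrad.signature import NdArrayType, Dtype
from catgrad.layers import linear_with_bias

N = NdArrayType((32,), Dtype.float32)   # batch (Dtype member name assumed)
X = NdArrayType((128,), Dtype.float32)  # input features
Y = NdArrayType((10,), Dtype.float32)   # output features

# An affine map N+X -> N+Y: a learned weight followed by a learned bias,
# i.e. roughly y = x @ W + b for each element of the batch.
layer = linear_with_bias(N, X, Y)
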
3 changes: 1 addition & 2 deletions catgrad/signature.py
@@ -1,5 +1,4 @@
-from typing import List, Tuple, Any, Callable, Protocol
-from abc import ABC, abstractmethod
+from typing import Tuple, Any, Protocol
 from dataclasses import dataclass
 from enum import Enum, auto
