diff --git a/src/dnn/libcudnn.jl b/src/dnn/libcudnn.jl index a16ff41f..8e2bb987 100644 --- a/src/dnn/libcudnn.jl +++ b/src/dnn/libcudnn.jl @@ -126,6 +126,20 @@ function cudnnSetConvolutionMathType(convDesc, mathType) convDesc, mathType) end +function cudnnSetConvolutionGroupCount(convDesc,groupCount) + @check ccall((:cudnnSetConvolutionGroupCount,libcudnn), + cudnnStatus_t, + (cudnnConvolutionDescriptor_t,Cint), + convDesc,groupCount) +end + +function cudnnGetConvolutionGroupCount(convDesc,groupCount) + @check ccall((:cudnnGetConvolutionGroupCount,libcudnn), + cudnnStatus_t, + (cudnnConvolutionDescriptor_t,Ptr{Cint}), + convDesc,groupCount) +end + function cudnnCreatePoolingDescriptor(poolingDesc) @check ccall((:cudnnCreatePoolingDescriptor,libcudnn), cudnnStatus_t,