diff --git a/Project.toml b/Project.toml index 7fdfcb0..4f4d295 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SourceCodeMcCormick" uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960" authors = ["Robert Gottlieb "] -version = "0.5.0" +version = "0.5.1" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/kernel_writer/kernel_write.jl b/src/kernel_writer/kernel_write.jl index 3e8eb46..16e4c1a 100644 --- a/src/kernel_writer/kernel_write.jl +++ b/src/kernel_writer/kernel_write.jl @@ -10,7 +10,7 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic) function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool) # Create a hash of the expression and check if the function already exists - expr_hash = string(hash(num+sum(gradlist)), base=62) + expr_hash = string(hash(string(num)*string(gradlist)), base=62) if (overwrite==false) && (isfile(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))) try func_name = eval(Meta.parse("f_"*expr_hash)) return func_name @@ -102,9 +102,6 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons elseif splitting==:high # Formerly default split_point = 1500 max_size = 2000 - # elseif splitting==:high # More splitting - # split_point = 1000 - # max_size = 1200 elseif splitting==:max # Extremely small split_point = 500 max_size = 750 @@ -116,7 +113,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons sparsity = detect_sparsity(factored, gradlist) # Decide if the kernel needs to be split - if (n_vars[end] < 31) && (n_lines[end] <= max_size) + if (n_vars[end] < 31) && ((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines))) # Complexity is fairly low; only a single kernel needed create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored, sparsity) push!(kernel_nums, 1) @@ -130,7 +127,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons while !complete # Determine which line to break at line_ID = findfirst(x -> x > split_point, n_lines) - vars_ID = findfirst(x -> x == 31, n_vars) + vars_ID = findfirst(x -> (x == 30) || (x == 31), n_vars) if isnothing(vars_ID) new_ID = line_ID elseif isnothing(line_ID) @@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons n_lines = complexity(factored) n_vars = var_counts(factored) - # If the total number of lines (not including the final line) is below 2000 + # If the total number of lines (not including the final line) is below the max size # and the number of variables is below 32, we can make the final kernel and be done if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size)) create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored, sparsity) @@ -328,7 +325,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a") # Put in the preamble. - write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist))) + if isempty(vars) + write(file, preamble_string(expr_hash, ["OUT";], 1, 1, length(gradlist))) + else + write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist))) + end + # Depending on the format of the expression, compose the kernel differently if typeof(expr) <: Real @@ -360,9 +362,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num end end else # There must be two elements in the dictionary - binary_vars = string.(get_name.(keys(key.dict))) + binary_vars = string.(get_name.(keys(expr.dict))) binary_vars = binary_vars[sort_vars(binary_vars)] - write(file, SCMC_quadaff_binary(vars..., expr.coeff, varlist)) + write(file, SCMC_quadaff_binary(binary_vars..., expr.coeff, varlist)) end elseif exprtype(expr)==ADD @@ -394,7 +396,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num # EAGO already does this and bypasses the need to calculate relaxations. # But, for compatibility with McCormick-style relaxations in ParBB, # it's easier to simply calculate what ParBB is expecting.) - write(file, postamble_quadaff(string.(vars), varlist)) + if isempty(varlist) + write(file, postamble_quadaff(String[], String[])) + elseif isempty(vars) + write(file, postamble_quadaff(String[], varlist)) + else + write(file, postamble_quadaff(string.(vars), varlist)) + end close(file) # Include this kernel so SCMC knows what it is @@ -403,7 +411,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num # Add onto the file the "main" CPU function that calls the kernel blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT)) file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a") - write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist))) + if isempty(gradlist) + write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, Symbol[])) + elseif isempty(vars) + write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, get_name.(gradlist))) + else + write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist))) + end close(file) # Include the file again to get the final kernel @@ -731,6 +745,7 @@ end # 7) log(inv(x1)) = -log(x1) [EAGO paper] # 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1 # 9) 1 / (1 + exp(-x)) = Sigmoid(x) +# 10) sin(x) = cos(x - pi/2) # # Forms that aren't relevant yet: # 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers) @@ -826,7 +841,7 @@ function perform_substitutions(old_factored::Vector{Equation}) end end # Create a factorization of this new expr - new_factorization = factor(new_expr) + new_factorization = factor(new_expr, split_div=true) # Scan through the new factorization to see if we can merge elements # with the original factored list done = false @@ -1191,7 +1206,7 @@ function perform_substitutions(old_factored::Vector{Equation}) new_expr *= arg end # Create a factorization of this new expr - new_factorization = factor(new_expr) + new_factorization = factor(new_expr, split_div=true) # Scan through the new factorization to see if we can merge elements @@ -1315,6 +1330,38 @@ function perform_substitutions(old_factored::Vector{Equation}) end end end + + # 10) sin(x) = cos(x - pi/2) + if exprtype(factored[index0].rhs)==TERM + if factored[index0].rhs.f==sin + # We found sin(arg). Check if (arg - pi/2) exists, + # and if so, also check if cos(arg - pi/2) exists. + scan_flag = true + index1 = findfirst(x -> isequal(x.rhs, arguments(factored[index0].rhs)[] - pi/2), factored) + if !isnothing(index1) + index2 = findfirst(x -> isequal(x.rhs, cos(factored[index1].lhs)), factored) + if !isnothing(index2) + # cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2 + for i in eachindex(factored) + @eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$index2].lhs)) + end + deleteat!(factored, index0) + else + # arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change + # index0 to be cos of index1.lhs instead of sin of arg + @eval $factored[$index0] = $factored[$index0].lhs ~ cos($factored[$index1].lhs) + end + else + # (arg - pi/2) doesn't exist, so we need to create it + newsym = gensym(:aux) + newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end]) + newvar = genvar(newsym) + insert!(factored, index0, Equation(Symbolics.value(newvar), arguments(factored[index0].rhs)[] - pi/2)) + @eval $factored[$index0+1] = $factored[$index0+1].lhs ~ cos($newvar) + end + break + end + end end end @@ -1511,6 +1558,8 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto write(file, SCMC_sigmoid_kernel(inputs..., gradlist, sparsity)) elseif RHS.f==sqrt write(file, SCMC_float_power_kernel(inputs..., 0.5, gradlist, sparsity)) + elseif RHS.f==cos + write(file, SCMC_cos_kernel(inputs..., gradlist, sparsity)) else close(file) error("Some function was used that we can't handle yet ($RHS)") @@ -1845,6 +1894,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star else total_lines += 190 end + new_ID = findfirst(x -> isequal(x.lhs, RHS.base), factorized) + if !isnothing(new_ID) + total_lines += _complexity(complexity, factorized, new_ID) + end elseif exprtype(RHS) == TERM if RHS.f==exp total_lines += 212 # Ranges from 212--310 @@ -1866,6 +1919,16 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star end elseif RHS.f==sqrt total_lines += 190 + new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized) + if !isnothing(new_ID) + total_lines += _complexity(complexity, factorized, new_ID) + end + elseif RHS.f==cos || RHS.f==sin + total_lines += 300 + new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized) + if !isnothing(new_ID) + total_lines += _complexity(complexity, factorized, new_ID) + end else error("Unknown function") end diff --git a/src/kernel_writer/math_kernels.jl b/src/kernel_writer/math_kernels.jl index 103952c..bc7a8f7 100644 --- a/src/kernel_writer/math_kernels.jl +++ b/src/kernel_writer/math_kernels.jl @@ -6,6 +6,12 @@ # these same functions, but in buffer/string form for the purposes of writing # new kernels. +# NOTE: These kernels might all be faster if we flip the ordering of indices. +# I.e., instead of having each row be a unique point to evaluate, make +# each column a unique point to evaluate. Preliminary checking on my +# workstation says this could be ~25% faster (tried for multiplication, +# 100000 unique points) + #= Unitary Rules =# @@ -4455,6 +4461,183 @@ function SCMC_large_float_power_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix, c return nothing end +# Cosine (argument should be in radians) +# NOTE: Sine can be cos(x - pi/2) +function SCMC_cos_kernel(OUT, x) +# function SCMC_cos_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix) + idx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x + stride = blockDim().x * gridDim().x + colmax = Int32((size(OUT,2)-4)/2) + + while idx <= Int32(size(OUT,1)) + # Reset the column counter + col = Int32(1) + + # Get lower and upper bounds from the interval + if (x[idx,4] - x[idx,3]) >= 2.0*pi + OUT[idx,3] = -1.0 + OUT[idx,4] = 1.0 + else + lo_quadrant, lo = quadrant(x[idx,3]) + hi_quadrant, hi = quadrant(x[idx,4]) + + if lo_quadrant == hi_quadrant + if x[idx,4] - x[idx,3] > 3.141592653589793 + OUT[idx,3] = -1.0 + OUT[idx,4] = 1.0 + elseif lo_quadrant==2 || lo_quadrant==3 + OUT[idx,3] = cos(lo) + OUT[idx,4] = cos(hi) + else + OUT[idx,3] = cos(hi) + OUT[idx,4] = cos(lo) + end + elseif lo_quadrant==2 && hi_quadrant==3 + OUT[idx,3] = cos(lo) + OUT[idx,4] = cos(hi) + elseif lo_quadrant==0 && hi_quadrant==1 + OUT[idx,3] = cos(hi) + OUT[idx,4] = cos(lo) + elseif (lo_quadrant==2 || lo_quadrant==3) && (hi_quadrant==0 || hi_quadrant==1) + OUT[idx,3] = min(cos(lo), cos(hi)) + OUT[idx,4] = 1.0 + elseif (lo_quadrant==0 || lo_quadrant==1) && (hi_quadrant==2 || hi_quadrant==3) + OUT[idx,3] = -1.0 + OUT[idx,4] = max(cos(lo), cos(hi)) + else + OUT[idx,3] = -1.0 + OUT[idx,4] = 1.0 + end + end + + + # get eps_min and eps_max + kL = Base.ceil(-0.5 - x[idx,3]/(2.0*pi)) + xL1 = x[idx,3] + 2.0*pi*kL + xU1 = x[idx,4] + 2.0*pi*kL + if (xL1 < -pi) || (xL1 > pi) + eps_min = NaN + eps_max = NaN + elseif xL1 <= 0.0 + if xU1 <= 0.0 + eps_min = x[idx,3] + eps_max = x[idx,4] + elseif xU1 >= pi + eps_min = pi - 2.0*pi*kL + eps_max = -2.0*pi*kL + else + eps_min = (cos(xL1) <= cos(xU1)) ? x[idx,3] : x[idx,4] + eps_max = -2.0*pi*kL + end + elseif xU1 <= pi + eps_min = x[idx,4] + eps_max = x[idx,3] + elseif xU1 >= 2.0*pi + eps_min = pi - 2.0*pi*kL + eps_max = 2.0*pi - 2.0*pi*kL + else + eps_min = pi - 2.0*pi*kL + eps_max = (cos(xL1) >= cos(xU1)) ? x[idx,3] : x[idx,4] + end + + midcv, cv_id, midcc, cc_id = midvals(x[idx,1], x[idx,2], eps_min, eps_max) + + # Call cv normally + cv, dcv = SCMC_cv_cos(midcv, x[idx,3], x[idx,4]) + OUT[idx,1] = cv + + # Call cc by shifting and negating the cv path + neg_cc, neg_dcc = SCMC_cv_cos(midcc - pi, x[idx,3] - pi, x[idx,4] - pi) + OUT[idx,2] = -neg_cc + dcc = -neg_dcc + + # Now we need mid_grad things... + if cv_id==1 + if cc_id==1 + while col <= colmax + OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv + OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc + col += Int32(1) + end + elseif cc_id==2 + while col <= colmax + OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv + OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc + col += Int32(1) + end + else + while col <= colmax + OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv + OUT[idx,end-1*colmax+col] = 0.0 + col += Int32(1) + end + end + elseif cv_id==2 + if cc_id==1 + while col <= colmax + OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv + OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc + col += Int32(1) + end + elseif cc_id==2 + while col <= colmax + OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv + OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc + col += Int32(1) + end + else + while col <= colmax + OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv + OUT[idx,end-1*colmax+col] = 0.0 + col += Int32(1) + end + end + else + if cc_id==1 + while col <= colmax + OUT[idx,end-2*colmax+col] = 0.0 + OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc + col += Int32(1) + end + elseif cc_id==2 + while col <= colmax + OUT[idx,end-2*colmax+col] = 0.0 + OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc + col += Int32(1) + end + else + while col <= colmax + OUT[idx,end-2*colmax+col] = 0.0 + OUT[idx,end-1*colmax+col] = 0.0 + col += Int32(1) + end + end + end + + # Perform the cut operation + if OUT[idx,1] < OUT[idx,3] + OUT[idx,1] = OUT[idx,3] + col = Int32(1) + while col <= colmax + OUT[idx,end-2*colmax+col] = 0.0 + col += Int32(1) + end + end + if OUT[idx,2] > OUT[idx,4] + OUT[idx,2] = OUT[idx,4] + col = Int32(1) + while col <= colmax + OUT[idx,end-1*colmax+col] = 0.0 + col += Int32(1) + end + end + + idx += stride + end + return nothing +end + + #= Binary Rules =# @@ -4862,6 +5045,269 @@ function SCMC_add_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix, y::CuDeviceMatr end +################## +# Helper functions for some kernels to use +function midvals(xcv::Float64, xcc::Float64, eps_min::Float64, eps_max::Float64) + if xcc >= xcv + if xcv == xcc + midcc = xcv + cc_id = Int32(2) + midcv = xcv + cv_id = Int32(2) + elseif xcv >= eps_max + if xcv >= eps_min + midcc = xcv + cc_id = Int32(2) + midcv = xcv + cv_id = Int32(2) + elseif eps_min >= xcc + midcc = xcv + cc_id = Int32(2) + midcv = xcc + cv_id = Int32(1) + else + midcc = xcv + cc_id = Int32(2) + midcv = eps_min + cv_id = Int32(3) + end + elseif eps_max >= xcc + if xcv >= eps_min + midcc = xcc + cc_id = Int32(1) + midcv = xcv + cv_id = Int32(2) + elseif eps_min >= xcc + midcc = xcc + cc_id = Int32(1) + midcv = xcc + cv_id = Int32(1) + else + midcc = xcc + cc_id = Int32(1) + midcv = eps_min + cv_id = Int32(3) + end + else + if xcv >= eps_min + midcc = eps_max + cc_id = Int32(3) + midcv = xcv + cv_id = Int32(2) + elseif eps_min >= xcc + midcc = eps_max + cc_id = Int32(3) + midcv = xcc + cv_id = Int32(1) + else + midcc = eps_max + cc_id = Int32(3) + midcv = eps_min + cv_id = Int32(3) + end + end + elseif eps_max >= xcv + if eps_min >= xcv + midcc = xcv + cc_id = Int32(2) + midcv = xcv + cv_id = Int32(2) + elseif xcc >= eps_min + midcc = xcv + cc_id = Int32(2) + midcv = xcc + cv_id = Int32(1) + else + midcc = xcv + cc_id = Int32(2) + midcv = eps_min + cv_id = Int32(3) + end + elseif xcc >= eps_max + if eps_min >= xcv + midcc = xcc + cc_id = Int32(1) + midcv = xcv + cv_id = Int32(2) + elseif xcc >= eps_min + midcc = xcc + cc_id = Int32(1) + midcv = xcc + cv_id = Int32(1) + else + midcc = xcc + cc_id = Int32(1) + midcv = eps_min + cv_id = Int32(3) + end + else + if eps_min >= xcv + midcc = eps_max + cc_id = Int32(3) + midcv = xcv + cv_id = Int32(2) + elseif xcc >= eps_min + midcc = eps_max + cc_id = Int32(3) + midcv = xcc + cv_id = Int32(1) + else + midcc = eps_max + cc_id = Int32(3) + midcv = eps_min + cv_id = Int32(3) + end + end + return midcv, cv_id, midcc, cc_id +end + +@inline function SCMC_cv_cos(x::Float64, xL::Float64, xU::Float64) + kL = Base.ceil(-0.5 - xL/(2.0*pi)) + if x <= (pi - 2.0*pi*kL) + xL1 = xL + 2.0*pi*kL + if xL1 >= pi/2.0 + return cos(x), -sin(x) + end + xU1 = min(xU + 2.0*pi*kL, pi) + if (xL1 >= -pi/2) && (xU1 <= pi/2) + if abs(xU - xL) < 1E-10 + return cos(xL), 0.0 + else + return cos(xL) + (x - xL)*(cos(xU) - cos(xL))/(xU - xL), + (cos(xU) - cos(xL))/(xU - xL) + end + end + return SCMC_cv_cosin(x + 2.0*pi*kL, xL1, xU1) + end + kU = Base.floor(0.5 - xU/(2.0*pi)) + if (x >= -pi - 2.0*pi*kU) + xU2 = xU + 2.0*pi*kU + if xU2 <= -pi/2.0 + return cos(x), -sin(x) + end + return SCMC_cv_cosin(x + 2.0*pi*kU, max(xL + 2.0*pi*kU, -pi), xU2) + end + return -1.0, 0.0 +end + +# Needs to return only two things (inlining to make comparisons with x) +@inline function SCMC_cv_cosin(x::Float64, xL::Float64, xU::Float64) + if abs(xL) <= abs(xU) + left = false + x0 = xU + xm = xL + else + left = true + x0 = xL + xm = xU + end + xj = cos_newton_or_golden_section(x0, xL, xU, xm) + if (left && (x <= xj)) || (~left && (x >= xj)) + return cos(x), -sin(x) + else + if abs(xm - xj) < 1e-10 + return cos(xm), 0.0 + else + return cos(xm) + (x - xm)*(cos(xm) - cos(xj))/(xm - xj), (cos(xm) - cos(xj))/(xm - xj) + end + end +end + +function cos_newton_or_golden_section(x0::Float64, xL::Float64, xU::Float64, envp::Float64) + dfk = 0.0 + xk = max(xL, min(x0, xU)) + fk = (xk - envp)*sin(xk) + cos(xk) - cos(envp) + iter = Int32(1) + while iter <= Int32(100) + dfk = (xk - envp)*cos(xk) + if abs(fk) < 1e-10 + return xk + end + if iszero(dfk) + xk = 0.0 + break # Need to do golden section + end + if (xk == xL) && (fk/dfk > 0.0) + return xk + elseif (xk == xU) && (fk/dfk < 0.0) + return xk + end + xk = max(xL, min(xU, xk - fk/dfk)) + fk = (xk - envp)*sin(xk) + cos(xk) - cos(envp) + iter += Int32(1) + end + + # If flag, we need to do golden section instead + a_golden = xL + fa_golden = (a_golden - envp)*sin(a_golden) + cos(a_golden) - cos(envp) + c_golden = xU + fc_golden = (c_golden - envp)*sin(c_golden) + cos(c_golden) - cos(envp) + + if fa_golden*fc_golden > 0 + xk = NaN + return xk + end + + b_golden = xU - (2.0 - Base.MathConstants.golden)*(xU - xL) + fb_golden = (b_golden - envp)*sin(b_golden) + cos(b_golden) - cos(envp) + + iter = Int32(1) + while iter <= Int32(100) + if (c_golden - b_golden > b_golden - a_golden) + x_golden = b_golden + (2.0 - Base.MathConstants.golden)*(c_golden - b_golden) + if abs(c_golden-a_golden) < 1.0e-10*(abs(b_golden) + abs(x_golden)) || iter == Int32(100) + xk = (c_golden + a_golden)/2.0 + return xk + end + iter += Int32(1) + fx_golden = (x_golden - envp)*sin(x_golden) + cos(x_golden) - cos(envp) + if fa_golden*fx_golden < 0.0 + c_golden = x_golden + fc_golden = fx_golden + else + a_golden = b_golden + fa_golden = fb_golden + b_golden = x_golden + fb_golden = fx_golden + end + else + x_golden = b_golden - (2.0 - Base.MathConstants.golden)*(b_golden - a_golden) + if abs(c_golden-a_golden) < 1.0e-10*(abs(b_golden) + abs(x_golden)) || iter == Int32(100) + xk = (c_golden + a_golden)/2.0 + return xk + end + iter += Int32(1) + fx_golden = (x_golden - envp)*sin(x_golden) + cos(x_golden) - cos(envp) + if fa_golden*fb_golden < 0.0 + c_golden = b_golden + fc_golden = fb_golden + b_golden = x_golden + fb_golden = fx_golden + else + a_golden = x_golden + fa_golden = fx_golden + end + end + end + + # Should never get to this point, but for completeness... + return xk +end + +# Directly from IntervalArithmetic.jl +function quadrant(x::Float64) + x_mod2pi = rem2pi(x, RoundNearest) + + x_mod2pi < -(pi/2.0) && return (Int32(2), x_mod2pi) + x_mod2pi < 0 && return (Int32(3), x_mod2pi) + x_mod2pi <= (pi/2.0) && return (Int32(0), x_mod2pi) + + return Int32(1), x_mod2pi +end + + + + ################## # Some templates that are useful for writing new kernels. diff --git a/src/kernel_writer/string_math_kernels.jl b/src/kernel_writer/string_math_kernels.jl index 869b6e1..5de8ab8 100644 --- a/src/kernel_writer/string_math_kernels.jl +++ b/src/kernel_writer/string_math_kernels.jl @@ -3803,7 +3803,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4021,7 +4021,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cc))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4167,7 +4167,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_hi))\n") write(buffer, " while col <= colmax\n") write(buffer, " $OUT_cvgrad $eq 0.0\n") @@ -4314,7 +4314,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cv))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4478,7 +4478,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_cc))\n") write(buffer, " while col <= colmax\n") if sparsity_case == 1 @@ -4624,7 +4624,7 @@ function SCMC_sigmoid_kernel(OUT::String, v1::String, varlist::Vector{String}, s write(buffer, " col += Int32(1)\n") write(buffer, " end\n") write(buffer, " else\n") - write(buffer, " $OUT_cv $eq ((1.0/(1.0 + exp(-$v1_lo)))*($v1_hi - $v1_lo) + (1.0/(1.0 + exp(-$v1_hi)))*($v1_lo - $v1_lo))/($v1_hi - $v1_lo)\n") + write(buffer, " $OUT_cv $eq 1.0/(1.0 + exp(-$v1_lo))\n") write(buffer, " $OUT_cc $eq 1.0/(1.0 + exp(-$v1_hi))\n") write(buffer, " while col <= colmax\n") write(buffer, " $OUT_cvgrad $eq 0.0\n") @@ -9697,6 +9697,528 @@ function SCMC_float_power_kernel(OUT::String, v1::String, POW::T, varlist::Vecto return String(take!(buffer)) end +# Cos +# max threads: ??? +function SCMC_cos_kernel(OUT::String, v1::String, varlist::Vector{String}, sparsity::Vector{Int}) + if startswith(v1, "temp") + v1_cv = "$(v1)_cv" + v1_cc = "$(v1)_cc" + v1_lo = "$(v1)_lo" + v1_hi = "$(v1)_hi" + v1_cvgrad = "$(v1)_cvgrad[col]" + v1_ccgrad = "$(v1)_ccgrad[col]" + elseif startswith(v1, "aux") + v1_cv = "$(v1)[idx,1]" + v1_cc = "$(v1)[idx,2]" + v1_lo = "$(v1)[idx,3]" + v1_hi = "$(v1)[idx,4]" + v1_cvgrad = "$(v1)[idx,end-2*colmax+col]" + v1_ccgrad = "$(v1)[idx,end-1*colmax+col]" + else + v1_cv = "$(v1)[idx,1]" + v1_cc = "$(v1)[idx,1]" + v1_lo = "$(v1)[idx,2]" + v1_hi = "$(v1)[idx,3]" + end + if startswith(OUT, "temp") + OUT_cv = "$(OUT)_cv" + OUT_cc = "$(OUT)_cc" + OUT_lo = "$(OUT)_lo" + OUT_hi = "$(OUT)_hi" + OUT_cvgrad = "$(OUT)_cvgrad[col]" + OUT_ccgrad = "$(OUT)_ccgrad[col]" + else + OUT_cv = "$(OUT)[idx,1]" + OUT_cc = "$(OUT)[idx,2]" + OUT_lo = "$(OUT)[idx,3]" + OUT_hi = "$(OUT)[idx,4]" + OUT_cvgrad = "$(OUT)[idx,end-2*colmax+col]" + OUT_ccgrad = "$(OUT)[idx,end-1*colmax+col]" + end + + # Get the anti-sparsity list (elements NOT being used) + antisparsity = collect(1:length(varlist)) + antisparsity = antisparsity[antisparsity .∉ Ref(sparsity)] + + # Determine the sparsity case: + # 1) Use sparsity list + # 2) Use antisparsity list (because it's shorter than the sparsity list) + # 3) Don't use either, simply calculate all elements + if length(sparsity) <= length(antisparsity) + sparsity_case = 1 + sparsity_string = join(["col == Int32($(x))" for x in sparsity], " || ") + elseif length(antisparsity) > 0 + antisparsity_string = join(["col == Int32($(x))" for x in antisparsity], " || ") + sparsity_case = 2 + else + sparsity_case = 3 + end + + # Create the buffer that we will write to + buffer = Base.IOBuffer() + + # Write all the lines to the buffer + if startswith(v1, r"aux|temp") + write(buffer, " ##############################\n") + write(buffer, " ## Cosine (Or Shifted Sine) ##\n") + write(buffer, " ##############################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " if ($v1_hi - $v1_lo) >= 2.0*pi\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " else\n") + write(buffer, " lo_quadrant, lo = SourceCodeMcCormick.quadrant($v1_lo)\n") + write(buffer, " hi_quadrant, hi = SourceCodeMcCormick.quadrant($v1_hi)\n") + write(buffer, "\n") + write(buffer, " if lo_quadrant == hi_quadrant\n") + write(buffer, " if $v1_hi - $v1_lo > 3.141592653589793\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif lo_quadrant==2 || lo_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " end\n") + write(buffer, " elseif lo_quadrant==2 && hi_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " elseif lo_quadrant==0 && hi_quadrant==1\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " elseif (lo_quadrant==2 || lo_quadrant==3) && (hi_quadrant==0 || hi_quadrant==1)\n") + write(buffer, " $OUT_lo = min(cos(lo), cos(hi))\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif (lo_quadrant==0 || lo_quadrant==1) && (hi_quadrant==2 || hi_quadrant==3)\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = max(cos(lo), cos(hi))\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get eps_min and eps_max\n") + write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n") + write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n") + write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n") + write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n") + write(buffer, " eps_min = NaN\n") + write(buffer, " eps_max = NaN\n") + write(buffer, " elseif xL1 <= 0.0\n") + write(buffer, " if xU1 <= 0.0\n") + write(buffer, " eps_min = $v1_lo\n") + write(buffer, " eps_max = $v1_hi\n") + write(buffer, " elseif xU1 >= pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " else\n") + write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " end\n") + write(buffer, " elseif xU1 <= pi\n") + write(buffer, " eps_min = $v1_hi\n") + write(buffer, " eps_max = $v1_lo\n") + write(buffer, " elseif xU1 >= 2.0*pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = 2.0*pi - 2.0*pi*kL\n") + write(buffer, " else\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = (cos(xL1) >= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get midcv and midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)\n") + write(buffer, " midcv, cv_id, midcc, cc_id = SourceCodeMcCormick.midvals($v1_cv, $v1_cc, eps_min, eps_max)\n") + write(buffer, "\n") + write(buffer, " # Call the SCMC_cv_cos function for both cv and cc\n") + write(buffer, " cv, dcv = SourceCodeMcCormick.SCMC_cv_cos(midcv, $v1_lo, $v1_hi)\n") + write(buffer, " neg_cc, neg_dcc = SourceCodeMcCormick.SCMC_cv_cos(midcc - pi, $v1_lo - pi, $v1_hi - pi)\n") + write(buffer, " $OUT_cv = cv\n") + write(buffer, " $OUT_cc = -neg_cc\n") + write(buffer, " dcc = -neg_dcc\n") + write(buffer, "\n") + write(buffer, " # Calculate subgradients\n") + write(buffer, " if cv_id==1\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_ccgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " elseif cv_id==2\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = $v1_cvgrad*dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " if cc_id==1\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_ccgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " elseif cc_id==2\n") + write(buffer, " while col <= colmax\n") + if sparsity_case == 1 + write(buffer, " if $sparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + elseif sparsity_case == 2 + write(buffer, " if $antisparsity_string\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + write(buffer, " end\n") + else + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = $v1_cvgrad*dcc\n") + end + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + else + ID = findfirst(==(v1), varlist) + isnothing(ID) && error("Empty varlist") + write(buffer, " ##############################\n") + write(buffer, " ## Cosine (Or Shifted Sine) ##\n") + write(buffer, " ##############################\n") + write(buffer, "\n") + write(buffer, " # Reset the column counter\n") + write(buffer, " col = Int32(1)\n") + write(buffer, "\n") + write(buffer, " # Begin rule\n") + write(buffer, " if ($v1_hi - $v1_lo) >= 2.0*pi\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " else\n") + write(buffer, " lo_quadrant, lo = SourceCodeMcCormick.quadrant($v1_lo)\n") + write(buffer, " hi_quadrant, hi = SourceCodeMcCormick.quadrant($v1_hi)\n") + write(buffer, "\n") + write(buffer, " if lo_quadrant == hi_quadrant\n") + write(buffer, " if $v1_hi - $v1_lo > 3.141592653589793\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif lo_quadrant==2 && hi_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " end\n") + write(buffer, " elseif lo_quadrant==2 && hi_quadrant==3\n") + write(buffer, " $OUT_lo = cos(lo)\n") + write(buffer, " $OUT_hi = cos(hi)\n") + write(buffer, " elseif lo_quadrant==0 && hi_quadrant==1\n") + write(buffer, " $OUT_lo = cos(hi)\n") + write(buffer, " $OUT_hi = cos(lo)\n") + write(buffer, " elseif (lo_quadrant==2 || lo_quadrant==3) && (hi_quadrant==0 || hi_quadrant==1)\n") + write(buffer, " $OUT_lo = min(cos(lo), cos(hi))\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " elseif (lo_quadrant==0 || lo_quadrant==1) && (hi_quadrant==2 || hi_quadrant==3)\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = max(cos(lo), cos(hi))\n") + write(buffer, " else\n") + write(buffer, " $OUT_lo = -1.0\n") + write(buffer, " $OUT_hi = 1.0\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get eps_min and eps_max\n") + write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n") + write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n") + write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n") + write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n") + write(buffer, " eps_min = NaN\n") + write(buffer, " eps_max = NaN\n") + write(buffer, " elseif xL1 <= 0.0\n") + write(buffer, " if xU1 <= 0.0\n") + write(buffer, " eps_min = $v1_lo\n") + write(buffer, " eps_max = $v1_hi\n") + write(buffer, " elseif xU1 >= pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " else\n") + write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " eps_max = -2.0*pi*kL\n") + write(buffer, " end\n") + write(buffer, " elseif xU1 <= pi\n") + write(buffer, " eps_min = $v1_hi\n") + write(buffer, " eps_max = $v1_lo\n") + write(buffer, " elseif xU1 >= 2.0*pi\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = 2.0*pi - 2.0*pi*kL\n") + write(buffer, " else\n") + write(buffer, " eps_min = pi - 2.0*pi*kL\n") + write(buffer, " eps_max = (cos(xL1) >= cos(xU1)) ? $v1_lo : $v1_hi\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Get midcv and midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)\n") + write(buffer, " midcv, cv_id, midcc, cc_id = SourceCodeMcCormick.midvals($v1_cv, $v1_cc, eps_min, eps_max)\n") + write(buffer, "\n") + write(buffer, " # Call the SCMC_cv_cos function for both cv and cc\n") + write(buffer, " cv, dcv = SourceCodeMcCormick.SCMC_cv_cos(midcv, $v1_lo, $v1_hi)\n") + write(buffer, " neg_cc, neg_dcc = SourceCodeMcCormick.SCMC_cv_cos(midcc - pi, $v1_lo - pi, $v1_hi - pi)\n") + write(buffer, " $OUT_cv = cv\n") + write(buffer, " $OUT_cc = -neg_cc\n") + write(buffer, " dcc = -neg_dcc\n") + write(buffer, "\n") + write(buffer, " # Calculate subgradients\n") + write(buffer, " if cv_id==1 || cv_id==2\n") + write(buffer, " if cc_id==1 || cc_id==2\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = dcv\n") + write(buffer, " $OUT_ccgrad = dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = dcv\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " if cc_id==1 || cc_id==2\n") + write(buffer, " while col <= colmax\n") + write(buffer, " if col == Int32($ID)\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = dcc\n") + write(buffer, " else\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " end\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " else\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, "\n") + write(buffer, " # Cut\n") + write(buffer, " if $OUT_cv < $OUT_lo\n") + write(buffer, " $OUT_cv = $OUT_lo\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_cvgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + write(buffer, " if $OUT_cc > $OUT_hi\n") + write(buffer, " $OUT_cc = $OUT_hi\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " $OUT_ccgrad = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, " end\n") + end + return String(take!(buffer)) +end + #= Binary Rules =# @@ -14107,6 +14629,23 @@ function SCMC_quadaff_initialize(CONST::Real) # Create the buffer that we will write to buffer = Base.IOBuffer() + # Reset subgradients to 0, since they only get added to in quadaff expressions + write(buffer, " ##################################\n") + write(buffer, " ## Reset Terms and Subgradients ##\n") + write(buffer, " ##################################\n") + write(buffer, "\n") + write(buffer, " temp1_cv = 0.0\n") + write(buffer, " temp1_cc = 0.0\n") + write(buffer, " temp1_lo = 0.0\n") + write(buffer, " temp1_hi = 0.0\n") + write(buffer, " col = Int32(1)\n") + write(buffer, " while col <= colmax\n") + write(buffer, " temp1_cvgrad[col] = 0.0\n") + write(buffer, " temp1_ccgrad[col] = 0.0\n") + write(buffer, " col += Int32(1)\n") + write(buffer, " end\n") + write(buffer, "\n") + # Write the initialization of the quadratic constants # to the buffer write(buffer, " #############################\n") @@ -14116,6 +14655,7 @@ function SCMC_quadaff_initialize(CONST::Real) write(buffer, " intercept_cv = $(Float64(CONST))\n") write(buffer, " intercept_cc = $(Float64(CONST))\n") write(buffer, "\n") + return String(take!(buffer)) end diff --git a/src/transform/utilities.jl b/src/transform/utilities.jl index c7c9d25..ebbc545 100644 --- a/src/transform/utilities.jl +++ b/src/transform/utilities.jl @@ -284,6 +284,9 @@ function pull_vars(eqns::Vector{Equation}) end return vars end +function pull_vars(eqn::T) where T<:Real + return Num[] +end # Sorts variables in a more logical ordering, to be consistent # with McCormick.jl organization.