Skip to content

Commit 4557455

Browse files
committed
wip: working printing for assignment operator
1 parent 4cefb1a commit 4557455

File tree

6 files changed

+90
-24
lines changed

6 files changed

+90
-24
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ using DispatchDoctor: @stable, @unstable
1212
include("NodePreallocation.jl")
1313
include("Strings.jl")
1414
include("Evaluate.jl")
15-
include("SpecialOperators.jl")
1615
include("EvaluateDerivative.jl")
1716
include("ChainRules.jl")
1817
include("EvaluationHelpers.jl")
1918
include("Simplify.jl")
2019
include("OperatorEnumConstruction.jl")
2120
include("Expression.jl")
2221
include("ExpressionAlgebra.jl")
22+
include("SpecialOperators.jl")
2323
include("Random.jl")
2424
include("Parse.jl")
2525
include("ParametricExpression.jl")

src/Evaluate.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ import ..NodeUtilsModule: is_constant
1010
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded
1111
import ..ValueInterfaceModule: is_valid, is_valid_array
1212

13+
# Overloaded by SpecialOperators.jl:
14+
function any_special_operators(_)
15+
return false
16+
end
17+
function special_operator end
18+
function deg2_eval_special end
19+
function deg1_eval_special end
20+
1321
const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
1422

1523
macro return_on_nonfinite_val(eval_options, val, X)
@@ -268,7 +276,7 @@ function _eval_tree_array(
268276
# we can just return the constant result.
269277
if tree.degree == 0
270278
return deg0_eval(tree, cX, eval_options)
271-
elseif is_constant(tree)
279+
elseif !any_special_operators(operators) && is_constant(tree)
272280
# Speed hack for constant trees.
273281
const_result = dispatch_constant_tree(tree, operators)::ResultOk{T}
274282
!const_result.ok &&
@@ -330,6 +338,7 @@ end
330338
eval_options::EvalOptions,
331339
) where {T}
332340
nbin = get_nbin(operators)
341+
special_operators = any_special_operators(operators)
333342
long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
334343
if long_compilation_time
335344
return quote
@@ -352,15 +361,15 @@ end
352361
i -> let op = operators.binops[i]
353362
if special_operator(op)
354363
deg2_eval_special(tree, cX, op, eval_options)
355-
elseif tree.l.degree == 0 && tree.r.degree == 0
364+
elseif !$(special_operators) && tree.l.degree == 0 && tree.r.degree == 0
356365
deg2_l0_r0_eval(tree, cX, op, eval_options)
357-
elseif tree.r.degree == 0
366+
elseif !$(special_operators) && tree.r.degree == 0
358367
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
359368
!result_l.ok && return result_l
360369
@return_on_nonfinite_array(eval_options, result_l.x)
361370
# op(x, y), where y is a constant or variable but x is not.
362371
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
363-
elseif tree.l.degree == 0
372+
elseif !$(special_operators) && tree.l.degree == 0
364373
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
365374
!result_r.ok && return result_r
366375
@return_on_nonfinite_array(eval_options, result_r.x)
@@ -393,7 +402,8 @@ end
393402
if long_compilation_time
394403
return quote
395404
op = operators.unaops[op_idx]
396-
special_operator(op) && return deg1_eval_special(tree, cX, op, eval_options)
405+
special_operator(op) &&
406+
return deg1_eval_special(tree, cX, op, eval_options, operators)
397407
result = _eval_tree_array(tree.l, cX, operators, eval_options)
398408
!result.ok && return result
399409
@return_on_nonfinite_array(eval_options, result.x)
@@ -408,8 +418,8 @@ end
408418
i -> i == op_idx,
409419
i -> let op = operators.unaops[i]
410420
if special_operator(op)
411-
deg1_eval_special(tree, cX, op, eval_options)
412-
elseif !special_operators &&
421+
deg1_eval_special(tree, cX, op, eval_options, operators)
422+
elseif !$(special_operators) &&
413423
tree.l.degree == 2 &&
414424
tree.l.l.degree == 0 &&
415425
tree.l.r.degree == 0
@@ -418,7 +428,7 @@ end
418428
dispatch_deg1_l2_ll0_lr0_eval(
419429
tree, cX, op, l_op_idx, operators.binops, eval_options
420430
)
421-
elseif !special_operators && tree.l.degree == 1 && tree.l.l.degree == 0
431+
elseif !$(special_operators) && tree.l.degree == 1 && tree.l.l.degree == 0
422432
# op(op2(x)), where x is a constant or variable.
423433
l_op_idx = tree.l.op
424434
dispatch_deg1_l1_ll0_eval(
@@ -941,10 +951,4 @@ end
941951
end
942952
end
943953

944-
# Overloaded by SpecialOperators.jl:
945-
function any_special_operators end
946-
function special_operator end
947-
function deg2_eval_special end
948-
function deg1_eval_special end
949-
950954
end

src/SpecialOperators.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,32 @@ module SpecialOperatorsModule
22

33
using ..OperatorEnumModule: OperatorEnum
44
using ..EvaluateModule: _eval_tree_array, @return_on_nonfinite_array, deg2_eval
5+
using ..ExpressionModule: AbstractExpression
6+
using ..ExpressionAlgebraModule: @declare_expression_operator
57

68
import ..EvaluateModule:
79
special_operator, deg2_eval_special, deg1_eval_special, any_special_operators
10+
import ..StringsModule: get_op_name
811

9-
function any_special_operators(::Type{OperatorEnum{B,U}}) where {B,U}
12+
function any_special_operators(::Union{O,Type{O}}) where {B,U,O<:OperatorEnum{B,U}}
1013
return any(special_operator, B.types) || any(special_operator, U.types)
1114
end
1215

1316
# Use this to customize evaluation behavior for operators:
1417
@inline special_operator(::Type) = false
1518
@inline special_operator(f) = special_operator(typeof(f))
1619

17-
# Base.@kwdef struct WhileOperator <: Function
18-
# max_iters::Int = 100
19-
# end
2020
Base.@kwdef struct AssignOperator <: Function
2121
target_register::Int
2222
end
23-
24-
# @inline special_operator(::Type{WhileOperator}) = true
23+
@declare_expression_operator((op::AssignOperator), 1)
2524
@inline special_operator(::Type{AssignOperator}) = true
25+
get_op_name(o::AssignOperator) = "[{FEATURE_" * string(o.target_register) * "} =]"
2626

27+
# Base.@kwdef struct WhileOperator <: Function
28+
# max_iters::Int = 100
29+
# end
30+
# @inline special_operator(::Type{WhileOperator}) = true
2731
# function deg2_eval_special(tree, cX, op::WhileOperator, eval_options)
2832
# cond = tree.l
2933
# body = tree.r
@@ -43,7 +47,7 @@ end
4347
# end
4448
# TODO: Need to void any instance of buffer when using while loop.
4549

46-
function deg1_eval_special(tree, cX, op::AssignOperator, eval_options)
50+
function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators)
4751
result = _eval_tree_array(tree.l, cX, operators, eval_options)
4852
!result.ok && return result
4953
@return_on_nonfinite_array(eval_options, result.x)
@@ -54,4 +58,4 @@ function deg1_eval_special(tree, cX, op::AssignOperator, eval_options)
5458
return result
5559
end
5660

57-
end
61+
end

src/Strings.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ end
5656
end
5757
end
5858

59+
const FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH = length("{FEATURE_")
60+
function replace_feature_placeholders(s::String, f_variable::Function, variable_names)
61+
return replace(
62+
s,
63+
r"\{FEATURE_(\d+)\}" =>
64+
m -> f_variable(
65+
parse(Int, m[(begin + FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH):(end - 1)]),
66+
variable_names,
67+
),
68+
)
69+
end
70+
5971
# Can overload these for custom behavior:
6072
needs_brackets(val::Real) = false
6173
needs_brackets(val::AbstractArray) = false
@@ -179,7 +191,9 @@ function string_tree(
179191
c
180192
end,
181193
)
182-
return String(strip_brackets(raw_output))
194+
string_output = String(strip_brackets(raw_output))
195+
string_output = replace_feature_placeholders(string_output, f_variable, variable_names)
196+
return string_output
183197
end
184198

185199
# Print an equation

test/test_special_operators.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using TestItems: @testitem
2+
3+
@testitem "AssignOperator basic functionality" begin
4+
using DynamicExpressions
5+
using DynamicExpressions.SpecialOperatorsModule: AssignOperator
6+
using DynamicExpressions.EvaluateModule: eval_tree_array
7+
using Test
8+
using Random
9+
10+
# Define operators and variable names
11+
assign_op2 = AssignOperator(; target_register=2)
12+
operators = OperatorEnum(;
13+
binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_op2]
14+
)
15+
variable_names = ["x1", "x2", "x3", "x4", "x5"]
16+
17+
# Test data
18+
X = zeros(Float64, 2, 3)
19+
X[1, :] .= [1.0, 2.0, 3.0]
20+
X[2, :] .= [0.5, 1.5, 2.5]
21+
22+
# 1. Basic register assignment - assign constant to register 2,
23+
# and then add the return to `x2` (which should now be 3.0!)
24+
x1 = Expression(Node(; feature=1); operators, variable_names)
25+
x2 = Expression(Node(; feature=2); operators, variable_names)
26+
assign_expr = assign_op2(0.0 * x1 + 3.0) + x2
27+
28+
@test string_tree(assign_expr) == "[x2 =]((0.0 * x1) + 3.0) + x2"
29+
30+
# We should see that x2 will become 3.0 _before_ adding
31+
result, completed = eval_tree_array(assign_expr, X)
32+
@test completed == true
33+
@test all(==(6.0), result)
34+
35+
# We should also see that X is not changed by this
36+
@test X[2, :] == [0.5, 1.5, 2.5]
37+
38+
# But, with the reverse order, we get the x2 _before_ it was reassigned
39+
assign_expr_reverse = x2 + assign_op2(0.0 * x1 + 3.0)
40+
result, completed = eval_tree_array(assign_expr_reverse, X)
41+
@test completed == true
42+
@test result == [3.5, 4.5, 5.5]
43+
end

test/unittest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,4 @@ include("test_expression_math.jl")
133133
include("test_structured_expression.jl")
134134
include("test_readonlynode.jl")
135135
include("test_zygote_gradient_wrapper.jl")
136+
include("test_special_operators.jl")

0 commit comments

Comments
 (0)