-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmvp.py
More file actions
149 lines (126 loc) · 3.98 KB
/
mvp.py
File metadata and controls
149 lines (126 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
A Python rewrite of the OCaml code from the following Jane Street blog post.
source: https://blog.janestreet.com/computations-that-differentiate-debug-and-document-themselves/
"""
from typing import List, Optional
"""
(** A computation involving some set of variables. It can be evaluated, and
the partial derivative of each variable will be automatically computed. *)
type t
module Variable : sig
(** An identifier used to name a variable. *)
module ID : String_id.S
(** A variable in a computation. *)
type t
val create : id:ID.t -> initial_value:float -> t
(** Returns the current value of this variable. *)
val get_value : t -> float
(** Returns the current partial derivative of this variable. *)
val get_derivative : t -> float
(** Sets the current value of this variable. *)
val set_value : t -> float -> unit
end
"""
class Variable:
    """A node in a computation graph: a named value plus the nodes it was built from.

    Leaves (created by ``constant`` / ``variable``) have no parents; nodes built
    by ``vsum`` / ``vsquare`` / ``vproduct`` / ``vinner_product`` keep their
    operands in ``parents`` and use ``ID`` as the operation tag.

    Attributes:
        ID: identifier shown by ``__repr__`` (a user-chosen name or an
            operation tag such as 'SUM').
        v: current value of the node.
        parents: nodes this one was computed from (empty list for leaves).
        derivative: partial derivative recorded for this node; 0.0 until
            something (such as a differentiation pass) overwrites it.
    """

    def __init__(self, id, init_val, parents: Optional[List] = None) -> None:
        self.ID = id
        self.v = init_val
        # A fresh list per instance, so leaves never share a parents container.
        self.parents: List["Variable"] = parents if parents is not None else []
        self.derivative = 0.0

    def __repr__(self) -> str:
        return '{}: {}'.format(self.ID, self.v)

    def get_value(self):
        """Return the current value of this variable."""
        return self.v

    def get_derivative(self) -> float:
        """Return the partial derivative currently recorded for this variable.

        This is the ``derivative`` attribute, which starts at 0.0.
        """
        return self.derivative

    def set_value(self, v) -> None:
        """Set the current value of this variable."""
        self.v = v
"""
(** Constructs a computation representing a constant value. *)
val constant : float -> t
"""
def constant(id, v) -> Variable:
    """Wrap the fixed value ``v`` in a parentless ``Variable`` named ``id``.

    Python counterpart of the OCaml ``constant : float -> t``; unlike the OCaml
    version it also takes a display name. The result is built exactly like the
    one from ``variable`` — both are plain leaves.
    """
    node = Variable(id, v)
    return node
"""
(** Constructs a computation representing a single variable. *)
val variable : Variable.t -> t
"""
def variable(id, v) -> Variable:
    """Create a parentless ``Variable`` named ``id`` with initial value ``v``.

    Python counterpart of the OCaml ``variable : Variable.t -> t``, except that
    it builds the variable itself from a name and a value.
    """
    return Variable(id=id, init_val=v)
"""
(** Constructs a computation representing the sum over some [t]s. *)
val sum : t list -> t
"""
def vsum(vars: List[Variable]) -> Variable:
    """Build a 'SUM' node whose value is the sum of its operands' values.

    Python counterpart of the OCaml ``sum : t list -> t``.

    Args:
        vars: the operand nodes. Any iterable is accepted; it is copied into a
            new list once, so the node neither aliases the caller's container
            (whose later mutation would rewire the graph) nor ends up with an
            exhausted generator as its parents.

    Returns:
        A new ``Variable`` with ID 'SUM' whose parents are the operands, in
        order. An empty input gives a value of 0.
    """
    operands = list(vars)
    return Variable('SUM', sum(var.v for var in operands), parents=operands)
"""
(** Constructs a computation representing the square of [t]. *)
val square : t -> t
"""
def vsquare(var) -> Variable:
    """Build a 'SQUARE' node holding the square of ``var``'s value.

    Python counterpart of the OCaml ``square : t -> t``; ``var`` is the node's
    only parent.
    """
    squared = var.v ** 2
    return Variable('SQUARE', squared, parents=[var])
def vproduct(var1, var2) -> Variable:
    """Build a 'PRODUCT' node holding ``var1.v * var2.v``.

    Both operands become the node's parents, in argument order.
    """
    value = var1.v * var2.v
    operands = [var1, var2]
    return Variable('PRODUCT', value, parents=operands)
def vinner_product(vars1, vars2) -> Variable:
    """Build an 'INNERPRODUCT' node for the sum of ``vars1[i] * vars2[i]``.

    The resulting graph is one 'PRODUCT' node per aligned pair, a 'SUM' node
    over those products, and an 'INNERPRODUCT' node on top that carries the
    sum's value and has the sum as its only parent.

    Raises:
        AssertionError: if the two sequences differ in length (only while
            assertions are enabled).
    """
    assert len(vars1) == len(vars2)
    pairwise = []
    for left, right in zip(vars1, vars2):
        pairwise.append(vproduct(left, right))
    total = vsum(pairwise)
    return Variable('INNERPRODUCT', total.v, parents=[total])
"""
(** [evaluate t] evaluates the computation [t] and returns the result, and
updates the derivative information in the variables in [t]. *)
val evaluate : t -> float
"""
def _topological_order(root):
    """Return every node reachable from ``root``, each listed after all of its parents."""
    order = []
    seen = set()
    # Explicit stack (node, expanded?) so deep graphs do not hit the recursion limit.
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            order.append(node)
            continue
        if id(node) in seen:
            continue
        seen.add(id(node))
        stack.append((node, True))
        for parent in node.parents:
            if id(parent) not in seen:
                stack.append((parent, False))
    return order


def _unknown_operation(node):
    """Build the error raised for an interior node whose ID is not a known operation."""
    return ValueError(
        'cannot evaluate node {!r}: unknown operation {!r}'.format(node, node.ID))


def _forward_value(node):
    """Recompute an interior node's value from its parents' current values."""
    op, parents = node.ID, node.parents
    if op == 'SUM':
        return sum(parent.v for parent in parents)
    if op == 'SQUARE':
        return parents[0].v ** 2
    if op == 'PRODUCT':
        return parents[0].v * parents[1].v
    if op == 'INNERPRODUCT':
        return parents[0].v
    raise _unknown_operation(node)


def _propagate_derivative(node):
    """Add ``node``'s contribution (chain rule) to each parent's ``derivative``."""
    op, parents, grad = node.ID, node.parents, node.derivative
    if not parents:
        return
    if op in ('SUM', 'INNERPRODUCT'):
        for parent in parents:
            parent.derivative += grad
    elif op == 'SQUARE':
        base = parents[0]
        base.derivative += 2.0 * base.v * grad
    elif op == 'PRODUCT':
        left, right = parents
        # If left is right (x * x), both lines hit the same node: 2 * x * grad.
        left.derivative += right.v * grad
        right.derivative += left.v * grad
    else:
        raise _unknown_operation(node)


def evaluate(var):
    """Evaluate the computation rooted at ``var`` and back-propagate derivatives.

    Python counterpart of the OCaml ``evaluate : t -> float``.

    First, every interior node (one with parents) has its value recomputed from
    its parents, inputs before consumers, so a leaf changed with ``set_value``
    is reflected in the result. Then reverse-mode differentiation sets on every
    node reachable from ``var`` a ``derivative`` attribute equal to
    d(var) / d(node); ``var`` itself gets 1.0. Nodes reached through several
    paths accumulate the contributions of all of them.

    The operation of an interior node is read from its ``ID``: 'SUM',
    'SQUARE', 'PRODUCT' or 'INNERPRODUCT' (the tags used by ``vsum``,
    ``vsquare``, ``vproduct`` and ``vinner_product``). Nodes without parents
    are leaves whose stored value is used as-is.

    Args:
        var: root node of the computation.

    Returns:
        The value of ``var`` after re-evaluation.

    Raises:
        ValueError: if an interior node has an ``ID`` that is not one of the
            operations listed above.
    """
    order = _topological_order(var)

    # Forward pass: ``order`` lists every node after its inputs.
    for node in order:
        if node.parents:
            node.v = _forward_value(node)

    # Backward pass: walking ``order`` in reverse visits every consumer of a
    # node before the node, so its derivative is complete when propagated.
    for node in order:
        node.derivative = 0.0
    var.derivative = 1.0
    for node in reversed(order):
        _propagate_derivative(node)
    return var.v
def expand(var, max_lvl=None, _lvl=0):
    """Print the computation rooted at ``var`` as an indented tree.

    Nodes are printed one per line in pre-order, parents in their stored
    order. Each line is ``' ' * depth`` followed by the separator space that
    ``print`` inserts and then the node's ``repr``. A node reachable through
    several paths is printed once per path.

    Args:
        var: root node to print.
        max_lvl: depth at which printing stops; nodes at depth ``max_lvl`` or
            deeper (and everything below them) are omitted. ``None`` prints
            the whole graph.
        _lvl: depth assigned to ``var`` itself (internal; normally 0).

    An explicit stack replaces recursion so long chains (e.g. many nested
    sums) do not overflow Python's recursion limit.
    """
    pending = [(var, _lvl)]
    while pending:
        node, lvl = pending.pop()
        # ``>=`` rather than ``==`` so a start level already past the cutoff
        # prints nothing instead of the entire graph.
        if max_lvl is not None and lvl >= max_lvl:
            continue
        print(' ' * lvl, node)
        # Push parents in reverse so the first parent is printed first.
        for parent in reversed(list(node.parents)):
            pending.append((parent, lvl + 1))
"""
```
let computation =
let var name initial_value =
variable (Variable.create ~id:(Variable.ID.of_string name) ~initial_value)
in
let x = var "x" 2. in
let y = var "y" 4. in
square (sum [x; square (sum [ y; constant 1.0 ])])
;;
```
"""
def expr1():
    """Build the blog's example expression, plus a trailing constant 2, and print it.

    The graph is ((x + (y + 1) ** 2) ** 2) + 2 with x = 2 and y = 4. It is
    printed in full, then cut at depth 2, then at depth 3; each dump is
    followed by a blank line.
    """
    one = constant('one', 1)
    two = constant('two', 2)
    y = variable('y', 4)
    x = variable('x', 2)
    inner = vsquare(vsum([y, one]))
    outer = vsquare(vsum([x, inner]))
    computation = vsum([outer, two])
    for depth in (None, 2, 3):
        expand(computation, max_lvl=depth)
        print()
def expr2():
    """Chain 'SUM' nodes adding 0, 2 and 4 onto a zero constant, then print the tree."""
    computation = constant('zero', 0)
    # Only the even numbers below 5 are added.
    for i in range(0, 5, 2):
        computation = vsum([computation, constant('i', i)])
    expand(computation)
def expr3():
    """Print a weighted sum of three monthly AUM figures built with ``vinner_product``.

    The tree is printed in full, then cut at depth 2, then at depth 3, with a
    blank line between consecutive dumps.
    """
    aum_figures = [('april aum', 10_000_000),
                   ('may aum', 12_000_000),
                   ('june aum', 12_500_000)]
    weight_figures = [('april weight', 0.772),
                      ('may weight', 0.765),
                      ('june weight', 0.758)]
    aums = [constant(name, value) for name, value in aum_figures]
    weights = [constant(name, value) for name, value in weight_figures]
    computation = vinner_product(aums, weights)
    for index, depth in enumerate((None, 2, 3)):
        if index:
            print()
        expand(computation, max_lvl=depth)
if __name__ == '__main__':
    # Point this at expr1 or expr2 to run one of the other demos instead.
    selected_demo = expr3
    selected_demo()