Skip to content

Commit 2818009

Browse files
committed
pid optimization, error checking
1 parent 6cdc0fa commit 2818009

2 files changed

Lines changed: 144 additions & 87 deletions

File tree

src/msrDynamics/_msrDynamics.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from tqdm import tqdm
77
from symengine import Mul
88
import sys
9+
import gc
910

1011
MAX_INT = sys.maxsize
1112

@@ -537,6 +538,7 @@ def __init__(self, name: str = None, m: float = 0.0, scp: float = 0.0, W: float
537538
self.y_out = np.array([]) # solution data, to be populated by System
538539
self._linked_nodes = [] # list of linked nodes
539540
self._dydt_linked = 0.0 # sym. expressions for linked nodes
541+
self.in_system = False # flag to check if node has been added to system
540542

541543
@property
542544
def dydt(self):
@@ -570,6 +572,9 @@ def set_dTdt_advective(self, source):
570572
if self.dndt or self.dcdt or self.drdt:
571573
raise ValueError('''This node has already been assigned
572574
point-kinetic dynamics''')
575+
if not self.in_system:
576+
raise ValueError('''Node dynamics cannot be set until added to
577+
a System() object''')
573578

574579
#check that node has been added to the system
575580
if self.y:
@@ -592,6 +597,9 @@ def set_dTdt_internal(self, source: list, k: list):
592597
if self.dndt or self.dcdt or self.drdt:
593598
raise ValueError('''This node has already been assigned
594599
point-kinetic dynamics''')
600+
if not self.in_system:
601+
raise ValueError('''Node dynamics cannot be set until added to
602+
a System() object''')
595603

596604
#check that node has been added to the system
597605
if self.y:
@@ -614,6 +622,9 @@ def set_dTdt_convective(self, source: list, hA: list):
614622
if self.dndt or self.dcdt or self.drdt:
615623
raise ValueError('''This node has already been assigned
616624
point-kinetic dynamics''')
625+
if not self.in_system:
626+
raise ValueError('''Node dynamics cannot be set until added to
627+
a System() object''')
617628
if (self.m <= 0.0) or (self.scp <= 0.0):
618629
print(f'node mass: {self.m:.2f}')
619630
print(f'node specific heat capacity: {self.scp:.2f}')
@@ -652,6 +663,9 @@ def set_dndt(self, r: y, beta_eff: float, Lambda: float, lam: list, C: list):
652663
self.dcdt:
653664
raise ValueError('''This node has already been assigned
654665
incompatible dynamics''')
666+
if not self.in_system:
667+
raise ValueError('''Node dynamics cannot be set until added to
668+
a System() object''')
655669

656670
#check that node has been added to the system
657671
if self.y:
@@ -688,6 +702,9 @@ def set_dcdt(self, n: y, beta: float, Lambda: float, lam: float, flow: bool = Fa
688702
self.dndt:
689703
raise ValueError('''This node has already been assigned
690704
incompatible dynamics''')
705+
if not self.in_system:
706+
raise ValueError('''Node dynamics cannot be set until added to
707+
a System() object''')
691708
#check that node has been added to the system
692709
if self.y:
693710
# reset in case of update
@@ -719,6 +736,9 @@ def set_drdt(self, sources: list, coeffs: list):
719736
self.dcdt:
720737
raise ValueError('''This node has already been assigned
721738
incompatible dynamics''')
739+
if not self.in_system:
740+
raise ValueError('''Node dynamics cannot be set until added to
741+
a System() object''')
722742
#check that node has been added to the system
723743
if self.y:
724744
# reset in case of update
@@ -731,6 +751,9 @@ def set_drdt(self, sources: list, coeffs: list):
731751
raise ValueError("Nodes need to be added to a System() object before setting dynamics.")
732752

733753
def set_dndt_decay(self, n: y, n0: float, rel_yield: float, lam: float):
754+
if not self.in_system:
755+
raise ValueError('''Node dynamics cannot be set until added to
756+
a System() object''')
734757
#check that node has been added to the system
735758
if self.y:
736759
# reset in case of update
@@ -743,6 +766,9 @@ def set_dydt_node(self, nodes: list, coeffs: list = None):
743766
"""
744767
Add dynamics of another node
745768
"""
769+
if not self.in_system:
770+
raise ValueError('''Node dynamics cannot be set until added to
771+
a System() object''')
746772
if coeffs is None:
747773
coeffs = [1.0]*len(nodes)
748774
for idx, n in enumerate(nodes):

src/msrDynamics/_pid_loop.py

Lines changed: 118 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(self,
4646
bound: tuple = None,
4747
clegg_integrator: bool = False,
4848
min_reading: float = None,
49+
store_all_output: bool = False,
4950
) -> None:
5051
"""
5152
Initializes the PID_loop object.
@@ -83,96 +84,126 @@ def __init__(self,
8384
self.output_sym = Function(self.name)
8485
self._output_func = None
8586
self.cumsum = 0.0
86-
self.p_output = []
87-
self.i_output = []
88-
self.d_output = []
89-
self.output = []
90-
self.err = []
91-
self.dt = []
92-
self.times = []
87+
if store_all_output:
88+
self.times = []
89+
self.output = []
90+
self.state = []
91+
self.p_output = []
92+
self.i_output = []
93+
self.d_output = []
94+
self.err = []
95+
self.dt = []
96+
self.dedt = []
97+
self.integral = []
98+
else:
99+
self.times = None
100+
self.output = None
101+
self.state = None
102+
self.p_output = None
103+
self.i_output = None
104+
self.d_output = None
105+
self.err = None
106+
self.dt = None
107+
self.dedt = None
108+
self.integral = None
109+
93110
self.err_prev = None
94-
self.dedt = []
95-
self.state = []
96111
self.bound = bound
97112
self.clegg_integrator = clegg_integrator
98113
self.de_prev = None
99114
self.min_reading = min_reading
100-
self.integral = []
101-
102-
@property
103-
def output_func(self):
104-
"""
105-
Generates or returns the PID controller function.
106-
107-
Returns:
108-
callable: A function implementing the PID control logic.
109-
"""
110-
if self._output_func is None:
111-
def pid_func(y, state, t):
112-
"""
113-
PID control logic for calculating the output.
114-
115-
Args:
116-
y (float): Current output value.
117-
state (float): Current state value.
118-
t (float): Current time.
119-
120-
Returns:
121-
float: PID control output value.
122-
"""
123-
124-
dt = t - self.times[-1] if self.times else t
125-
if (self.min_reading) and (state < self.min_reading):
126-
p_out, i_out, d_out, out, err, dedt = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
127-
else:
128-
# p
129-
err = state - self.setpoint_value
130-
if self.err_prev and self.clegg_integrator:
131-
if np.sign(self.err_prev) != np.sign(err):
132-
self.cumsum = 0.0
133-
# i
134-
self.cumsum += err*dt
135-
# d
136-
de = err - self.err_prev if self.err_prev is not None else 0.0
137-
if dt == 0.0:
138-
out = self.output[-1] if self.output else 0.0
139-
return out
140-
dedt = de / dt
141-
142-
p_out = self.k_p*err
143-
i_out = self.k_i*self.cumsum
144-
d_out = self.k_d*dedt
145-
calc = p_out + d_out + i_out + self.base_value
146-
147-
if self.bound:
148-
out = max(self.bound[0], min(calc, self.bound[1]))
115+
self.store_all_output = store_all_output
116+
self.t_prev = None
117+
self.out_prev = None
118+
119+
def output_func(self):
120+
"""
121+
Generates or returns the PID controller function.
122+
123+
Returns:
124+
callable: A function implementing the PID control logic.
125+
"""
126+
if self._output_func is None:
127+
def pid_func(y, state, t):
128+
"""
129+
PID control logic for calculating the output.
130+
131+
Args:
132+
y (float): Current output value.
133+
state (float): Current state value.
134+
t (float): Current time.
135+
136+
Returns:
137+
float: PID control output value.
138+
"""
139+
140+
# unpack attributes for performance
141+
t_prev = self.t_prev
142+
min_reading = self.min_reading
143+
setpoint_value = self.setpoint_value
144+
err_prev = self.err_prev
145+
ci = self.clegg_integrator
146+
out_prev = self.out_prev
147+
k_p = self.k_p
148+
k_i = self.k_i
149+
k_d = self.k_d
150+
base_value = self.base_value
151+
bound = self.bound
152+
dt = t - t_prev if t_prev else t
153+
if (min_reading) and (state < min_reading):
154+
p_out, i_out, d_out, out, err, dedt = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
149155
else:
150-
out = calc
151-
152-
# store inputs/outputs
153-
self.err_prev = err
154-
self.state.append(state)
155-
self.times.append(t)
156-
self.p_output.append(p_out)
157-
self.i_output.append(i_out)
158-
self.d_output.append(d_out)
159-
self.output.append(out)
160-
self.err.append(err)
161-
self.dt.append(dt)
162-
self.dedt.append(dedt)
163-
self.integral.append(self.cumsum)
164-
return out
165-
166-
return pid_func
167-
else:
168-
return self._output_func
169-
170-
@output_func.setter
171-
def output_func(self, custom_output_func):
172-
"""
173-
Sets a custom function for the PID output logic.
174-
175-
Args:
176-
custom_output_func (callable): Custom PID logic function.
177-
"""
178-
self._output_func = custom_output_func
156+
# p
157+
err = state - setpoint_value
158+
if err_prev and ci:
159+
if np.sign(err_prev) != np.sign(err):
160+
self.cumsum = 0.0
161+
# i
162+
self.cumsum += err*dt
163+
# d
164+
de = err - err_prev if err_prev is not None else 0.0
165+
if dt == 0.0:
166+
out = out_prev if out_prev else 0.0
167+
return out
168+
dedt = de / dt
169+
170+
p_out = k_p*err
171+
i_out = k_i*self.cumsum
172+
d_out = k_d*dedt
173+
calc = p_out + d_out + i_out + base_value
174+
175+
if bound:
176+
out = max(bound[0], min(calc, bound[1]))
177+
else:
178+
out = calc
179+
180+
# store inputs/outputs
181+
self.err_prev = err
182+
self.t_prev = t
183+
self.out_prev = out
184+
if self.store_all_output:
185+
self.times.append(t)
186+
self.output.append(out)
187+
self.state.append(state)
188+
self.p_output.append(p_out)
189+
self.i_output.append(i_out)
190+
self.d_output.append(d_out)
191+
self.err.append(err)
192+
self.dt.append(dt)
193+
self.dedt.append(dedt)
194+
self.integral.append(self.cumsum)
195+
return out
196+
197+
return pid_func
198+
199+
self.output_func = output_func(self)
200+
201+
# @output_func.setter
202+
# def output_func(self, custom_output_func):
203+
# """
204+
# Sets a custom function for the PID output logic.
205+
206+
# Args:
207+
# custom_output_func (callable): Custom PID logic function.
208+
# """
209+
# self._output_func = custom_output_func

0 commit comments

Comments
 (0)