diff --git a/examples/eg004r__fitting_JR_example.py b/examples/eg004r__fitting_JR_example.py
index d8a43f53..7934d4d7 100644
--- a/examples/eg004r__fitting_JR_example.py
+++ b/examples/eg004r__fitting_JR_example.py
@@ -23,7 +23,7 @@
 # whobpyt stuff
 import whobpyt
 from whobpyt.datatypes import par, Recording
-from whobpyt.models.JansenRit import RNNJANSEN, ParamsJR
+from whobpyt.models.JansenRit import RNNJANSEN, ParamsJR, JansenRit_np
 from whobpyt.optimization.custom_cost_JR import CostsJR
 from whobpyt.run import Model_fitting
 
@@ -168,4 +168,26 @@
 ax[1].set_title('Test')
 ax[2].plot(eeg_data.T)
 ax[2].set_title('empirical')
-plt.show()
\ No newline at end of file
+plt.show()
+
+
+# %%
+# Modified JR Validation Model
+# ---------------------------------------------------
+#
+# Re-simulate the fitted model with the NumPy implementation of the JR model (JansenRit_np) to validate it.
+
+val_sim_len = 20  # Simulation length in secs
+model_validate = JansenRit_np(model.node_size, model.step_size, model.output_size, model.tr, model.sc, model.lm.detach().numpy(), model.dist.detach().numpy(), model.params)
+
+state_hist, hE = model_validate.forward(external=u, hx=model_validate.createIC(ver=0), hE=np.zeros((model.node_size, 500)), sim_len=val_sim_len)
+
+# %%
+# Plot the simulated pyramidal activity (M)
+plt.figure(figsize=(16, 8))
+plt.title("M")
+for n in range(model.node_size):
+    plt.plot(state_hist[0:2000, n, 0:1], label="M Node = " + str(n))  # Plotting the pyramidal state M for each node
+    #plt.plot(state_hist[0:200, n, 1:2] - state_hist[0:200, n, 2:3], label="EEG Node = " + str(n))  # Plotting E - I
+
+plt.xlabel('Time Steps (multiply by step_size to get msec), step_size = ' + str(step_size))
+#plt.legend()
diff --git a/whobpyt/models/JansenRit/__init__.py b/whobpyt/models/JansenRit/__init__.py
index 213b8224..56957fd1 100644
--- a/whobpyt/models/JansenRit/__init__.py
+++ b/whobpyt/models/JansenRit/__init__.py
@@ -1,2 +1,3 @@
 from .jansen_rit import RNNJANSEN
-from .ParamsJR import ParamsJR
\ No newline at end of file
+from .ParamsJR import ParamsJR
+from .jansen_rit_validate import JansenRit_np
diff --git a/whobpyt/models/JansenRit/jansen_rit_validate.py b/whobpyt/models/JansenRit/jansen_rit_validate.py
new file mode 100644
index 00000000..7f5f2a81
--- /dev/null
+++ b/whobpyt/models/JansenRit/jansen_rit_validate.py
@@ -0,0 +1,165 @@
+# Simulate JR with numpy code for validation
+# Sorenza Bastiaens
+import numpy as np
+
+class JansenRit_np():
+
+    def __init__(self, node_size, step_size, output_size, tr, sc, lm, dist, params):
+
+        # Initialize the JR model
+        #
+        # INPUT
+        #   node_size: Int - number of nodes in the network to model
+        #   step_size: Float - integration step size in ms (e.g. 0.1)
+        #   output_size: Int - number of EEG channels
+        #   tr: Float - sampling period of the empirical recording in ms
+        #   sc: Array [node_size, node_size] - connectivity (e.g. structural connectivity)
+        #   lm: Array [output_size, node_size] - leadfield matrix
+        #   dist: Array [node_size, node_size] - distance between nodes
+        #   params: ParamsJR - the parameters that all nodes in the network will share
+        self.step_size = step_size      # integration step in ms (e.g. 0.1)
+        self.tr = tr                    # tr in ms
+        self.sc = sc                    # structural connectivity
+        self.node_size = node_size      # num of ROIs
+        self.output_size = output_size  # num of EEG channels
+        self.params = params
+        self.lm = lm                    # leadfield matrix
+        self.dist = dist                # distance between nodes
+        self.state_size = 6             # number of state variables per node
+
+    def createIC(self, ver):
+        # Uniform random initial conditions for all state variables of all nodes
+        state_lb = -0.5
+        state_ub = 0.5
+
+        return np.random.uniform(state_lb, state_ub, (self.node_size, self.state_size))
+
+    def forward(self, external, hx, hE, sim_len):
+
+        # Runs the JR model with forward Euler integration
+
+        # Defining JR parameters as numpy
+        A = self.params.A.npValue()
+        a = self.params.a.npValue()
+        B = self.params.B.npValue()
+        b = self.params.b.npValue()
+        g = self.params.g.npValue()
+        c1 = self.params.c1.npValue()
+        c2 = self.params.c2.npValue()
+        c3 = self.params.c3.npValue()
+        c4 = self.params.c4.npValue()
+        std_in = self.params.std_in.npValue()
+        vmax = self.params.vmax.npValue()
+        v0 = self.params.v0.npValue()
+        r = self.params.r.npValue()
+        y0 = self.params.y0.npValue()
+        mu = self.params.mu.npValue()
+        k = self.params.k.npValue()
+        cy0 = self.params.cy0.npValue()
+        ki = self.params.ki.npValue()
+
+        g_f = self.params.g_f.npValue()
+        g_b = self.params.g_b.npValue()
+
+        # Sigmoid function
+        def sigmoid(x, vmax, v0, r):
+            return vmax / (1 + np.exp(r * (v0 - x)))
+
+        init_state = hx
+        step_size = self.step_size
+
+        # Last channel is reserved for the (currently disabled) EEG projection
+        state_hist = np.zeros((int(sim_len / step_size), self.node_size, 7))
+        M = init_state[:, 0:1]
+        E = init_state[:, 1:2]
+        I = init_state[:, 2:3]
+        Mv = init_state[:, 3:4]
+        Ev = init_state[:, 4:5]
+        Iv = init_state[:, 5:6]
+
+        num_steps = int(sim_len / step_size)
+        dt = step_size
+        self.w_bb = np.zeros((self.node_size, self.node_size))
+        self.w_ff = np.zeros((self.node_size, self.node_size))
+        self.w_ll = np.zeros((self.node_size, self.node_size))
+
+        # Update the Laplacian based on the updated connection gains w_bb.
+        w_b = np.exp(self.w_bb) * np.array(self.sc)
+        w_n_b = w_b / np.linalg.norm(w_b)
+        self.sc_m_b = w_n_b
+        dg_b = -np.diag(np.sum(w_n_b, axis=1))
+
+        # Update the Laplacian based on the updated connection gains w_ff.
+        w_f = np.exp(self.w_ff) * np.array(self.sc)
+        w_n_f = w_f / np.linalg.norm(w_f)
+        self.sc_m_f = w_n_f
+        dg_f = -np.diag(np.sum(w_n_f, axis=1))
+
+        # Update the Laplacian based on the updated connection gains w_ll.
+        w_l = np.exp(self.w_ll) * np.array(self.sc)
+        w_n_l = (0.5 * (w_l + np.transpose(w_l, (1, 0)))) / np.linalg.norm(
+            0.5 * (w_l + np.transpose(w_l, (1, 0))))
+        self.sc_fitted = w_n_l
+        dg_l = -np.diag(np.sum(w_n_l, axis=1))
+
+        # Delays between nodes, in integration steps
+        self.delays = (self.dist / mu).astype(int)
+
+        # TODO currently single node, need to add all the connections and make it multiple nodes
+        for i in range(num_steps):
+
+            # LEd is to include the delays from other nodes
+            # con_1 = 1
+            # Don't include boundaries so no k_lb for example and no m(x) stuff
+            # Basically rM includes (LEd_l + 1 * np.matmul(dg_l, M))
+            # Calculate the derivatives
+            # Lateral is P-P
+            # Forward is P-E
+            # Backward is P-I
+            hE_new = hE.copy()
+            # Delayed pyramidal activity per connection (numpy equivalent of hE_new.gather(1, self.delays))
+            Ed = np.take_along_axis(hE_new, self.delays, axis=1)
+            LEd_b = np.reshape(np.sum(w_n_b * np.transpose(Ed, (1, 0)), 1), (self.node_size, 1))  # Not sure if this needs to be included in validation
+            LEd_f = np.reshape(np.sum(w_n_f * np.transpose(Ed, (1, 0)), 1), (self.node_size, 1))
+            LEd_l = np.reshape(np.sum(w_n_l * np.transpose(Ed, (1, 0)), 1), (self.node_size, 1))
+
+            # NOTE: the external input passed to forward() is not used yet; a constant drive is applied instead.
+            # Need to only add it within a certain time frame, test again with 0
+            u_tms = 200
+            rM = k * ki * u_tms + std_in * np.random.randn(self.node_size, 1) + g * (LEd_l + 1 * np.matmul(dg_l, M))
+            rE = std_in * np.random.randn(self.node_size, 1) + g_f * (LEd_f + 1 * np.matmul(dg_f, E - I))
+            rI = std_in * np.random.randn(self.node_size, 1) + g_b * (-LEd_b - 1 * np.matmul(dg_b, E - I))
+
+            dM = dt * Mv
+            dE = dt * Ev
+            dI = dt * Iv
+            # BE CAREFUL: rM in the original code already contains the sigmoid, so only take everything else from the original code
+            dMv = dt * (A * a * (rM + sigmoid(E - I, vmax, v0, r)) - (2 * a * Mv) - (a ** 2 * M))
+            dEv = dt * (A * a * (mu + rE + c2 * sigmoid(c1 * M, vmax, v0, r)) - (2 * a * Ev) - (a ** 2 * E))
+            dIv = dt * (B * b * (rI + c4 * sigmoid(c3 * M, vmax, v0, r)) - (2 * b * Iv) - (b ** 2 * I))
+
+            # Update the state
+            M = M + dM
+            E = E + dE
+            I = I + dI
+            Mv = Mv + dMv
+            Ev = Ev + dEv
+            Iv = Iv + dIv
+            hE = np.concatenate((M, hE[:, :-1]), axis=1)  # update placeholder buffer of delayed pyramidal activity
+
+            # Capture the states at every step.
+            state_hist[i, :, 0:1] = M
+            state_hist[i, :, 1:2] = E
+            state_hist[i, :, 2:3] = I
+            state_hist[i, :, 3:4] = Mv
+            state_hist[i, :, 4:5] = Ev
+            state_hist[i, :, 5:6] = Iv
+
+            #lm_t = (self.lm.T / np.sqrt(self.lm ** 2).sum(1)).T
+            #self.lm_t = (lm_t - 1 / self.output_size * np.matmul(np.ones((1, self.output_size)), lm_t))
+            #temp = cy0 * np.matmul(self.lm_t, M[:self.node_size, :]) - 1 * y0
+            #state_hist[i, :, 6:7] = temp  # eeg_window
+
+        # Should then downsample the state_hist to the sampling rate of the EEG
+        return state_hist, hE
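
Illustrative sketch (not part of the patch above): in JansenRit_np.forward(), the delayed input Ed is gathered with np.take_along_axis, the NumPy counterpart of the torch call referenced in the inline comment (hE_new.gather(1, self.delays)). A minimal, self-contained cross-check of that equivalence, with arbitrary shapes and torch imported only for the comparison:

    import numpy as np
    import torch

    node_size, buffer_len = 4, 500
    rng = np.random.default_rng(0)

    hE = rng.standard_normal((node_size, buffer_len))             # buffer of past pyramidal activity
    delays = rng.integers(0, buffer_len, (node_size, node_size))  # delay in steps for each connection

    Ed_np = np.take_along_axis(hE, delays, axis=1)                # Ed_np[i, j] = hE[i, delays[i, j]]
    Ed_torch = torch.from_numpy(hE).gather(1, torch.from_numpy(delays))

    assert np.allclose(Ed_np, Ed_torch.numpy())

Both calls select, for every connection (i, j), node i's pyramidal activity delays[i, j] integration steps in the past; the LEd_* terms in forward() then weight the transposed result by the normalized connectivity and sum it per node.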