# ==== scratch/__init__.py ====
# (new, empty package marker)

# ==== scratch/config.py ====
"""
config.py

@author: wronk

Configuration file for training/testing shallow inverse model
"""

import numpy as np

# Entries in structurals and subjects must correspond elementwise,
# i.e. structurals[i] is the structural MRI for experimental subjects[i].

# Structural MRI subject names
conf_structurals = ['AKCLEE_103', 'AKCLEE_104', 'AKCLEE_105', 'AKCLEE_106',
                    'AKCLEE_107', 'AKCLEE_109', 'AKCLEE_110', 'AKCLEE_115',
                    'AKCLEE_117', 'AKCLEE_118', 'AKCLEE_119', 'AKCLEE_121',
                    'AKCLEE_125', 'AKCLEE_126', 'AKCLEE_131', 'AKCLEE_132']

# Experimental subject names
conf_subjects = ['eric_sps_03', 'eric_sps_04', 'eric_sps_05', 'eric_sps_06',
                 'eric_sps_07', 'eric_sps_09', 'eric_sps_10', 'eric_sps_15',
                 'eric_sps_17', 'eric_sps_18', 'eric_sps_19', 'eric_sps_21',
                 'eric_sps_25', 'eric_sps_26', 'eric_sps_31', 'eric_sps_32']

# Model params shared by training and testing
common_params = dict(dt=0.001,    # Time step between samples (s)
                     SNR=np.inf,  # np.inf -> noiseless sensor simulation
                     rho=0.1,     # Target mean activation for sparsity term
                     lam=1.)      # Weighting of regularizer cost

# Model training params
training_params = dict(n_training_times_noise=1000,  # Number of noise data samples
                       n_training_times_sparse=0,    # Number of sparse data samples
                       batch_size=100,
                       n_training_iters=int(1e3),    # Number of training iterations
                       opt_lr=1e-4)                  # Learning rate for optimizer

# Model evaluation params
eval_params = dict(n_avg_verts=25,     # Number of verts to avg when determining est position
                   n_test_verts=1000,  # Probably should be <= 1000 to avoid mem problems
                   linear_inv='MNE')   # sLORETA or MNE

# ==== scratch/eval_model.py ====
Tests localization error and point spread + + Usage: + eval_model.py [--subj=] + eval_model.py (-h | --help) + + Options: + --subj= Specify subject to process with structural name + -h, --help Show this screen +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path as op + +import mne +from mne import SourceEstimate +from mne.simulation import simulate_sparse_stc, simulate_stc, simulate_evoked +from mne.minimum_norm import apply_inverse + +import tensorflow as tf +import numpy as np + +from shallow_fun import (load_subject_objects, gen_evoked_subject, + get_all_vert_positions, get_largest_dip_positions, + get_localization_metrics, eval_error_norm, + norm_transpose) +from shallow import make_tanh_network, sparse_objective +from config import common_params, eval_params, conf_structurals, conf_subjects + +# Removing eric_sps_32/AKCL_132 b/c of vertex issue +structurals = conf_structurals[:-1] +subjects = conf_subjects[:-1] + + +if __name__ == "__main__": + + from docopt import docopt + argv = docopt(__doc__) + + megdir = argv[''] + structdir = argv[''] + model_fname = argv[''] + + struct = None + subj = None + if argv['--subj']: + struct = argv['--subj'] + subj = subjects[structurals.index(struct)] + + # Number of verts to avg when determining est position + n_avg_verts = eval_params['n_avg_verts'] + # Probably should be <= 1000 to avoid mem problems + n_test_verts = eval_params['n_test_verts'] + + sess = tf.Session() + + # Get subject info and create data + subj, fwd, inv, cov, evoked_info = load_subject_objects(megdir, subj, + struct) + vert_list = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']] + n_verts = fwd['src'][0]['nuse'] + fwd['src'][1]['nuse'] + + print("Simulating and normalizing data") + sensor_dim = len(fwd['info']['ch_names']) + source_dim = fwd['src'][0]['nuse'] + fwd['src'][1]['nuse'] + + print("Reconstructing model and restoring saved weights") + # 
Reconstruct network + network_dims = [source_dim // 2, source_dim // 2, source_dim] + yhat, h_list, x_sensor = make_tanh_network(sensor_dim, source_dim, network_dims) + sparse_cost, y_source, tf_rho, tf_lam = sparse_objective(sensor_dim, source_dim, + yhat, h_list, + sess) + saver = tf.train.Saver() + saver.restore(sess, model_fname) + + print("\nEvaluating deep learning approach...\n") + + # Simulate unit dipole activations + + rand_verts = np.sort(np.random.choice(range(n_verts), n_test_verts, + replace=False)) + sim_vert_data = np.eye(n_verts)[:, rand_verts] + evoked, stc = gen_evoked_subject(sim_vert_data, fwd, cov, evoked_info, + common_params['dt'], + common_params['SNR']) + + # Normalize data and transpose so it's (n_observations x n_chan) + x_test = norm_transpose(evoked.data) + y_test = norm_transpose(stc.data) + + # Ground truth dipole positions + vert_positions = get_all_vert_positions(inv['src']) + true_act_positions = vert_positions[rand_verts, :] + + feed_dict = {x_sensor: x_test, y_source: y_test, + tf_rho: common_params['rho'], tf_lam: common_params['lam']} + src_est_dl = sess.run(yhat, feed_dict) + stc_dl = SourceEstimate(src_est_dl.T, vertices=vert_list, subject=struct, + tmin=0, tstep=common_params['dt']) + + # Calculate vector norm error + error_norm_dl = eval_error_norm(y_test, src_est_dl) + + # Get position of most active dipoles and calc accuracy metrics (in meters) + est_act_positions = get_largest_dip_positions(src_est_dl, n_avg_verts, + vert_positions) + accuracy_dl, point_spread_dl = get_localization_metrics(true_act_positions, + est_act_positions) + + print("\nEvaluating standard linear approach...\n") + # + # Evaluate standard MNE methods + # + stc_std = apply_inverse(evoked, inv, method=eval_params['linear_inv']) + src_est_std = stc_std.data.T + + # Calculate vector norm error + error_norm_std = eval_error_norm(y_test, src_est_std) + est_act_positions = get_largest_dip_positions(src_est_std, n_avg_verts, + vert_positions) + 
accuracy_std, point_spread_std = get_localization_metrics(true_act_positions, + est_act_positions) + + sess.close() + print('\bShallow; error norm average for {} verts: {:0.4f}'.format( + n_test_verts, np.mean(error_norm_dl))) + print('Linear method: error norm average for {} verts: {:0.4f}\n'.format( + n_test_verts, np.mean(error_norm_std))) + print('Shallow; Loc. accuracy: {:0.5f}, Avg. Point spread: {:0.5f}'.format( + accuracy_dl, np.mean(point_spread_dl))) + print('Linear method; Loc. accuracy: {:0.5f}, Avg. Point spread: {:0.5f}\n'.format( + accuracy_std, np.mean(point_spread_std))) diff --git a/scratch/shallow.py b/scratch/shallow.py index efa6456..e076b0e 100644 --- a/scratch/shallow.py +++ b/scratch/shallow.py @@ -18,99 +18,83 @@ import os import os.path as op +import tensorflow as tf +import numpy as np +from scipy.sparse import random as random_sparse + import mne from mne import SourceEstimate from mne.minimum_norm import read_inverse_operator from mne.simulation import simulate_sparse_stc, simulate_stc, simulate_evoked from mne.externals.h5io import read_hdf5, write_hdf5 -import tensorflow as tf -import numpy as np - - -# Entries in structurals and subjects must correspoond, -# i.e. structurals[i] === subjects[i]. 
-structurals = ['AKCLEE_103', 'AKCLEE_104', 'AKCLEE_105', 'AKCLEE_106', - 'AKCLEE_107', 'AKCLEE_109', 'AKCLEE_110', 'AKCLEE_115', - 'AKCLEE_117', 'AKCLEE_118', 'AKCLEE_119', 'AKCLEE_121', - 'AKCLEE_125', 'AKCLEE_126', 'AKCLEE_131', 'AKCLEE_132'] -subjects = ['eric_sps_03', 'eric_sps_04', 'eric_sps_05', 'eric_sps_06', - 'eric_sps_07', 'eric_sps_09', 'eric_sps_10', 'eric_sps_15', - 'eric_sps_17', 'eric_sps_18', 'eric_sps_19', 'eric_sps_21', - 'eric_sps_25', 'eric_sps_26', 'eric_sps_31', 'eric_sps_32'] +from shallow_fun import (load_subject_objects, gen_evoked_subject, + get_data_batch, get_all_vert_positions, + get_largest_dip_positions, get_localization_metrics, + eval_error_norm, norm_transpose) +from config import (common_params, training_params, conf_structurals, + conf_subjects) # Removing eric_sps_32/AKCL_132 b/c of vertex issue -structurals = structurals[:-1] -subjects = subjects[:-1] +structurals = conf_structurals[:-1] +subjects = conf_subjects[:-1] +seed = 0 -def load_subject_objects(megdatadir, subj, struct): +def weight_variable(shape, name=None): + init = tf.truncated_normal(shape, stddev=0.1) + if name is not None: + return tf.Variable(init, name=name) - print(" %s: -- loading meg objects" % subj) - - fname_fwd = op.join(megdatadir, subj, 'forward', - '%s-sss-fwd.fif' % subj) - fwd = mne.read_forward_solution(fname_fwd, force_fixed=True, surf_ori=True) - - fname_inv = op.join(megdatadir, subj, 'inverse', - '%s-55-sss-meg-eeg-fixed-inv.fif' % subj) - inv = read_inverse_operator(fname_inv) - - fname_epochs = op.join(megdatadir, subj, 'epochs', - 'All_55-sss_%s-epo.fif' % subj) - #epochs = mne.read_epochs(fname_epochs) - #evoked = epochs.average() - #evoked_info = evoked.info - evoked_info = mne.io.read_info(fname_epochs) - cov = inv['noise_cov'] - - print(" %s: -- finished loading meg objects" % subj) + return tf.Variable(init) - return subj, fwd, inv, cov, evoked_info +def bias_variable(shape, name=None): + init = tf.constant(0.1, shape=shape) -def 
gen_evoked_subject(signal, fwd, cov, evoked_info, dt, noise_snr, - seed=None): - """Function to generate evoked and stc from signal array""" + if name is not None: + return tf.Variable(init, name=name) - vertices = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']] - stc = SourceEstimate(sim_vert_data, vertices, tmin=0, tstep=dt) + return tf.Variable(init) - evoked = simulate_evoked(fwd, stc, evoked_info, cov, noise_snr, - random_state=seed) - evoked.add_eeg_average_proj() - return evoked, stc +def make_tanh_network(sensor_dim, source_dim, dims): + """Function to create neural network""" + x_sensor = tf.placeholder(tf.float32, shape=[None, sensor_dim], + name="x_sensor") -def weight_variable(shape): - init = tf.truncated_normal(shape, stddev=0.11) - return tf.Variable(init) + W_list, b_list, h_list = [], [], [] + dims.insert(0, sensor_dim) # Augment with input layer dim -def bias_variable(shape): - init = tf.constant(0.1, shape=shape) - return tf.Variable(init) + # Loop through and create network layer at each step + for di, (dim1, dim2) in enumerate(zip(dims[:-1], dims[1:])): + W_list.append(weight_variable([dim1, dim2], name='W%i' % di)) + b_list.append(bias_variable([dim2], name='b%i' % di)) -def make_tanh_network(sensor_dim, source_dim): + # Handle input layer separately + if di == 0: + h_list.append(tf.nn.tanh(tf.matmul(x_sensor, W_list[-1]) + + b_list[-1])) + else: + h_list.append(tf.nn.tanh(tf.matmul(h_list[-1], W_list[-1]) + + b_list[-1])) - x_sensor = tf.placeholder(tf.float32, shape=[None, sensor_dim], name="x_sensor") + # Attach histogram summaries to weight functions + tf.histogram_summary('W%i Hist' % di, W_list[-1]) - W1 = weight_variable([sensor_dim, source_dim//2]) - b1 = bias_variable([source_dim//2]) - h1 = tf.nn.tanh(tf.matmul(x_sensor, W1) + b1) + # Return y_hat (final h_list layer), rest of h_list, and data placeholder + return h_list[-1], h_list[:-1], x_sensor - W2 = weight_variable([source_dim//2, source_dim]) - b2 = 
bias_variable([source_dim]) - h2 = tf.nn.tanh(tf.matmul(h1, W2) + b2) - W3 = weight_variable([source_dim, source_dim]) - b3 = bias_variable([source_dim]) - yhat = tf.nn.tanh(tf.matmul(h2, W3) + b3) +def bernoulli(act, rho): + """Helper to calculate sparsity penalty based on KL divergence""" - return yhat, h1, h2, x_sensor + return (rho * (tf.log(rho) - tf.log(act + 1e-6)) + + (1 - rho) * (tf.log(1 - rho) - tf.log(1 - act + 1e-6))) -def sparse_objective(sensor_dim, source_dim, yhat, h1, h2, sess): +def sparse_objective(sensor_dim, source_dim, yhat, h_list, sess): y_source = tf.placeholder(tf.float32, shape=[None, source_dim], name="y_source") rho = tf.placeholder(tf.float32, shape=(), name="rho") @@ -120,20 +104,20 @@ def sparse_objective(sensor_dim, source_dim, yhat, h1, h2, sess): error = tf.reduce_sum(tf.squared_difference(y_source, yhat)) # Remap activations to [0,1] - a1 = 0.5*h1 + 0.5 - a2 = 0.5*h2 + 0.5 + act_list = [0.5 * h_obj + 0.5 for h_obj in h_list] - kl_bernoulli_h1 = (rho*(tf.log(rho) - tf.log(a1+1e-6) - + (1-rho)*(tf.log(1-rho) - tf.log(1-a1+1e-6)))) - kl_bernoulli_h2 = (rho*(tf.log(rho) - tf.log(a2+1e-6) - + (1-rho)*(tf.log(1-rho) - tf.log(1-a2+1e-6)))) - regularizer = (tf.reduce_sum(kl_bernoulli_h1) - + tf.reduce_sum(kl_bernoulli_h2)) + kl_bernoulli_list = [bernoulli(act, rho) for act in act_list] - cost = error + lam*regularizer + regularizer = sum([tf.reduce_sum(kl_bernoulli_h) + for kl_bernoulli_h in kl_bernoulli_list]) - return cost, y_source, rho, lam + cost = error + lam * regularizer + # Attach summaries + tf.scalar_summary('error', error) + tf.scalar_summary('cost function', cost) + + return cost, y_source, rho, lam if __name__ == "__main__": @@ -149,53 +133,83 @@ def sparse_objective(sensor_dim, source_dim, yhat, h1, h2, sess): struct = argv['--subj'] subj = subjects[structurals.index(struct)] - # Params - n_times = 250 - dt = 0.001 - noise_snr = np.inf + n_training_times = training_params['n_training_times_noise'] + \ + 
training_params['n_training_times_sparse'] + + n_training_iters = training_params['n_training_iters'] - niter = 100 - rho = 0.05 - lam = 1. + save_network = True + fpath_save = op.join('model_subj_{}_iters.meta'.format(n_training_iters)) + n_training_times = training_params['n_training_times_noise'] + \ + training_params['n_training_times_sparse'] # Get subject info and create data subj, fwd, inv, cov, evoked_info = load_subject_objects(megdir, subj, struct) - n_verts = fwd['src'][0]['nuse'] + fwd['src'][1]['nuse'] - sim_vert_data = np.random.randn(n_verts, n_times) - - print("applying forward operator") - evoked, stc = gen_evoked_subject(sim_vert_data, fwd, cov, evoked_info, dt, - noise_snr) - - print("building neural net") + sensor_dim = len(fwd['info']['ch_names']) + source_dim = fwd['src'][0]['nuse'] + fwd['src'][1]['nuse'] + network_dims = [source_dim // 2, source_dim // 2, source_dim] + + sparse_dens = 1. / source_dim # Density of non-zero vals in sparse training data + sparse_dist = np.random.randn + + noise_data = np.random.randn(source_dim, training_params['n_training_times_noise']) + sparse_data = random_sparse(source_dim, training_params['n_training_times_sparse'], + sparse_dens, random_state=seed, + dtype=np.float32, data_rvs=sparse_dist).toarray() + sim_train_data = np.concatenate((noise_data, sparse_data), axis=1) + + print("Simulating and normalizing training data") + evoked, stc = gen_evoked_subject(sim_train_data, fwd, cov, evoked_info, common_params['dt'], + common_params['SNR']) + # Normalize training data to lie between -1 and 1 + # XXX: Appropriate to do this? 
Maybe need to normalize src space only + # before generating sens data + x_train = norm_transpose(evoked.data) + y_train = norm_transpose(sim_train_data) + + print("Building neural net") sess = tf.Session() + yhat, h_list, x_sensor = make_tanh_network(sensor_dim, source_dim, network_dims) + sparse_cost, y_source, tf_rho, tf_lam = sparse_objective(sensor_dim, + source_dim, yhat, + h_list, sess) + merged_summaries = tf.merge_all_summaries() + train_writer = tf.train.SummaryWriter('./train_summaries', sess.graph) - # Create neural network - sensor_dim = evoked.data.shape[0] - source_dim = n_verts - - yhat, h1, h2, x_sensor = make_tanh_network(sensor_dim, source_dim) - sparse_cost, y_source, tf_rho, tf_lam = sparse_objective(sensor_dim, source_dim, - yhat, h1, h2, sess) - - train_step = tf.train.AdamOptimizer(1e-4).minimize(sparse_cost) + train_step = tf.train.AdamOptimizer( + training_params['opt_lr']).minimize(sparse_cost) + saver = tf.train.Saver() sess.run(tf.initialize_all_variables()) - - x_sens = np.ascontiguousarray(evoked.data.T) - x_sens /= np.max(np.abs(x_sens)) - y_src = np.ascontiguousarray(sim_vert_data.T) - y_src /= np.max(np.abs(y_src)) - - print("optimizing...") - niter = 100 - for i in xrange(niter): - _, obj = sess.run([train_step, sparse_cost], - feed_dict={x_sensor: x_sens, y_source: y_src, - tf_rho: rho, tf_lam: lam} - ) - - print(" it: %d i, cost: %.2f" % (i+1, obj)) - - # Evaluate net + print("\nSim params\n----------\nn_iter: {}\nn_training_times: {}\nSNR: \ + {}\nbatch_size: {}\n".format(n_training_iters, n_training_times, + str(common_params['SNR']), training_params['batch_size'])) + print("Optimizing...") + for ii in xrange(n_training_iters): + # Get random batch of data + x_sens_batch, y_src_batch = get_data_batch(x_train, y_train, + training_params['batch_size'], seed=ii) + + # Take training step + feed_dict = {x_sensor: x_sens_batch, y_source: y_src_batch, + tf_rho: common_params['rho'], + tf_lam: common_params['lam']} + + # Save 
summaries for tensorboard every 10 steps + if ii % 10 == 0: + _, obj, summary = sess.run([train_step, sparse_cost, + merged_summaries], feed_dict) + train_writer.add_summary(summary, ii) + print("\titer: %04i, cost: %.2f" % (ii, obj)) + else: + _, obj = sess.run([train_step, sparse_cost], feed_dict) + + if save_network: + save_fold = 'saved_models' + if not os.path.isdir(save_fold): + os.mkdir(save_fold) + + saver.save(sess, save_fold + '/model_{}'.format(struct)) + + sess.close() diff --git a/scratch/shallow_fun.py b/scratch/shallow_fun.py new file mode 100644 index 0000000..9ee2031 --- /dev/null +++ b/scratch/shallow_fun.py @@ -0,0 +1,194 @@ +""" +shallow_fun.py + +Collection of helper functions +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path as op + +import numpy as np + +import mne +from mne import SourceEstimate +from mne.minimum_norm import read_inverse_operator +from mne.simulation import simulate_sparse_stc, simulate_stc, simulate_evoked + + +def load_subject_objects(megdatadir, subj, struct, verbose=False): + + print(" %s: -- loading meg objects" % subj) + + fname_fwd = op.join(megdatadir, subj, 'forward', + '%s-sss-fwd.fif' % subj) + fwd = mne.read_forward_solution(fname_fwd, force_fixed=True, surf_ori=True, + verbose=verbose) + + fname_inv = op.join(megdatadir, subj, 'inverse', + '%s-55-sss-meg-eeg-fixed-inv.fif' % subj) + inv = read_inverse_operator(fname_inv, verbose=verbose) + + fname_epochs = op.join(megdatadir, subj, 'epochs', + 'All_55-sss_%s-epo.fif' % subj) + #epochs = mne.read_epochs(fname_epochs) + #evoked = epochs.average() + #evoked_info = evoked.info + evoked_info = mne.io.read_info(fname_epochs, verbose=verbose) + cov = inv['noise_cov'] + + print(" %s: -- finished loading meg objects" % subj) + + return subj, fwd, inv, cov, evoked_info + + +def gen_evoked_subject(signal, fwd, cov, evoked_info, dt, noise_snr, + seed=None): + """Function to 
def gen_evoked_subject(signal, fwd, cov, evoked_info, dt, noise_snr,
                       seed=None):
    """Generate an Evoked and SourceEstimate from a source signal array.

    Parameters
    ----------
    signal : np.array shape (n_src, n_times)
        Source-space activation to project to the sensors.
    fwd : Forward
        Forward operator (provides the vertex lists).
    cov : Covariance
        Noise covariance used for the simulated sensor noise.
    evoked_info : Info
        Measurement info for the simulated Evoked.
    dt : float
        Time step between samples (s).
    noise_snr : float
        SNR for simulate_evoked; np.inf gives noiseless data.
    seed : int | None
        Random state for the noise simulation.

    Returns
    -------
    (evoked, stc) tuple
    """
    vertices = [fwd['src'][0]['vertno'], fwd['src'][1]['vertno']]
    stc = SourceEstimate(signal, vertices, tmin=0, tstep=dt)

    evoked = simulate_evoked(fwd, stc, evoked_info, cov, noise_snr,
                             random_state=seed)
    evoked.set_eeg_reference()

    return evoked, stc


def norm_transpose(data):
    """Transpose data and scale by the global max absolute value.

    Returns a new (n_times, n_chan) float array; the input is never modified.
    """
    #XXX Probably should switch to sklearn's standard scaler
    # (0 mean, unit var, and saves the scaling if we need to apply it later)
    # np.array always copies, so the in-place divide below cannot alias the
    # caller's array (np.ascontiguousarray could return a view), and integer
    # input no longer breaks the in-place float division
    data_fixed = np.array(data.T, dtype=np.float64, order='C')
    data_fixed /= np.max(np.abs(data_fixed))

    return data_fixed


def get_data_batch(x_data, y_label, batch_size, seed=None):
    """Get a random batch (with replacement) of an evoked/stc row pair.

    Rows are sampled by index so x and y stay aligned; `seed` makes the
    batch reproducible (the training loop seeds by iteration number).
    """
    np.random.seed(seed)
    # NOTE: randint samples with replacement, so a batch may repeat rows
    rand_inds = np.random.randint(x_data.shape[0], size=batch_size)

    return x_data[rand_inds, :], y_label[rand_inds, :]


def get_all_vert_positions(src):
    """Function to get 3-space position of used dipoles

    Parameters
    ----------
    src: SourceSpaces
        Source space object for subject. Needed to get dipole positions

    Returns
    -------
    dip_pos: np.array shape(n_src x 3)
        3-space positions of used dipoles
    """
    # Get vertex numbers and positions that are in use
    # (usually ~4k - 5k out of ~150k)
    left_vertno = src[0]['vertno']
    right_vertno = src[1]['vertno']

    dip_pos = np.concatenate((src[0]['rr'][left_vertno, :],
                              src[1]['rr'][right_vertno, :]))

    return dip_pos


def eval_error_norm(src_data_orig, src_data_est):
    """Function to compute norm of the error vector at each time sample

    Parameters
    ----------
    src_data_orig: numpy matrix size (n_samples x n_src)
        Ground truth source estimate used to generate sensor data
    src_data_est: numpy matrix (n_samples x n_src)
        Source estimate of sensor data created using src_data_orig

    Returns
    -------
    error_norm: np.array size(n_samples)
        Norm of vector between true activation and estimated activation
    """
    #TODO: might want to normalize by number of vertices since subject source
    # spaces can have different number of dipoles

    # Vectorized row-wise norm; equivalent to looping over rows but one pass
    return np.linalg.norm(src_data_orig - src_data_est, axis=1)


def get_largest_dip_positions(data, n_verts, dip_pos):
    """Function to get positions of the highest-activated dipoles

    Parameters
    ----------
    data: np.array shape(n_times x n_src)
        Source estimate data
    n_verts: int
        Number of vertices to use when computing maximum activation centroid
    dip_pos: np.array shape(n_src x 3)
        3-space positions of all dipoles in source space.

    Returns
    -------
    largest_dip_pos: np.array shape(n_times x n_verts x 3)
        Positions of the `n_verts` largest activations at each time point
    """
    #TODO: How to handle negative current vals? Use abs?

    # argsort each row and keep the last n_verts columns (largest values);
    # fancy-indexing dip_pos yields shape (n_times, n_verts, 3) in the same
    # order as the per-row loop it replaces
    largest_dip_inds = np.argsort(data, axis=1)[:, -n_verts:]

    return dip_pos[largest_dip_inds, :]


def get_localization_metrics(true_pos, largest_dip_pos):
    """Helper to get accuracy and point spread

    Parameters
    ----------
    true_pos: np.array shape(n_times, 3)
        3D position of dipole that was simulated active
    largest_dip_pos: np.array shape(n_times, n_dipoles, 3)
        3D positions of top `n_dipoles` dipoles with highest activation

    Returns
    -------
    accuracy: float
        Norm of the (n_times x 3) matrix of centroid-vs-truth differences
    point_spread: np.array shape(n_times,)
        Mean distance from the true position to each top dipole, per time
    """
    centroids = np.mean(largest_dip_pos, axis=1)
    accuracy = np.linalg.norm(true_pos - centroids)

    # Difference in x/y/z positions from true activation to each src
    point_distance = np.subtract(largest_dip_pos, true_pos[:, np.newaxis, :])

    # Euclidean distance (w/ norm), then mean over all dipoles
    point_spread = np.mean(np.linalg.norm(point_distance, axis=-1), axis=-1)

    return accuracy, point_spread