111 changes: 88 additions & 23 deletions library/statistics.py
@@ -45,10 +45,14 @@ def calculate_n(self, batch_vectorizer):
         '''
         self.theta = self.model.transform(batch_vectorizer)
         self.pwd = np.dot(self.phi.values, self.theta.values)
-        self.nwd = np.zeros((self.phi.shape[0], self.theta.shape[1]))
         phi_index = self.phi.index
-        self.nwd = pd.DataFrame(self.nwd, phi_index, self.theta.columns)
+        is_phi_multiindex = isinstance(phi_index, pd.MultiIndex)
+        self.nwd = pd.DataFrame(
+            np.zeros((self.phi.shape[0], self.theta.shape[1])),
+            index=phi_index, columns=self.theta.columns
+        )
+        print(self.nwd.shape)
+        phi_index_set = set(phi_index)
 
         doc2token = {}
         for batch_id in range(len(batch_vectorizer._batches_list)):
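Note: `self.pwd = np.dot(self.phi.values, self.theta.values)` above is the model's token-given-document matrix, p(w|d) = sum_t phi_wt * theta_td, while `self.nwd` holds the observed counts n_wd that the batch loop below fills in. A minimal sketch of that product on toy shapes (the values are illustrative, not from the PR):

    import numpy as np

    phi = np.array([[0.7, 0.1],    # p(w|t): 3 tokens x 2 topics
                    [0.2, 0.3],
                    [0.1, 0.6]])
    theta = np.array([[0.5, 0.9],  # p(t|d): 2 topics x 2 documents
                      [0.5, 0.1]])
    pwd = phi @ theta              # p(w|d): 3 tokens x 2 documents
    assert np.allclose(pwd.sum(axis=0), 1.0)  # each column is a distribution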
@@ -65,15 +69,21 @@ def calculate_n(self, batch_vectorizer):
                 for token_id, token_weight in zip(item.token_id, item.token_weight):
                     token = batch.token[token_id]
                     modality = batch.class_id[token_id]
-                    doc2token[theta_item_id]['tokens'].append(token)
-                    doc2token[theta_item_id]['weights'].append(token_weight)
+                    token_key = (modality, token) if is_phi_multiindex else token
+
+                    # count only tokens that occur in the model's vocabulary;
+                    # set membership is O(1) per token, unlike phi_index.isin
+                    if token_key in phi_index_set:
+                        doc2token[theta_item_id]['tokens'].append(token_key)
+                        doc2token[theta_item_id]['weights'].append(token_weight)
+                        self.nwd.loc[token_key, theta_item_id] += token_weight
 
+                    '''
                     if is_phi_multiindex:
                         if phi_index.isin([(modality, token)]).any():
                             self.nwd.loc[(modality, token), theta_item_id] += token_weight
                     elif phi_index.isin([token]).any():
                         self.nwd.loc[token, theta_item_id] += token_weight
+                    '''
 
         previous_num_document_passes = self.model._num_document_passes
         self.model._num_document_passes = 10
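The `phi_index_set` membership test above replaces the per-token `phi_index.isin([...])` scans that the change quotes out: `isin` rebuilds a boolean mask over the whole vocabulary for every token occurrence, while `in` on a prebuilt set is O(1). A small sketch of the equivalence (toy index; the names are illustrative):

    import pandas as pd

    phi_index = pd.MultiIndex.from_tuples(
        [('@default_class', 'cat'), ('@default_class', 'dog')])
    phi_index_set = set(phi_index)        # built once, before the loop

    token_key = ('@default_class', 'cat')
    assert token_key in phi_index_set             # O(1) per lookup
    assert phi_index.isin([token_key]).any()      # scans all of phi_index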
@@ -87,10 +97,14 @@ def calculate_n(self, batch_vectorizer):
         tokens = functools.reduce(operator.iconcat, tokens, [])
 
         ndw = np.concatenate([np.array(doc2token[doc_id]['weights']) for doc_id in docs_unique])
+
         self._ndw = np.tile(ndw, (self.ptdw.shape[0], 1))
+        print(self._ndw.shape)
 
         self.ptdw.columns = pd.MultiIndex.from_arrays([docs, tokens], names=('doc', 'token'))
         self.ntdw = self.ptdw * self._ndw
+        # self.nwd = pd.DataFrame(data=ndw, index=self.ntdw.columns).T
+        print(self.nwd.shape)
 
         self.ntd = self.ntdw.groupby(level=0, axis=1).sum()
         self.nwt = self.ntdw.groupby(level=1, axis=1).sum().T
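For reference, `self.ntdw` carries a (doc, token) MultiIndex on its columns, so the two `groupby` calls above are the marginal counts n_td = sum_w n_tdw and n_wt = sum_d n_tdw. A toy check (illustrative data; note that `groupby(..., axis=1)` is deprecated in pandas 2.x, where the same sum is spelled via the transpose):

    import pandas as pd

    cols = pd.MultiIndex.from_tuples(
        [(0, 'cat'), (0, 'dog'), (1, 'cat')], names=('doc', 'token'))
    ntdw = pd.DataFrame([[1.0, 2.0, 3.0]], index=['topic_0'], columns=cols)

    ntd = ntdw.groupby(level=0, axis=1).sum()    # sum over tokens -> n_td
    nwt = ntdw.groupby(level=1, axis=1).sum().T  # sum over docs -> n_wt
    assert ntd.loc['topic_0', 0] == 3.0          # 1 + 2
    assert nwt.loc['cat', 'topic_0'] == 4.0      # 1 + 3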
@@ -125,12 +139,7 @@ def calculate_s_t(self, batch_vectorizer, alpha=None, use_ptdw=None, calculate_n=False):
         if calculate_n:
             self.calculate_n(batch_vectorizer)
 
-        if alpha is not None:
-            model_loss = (self.pwd < alpha / self.nd).astype(int)
-        else:
-            model_loss = self.nwd / self.nd / self.pwd
-            model_loss[np.isnan(model_loss)] = 1
-            model_loss = np.log(model_loss)
+        model_loss = self.calc_model_loss(alpha)
 
         s_t = np.zeros(self.ntd.shape[0])
         for t in range(s_t.shape[0]):
Expand Down Expand Up @@ -177,12 +186,7 @@ def calculate_s_td(self, batch_vectorizer, alpha=None, use_ptdw=False, calculate
if calculate_n:
self.calculate_n(batch_vectorizer)

if alpha is not None:
model_loss = (self.pwd < alpha / self.nd).astype(int)
else:
model_loss = self.nwd / self.nd / self.pwd
model_loss[np.isnan(model_loss)] = 1
model_loss = np.log(model_loss)
model_loss = self.calc_model_loss()

s_td = np.zeros(self.ntd.shape)
for t in range(s_td.shape[0]):
@@ -202,13 +206,7 @@ def calculate_s_wt(self, batch_vectorizer, alpha=None, use_ptdw=False, calculate_n=False):
         if calculate_n:
             self.calculate_n(batch_vectorizer)
 
-        if alpha is not None:
-            model_loss = (self.pwd < alpha / self.nd).astype(int)
-        else:
-            model_loss = self.nwd / self.nd / self.pwd
-            model_loss[np.isnan(model_loss)] = 1
-            model_loss = np.log(model_loss)
-
+        model_loss = self.calc_model_loss(alpha)
         s_wt = np.zeros(self.nwt.shape)
         for t in range(s_wt.shape[1]):
             ptwd_t = np.matmul(self.phi.iloc[:, t].values.reshape(-1, 1), self.theta.iloc[t, :].values.reshape(1, -1))
@@ -221,6 +219,14 @@ def calculate_s_wt(self, batch_vectorizer, alpha=None, use_ptdw=False, calculate_n=False):
 
         return s_wt
 
+    def calc_model_loss(self, alpha=None):
+        if alpha is not None:
+            # binary loss: 1 where the model puts less mass on (w, d)
+            # than the threshold alpha / n_d
+            model_loss = (self.pwd < alpha / self.nd).astype(int)
+        else:
+            # log-ratio of observed to predicted token frequency;
+            # empty (NaN) cells are mapped to log(1) = 0
+            model_loss = self.nwd / self.nd / self.pwd
+            model_loss[np.isnan(model_loss)] = 1
+            model_loss = np.log(model_loss)
+        return model_loss
+
     def calculate_topic_statistics(self, batch_vectorizer, alpha=1, recalculate_n=True, calculate_n=False):
         '''
         Calculates topic semantic heterogeneity and topic impurity
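With n_d the document lengths, the helper added above computes, by default, the pointwise log-ratio of observed to predicted token frequency, loss_wd = log((n_wd / n_d) / p_wd), sending empty 0/0 cells to log(1) = 0; with `alpha` set it returns the binary indicator [p_wd < alpha / n_d] instead. A toy check of the default branch (illustrative values only):

    import numpy as np

    nwd = np.array([[2.0], [0.0]])   # observed counts n_wd: two tokens, one doc
    pwd = np.array([[0.5], [0.0]])   # predicted p(w|d); second token unseen
    nd = 2.0                         # document length n_d

    with np.errstate(invalid='ignore'):
        loss = nwd / nd / pwd        # (n_wd / n_d) / p_wd -> [[2.0], [nan]]
    loss[np.isnan(loss)] = 1         # 0/0 cells contribute log(1) = 0
    loss = np.log(loss)
    assert np.allclose(loss, [[np.log(2.0)], [0.0]])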
@@ -243,3 +249,62 @@ def calculate_topic_statistics(self, batch_vectorizer, alpha=1, recalculate_n=True, calculate_n=False):
         bin_ptdw_imp_t = self.calculate_imp_t(batch_vectorizer, binary_loss=True, use_ptdw=True)
 
         return s_t, bin_s_t, ptdw_s_t, bin_ptdw_s_t, imp_t, bin_imp_t, ptdw_imp_t, bin_ptdw_imp_t
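A hedged usage sketch for the entry point above; the wrapper instance and its construction are assumptions (they sit outside this diff), and only `calculate_topic_statistics` and its signature come from the code shown:

    # Hypothetical driver: `stats` is an instance of the statistics class this
    # diff modifies, wrapping a fitted BigARTM model, and `batch_vectorizer`
    # is the artm.BatchVectorizer the model was trained on.
    (s_t, bin_s_t, ptdw_s_t, bin_ptdw_s_t,
     imp_t, bin_imp_t, ptdw_imp_t, bin_ptdw_imp_t) = stats.calculate_topic_statistics(
        batch_vectorizer, alpha=1, recalculate_n=True)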


+'''
+def select_nonzeros(some_series):
+    return some_series[some_series.nonzero()[0]]
+
+
+def compute_all_stats(model, demo_data, modality="@lemmatized"):
+    n_tdw, n_td, n_wt, n_t, n_dw = tn_calculate_n(model._model, demo_data.get_batch_vectorizer(), modality)
+
+    phi = model.get_phi()
+    theta = model.get_theta(dataset=demo_data)
+
+    predicted_p_wd = np.dot(phi, theta)
+    predicted_p_wd = pd.DataFrame(data=predicted_p_wd, index=phi.index, columns=theta.columns).loc[modality]
+
+    observed_p_wd = np.zeros_like(predicted_p_wd)
+    observed_p_wd = pd.DataFrame(data=observed_p_wd, index=phi.loc[modality].index, columns=theta.columns)
+
+    observed_pdw_series = n_dw.loc[0]
+
+    to_iter = observed_pdw_series.index.levels[0].unique()
+
+    for doc in tqdm(to_iter, total=to_iter.shape[0]):
+        observed_p_wd[doc] = observed_p_wd[doc].add(observed_pdw_series.loc[doc], fill_value=0)
+
+    observed_p_wd = observed_p_wd / observed_p_wd.sum(axis=0)
+    model_loss = np.log(observed_p_wd / predicted_p_wd + (observed_p_wd == 0).astype(int))
+
+    tmp = model_loss.T.values.flatten()
+    loss_series = pd.Series(data=tmp,
+                            index=pd.MultiIndex.from_product(
+                                [list(model_loss.columns), list(model_loss.index)],
+                                names=['doc', 'token'])
+                            )
+    loss_series = select_nonzeros(loss_series)
+
+    s_t = np.zeros(n_td.shape[0])
+    s_td = np.zeros(n_td.shape)
+    s_wt = np.zeros(n_wt.shape)
+
+    for t, topic in enumerate(tqdm(model.topic_names)):
+        topical_series = select_nonzeros(n_tdw.loc[topic])
+        assert sum(topical_series == 0) == 0
+
+        product_series = topical_series * loss_series  # .loc[topical_series.index]
+
+        s_t[t] = product_series.sum() / topical_series.sum()
+        s_td[t, :] = product_series.sum(level=0) / topical_series.sum(level=0)
+        s_wt[:, t] = product_series.sum(level=1) / topical_series.sum(level=1)
+
+    s_t = pd.DataFrame(data=s_t, index=model.topic_names)
+    s_td = pd.DataFrame(data=s_td, index=n_td.index, columns=n_td.columns)
+    s_wt = pd.DataFrame(data=s_wt, index=n_wt.index, columns=n_wt.columns).fillna(0)
+
+    return s_t, s_td, s_wt
+'''