-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_embedding.py
More file actions
52 lines (45 loc) · 1.75 KB
/
evaluate_embedding.py
File metadata and controls
52 lines (45 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV, KFold, StratifiedKFold
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
def evaluate_embedding_with_split(train_emb, train_labels, test_emb, test_labels, eval_mode='svc'):
    """Fit a classifier on training embeddings and score it on test embeddings.

    Args:
        train_emb: Training-set embeddings; any 2-D array-like accepted by
            scikit-learn's ``fit`` (one row per sample).
        train_labels: Training-set labels, one per row of ``train_emb``.
        test_emb: Test-set embeddings with the same feature dimension as
            ``train_emb``.
        test_labels: Test-set labels, one per row of ``test_emb``.
        eval_mode: Classifier to use: 'svc', 'randomforest', 'linearsvc'
            or 'logistic'.

    Returns:
        tuple: ``(accuracy, preds)`` -- the test-set accuracy as a float and
        the classifier's predictions for ``test_emb``.

    Raises:
        ValueError: If ``eval_mode`` is not a supported mode, or (from
            ``accuracy_score``) if ``test_labels`` and the predictions have
            different lengths.
    """
    # Lambdas defer construction so only the selected estimator is built,
    # and only after the mode has been validated.
    factories = {
        # probability=True was dropped: only predict() is used below, and
        # enabling it makes fit() run an extra internal 5-fold CV for Platt
        # scaling without changing predict()'s output.
        'svc': lambda: SVC(),
        # NOTE(review): no random_state, so accuracy varies between runs.
        'randomforest': lambda: RandomForestClassifier(n_estimators=100),
        'linearsvc': lambda: LinearSVC(max_iter=10000),
        'logistic': lambda: LogisticRegression(max_iter=10000),
    }
    factory = factories.get(eval_mode)
    if factory is None:
        raise ValueError(f"Unknown evaluation mode: {eval_mode}")

    clf = factory()
    clf.fit(train_emb, train_labels)
    preds = clf.predict(test_emb)
    # accuracy_score flattens column-vector labels and checks lengths. The
    # previous `(preds == test_labels).mean()` silently broadcast an (n, 1)
    # label array against the (n,) predictions into an (n, n) comparison.
    acc = accuracy_score(test_labels, preds)
    return acc, preds