-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpureembedding.py
More file actions
105 lines (94 loc) · 7.95 KB
/
pureembedding.py
File metadata and controls
105 lines (94 loc) · 7.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np # 引入 numpy 用于向量计算
from numpy.linalg import norm # 用于计算向量范数
import os # 引入 os 用于文件检查
# 纯embedding 余弦相似度匹配
def load_patterns(file_path):
"""
读取词条和Embedding,返回 (原始词条, embedding向量) 的元组列表.
文件格式可以为 "词条名称|向量..." 或 "部分1|部分2|向量..."。
向量前的所有部分被视为一个整体词条名。
"""
patterns_data = []
if not os.path.exists(file_path):
print(f"错误:词典文件未找到: {file_path}")
return patterns_data
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f):
line = line.strip()
if not line:
continue
parts = line.rsplit('|', 1)
if len(parts) == 2:
original_pattern = parts[0].strip()
embedding_str = parts[1].strip()
try:
embedding_vector = np.fromstring(embedding_str, sep=' ', dtype=np.float32)
if original_pattern:
patterns_data.append((original_pattern, embedding_vector))
else:
print(f"警告:文件 {file_path} 第 {line_num + 1} 行解析出的词条名称为空。")
except ValueError:
print(f"警告:无法解析文件 {file_path} 第 {line_num + 1} 行的Embedding: '{embedding_str}'")
else:
print(f"警告:文件 {file_path} 第 {line_num + 1} 行格式错误,无法按 '|' 分离出向量部分。行内容: '{line}'")
except Exception as e:
print(f"读取词典文件 {file_path} 时发生意外错误: {e}")
return patterns_data
def cosine_similarity(vec_a, vec_b):
"""计算两个向量之间的余弦相似度"""
if vec_a is None or vec_b is None or vec_a.size == 0 or vec_b.size == 0:
return -2.0 # 返回一个明确表示错误的值,区别于有效相似度
if vec_a.shape != vec_b.shape:
try:
vec_b = vec_b.reshape(vec_a.shape)
except ValueError:
print(f"错误:无法匹配向量维度 {vec_a.shape} vs {vec_b.shape}")
return -2.0
if vec_a.shape != vec_b.shape:
print(f"错误:无法匹配向量维度 {vec_a.shape} vs {vec_b.shape}")
return -2.0
norm_a = norm(vec_a)
norm_b = norm(vec_b)
if norm_a == 0 or norm_b == 0:
return 0.0
epsilon = 1e-8
similarity = np.dot(vec_a, vec_b) / (norm_a * norm_b + epsilon)
return np.clip(similarity, -1.0, 1.0)
# 移除 find_closest_pattern 函数,因为主逻辑在 main 中实现
if __name__ == "__main__":
dict_file = "attackpattern_embeddings.txt" # 要读取的文件名
top_n_to_show = 100 # 要显示的最高相似度条目数量
# 解析输入向量 (这里仍然使用之前例子中的向量)
input_vector_str = "-0.144820 0.065548 -0.600861 -0.153625 0.089813 -0.109741 -0.112292 -0.091874 0.217631 0.275176 -0.021560 -0.019503 0.721424 -0.295027 -0.202383 -0.240452 -0.333778 -0.148215 0.122652 -0.079071 -0.146889 -0.278756 0.020142 -0.015602 0.018453 0.147108 -0.255867 -0.241694 -0.477927 -0.464876 0.202725 -0.040899 0.060236 0.157702 0.462763 -0.212102 -0.256463 -0.078805 0.151989 0.121381 0.029166 0.095806 -0.243763 -0.070208 0.049676 -0.004198 -0.149807 0.163636 -0.250134 0.203135 -0.181013 0.218106 -0.146606 0.151513 -0.180215 -0.064386 0.060377 0.096040 -0.325038 -0.052464 0.289355 0.218500 -0.306790 0.035398 -0.070002 0.422087 0.324049 0.046197 0.351199 -0.306682 -0.234246 0.174341 -0.389343 0.220956 -0.193955 -0.102264 -0.249018 0.179854 0.130489 -0.050503 -0.084504 0.304664 -0.040066 -0.098755 -0.102910 -0.040308 0.180854 0.039847 0.162105 -0.280922 0.113717 -0.222711 0.267084 0.041831 -0.054989 -0.106423 -0.298047 0.231791 0.457714 1.294187 -0.302475 0.153918 -0.185550 -0.187438 0.068898 -0.156385 0.591874 -0.087013 0.061587 -0.130047 0.041219 0.122747 0.201285 -0.157583 0.291229 0.167601 -0.220790 -0.008287 -0.450951 0.402498 -0.038640 0.023099 0.230809 0.363636 0.054768 -0.150516 -0.222419 0.021246 0.105147 -0.282037 0.165069 -0.154344 -0.190550 0.249799 -0.048917 -0.391530 -0.226325 0.049610 -0.042397 0.044014 -0.021086 -0.194049 0.125340 0.300763 -0.198892 0.630934 0.168318 -0.333051 -0.004321 0.240483 0.144884 -0.209747 0.219995 0.233820 -0.348865 -0.039827 0.271932 0.133562 -0.070429 -0.404245 0.156540 0.089586 -0.178018 0.110618 0.084201 -0.350635 0.095982 -0.037210 0.287156 0.008164 0.134575 -0.018384 -0.225881 -0.352525 -0.115055 -0.120432 0.095733 -0.202494 0.090298 -0.158642 0.014019 -0.201444 -0.178822 0.242450 0.144619 -0.284491 -0.113471 0.111357 0.071273 0.073867 -0.328825 0.061805 0.181553 -0.160747 -0.218722 0.063351 -0.031255 -0.019151 -0.160904 -0.232111 0.007824 0.099044 -0.159014 -0.294597 0.042182 -0.183868 -0.019738 0.258573 -0.148923 -0.264688 -0.267380 0.148241 -0.434820 0.122159 0.195685 -0.218912 -0.086031 0.212410 0.161564 0.000308 0.086573 -0.217004 0.304690 -0.219592 -0.073335 0.141120 -0.242910 0.085586 0.100350 0.067089 0.201230 0.065903 -0.077217 0.003316 -0.179237 0.149166 -0.040958 -0.254475 -0.011164 -0.269732 0.200748 0.044622 -0.260597 -0.011343 -0.175607 -0.103322 -0.449325 0.103831 0.245419 0.106078 -0.332137 -0.356736 0.233364 0.216993 0.119345 0.106316 -0.054492 0.160198 0.016703 -0.134365 0.215763 -0.384037 -0.051453 -0.072110 -0.085095 -0.007655 -0.065900 0.269769 -0.148769 0.311556 -0.134349 0.288912 0.603764 0.050451 0.226609 0.039185 0.321784 -0.187419 0.121260 0.280291 -0.076132 -0.032083 0.148014 0.058734 0.167625 0.109758 -0.442004 -0.038028 0.003863 0.073480 0.044744 0.136258 0.019985 -0.072736 0.408221 -0.522739 -0.128427 -0.041622 0.062005 0.040775 -0.783812 -0.176879 0.127332 0.217924 0.168247 -0.030973 -0.070785 0.093790 0.497498 -0.096935 0.116922 -0.453464 -0.143181 -0.137324 -0.211777 -0.321468 -0.106098 0.229011 -0.160774 -0.564785 0.268416 -0.232299 -0.002226 0.677005 0.436654 -0.057344 -0.143667 0.453290 0.218914 0.001706 0.071355 -0.286742 -0.399693 -0.338367 0.471329 -0.311500 0.153438 -0.183239 0.085723 -0.313507 0.083238 0.117130 -0.147364 -0.171759 0.304897 0.044022 0.363251 0.616331 -0.009404 0.284844 0.381188 -0.050236 0.133500 0.234174 -0.171405 0.305686 0.163267 0.246391 -0.219697 -0.126274 0.399859 0.154139 -0.069083 0.170848 -0.598348 0.065538 -0.111276 0.229887 0.533292 -0.053299 -0.013589 -0.474127 0.010863 0.124984 0.049252 0.026583 0.236116 -0.186822 0.352161 0.272370 0.033416 0.029649 -0.119141 -0.077214"
try:
input_embedding_vector = np.fromstring(input_vector_str, sep=' ', dtype=np.float32)
except ValueError:
print("错误:无法解析输入向量字符串。")
input_embedding_vector = None
# 加载词典数据
patterns_data = load_patterns(dict_file)
if input_embedding_vector is not None and patterns_data:
# 计算所有相似度
similarity_results = []
for term, vector in patterns_data:
similarity = cosine_similarity(input_embedding_vector, vector)
if similarity > -1.1: # 过滤掉计算错误的情况 (-2.0)
similarity_results.append((term, similarity))
# 按相似度降序排序
similarity_results.sort(key=lambda item: item[1], reverse=True)
# 打印 Top N 结果
print(f"\n与输入向量最相似的前 {top_n_to_show} 个词条:")
count = 0
for term, similarity in similarity_results:
if count < top_n_to_show:
print(f" - '{term}': {similarity:.6f}")
count += 1
else:
break # 只显示前 N 个
if count == 0:
print(" (未找到任何有效相似度结果)")
elif not patterns_data:
print(f"错误:未能从 {dict_file} 加载任何数据。")
else:
print("错误:输入向量无效。")