forked from Moemu/Muice-Chatbot
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSuperMuice.py
More file actions
180 lines (154 loc) · 6.69 KB
/
SuperMuice.py
File metadata and controls
180 lines (154 loc) · 6.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import sqlite3
import json
from math import sqrt
import jieba
import random
import logging
import time
import re
from snownlp import SnowNLP
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn
# Module-level logger used by the SuperMuice class below; the logger name is
# fixed to "SuperMuice" rather than derived from __name__.
logger = logging.getLogger("SuperMuice")
class SuperMuice:
    """Self-moderation / proactive-reply layer wrapped around a Muice chatbot.

    Keeps a SQLite table of "forbidden" reply fingerprints. When the bot's
    reply is not positive in sentiment and its fingerprint is close to a
    stored one, ``process_message`` may ask the caller to retract the reply
    or to send a face-saving follow-up. For image replies it may instead
    start a new topic.
    """

    # Chatbot backend; process_message() calls its refresh(), ask(prompt,
    # user_id, group_id) and ask_bypass(prompt) methods.
    muice: object = None
    # Image database handle; stored by __init__ but not used within this class.
    image_db: object = None
    # Class-level default path; __init__ always overwrites it per instance.
    db_path: str = "./data/Muice_Chatbot_Plugin/supermuice.db"

    # Number of dimensions (and therefore bits) in a content fingerprint.
    _HASH_DIMENSIONS = 128
    # Minimum fingerprint similarity for a stored entry to count as a match.
    _MATCH_THRESHOLD = 0.8
    # Canned follow-up lines used when the stored action is not "delete_msg".
    _DELETE_MESSAGES = (
        "你什么都没看到,对吧",
        "不许看!你什么都没看到!",
        "你看不见我,看不见我",
        "emmm你就当我没说过那句话吧",
        "大家就当作无事发生",
    )

    def __init__(self, muice: object, image_db: object, db_path: str):
        """Store collaborators, build the progress spinner and open the database.

        Args:
            muice: chatbot backend used for proactive replies.
            image_db: image database handle (kept for callers; unused here).
            db_path: path of the SQLite database file to open/create.
        """
        self.muice = muice
        self.image_db = image_db
        # transient=True: the spinner is cleared from the terminal when it stops.
        self.progress = Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            transient=True,
        )
        self.db_path = db_path
        logger.info("初始化 SuperMuice 数据库...")
        self.load_data()
        logger.info("SuperMuice 初始化完成!")

    def load_data(self):
        """Open (creating if necessary) the database and ensure the schema exists.

        Sets ``self.conn`` and ``self.cursor``.
        """
        db_dir = os.path.dirname(self.db_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        # exist_ok avoids failing if the directory appears concurrently.
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)
        # check_same_thread=False lets this connection be used from threads
        # other than the one that created it.
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self.cursor = self.conn.cursor()
        with self.conn:
            # The `answer` column stores the content fingerprint, not raw text.
            self.cursor.execute(
                """CREATE TABLE IF NOT EXISTS forbidden_words (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                answer TEXT,
                action TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )"""
            )

    def generate_random_vector(self, dimensions=128, seed=None):
        """Return a random unit vector with ``dimensions`` components.

        Args:
            dimensions: vector length.
            seed: optional seed (e.g. a word). When given, the same seed always
                yields the same vector, across calls and process restarts,
                because ``random.Random`` hashes str seeds deterministically.
                When None, the module-level RNG is used (the old behavior).
        """
        rng = random if seed is None else random.Random(seed)
        components = [rng.gauss(0, 1) for _ in range(dimensions)]
        magnitude = sqrt(sum(x * x for x in components))
        return [x / magnitude for x in components]

    def simhash(self, content):
        """Return a bit-string fingerprint ('0'/'1' chars) of ``content``.

        The content is tokenized with jieba. Each distinct word maps to a
        pseudo-random unit vector seeded by the word itself. The word vectors
        are summed (repeated words count repeatedly), and each bit records the
        sign of one component. Similar texts therefore share most bits.
        """
        dimensions = self._HASH_DIMENSIONS
        totals = [0.0] * dimensions
        words = list(jieba.cut(content, cut_all=False))
        # Per-word seeding makes fingerprints stable. Unseeded random vectors
        # gave the same text a different hash on every call, which made
        # matching against stored fingerprints meaningless. Entries stored
        # before this fix do not match reliably and should be re-added.
        word_vectors = {
            word: self.generate_random_vector(dimensions, seed=word)
            for word in set(words)
        }
        for word in words:
            for i, component in enumerate(word_vectors[word]):
                totals[i] += component
        return "".join("1" if value >= 0 else "0" for value in totals)

    def similarity(self, a, b):
        """Return the fraction of positions where fingerprints ``a`` and ``b`` agree.

        The result is normalized by ``len(a)``; an empty ``a`` yields 0.0.
        """
        if not a:
            return 0.0
        matches = sum(x == y for x, y in zip(a, b))
        return matches / len(a)

    def search_forbidden_content(self, content):
        """Find the stored forbidden entry most similar to ``content``.

        Returns:
            ``(action, similarity)`` for the best-matching row (the first such
            row on ties), or ``(None, None)`` when there are no usable rows.
        """
        rows = self.conn.execute(
            "SELECT answer, action FROM forbidden_words"
        ).fetchall()
        if not rows:
            # Skip tokenizing/hashing when there is nothing to compare against.
            return None, None
        hash_value = self.simhash(content)
        return max(
            (
                (action, self.similarity(hash_value, stored_hash))
                for stored_hash, action in rows
                if stored_hash
            ),
            key=lambda pair: pair[1],
            default=(None, None),
        )

    def insert_content(self, answer, action):
        """Store the fingerprint of ``answer`` together with ``action``."""
        hash_value = self.simhash(answer)
        with self.conn:
            self.cursor.execute(
                "INSERT INTO forbidden_words (answer, action) VALUES (?,?)",
                (hash_value, action),
            )

    def classify_emotion(self, content: str) -> float:
        """Return SnowNLP's sentiment score for ``content``.

        Per SnowNLP, this is the probability in [0, 1] that the text is positive.
        """
        return SnowNLP(content).sentiments

    async def process_message(self, answer, user_id: int, group_id: int) -> dict:
        """Decide whether the bot should react to its own reply ``answer``.

        Args:
            answer: the bot's reply; a list means "one of these" (one is picked
                at random), and an object with ``type == 'image'`` is treated
                as an image message segment.
            user_id: id placed in ``sender_user_id`` of returned actions.
            group_id: group id passed to the backend and returned actions.

        Returns:
            None to do nothing, or an action dict:
            ``{"action": "send_msg", "message", "sender_user_id", "group_id"}``
            after an image reply; ``{"action": "delete_msg"}`` to retract; or
            ``{"action": <stored action>, "message": <canned line>, ...}``.
        """
        task = self.progress.add_task("[bold green][超级沐雪] 正在思考", total=None)
        try:
            with self.progress:
                return self._decide_action(answer, user_id, group_id)
        finally:
            # Single cleanup point: the task is removed on every return path
            # and also when the chatbot backend raises.
            self.progress.remove_task(task)

    def _decide_action(self, answer, user_id, group_id):
        """Core decision logic of process_message (no progress-bar handling)."""
        logger.info(f"回复:{str(answer)}")
        if isinstance(answer, list):
            if not answer:
                return None
            answer = random.choice(answer)
        # MessageSegment is a dataclass; image segments have type == 'image'.
        if hasattr(answer, 'type') and answer.type == 'image':
            return self._react_to_image(user_id, group_id)
        # Only non-positive replies (score <= 0.5) are candidates for retraction.
        if (emotion := self.classify_emotion(answer)) > 0.5:
            return None
        logger.info(f"情感值:{emotion}")
        try:
            action, score = self.search_forbidden_content(answer)
        except Exception as e:
            # Best-effort: a lookup failure must not break message handling.
            logger.error(e)
            return None
        if not score:
            return None
        # Close match: act with 50% probability so behavior feels less mechanical.
        if score >= self._MATCH_THRESHOLD and random.randint(1, 100) <= 50:
            if action == "delete_msg":
                return {"action": "delete_msg"}
            return {
                "action": action,
                "message": random.choice(self._DELETE_MESSAGES),
                "sender_user_id": user_id,
                "group_id": group_id,
            }
        return None

    def _react_to_image(self, user_id, group_id):
        """Maybe follow an image reply with a fresh topic or a spontaneous thought."""
        # 50% of image replies get no follow-up at all.
        if random.randint(1, 100) > 50:
            return None
        if random.randint(0, 1) == 0:
            self.muice.refresh()
            message = self.muice.ask("(创造一个新话题)", user_id, group_id)
        elif random.randint(0, 4) == 0:
            message = self.muice.ask_bypass("(分享一下你的一些想法)")
        else:
            return None
        return {
            "action": "send_msg",
            "message": message,
            "sender_user_id": user_id,
            "group_id": group_id,
        }

    @staticmethod
    def _parse_forbidden_args(args):
        """Split command args into ``(content, action)``; None if no content.

        The first whitespace-separated token is the content; the remainder (if
        any) is the action, defaulting to "delete_msg".
        """
        parts = " ".join(args).strip().split(None, 1)
        if not parts:
            return None
        if len(parts) == 1:
            return parts[0], "delete_msg"
        return parts[0], parts[1].strip()

    def add_forbidden_content(self, *args):
        """Handle ``!supermuice_forbidden <content> [action]``.

        Returns a usage string when no content is given, otherwise stores the
        entry and returns a confirmation message.
        """
        parsed = self._parse_forbidden_args(args)
        if parsed is None:
            return "Usage: !supermuice_forbidden <content> [action]"
        content, action = parsed
        self.insert_content(content, action)
        return "已添加禁言内容。"