Skip to content

Commit e1494ce

Browse files
committed
feat(rerank): upgrade from manual features to neural network learning
1 parent b83f7ed commit e1494ce

6 files changed

Lines changed: 839 additions & 138 deletions

File tree

spx-algorithm/database/user_feedback/models.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,45 +56,96 @@ class PairwiseTrainingSample:
5656
feedback_id: int # 原始反馈记录ID
5757

5858
def get_feature_vector(self) -> List[float]:
59-
"""提取特征向量用于训练
59+
"""提取神经网络特征向量用于训练
6060
6161
特征包括:
62-
- query向量与better图片向量的余弦相似度
63-
- query向量与worse图片向量的余弦相似度
64-
- better图片向量与worse图片向量的余弦相似度
65-
- 向量间的点积、欧氏距离等交互特征
62+
- 原始向量特征: query_vec, better_vec, worse_vec
63+
- 交互特征: element-wise乘积和差值
64+
- 对比特征: better vs worse直接对比
65+
- 统计特征: 余弦相似度和欧氏距离
6666
"""
6767
import numpy as np
6868

6969
query_vec = np.array(self.query_vector)
7070
better_vec = np.array(self.pic_vector_better)
7171
worse_vec = np.array(self.pic_vector_worse)
7272

73-
return compute_pairwise_features(query_vec, better_vec, worse_vec)
73+
return compute_neural_network_features(query_vec, better_vec, worse_vec)
7474

7575

76-
def compute_pairwise_features(query_vec, better_vec, worse_vec) -> List[float]:
76+
def compute_neural_network_features(query_vec, better_vec, worse_vec) -> List[float]:
7777
"""
78-
统一的pair-wise特征计算函数,供训练和预测阶段共用
78+
统一的神经网络特征计算函数,供训练和预测阶段共用
7979
8080
设计说明:
81-
- 该函数计算query、better、worse三个向量之间的pair-wise特征
81+
- 该函数计算query、better、worse三个向量之间的神经网络特征
8282
- 训练时:better是用户选择的图片,worse是未选择的图片
8383
- 预测时:better是当前候选图片,worse是动态参考向量(候选集合的平均向量)
8484
8585
特征设计理由:
86-
1. 相似度特征:衡量query与两个图片的语义匹配程度
87-
2. 距离特征:衡量向量空间中的几何距离关系
88-
3. 差异特征:直接比较better vs worse的相对优劣
89-
4. 长度特征:向量的模长反映了特征的激活强度
86+
1. 原始向量特征:保留完整信息,让神经网络自动学习
87+
2. 交互特征:捕捉query与候选图片间的特征对应关系
88+
3. 对比特征:直接比较better vs worse的相对优劣
89+
4. 统计特征:传统手工特征作为补充
9090
9191
Args:
92-
query_vec: 查询向量 (numpy array)
93-
better_vec: 更好的图片向量 (numpy array)
94-
worse_vec: 更差的图片向量 (numpy array)
92+
query_vec: 查询向量 (numpy array, 长度d)
93+
better_vec: 更好的图片向量 (numpy array, 长度d)
94+
worse_vec: 更差的图片向量 (numpy array, 长度d)
9595
9696
Returns:
97-
10维特征向量列表
97+
(6d+6)维特征向量列表
98+
"""
99+
import numpy as np
100+
101+
features = []
102+
103+
# 1. 原始向量特征 (3d维)
104+
features.extend(query_vec.tolist()) # d维
105+
features.extend(better_vec.tolist()) # d维
106+
features.extend(worse_vec.tolist()) # d维
107+
108+
# 2. 交互特征 (2d维)
109+
# element-wise乘积 - 捕捉特征对应关系
110+
element_wise_product = query_vec * better_vec
111+
features.extend(element_wise_product.tolist()) # d维
112+
113+
# element-wise差值 - 体现语义gap
114+
element_wise_diff = query_vec - better_vec
115+
features.extend(element_wise_diff.tolist()) # d维
116+
117+
# 3. 对比特征 (d维)
118+
# better vs worse直接对比
119+
better_worse_diff = better_vec - worse_vec
120+
features.extend(better_worse_diff.tolist()) # d维
121+
122+
# 4. 统计特征 (6维) - 传统手工特征作为补充
123+
# 余弦相似度
124+
cosine_sim_query_better = np.dot(query_vec, better_vec) / (np.linalg.norm(query_vec) * np.linalg.norm(better_vec) + 1e-8)
125+
cosine_sim_query_worse = np.dot(query_vec, worse_vec) / (np.linalg.norm(query_vec) * np.linalg.norm(worse_vec) + 1e-8)
126+
cosine_sim_better_worse = np.dot(better_vec, worse_vec) / (np.linalg.norm(better_vec) * np.linalg.norm(worse_vec) + 1e-8)
127+
128+
# 欧氏距离
129+
l2_dist_query_better = np.linalg.norm(query_vec - better_vec)
130+
l2_dist_query_worse = np.linalg.norm(query_vec - worse_vec)
131+
l2_dist_better_worse = np.linalg.norm(better_vec - worse_vec)
132+
133+
features.extend([
134+
cosine_sim_query_better,
135+
cosine_sim_query_worse,
136+
cosine_sim_better_worse,
137+
l2_dist_query_better,
138+
l2_dist_query_worse,
139+
l2_dist_better_worse
140+
])
141+
142+
return features
143+
144+
145+
# 保留原始函数以保持向后兼容(可选)
146+
def compute_pairwise_features(query_vec, better_vec, worse_vec) -> List[float]:
147+
"""
148+
旧版本pair-wise特征计算函数(保留以保持向后兼容)
98149
"""
99150
import numpy as np
100151

spx-algorithm/services/reranking/feature_extractor.py

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from typing import List, Dict, Any, Optional, Tuple
88
from ..image_matching.clip_service import CLIPService
9-
from database.user_feedback.models import UserFeedback, PairwiseTrainingSample, TrainingDataset, compute_pairwise_features
9+
from database.user_feedback.models import UserFeedback, PairwiseTrainingSample, TrainingDataset, compute_neural_network_features
1010

1111
logger = logging.getLogger(__name__)
1212

@@ -152,14 +152,14 @@ def _get_query_vector(self, query_text: str) -> Optional[np.ndarray]:
152152
def extract_ranking_features(self, query_text: str,
153153
candidates: List[Dict[str, Any]]) -> List[List[float]]:
154154
"""
155-
为排序预测提取pair-wise特征,使用动态参考向量方案
155+
为排序预测提取神经网络特征,使用动态参考向量方案
156156
157157
Args:
158158
query_text: 查询文本
159159
candidates: 候选结果列表,每个包含id, vector等字段
160160
161161
Returns:
162-
每个候选结果的pair-wise特征向量列表(10维
162+
每个候选结果的神经网络特征向量列表(6d+6维
163163
"""
164164
try:
165165
query_vector = self._get_query_vector(query_text)
@@ -187,7 +187,9 @@ def extract_ranking_features(self, query_text: str,
187187

188188
if not candidate_vectors:
189189
logger.error("所有候选结果都缺少向量数据")
190-
return [[0.0] * 10] * len(candidates)
190+
# 假设向量维度为512,神经网络特征维度为6*512+6=3078
191+
feature_dim = 6 * len(query_vector) + 6 if query_vector is not None else 3078
192+
return [[0.0] * feature_dim] * len(candidates)
191193

192194
# 计算动态参考向量:候选集合的平均向量
193195
# 这个平均向量代表当前候选集合的"平均质量水平"
@@ -196,57 +198,73 @@ def extract_ranking_features(self, query_text: str,
196198
query_vec = np.array(query_vector)
197199
features_list = []
198200

199-
# 为每个候选图片计算pair-wise特征
201+
# 为每个候选图片计算神经网络特征
200202
for i, candidate in enumerate(candidates):
201203
pic_vector = candidate.get('vector', [])
202204
if len(pic_vector) == 0:
203205
# 对于缺少向量的候选,填充零特征向量
204-
features_list.append([0.0] * 10)
206+
feature_dim = 6 * len(query_vector) + 6
207+
features_list.append([0.0] * feature_dim)
205208
continue
206209

207210
pic_vec = np.array(pic_vector)
208211

209-
# 使用统一的pair-wise特征计算函数
212+
# 使用统一的神经网络特征计算函数
210213
# candidate作为"better",reference_vector作为"worse"
211-
features = compute_pairwise_features(query_vec, pic_vec, reference_vector)
214+
features = compute_neural_network_features(query_vec, pic_vec, reference_vector)
212215
features_list.append(features)
213216

214-
logger.debug(f"成功提取{len(features_list)}个候选结果的pair-wise特征")
217+
logger.debug(f"成功提取{len(features_list)}个候选结果的神经网络特征")
215218
return features_list
216219

217220
except Exception as e:
218221
logger.error(f"排序特征提取失败: {e}")
219-
return [[0.0] * 10] * len(candidates)
222+
# 如果出错,返回合适维度的零向量
223+
try:
224+
feature_dim = 6 * len(query_vector) + 6 if query_vector is not None else 3078
225+
return [[0.0] * feature_dim] * len(candidates)
226+
except:
227+
return [[0.0] * 3078] * len(candidates)
220228

221229
def get_feature_names(self) -> List[str]:
222-
"""获取特征名称列表"""
223-
return [
224-
'query_better_similarity',
225-
'query_worse_similarity',
226-
'better_worse_similarity',
227-
'similarity_difference',
228-
'query_better_distance',
229-
'query_worse_distance',
230-
'distance_difference',
231-
'query_norm',
232-
'better_norm',
233-
'worse_norm'
234-
]
230+
"""获取神经网络特征名称列表"""
231+
# 假设向量维度为512(从CLIP模型配置获取)
232+
d = 512
233+
feature_names = []
234+
235+
# 原始向量特征 (3d维)
236+
for i in range(d):
237+
feature_names.append(f'query_vec_{i}')
238+
for i in range(d):
239+
feature_names.append(f'better_vec_{i}')
240+
for i in range(d):
241+
feature_names.append(f'worse_vec_{i}')
242+
243+
# 交互特征 (2d维)
244+
for i in range(d):
245+
feature_names.append(f'element_wise_product_{i}')
246+
for i in range(d):
247+
feature_names.append(f'element_wise_diff_{i}')
248+
249+
# 对比特征 (d+6维)
250+
for i in range(d):
251+
feature_names.append(f'better_worse_diff_{i}')
252+
253+
# 统计特征 (6维)
254+
feature_names.extend([
255+
'cosine_sim_query_better',
256+
'cosine_sim_query_worse',
257+
'cosine_sim_better_worse',
258+
'l2_dist_query_better',
259+
'l2_dist_query_worse',
260+
'l2_dist_better_worse'
261+
])
262+
263+
return feature_names
235264

236265
def get_ranking_feature_names(self) -> List[str]:
237-
"""获取排序特征名称列表(与训练时保持一致的pair-wise特征)"""
238-
return [
239-
'query_better_similarity', # query与候选图片相似度
240-
'query_worse_similarity', # query与参考向量相似度
241-
'better_worse_similarity', # 候选图片与参考向量相似度
242-
'similarity_difference', # 相似度差异(候选-参考)
243-
'query_better_distance', # query与候选图片距离
244-
'query_worse_distance', # query与参考向量距离
245-
'distance_difference', # 距离差异(参考-候选)
246-
'query_norm', # query向量长度
247-
'better_norm', # 候选图片向量长度
248-
'worse_norm' # 参考向量长度
249-
]
266+
"""获取排序特征名称列表(与训练时保持一致的神经网络特征)"""
267+
return self.get_feature_names()
250268

251269
def clear_cache(self):
252270
"""清空向量缓存"""

spx-algorithm/services/reranking/ltr_model.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
LTR重排序模型:基于LightGBM的pair-wise学习排序
2+
LTR重排序模型:支持LightGBM和神经网络的统一接口
33
"""
44

55
import logging
@@ -11,21 +11,23 @@
1111

1212

1313
class LTRModel:
14-
"""Learning to Rank 模型类"""
14+
"""Learning to Rank 模型类:支持LightGBM和神经网络"""
1515

16-
def __init__(self, model_path: Optional[str] = None):
16+
def __init__(self, model_path: Optional[str] = None, model_type: str = "neural_network"):
1717
"""
1818
初始化LTR模型
1919
2020
Args:
2121
model_path: 模型文件路径
22+
model_type: 模型类型,'lightgbm' 或 'neural_network'
2223
"""
2324
self.model_path = model_path or "models/ltr_model.pkl"
24-
self.trainer = LTRTrainer(self.model_path)
25+
self.model_type = model_type
26+
self.trainer = LTRTrainer(self.model_path, model_type)
2527
self.feature_extractor = None # 需要在初始化时注入
2628
self.is_trained = False
2729

28-
logger.info(f"LTR模型初始化完成,模型路径: {self.model_path}")
30+
logger.info(f"LTR模型初始化完成,模型类型: {model_type}, 模型路径: {self.model_path}")
2931

3032
def set_feature_extractor(self, feature_extractor: LTRFeatureExtractor):
3133
"""设置特征提取器"""
@@ -74,14 +76,14 @@ def predict_ranking_scores(self, query_text: str,
7476
logger.error("特征提取器未设置")
7577
return [0.0] * len(candidates)
7678

77-
# 提取排序特征
79+
# 提取排序特征(神经网络或传统特征)
7880
features = self.feature_extractor.extract_ranking_features(query_text, candidates)
7981

8082
if not features:
8183
logger.warning("特征提取失败,返回默认分数")
8284
return [candidate.get('similarity', 0.0) for candidate in candidates]
8385

84-
# 预测排序分数
86+
# 预测排序分数(自动适配模型类型)
8587
scores = self.trainer.predict_ranking_scores(features)
8688

8789
if not scores:
@@ -168,6 +170,7 @@ def get_model_info(self) -> Dict[str, Any]:
168170
return {
169171
**trainer_info,
170172
'model_path': self.model_path,
173+
'model_type': self.model_type,
171174
'has_feature_extractor': self.feature_extractor is not None,
172175
'is_ready': self.is_trained and self.feature_extractor is not None
173176
}

0 commit comments

Comments
 (0)