Skip to content

Commit 4960068

Browse files
committed
feat(match): add image-to-image matching API with L2 distance metric
1 parent 32f9e47 commit 4960068

6 files changed

Lines changed: 241 additions & 4 deletions

File tree

spx-algorithm/README.md

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
## 主要特性
1515

16-
-**智能图像搜索**:基于语义理解的图像检索
16+
- **智能图像搜索**:基于语义理解的文本搜索图片
17+
- **图像相似度匹配**:基于视觉特征的图片匹配图片
1718
-**资源管理**:统一的资源添加、查询、删除接口
1819
-**SVG 图片支持**:专门优化的 SVG 图片处理流程
1920
-**批量操作**:支持批量添加和处理资源
@@ -168,7 +169,7 @@ Content-Type: application/json
168169
}
169170
```
170171

171-
#### 搜索资源
172+
#### 搜索资源(文本搜索图片)
172173
```http
173174
POST /v1/resource/search
174175
Content-Type: application/json
@@ -180,6 +181,18 @@ Content-Type: application/json
180181
}
181182
```
182183

184+
#### 匹配资源(图片匹配图片)
185+
```http
186+
POST /v1/resource/match
187+
Content-Type: application/json
188+
189+
{
190+
"svg_content": "<svg><circle cx=\"50\" cy=\"50\" r=\"40\" fill=\"red\" /></svg>",
191+
"top_k": 10,
192+
"threshold": 1.0
193+
}
194+
```
195+
183196
### 用户反馈和模型训练
184197

185198
#### 提交用户反馈
@@ -394,7 +407,8 @@ gunicorn -w 4 -b 0.0.0.0:5000 \
394407
### 对外接口(`/v1/resource/`)
395408
- **添加资源**`POST /v1/resource/add` - 添加单个资源
396409
- **批量添加**`POST /v1/resource/batch` - 批量添加资源
397-
- **搜索资源**`POST /v1/resource/search` - 基于语义的资源搜索
410+
- **搜索资源**`POST /v1/resource/search` - 基于语义的文本搜索图片
411+
- **匹配资源**`POST /v1/resource/match` - 基于视觉相似度的图片匹配图片
398412

399413
### 内部接口(`/v1/internal/`)
400414
- **系统调试**:向量数据查看、详细搜索调试
@@ -408,8 +422,9 @@ gunicorn -w 4 -b 0.0.0.0:5000 \
408422

409423
### 设计原则
410424
- **统一性**:所有资源相关操作都在 `/v1/resource/` 下,避免接口分散
411-
- **简洁性**对外只暴露核心的添加和搜索功能,隐藏复杂的管理操作
425+
- **简洁性**:对外只暴露核心的添加、搜索和匹配功能,隐藏复杂的管理操作
412426
- **职责分离**:业务功能与调试维护功能严格分离
427+
- **算法多样性**:支持文本搜索(语义匹配)和图片匹配(视觉相似度)两种不同的检索模式
413428
- **用户驱动**:通过用户反馈持续优化搜索体验
414429

415430
## 故障排除

spx-algorithm/api/routes/resource_routes.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,70 @@ def search_resources(data: dict):
240240
'details': str(e)
241241
}), 500
242242

243+
244+
@resource_bp.route('/match', methods=['POST'])
@validate_json_request(['svg_content'], {
    'svg_content': lambda x: isinstance(x, str) and len(x.strip()) > 0,
    # bool is a subclass of int, so it must be excluded explicitly; otherwise
    # `"top_k": true` would be accepted and treated as 1.
    'top_k': lambda x: isinstance(x, int) and not isinstance(x, bool) and x > 0,
    # The distance threshold has no fixed upper bound, but it must be a real,
    # finite, non-negative number. The chained comparison rejects NaN and
    # Infinity, which a lenient JSON parser may let through and which would
    # otherwise be echoed back as non-standard JSON.
    'threshold': lambda x: (
        isinstance(x, (int, float))
        and not isinstance(x, bool)
        and 0 <= x < float('inf')
    ),
})
@ensure_coordinator_initialized
def match_resources(data: dict):
    """
    Image-to-image matching endpoint (vector-distance based).

    Expected JSON body:
        {
            "svg_content": "<svg>...</svg>",  # required, non-blank SVG text
            "top_k": 10,                      # optional, positive int, default 10
            "threshold": 0.0                  # optional, max distance, default 0.0
        }

    Unlike the text search endpoint, this endpoint ranks by L2 distance:
    - ``threshold`` is the maximum allowed distance; 0 disables filtering.
    - Results are ordered by distance; smaller means more similar.

    Returns:
        200 with ``success``, ``top_k``, ``threshold``, ``results_count`` and
        ``results`` when the coordinator reports success.
        500 with code ``MATCH_FAILED`` when the coordinator reports failure.
        400 with code ``INVALID_PARAMETER`` on ``ValueError``.
        500 with code ``INTERNAL_ERROR`` on any other exception.
    """
    try:
        svg_content = data['svg_content'].strip()
        top_k = data.get('top_k', 10)
        threshold = data.get('threshold', 0.0)

        logger.info(f"收到图片匹配请求: top_k={top_k}, threshold={threshold}, svg_length={len(svg_content)}")

        # Delegate the actual matching to the coordinator.
        result = coordinator.match_by_image(
            svg_content=svg_content,
            top_k=top_k,
            threshold=threshold
        )

        if result['success']:
            return jsonify({
                'success': True,
                'top_k': top_k,
                'threshold': threshold,
                'results_count': result['results_count'],
                'results': result['results']
            })

        return jsonify({
            'success': False,
            'error': result.get('error', '图片匹配失败'),
            'code': 'MATCH_FAILED'
        }), 500

    except ValueError as e:
        logger.error(f"图片匹配参数异常: {e}")
        return jsonify({
            'error': str(e),
            'code': 'INVALID_PARAMETER'
        }), 400

    except Exception as e:
        # exc_info keeps the traceback in the server log; the message text
        # itself is unchanged.
        logger.error(f"图片匹配系统异常: {e}", exc_info=True)
        # NOTE(review): 'details' exposes the raw exception text to the client.
        # Kept as-is to match the response shape used by the search endpoint;
        # consider dropping it for externally exposed deployments.
        return jsonify({
            'error': '内部服务器错误',
            'code': 'INTERNAL_ERROR',
            'details': str(e)
        }), 500
309+

spx-algorithm/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def index():
6161
"add_resource": "/v1/resource/add (POST)",
6262
"batch_add_resources": "/v1/resource/batch (POST)",
6363
"search_resources": "/v1/resource/search (POST)",
64+
"match_resources": "/v1/resource/match (POST)",
6465
"submit_feedback": "/v1/feedback/submit (POST)",
6566
"train_ltr_model": "/v1/feedback/train (POST)",
6667
"internal_debug": "/v1/internal/* (仅用于调试和管理)"

spx-algorithm/coordinator/search_coordinator.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,63 @@ def get_rerank_service(self) -> RerankService:
321321
"""获取重排序服务实例(用于API层调用)"""
322322
return self.rerank_service
323323

324+
def match_by_image(self, svg_content: str,
                   top_k: int = 10,
                   threshold: float = 0.0,
                   **kwargs) -> Dict[str, Any]:
    """
    Run image-to-image matching for an SVG query (L2-distance based).

    The call is forwarded to ``self.image_matching_service.search_by_image``
    and the hits are wrapped in a response dict. No rerank step is invoked
    in this method.

    Args:
        svg_content: SVG markup of the query image.
        top_k: Number of results requested.
        threshold: Maximum distance threshold (smaller distance means more
            similar; 0 means no filtering).
        **kwargs: Accepted but not used by this method.

    Returns:
        On success: a dict with ``success=True``, ``top_k``, ``threshold``,
        ``results_count``, ``results`` and ``pipeline_stages``.
        On any exception: a dict with ``success=False``, ``error``, the
        echoed ``top_k``/``threshold``, ``results_count=0`` and an empty
        ``results`` list. Exceptions are not propagated to the caller.
    """
    try:
        logger.info(f"开始图片匹配流程: top_k={top_k}, threshold={threshold}")

        hits = self.image_matching_service.search_by_image(
            svg_content, top_k, threshold
        )

        if hits:
            logger.info(f"图片匹配完成,返回 {len(hits)} 个结果")
        else:
            # Any falsy return value is normalised to a fresh empty list so
            # the response shape is the same either way.
            logger.info("图片匹配没有返回结果")
            hits = []

        return {
            'success': True,
            'top_k': top_k,
            'threshold': threshold,
            'results_count': len(hits),
            'results': hits,
            'pipeline_stages': ['image_matching']
        }

    except Exception as e:
        logger.error(f"图片匹配流程异常: {e}")
        return {
            'success': False,
            'error': str(e),
            'top_k': top_k,
            'threshold': threshold,
            'results_count': 0,
            'results': []
        }
380+
324381
def health_check(self) -> Dict[str, Any]:
325382
"""健康检查"""
326383
if not self.image_matching_service:

spx-algorithm/database/resource_vector/operations.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,57 @@ def search_by_vector(self, query_vector: List[float], limit: int = 10) -> List[D
252252
logger.error(f"向量搜索失败: {e}")
253253
return []
254254

255+
def search_by_vector_distance(self, query_vector: List[float], limit: int = 10) -> List[Dict[str, Any]]:
    """
    Search the collection by vector distance (used for image-to-image matching).

    Args:
        query_vector: Query embedding.
        limit: Maximum number of hits to request.

    Returns:
        A list of dicts with ``rank`` (1-based), ``distance``, ``similarity``
        (computed as ``1 / (1 + distance)``), ``id``, ``url`` and
        ``added_at``. ``updated_at`` is included only when it is truthy and
        differs from ``added_at``. Returns an empty list if anything raises;
        the error is logged and not propagated.
    """
    try:
        # Request the L2 metric for this search rather than cosine similarity.
        # NOTE(review): Milvus generally expects the search metric to match the
        # metric the index was built with - confirm the collection's index.
        # NOTE(review): Milvus documents its L2 value as computed before the
        # square root is applied - confirm before tuning thresholds.
        l2_search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10}  # tunable
        }

        response = self.collection.search(
            data=[query_vector],
            anns_field="vector",
            param=l2_search_params,
            limit=limit,
            output_fields=["id", "url", "added_at", "updated_at"]
        )

        matches = []
        # Only one query vector is sent, so only the first hit list is read.
        for position, hit in enumerate(response[0], start=1):
            distance = float(hit.score)  # smaller means more similar
            entity = hit.entity
            record = {
                'rank': position,
                'distance': distance,
                'similarity': 1.0 / (1.0 + distance),
                'id': entity.get('id'),
                'url': entity.get('url'),
                'added_at': entity.get('added_at')
            }

            # Expose updated_at only when it carries extra information.
            updated_at = entity.get('updated_at')
            if updated_at and updated_at != record['added_at']:
                record['updated_at'] = updated_at

            matches.append(record)

        logger.info(f"向量距离搜索完成,返回 {len(matches)} 个结果")
        return matches

    except Exception as e:
        logger.error(f"向量距离搜索失败: {e}")
        return []
305+
255306
def get_all_data(self, include_vectors: bool = False, limit: Optional[int] = None,
256307
offset: int = 0) -> List[Dict[str, Any]]:
257308
"""

spx-algorithm/services/image_matching/matching_service.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,52 @@ def search_by_text(self, query_text: str, top_k: int = 10,
235235
logger.error(f"文本搜索异常: {e}")
236236
return []
237237

238+
def search_by_image(self, svg_content: str, top_k: int = 10,
                    threshold: float = 0.0) -> List[Dict[str, Any]]:
    """
    Find stored images similar to the given SVG content.

    Args:
        svg_content: SVG markup of the query image.
        top_k: Number of hits requested from the vector store.
        threshold: Maximum allowed distance. Hits whose ``distance`` exceeds
            it are dropped. Filtering is applied only when the value is
            greater than 0.

    Returns:
        The (optionally filtered) hits returned by
        ``self.milvus_ops.search_by_vector_distance``. Returns an empty list
        when the SVG is blank, vectorization yields ``None``, or any
        exception is raised; errors are logged and not propagated.
    """
    try:
        if not svg_content.strip():
            logger.error("SVG内容不能为空")
            return []

        logger.info(f"开始图片搜索: top_k={top_k}, threshold={threshold}, svg_length={len(svg_content)}")

        # Turn the SVG into an embedding.
        embedding = self.vector_service.vectorize_image_from_svg_content(svg_content)
        if embedding is None:
            logger.error("SVG图片向量化失败")
            return []

        # Distance-metric lookup in the vector database.
        candidates = self.milvus_ops.search_by_vector_distance(
            embedding.tolist(),
            limit=top_k
        )

        # For a distance metric the threshold is an upper bound: keep hits
        # whose distance is less than or equal to it.
        if threshold > 0:
            total_before = len(candidates)
            candidates = [hit for hit in candidates if hit['distance'] <= threshold]
            logger.info(f"距离阈值过滤: {total_before} -> {len(candidates)} 个结果(最大距离: {threshold})")

        logger.info(f"图片搜索完成: 找到 {len(candidates)} 个结果")
        return candidates

    except Exception as e:
        logger.error(f"图片搜索异常: {e}")
        return []
283+
238284
def delete_image(self, id: int) -> bool:
239285
"""
240286
删除图片

0 commit comments

Comments
 (0)