imgsearcher/app/api/baidu_image_search.py
eust-w 487f3af948 feat: modifications based on team suggestions
- Add MongoDB type manager implementation (TypeManagerMongo)
- Update environment variables configuration to support MongoDB connection
- Add chat functionality
- Integrate Azure OpenAI API support
- Update dependencies and startup script
2025-04-11 12:00:32 +08:00

403 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import json
import os
import requests
from dotenv import load_dotenv
# Load environment variables (python-dotenv's load_dotenv, imported above)
load_dotenv()
class BaiduImageSearch:
    """Wrapper around Baidu's similar-image search REST API.

    The API key and secret key are read from the ``BAIDU_API_KEY`` and
    ``BAIDU_SECRET_KEY`` environment variables, and an access token is
    requested as soon as an instance is created.

    NOTE(review): the token is fetched once at construction time and is never
    refreshed here; confirm the token lifetime suits long-running processes.
    """

    # Seconds to wait on each HTTP request. Without a timeout a stalled
    # connection would block the caller forever.
    REQUEST_TIMEOUT = 30
    # Only the best-scoring results take part in the type analysis.
    TOP_RESULTS = 10
    # Minimum similarity score for a result to count in "method3".
    SCORE_THRESHOLD = 0.6

    _TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    _API_BASE = "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar"
    _FORM_HEADERS = {'Content-Type': 'application/x-www-form-urlencoded'}
    _METHOD_KEYS = ("method1", "method2", "method3", "method4", "method5")

    def __init__(self):
        """Read the API credentials from the environment and fetch a token."""
        self.api_key = os.getenv('BAIDU_API_KEY')
        self.secret_key = os.getenv('BAIDU_SECRET_KEY')
        self.access_token = None
        self.get_access_token()

    def get_access_token(self):
        """Request an access token and store it on ``self.access_token``.

        Returns:
            str: the access token.

        Raises:
            RuntimeError: if the HTTP status is not 200, or the response
                carries no ``access_token`` field.
        """
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        response = requests.post(
            self._TOKEN_URL, params=params, timeout=self.REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise RuntimeError(f"获取access_token失败: {response.text}")
        token = response.json().get('access_token')
        if not token:
            # Fail here rather than building request URLs with
            # "access_token=None" later on.
            raise RuntimeError(f"获取access_token失败: {response.text}")
        self.access_token = token
        return token

    # ------------------------------------------------------------------
    # Private helpers shared by the public endpoints
    # ------------------------------------------------------------------

    def _endpoint(self, action):
        """Build the URL of API *action* including the current access token."""
        return f"{self._API_BASE}/{action}?access_token={self.access_token}"

    @staticmethod
    def _image_params(image_path=None, image_base64=None, url=None,
                      cont_sign=None, allow_cont_sign=False):
        """Return the request parameter identifying the target image.

        The first truthy source wins, in the order image_path, image_base64,
        url and (only when *allow_cont_sign* is true) cont_sign.

        Raises:
            ValueError: if no usable source was supplied.
        """
        if image_path:
            with open(image_path, 'rb') as f:
                return {'image': base64.b64encode(f.read())}
        if image_base64:
            return {'image': image_base64}
        if url:
            return {'url': url}
        if allow_cont_sign:
            if cont_sign:
                return {'cont_sign': cont_sign}
            raise ValueError("必须提供image_path、image_base64、url或cont_sign其中之一")
        raise ValueError("必须提供image_path、image_base64或url其中之一")

    def _post_form(self, action, params):
        """POST *params* as a URL-encoded form to *action*; return the response."""
        return requests.post(
            self._endpoint(action),
            data=params,
            headers=self._FORM_HEADERS,
            timeout=self.REQUEST_TIMEOUT,
        )

    @staticmethod
    def _json_or_raise(response, error_prefix):
        """Return the decoded JSON body, or raise RuntimeError on a non-200 status."""
        if response.status_code == 200:
            return response.json()
        raise RuntimeError(f"{error_prefix}: {response.text}")

    @staticmethod
    def _loggable(params):
        """Return a copy of *params* with the base64 image replaced by its length."""
        return {
            key: f"<base64 image, {len(value)} chars>" if key == 'image' else value
            for key, value in params.items()
        }

    @staticmethod
    def _parse_brief(item):
        """Return the ``brief`` of a search result as a dict, or None if unusable.

        A missing brief is treated as ``{"type": ""}``. A brief that is
        already a dict is used as is; anything else is parsed as JSON text.
        """
        if not isinstance(item, dict):
            print(f"处理搜索结果项出错: {item}")
            return None
        brief = item.get('brief', '{"type":""}')
        if isinstance(brief, dict):
            # str() on a dict yields a single-quoted repr that json.loads
            # rejects, so use the mapping directly.
            return brief
        if not isinstance(brief, str):
            brief = str(brief)
        try:
            info = json.loads(brief)
        except json.JSONDecodeError as e:
            print(f"JSON解析错误: {e}, brief: {brief}")
            return None
        if not isinstance(info, dict):
            print(f"处理搜索结果项出错: {item}")
            return None
        return info

    @staticmethod
    def _score_of(item):
        """Return the numeric ``score`` of *item*, or 0.0 if missing or invalid."""
        raw = item.get('score', 0) if isinstance(item, dict) else 0
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _top_key(weights):
        """Return the key with the largest value (first one on ties), or ""."""
        return max(weights, key=weights.get) if weights else ""

    # ------------------------------------------------------------------
    # Public endpoints
    # ------------------------------------------------------------------

    def add_image(self, image_path=None, image_base64=None, url=None, brief=None, tags=None):
        """Add an image to the image library.

        Args:
            image_path: path of a local image file.
            image_base64: base64-encoded image data.
            url: URL of the image.
            brief: required summary, at most 256 bytes, e.g.
                {"name": "image name", "id": "123"}. A dict is serialized
                to JSON.
            tags: category info: integers in 1-65535 separated by commas,
                at most two tags with no hierarchy between them, e.g.
                "100,11". Searches can be restricted by tag.

        Returns:
            dict: the decoded API response.

        Raises:
            ValueError: if brief is missing or no image source is given.
            RuntimeError: if the HTTP status is not 200.
        """
        if not brief:
            raise ValueError("brief参数是必须的")
        if isinstance(brief, dict):
            brief = json.dumps(brief, ensure_ascii=False)

        params = {'brief': brief}
        params.update(self._image_params(image_path, image_base64, url))
        if tags:
            params['tags'] = tags

        response = self._post_form('add', params)
        # Log a summary only: dumping the whole base64 payload floods stdout.
        print(f"添加图片请求: {self._loggable(params)}")
        print(f"添加图片响应: {response.text}")
        return self._json_or_raise(response, "添加图片失败")

    def search_image(self, image_path=None, image_base64=None, url=None, tags=None, tag_logic=None, type_filter=None, pn=0, rn=300):
        """Search the library for images similar to the given one.

        Args:
            image_path: path of a local image file.
            image_base64: base64-encoded image data.
            url: URL of the image.
            tags: category filter, e.g. "1,2".
            tag_logic: tag logic, 0 for AND, 1 for OR.
            type_filter: keep only results whose brief JSON has this value
                under the "type" key. The filter is applied locally to the
                page returned by the API.
            pn: paging start offset, default 0.
            rn: number of results to return, default 300, maximum 1000.

        Returns:
            dict: the decoded API response. When it holds at least one
            result, a "most_reliable_types" entry is added.

        Raises:
            ValueError: if no image source is given.
            RuntimeError: if the HTTP status is not 200.
        """
        params = self._image_params(image_path, image_base64, url)
        if tags:
            params['tags'] = tags
        if tag_logic is not None:
            params['tag_logic'] = tag_logic
        if pn is not None:
            params['pn'] = pn
        if rn is not None:
            params['rn'] = rn

        # Still stored for any external reader of the attribute, but the
        # filtering below uses the local argument so that concurrent calls
        # on a shared instance cannot see each other's filter.
        self.type_filter = type_filter

        response = self._post_form('search', params)
        result = self._json_or_raise(response, "检索图片失败")
        if not isinstance(result, dict):
            return result

        hits = result.get('result')
        if type_filter and hits:
            matching = []
            for item in hits:
                info = self._parse_brief(item)
                if info is not None and info.get('type', '') == type_filter:
                    matching.append(item)
            result['result'] = matching
            result['result_num'] = len(matching)

        if result.get('result'):
            result['most_reliable_types'] = self._determine_most_reliable_types(result['result'])
        return result

    def delete_image(self, image_path=None, image_base64=None, url=None, cont_sign=None):
        """Delete an image from the image library.

        Args:
            image_path: path of a local image file.
            image_base64: base64-encoded image data.
            url: URL of the image.
            cont_sign: image signature; supports batch deletion, e.g.
                "932301884,1068006219;316336521,553141152".

        Returns:
            dict: the decoded API response.

        Raises:
            ValueError: if none of the four identifiers is given.
            RuntimeError: if the HTTP status is not 200.
        """
        params = self._image_params(
            image_path, image_base64, url, cont_sign, allow_cont_sign=True
        )
        response = self._post_form('delete', params)
        return self._json_or_raise(response, "删除图片失败")

    def _determine_most_reliable_types(self, results):
        """Pick the most plausible type of the results using five heuristics.

        Only the ``TOP_RESULTS`` highest-scoring results are considered, and
        among them only those whose brief carries a non-empty "type".

        Args:
            results: list of search result items.

        Returns:
            dict: keys "method1" .. "method5", each mapping to the chosen
            type, or "" when that method has no candidate.
        """
        no_answer = {key: "" for key in self._METHOD_KEYS}
        if not results:
            return no_answer

        # sorted() is stable, so equally scored items keep the API's order.
        top_results = sorted(results, key=self._score_of, reverse=True)[:self.TOP_RESULTS]
        print(f"使用前{len(top_results)}个结果进行最可信类型分析")

        type_scores = []
        for item in top_results:
            info = self._parse_brief(item)
            if info is None:
                continue
            item_type = info.get('type', '')
            if item_type:  # ignore results without a type
                type_scores.append((item_type, self._score_of(item)))
        if not type_scores:
            return no_answer

        score_sum = {}       # type -> sum of scores
        strong_sum = {}      # type -> sum of scores at or above the threshold
        occurrences = {}     # type -> number of results
        for item_type, score in type_scores:
            score_sum[item_type] = score_sum.get(item_type, 0) + score
            occurrences[item_type] = occurrences.get(item_type, 0) + 1
            if score >= self.SCORE_THRESHOLD:
                strong_sum[item_type] = strong_sum.get(item_type, 0) + score

        # Count multiplied by average score. Mathematically this equals the
        # summed score used by method2, up to floating-point rounding.
        count_times_avg = {
            item_type: count * (score_sum[item_type] / count)
            for item_type, count in occurrences.items()
        }

        return {
            # Type of the single highest-scoring result.
            "method1": max(type_scores, key=lambda pair: pair[1])[0],
            # Score-weighted vote.
            "method2": self._top_key(score_sum),
            # Score-weighted vote restricted to scores >= SCORE_THRESHOLD.
            "method3": self._top_key(strong_sum),
            # Plain majority vote.
            "method4": self._top_key(occurrences),
            # Weighted majority vote (count * average score).
            "method5": self._top_key(count_times_avg),
        }

    def update_image(self, image_path=None, image_base64=None, url=None, cont_sign=None, brief=None, tags=None):
        """Update the summary and category info of an image in the library.

        Args:
            image_path: path of a local image file.
            image_base64: base64-encoded image data.
            url: URL of the image.
            cont_sign: image signature.
            brief: new summary, at most 256 bytes. A dict is serialized
                to JSON.
            tags: new category info, at most two tags, e.g. "1,2".

        Returns:
            dict: the decoded API response.

        Raises:
            ValueError: if none of the four identifiers is given.
            RuntimeError: if the HTTP status is not 200.
        """
        params = self._image_params(
            image_path, image_base64, url, cont_sign, allow_cont_sign=True
        )
        if brief:
            if isinstance(brief, dict):
                brief = json.dumps(brief, ensure_ascii=False)
            params['brief'] = brief
        if tags:
            params['tags'] = tags

        response = self._post_form('update', params)
        return self._json_or_raise(response, "更新图片失败")