- Add MongoDB type manager implementation (TypeManagerMongo) - Update environment variables configuration to support MongoDB connection - Add chat functionality - Integrate Azure OpenAI API support - Update dependencies and startup script
403 lines
15 KiB
Python
403 lines
15 KiB
Python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Wrapper around the Baidu similar-image search HTTP API."""

import base64
import json
import os

import requests
from dotenv import load_dotenv

# Load environment variables (the API credentials are read with os.getenv
# below) from a .env file before any client is constructed.
load_dotenv()
|
||
|
||
class BaiduImageSearch:
    """Client for the Baidu similar-image search API.

    On construction the client reads ``BAIDU_API_KEY`` and
    ``BAIDU_SECRET_KEY`` from the environment and immediately exchanges them
    for an access token, which is appended to every API request URL.
    """

    # Seconds to wait for the remote API. Without a timeout a stalled
    # connection would block the caller forever.
    REQUEST_TIMEOUT = 30
    # Only this many highest-scoring hits take part in the type analysis.
    TOP_RESULT_COUNT = 10
    # Minimum score a hit needs to be counted by "method3".
    SCORE_THRESHOLD = 0.6

    _TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    _API_BASE_URL = "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar"
    _FORM_HEADERS = {'Content-Type': 'application/x-www-form-urlencoded'}
    _MISSING_IMAGE_ERROR = "必须提供image_path、image_base64或url其中之一"
    _MISSING_IMAGE_OR_SIGN_ERROR = "必须提供image_path、image_base64、url或cont_sign其中之一"

    def __init__(self):
        """Read the API credentials from the environment and fetch a token."""
        self.api_key = os.getenv('BAIDU_API_KEY')
        self.secret_key = os.getenv('BAIDU_SECRET_KEY')
        self.access_token = None
        self.get_access_token()

    def get_access_token(self):
        """Request an access token and store it on the instance.

        Returns:
            The access token string.

        Raises:
            Exception: If the token endpoint does not answer with HTTP 200,
                or answers 200 without an ``access_token`` field.
        """
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        response = requests.post(
            self._TOKEN_URL, params=params, timeout=self.REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise Exception(f"获取access_token失败: {response.text}")

        token = response.json().get('access_token')
        if not token:
            # Fail here rather than letting every later call go out with
            # "access_token=None" in its URL.
            raise Exception(f"获取access_token失败: {response.text}")
        self.access_token = token
        return token

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _image_source_params(image_path=None, image_base64=None, url=None,
                             cont_sign=None, allow_cont_sign=False):
        """Build the request parameter identifying the target image.

        The first truthy source wins, in the order ``image_path``,
        ``image_base64``, ``url`` and (only when ``allow_cont_sign`` is
        true) ``cont_sign``. A local file is read and base64-encoded.

        Raises:
            ValueError: If no usable source was supplied.
        """
        if image_path:
            with open(image_path, 'rb') as f:
                return {'image': base64.b64encode(f.read())}
        if image_base64:
            return {'image': image_base64}
        if url:
            return {'url': url}
        if allow_cont_sign:
            if cont_sign:
                return {'cont_sign': cont_sign}
            raise ValueError(BaiduImageSearch._MISSING_IMAGE_OR_SIGN_ERROR)
        raise ValueError(BaiduImageSearch._MISSING_IMAGE_ERROR)

    @staticmethod
    def _serialize_brief(brief):
        """Return ``brief`` as a string, JSON-encoding it when it is a dict."""
        if isinstance(brief, dict):
            return json.dumps(brief, ensure_ascii=False)
        return brief

    @staticmethod
    def _brief_type(item):
        """Return the ``type`` field stored in a search hit's brief.

        A brief that is already a dict is used directly; anything else is
        converted to a string and parsed as JSON.

        Raises:
            json.JSONDecodeError: If the brief is not valid JSON.
            AttributeError: If the hit or the decoded brief is not a dict.
        """
        brief = item.get('brief', '{"type":""}')
        if isinstance(brief, dict):
            # str() of a dict is not valid JSON, so going through
            # json.loads would wrongly discard this hit.
            return brief.get('type', '')
        if not isinstance(brief, str):
            brief = str(brief)
        return json.loads(brief).get('type', '')

    @staticmethod
    def _numeric_score(item):
        """Return the hit's score as a number, or 0 when it is unusable."""
        score = item.get('score', 0) if isinstance(item, dict) else 0
        if isinstance(score, (int, float)):
            return score
        try:
            return float(score)
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _empty_types():
        """Return a fresh result dict with every method left undecided."""
        return {
            "method1": "",
            "method2": "",
            "method3": "",
            "method4": "",
            "method5": "",
        }

    @staticmethod
    def _best_key(weights):
        """Return the key with the largest value (first on ties), or ""."""
        return max(weights, key=weights.get) if weights else ""

    @staticmethod
    def _json_or_raise(response, error_prefix):
        """Decode a 200 response as JSON; raise ``Exception`` otherwise."""
        if response.status_code == 200:
            return response.json()
        raise Exception(f"{error_prefix}: {response.text}")

    def _post(self, action, params):
        """POST form-encoded ``params`` to one of the similar-image endpoints."""
        request_url = f"{self._API_BASE_URL}/{action}?access_token={self.access_token}"
        return requests.post(
            request_url,
            data=params,
            headers=dict(self._FORM_HEADERS),
            timeout=self.REQUEST_TIMEOUT,
        )

    def _filter_by_type(self, hits, type_filter):
        """Return the hits whose brief ``type`` equals ``type_filter``.

        Hits whose brief cannot be read are reported and skipped so that one
        malformed entry does not discard the whole search.
        """
        kept = []
        for item in hits:
            try:
                if self._brief_type(item) == type_filter:
                    kept.append(item)
            except json.JSONDecodeError as e:
                print(f"JSON解析错误: {e}, brief: {e.doc}")
            except Exception as e:  # deliberate best-effort per hit
                print(f"处理搜索结果项出错: {e}")
        return kept

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_image(self, image_path=None, image_base64=None, url=None, brief=None, tags=None):
        """Add an image to the image library.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            brief: Required summary stored with the image (at most 256
                bytes), e.g. ``{"name": "picture name", "id": "123"}``.
                A dict is JSON-encoded before sending.
            tags: Category tags: integers in 1-65535 separated by commas,
                at most two, e.g. ``"100,11"``.

        Returns:
            dict: The decoded API response.

        Raises:
            ValueError: If ``brief`` is missing or no image source is given.
            Exception: If the API does not answer with HTTP 200.
        """
        # brief is validated first so a missing brief is reported before any
        # file is read.
        if not brief:
            raise ValueError("brief参数是必须的")
        params = {'brief': self._serialize_brief(brief)}
        params.update(self._image_source_params(image_path, image_base64, url))
        if tags:
            params['tags'] = tags

        response = self._post('add', params)

        # Log the request without the image payload itself: dumping the whole
        # base64 body floods stdout.
        loggable = dict(params)
        if 'image' in loggable:
            loggable['image'] = f"<base64 image, {len(loggable['image'])} bytes>"
        print(f"添加图片请求: {loggable}")
        print(f"添加图片响应: {response.text}")

        return self._json_or_raise(response, "添加图片失败")

    def search_image(self, image_path=None, image_base64=None, url=None, tags=None, tag_logic=None, type_filter=None, pn=0, rn=300):
        """Search the library for images similar to the given one.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            tags: Category filter, e.g. ``"1,2"``.
            tag_logic: Tag logic: 0 means AND, 1 means OR.
            type_filter: When set, keep only hits whose brief ``type``
                equals this value (applied client-side to the response).
            pn: Paging start offset, default 0.
            rn: Number of results to return, default 300, at most 1000.

        Returns:
            dict: The decoded API response. When it contains hits, a
            ``most_reliable_types`` entry is added; when ``type_filter`` is
            applied, ``result`` and ``result_num`` reflect the filtered hits.

        Raises:
            ValueError: If no image source is given.
            Exception: If the API does not answer with HTTP 200.
        """
        params = self._image_source_params(image_path, image_base64, url)
        if tags:
            params['tags'] = tags
        if tag_logic is not None:
            params['tag_logic'] = tag_logic
        if pn is not None:
            params['pn'] = pn
        if rn is not None:
            params['rn'] = rn

        # Still recorded on the instance for callers that read it, but the
        # filtering below uses the local argument so that overlapping calls
        # on a shared client cannot affect each other.
        self.type_filter = type_filter

        response = self._post('search', params)
        result = self._json_or_raise(response, "检索图片失败")

        if type_filter and 'result' in result and result['result']:
            filtered_results = self._filter_by_type(result['result'], type_filter)
            result['result'] = filtered_results
            result['result_num'] = len(filtered_results)

        if 'result' in result and result['result']:
            result['most_reliable_types'] = self._determine_most_reliable_types(result['result'])

        return result

    def delete_image(self, image_path=None, image_base64=None, url=None, cont_sign=None):
        """Delete an image from the image library.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            cont_sign: Image signature; supports batch deletion, e.g.
                ``"932301884,1068006219;316336521,553141152"``.

        Returns:
            dict: The decoded API response.

        Raises:
            ValueError: If none of the four identifiers is given.
            Exception: If the API does not answer with HTTP 200.
        """
        params = self._image_source_params(
            image_path, image_base64, url, cont_sign, allow_cont_sign=True
        )
        response = self._post('delete', params)
        return self._json_or_raise(response, "删除图片失败")

    def _determine_most_reliable_types(self, results):
        """Pick the most plausible type of the query image by several votes.

        Only the ``TOP_RESULT_COUNT`` highest-scoring hits are considered,
        and only those whose brief carries a non-empty ``type``.

        Args:
            results: List of search hits.

        Returns:
            dict: ``method1`` .. ``method5`` mapped to the winning type of
            each method, or ``""`` where no type could be determined:

            * method1 - type of the single highest-scoring hit
            * method2 - largest sum of scores per type
            * method3 - as method2, counting only scores >= SCORE_THRESHOLD
            * method4 - most frequent type
            * method5 - count * average score per type
        """
        if not results:
            return self._empty_types()

        # Scores are normalised before sorting so that a missing or
        # non-numeric score cannot make the comparison raise.
        scored = [(item, self._numeric_score(item)) for item in results]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        top_results = scored[:self.TOP_RESULT_COUNT]

        print(f"使用前{len(top_results)}个结果进行最可信类型分析")

        type_scores = []
        for item, score in top_results:
            try:
                item_type = self._brief_type(item)
            except json.JSONDecodeError as e:
                print(f"JSON解析错误: {e}, brief: {e.doc}")
                continue
            except Exception as e:  # deliberate best-effort per hit
                print(f"处理搜索结果项出错: {e}, item: {item}")
                continue
            if item_type:  # hits without a type cannot vote
                type_scores.append((item_type, score))

        if not type_scores:
            return self._empty_types()

        total_by_type = {}
        strong_total_by_type = {}
        count_by_type = {}
        for item_type, score in type_scores:
            total_by_type[item_type] = total_by_type.get(item_type, 0) + score
            count_by_type[item_type] = count_by_type.get(item_type, 0) + 1
            if score >= self.SCORE_THRESHOLD:
                strong_total_by_type[item_type] = (
                    strong_total_by_type.get(item_type, 0) + score
                )

        # NOTE(review): count * (total / count) is mathematically the same
        # as the plain total used by method2, so method5 can only differ
        # from method2 through floating-point rounding. Kept as-is so
        # existing results do not change; confirm the intended weighting.
        weighted_count_by_type = {
            item_type: count * (total_by_type[item_type] / count)
            for item_type, count in count_by_type.items()
        }

        return {
            "method1": max(type_scores, key=lambda pair: pair[1])[0],  # highest score
            "method2": self._best_key(total_by_type),  # score-weighted vote
            "method3": self._best_key(strong_total_by_type),  # thresholded weighted vote
            "method4": self._best_key(count_by_type),  # majority vote
            "method5": self._best_key(weighted_count_by_type),  # weighted majority vote
        }

    def update_image(self, image_path=None, image_base64=None, url=None, cont_sign=None, brief=None, tags=None):
        """Update the summary and category tags of an image in the library.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            cont_sign: Image signature.
            brief: New summary (at most 256 bytes). A dict is JSON-encoded
                before sending.
            tags: New category tags, at most two, e.g. ``"1,2"``.

        Returns:
            dict: The decoded API response.

        Raises:
            ValueError: If none of the four identifiers is given.
            Exception: If the API does not answer with HTTP 200.
        """
        params = self._image_source_params(
            image_path, image_base64, url, cont_sign, allow_cont_sign=True
        )
        if brief:
            params['brief'] = self._serialize_brief(brief)
        if tags:
            params['tags'] = tags

        response = self._post('update', params)
        return self._json_or_raise(response, "更新图片失败")
|