#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Small client for Baidu's similar-image search REST API."""

import base64
import json
import logging
import os

import requests
from dotenv import load_dotenv

# Load environment variables (python-dotenv) before the credentials are read.
load_dotenv()

logger = logging.getLogger(__name__)


class BaiduImageSearch:
    """Wrapper around the Baidu similar-image search API.

    Exposes add / search / delete / update operations on the image library.
    Credentials are read from the ``BAIDU_API_KEY`` and ``BAIDU_SECRET_KEY``
    environment variables, and an access token is requested on construction.
    """

    TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    API_BASE = (
        "https://aip.baidubce.com/rest/2.0/image-classify/v1/"
        "realtime_search/similar"
    )
    # Seconds to wait for the server; without it requests can block forever.
    REQUEST_TIMEOUT = 30
    # Number of best-scoring hits used when voting on the most reliable type.
    TOP_RESULT_COUNT = 10
    # Minimum score a hit needs to take part in the thresholded vote (method3).
    SCORE_THRESHOLD = 0.6

    def __init__(self):
        """Read the API credentials from the environment and fetch a token."""
        self.api_key = os.getenv('BAIDU_API_KEY')
        self.secret_key = os.getenv('BAIDU_SECRET_KEY')
        self.access_token = None
        self.get_access_token()

    def get_access_token(self):
        """Request an access token and store it on the instance.

        Returns:
            The access token string.

        Raises:
            Exception: If the HTTP status is not 200, or the response body
                carries no ``access_token`` field.
        """
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        response = requests.post(
            self.TOKEN_URL, params=params, timeout=self.REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            raise Exception(f"获取access_token失败: {response.text}")

        token = response.json().get('access_token')
        if not token:
            # A 200 reply without a token would otherwise leave the client
            # sending "access_token=None" on every later call.
            raise Exception(f"获取access_token失败: {response.text}")

        self.access_token = token
        return token

    @staticmethod
    def _image_params(image_path=None, image_base64=None, url=None,
                      cont_sign=None, allow_cont_sign=False):
        """Build the request field that identifies the target image.

        The first truthy source wins, in this order: ``image_path`` (file is
        read and base64-encoded), ``image_base64``, ``url`` and, only when
        ``allow_cont_sign`` is true, ``cont_sign``.

        Raises:
            ValueError: If none of the accepted sources is provided.
        """
        if image_path:
            with open(image_path, 'rb') as image_file:
                return {'image': base64.b64encode(image_file.read())}
        if image_base64:
            return {'image': image_base64}
        if url:
            return {'url': url}
        if allow_cont_sign:
            if cont_sign:
                return {'cont_sign': cont_sign}
            raise ValueError("必须提供image_path、image_base64、url或cont_sign其中之一")
        raise ValueError("必须提供image_path、image_base64或url其中之一")

    def _post(self, action, params, label):
        """POST form-encoded ``params`` to an endpoint and return the JSON body.

        Args:
            action: Last path segment of the endpoint, e.g. ``"add"``.
            params: Form fields to send.
            label: Operation name used in log and error messages.

        Raises:
            Exception: If the HTTP status is not 200.
        """
        request_url = f"{self.API_BASE}/{action}?access_token={self.access_token}"
        headers = {'Content-Type': 'application/x-www-form-urlencoded'}

        if logger.isEnabledFor(logging.DEBUG):
            # Never dump the base64 payload itself; it can be megabytes long.
            loggable = {
                key: (f"<base64, length {len(value)}>" if key == 'image' else value)
                for key, value in params.items()
            }
            logger.debug("%s请求: %s", label, loggable)

        response = requests.post(
            request_url,
            data=params,
            headers=headers,
            timeout=self.REQUEST_TIMEOUT,
        )
        logger.debug("%s响应: %s", label, response.text)

        if response.status_code != 200:
            raise Exception(f"{label}失败: {response.text}")
        return response.json()

    @staticmethod
    def _item_type(item):
        """Return the ``type`` stored in a search hit's brief, or ``None``.

        The brief is expected to be a JSON object (as a string, or already a
        dict). Hits whose brief cannot be parsed are logged and yield ``None``
        so that callers can skip them.
        """
        if not isinstance(item, dict):
            logger.warning("处理搜索结果项出错: %s, item: %s", "item is not a dict", item)
            return None

        brief = item.get('brief', '{"type":""}')
        if not isinstance(brief, dict):
            if not isinstance(brief, str):
                brief = str(brief)
            try:
                brief = json.loads(brief)
            except json.JSONDecodeError as exc:
                logger.warning("JSON解析错误: %s, brief: %s", exc, brief)
                return None
            if not isinstance(brief, dict):
                logger.warning(
                    "处理搜索结果项出错: %s, item: %s", "brief is not a JSON object", item
                )
                return None
        return brief.get('type', '')

    @staticmethod
    def _pick_best(weights):
        """Return the key with the largest value (first one on ties), or ""."""
        return max(weights, key=weights.get) if weights else ""

    def add_image(self, image_path=None, image_base64=None, url=None, brief=None, tags=None):
        """Add an image to the library.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            brief: Required summary stored with the image, at most 256 bytes,
                e.g. ``{"name": "image name", "id": "123"}``. A dict is
                serialized to JSON.
            tags: Category tags: integers in 1-65535, comma separated, at most
                two (no hierarchy between them), e.g. ``"100,11"``. Searches
                can be restricted by tag.

        Returns:
            dict: The decoded API response.

        Raises:
            ValueError: If ``brief`` is missing or no image source is given.
            Exception: If the HTTP status is not 200.
        """
        if not brief:
            raise ValueError("brief参数是必须的")
        if isinstance(brief, dict):
            brief = json.dumps(brief, ensure_ascii=False)

        params = {'brief': brief}
        params.update(self._image_params(image_path, image_base64, url))
        if tags:
            params['tags'] = tags

        return self._post("add", params, "添加图片")

    def search_image(self, image_path=None, image_base64=None, url=None, tags=None,
                     tag_logic=None, type_filter=None, pn=0, rn=300):
        """Search the library for images similar to the given one.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            tags: Tag filter, e.g. ``"1,2"``.
            tag_logic: Tag logic: 0 for AND, 1 for OR.
            type_filter: When set, keep only hits whose brief ``type`` equals
                this value (filtering happens client-side on the response).
            pn: Paging start offset, default 0.
            rn: Number of results to return, default 300, maximum 1000.

        Returns:
            dict: The decoded API response. When it contains hits, a
            ``most_reliable_types`` entry is added; when ``type_filter`` is
            applied, ``result`` and ``result_num`` reflect the filtered hits.

        Raises:
            ValueError: If no image source is given.
            Exception: If the HTTP status is not 200.
        """
        params = self._image_params(image_path, image_base64, url)
        if tags:
            params['tags'] = tags
        if tag_logic is not None:
            params['tag_logic'] = tag_logic
        if pn is not None:
            params['pn'] = pn
        if rn is not None:
            params['rn'] = rn

        # Still exposed as an attribute for existing callers, but the logic
        # below uses the local argument so overlapping calls cannot interfere.
        self.type_filter = type_filter

        result = self._post("search", params, "检索图片")

        matches = result.get('result') if isinstance(result, dict) else None
        if type_filter and matches:
            matches = [item for item in matches if self._item_type(item) == type_filter]
            result['result'] = matches
            result['result_num'] = len(matches)

        if matches:
            result['most_reliable_types'] = self._determine_most_reliable_types(matches)
        return result

    def delete_image(self, image_path=None, image_base64=None, url=None, cont_sign=None):
        """Delete an image from the library.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            cont_sign: Image signature; supports batch deletion, e.g.
                ``"932301884,1068006219;316336521,553141152"``.

        Returns:
            dict: The decoded API response.

        Raises:
            ValueError: If none of the identifying arguments is given.
            Exception: If the HTTP status is not 200.
        """
        params = self._image_params(
            image_path, image_base64, url, cont_sign, allow_cont_sign=True
        )
        return self._post("delete", params, "删除图片")

    def _determine_most_reliable_types(self, results):
        """Vote on the most reliable type using five different methods.

        Only the ``TOP_RESULT_COUNT`` highest-scoring hits are considered, and
        hits without a parseable, non-empty type are ignored.

        Args:
            results: List of search hits (dicts with ``score`` and ``brief``).

        Returns:
            dict: ``method1``..``method5`` mapped to the winning type, or to
            ``""`` when a method has no candidate.
        """
        verdict = dict.fromkeys(
            ("method1", "method2", "method3", "method4", "method5"), ""
        )
        if not results:
            return verdict

        ranked = sorted(results, key=lambda hit: hit.get('score', 0), reverse=True)
        top_results = ranked[:self.TOP_RESULT_COUNT]
        logger.debug("使用前%d个结果进行最可信类型分析", len(top_results))

        type_scores = []
        for hit in top_results:
            hit_type = self._item_type(hit)
            if hit_type:  # only hits that carry a type take part in the vote
                type_scores.append((hit_type, hit.get('score', 0)))
        if not type_scores:
            return verdict

        score_totals = {}
        hit_counts = {}
        strong_totals = {}
        for hit_type, score in type_scores:
            score_totals[hit_type] = score_totals.get(hit_type, 0) + score
            hit_counts[hit_type] = hit_counts.get(hit_type, 0) + 1
            if score >= self.SCORE_THRESHOLD:
                strong_totals[hit_type] = strong_totals.get(hit_type, 0) + score

        # NOTE(review): count * average equals the summed score, so method5
        # mirrors method2 up to float rounding -- confirm whether a different
        # weighting was intended.
        weighted_counts = {
            hit_type: count * (score_totals[hit_type] / count)
            for hit_type, count in hit_counts.items()
        }

        # Method 1: type of the single best-scoring hit.
        verdict["method1"] = max(type_scores, key=lambda pair: pair[1])[0]
        # Method 2: score-weighted vote.
        verdict["method2"] = self._pick_best(score_totals)
        # Method 3: score-weighted vote over hits at or above the threshold.
        verdict["method3"] = self._pick_best(strong_totals)
        # Method 4: plain majority vote.
        verdict["method4"] = self._pick_best(hit_counts)
        # Method 5: majority vote weighted by the average score.
        verdict["method5"] = self._pick_best(weighted_counts)
        return verdict

    def update_image(self, image_path=None, image_base64=None, url=None, cont_sign=None,
                     brief=None, tags=None):
        """Update the summary and tags of an image in the library.

        Args:
            image_path: Path of a local image file.
            image_base64: Base64-encoded image data.
            url: URL of the image.
            cont_sign: Image signature.
            brief: New summary, at most 256 bytes. A dict is serialized to
                JSON.
            tags: New category tags, at most two, e.g. ``"1,2"``.

        Returns:
            dict: The decoded API response.

        Raises:
            ValueError: If none of the identifying arguments is given.
            Exception: If the HTTP status is not 200.
        """
        params = self._image_params(
            image_path, image_base64, url, cont_sign, allow_cont_sign=True
        )
        if brief:
            if isinstance(brief, dict):
                brief = json.dumps(brief, ensure_ascii=False)
            params['brief'] = brief
        if tags:
            params['tags'] = tags

        return self._post("update", params, "更新图片")