imgsearcher/app/api/baidu_image_search.py
2025-04-09 11:13:17 +08:00

227 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import json
import os
import requests
from dotenv import load_dotenv
# Load environment variables at import time (python-dotenv looks for a .env
# file) so that BaiduImageSearch can read BAIDU_API_KEY / BAIDU_SECRET_KEY
# via os.getenv.
load_dotenv()
class BaiduImageSearch:
    """Thin client for Baidu's "similar image" realtime-search API.

    Wraps the add / search / delete / update endpoints of the image library.
    Credentials are read from the ``BAIDU_API_KEY`` and ``BAIDU_SECRET_KEY``
    environment variables, and an OAuth access token is fetched eagerly when
    the client is constructed.

    Every endpoint method returns the decoded JSON body for any HTTP 200
    response and raises ``Exception`` for any other HTTP status. The body is
    not otherwise inspected for Baidu's ``error_code`` / ``error_msg`` fields,
    except to transparently refresh an invalid/expired access token once.
    """

    #: Base URL shared by all image-library endpoints.
    API_BASE = "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar"
    #: OAuth endpoint that exchanges the API key/secret for an access token.
    TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    #: Baidu error codes for "access token invalid" (110) and "access token
    #: expired" (111); a request failing with one of these is retried once
    #: after fetching a fresh token.
    TOKEN_ERROR_CODES = (110, 111)
    #: Default per-request timeout in seconds. Without a timeout ``requests``
    #: can wait forever on a stalled connection.
    timeout = 10

    def __init__(self, timeout=None):
        """Read credentials from the environment and fetch an access token.

        Args:
            timeout: optional per-request timeout in seconds; when omitted the
                class default ``BaiduImageSearch.timeout`` is used.

        Raises:
            Exception: if the access token cannot be obtained.
        """
        self.api_key = os.getenv('BAIDU_API_KEY')
        self.secret_key = os.getenv('BAIDU_SECRET_KEY')
        if timeout is not None:
            self.timeout = timeout
        self.access_token = None
        self.get_access_token()

    def get_access_token(self):
        """Fetch a new OAuth access token and cache it on ``self.access_token``.

        Returns:
            str: the new access token.

        Raises:
            Exception: if the token endpoint does not answer with HTTP 200, or
                answers without an ``access_token`` field.
        """
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        response = requests.post(self.TOKEN_URL, params=params, timeout=self.timeout)
        if response.status_code != 200:
            raise Exception(f"获取access_token失败: {response.text}")
        token = response.json().get('access_token')
        if not token:
            # Storing None here would make every later call fail with an
            # obscure authentication error instead of failing right away.
            raise Exception(f"获取access_token失败: {response.text}")
        self.access_token = token
        return token

    def add_image(self, image_path=None, image_base64=None, url=None, brief=None, tags=None):
        """Add an image to the library.

        Exactly one image source is sent, chosen in priority order
        ``image_path`` > ``image_base64`` > ``url``.

        Args:
            image_path: path of a local image file (read and base64-encoded).
            image_base64: the image, already base64-encoded.
            url: URL of the image.
            brief: required summary info (at most 256 bytes), e.g.
                ``{"name": "image name", "id": "123"}``; a dict is JSON-encoded.
            tags: optional category info (at most 2 tags), e.g. ``"1,2"``.

        Returns:
            dict: the decoded API response.

        Raises:
            ValueError: if ``brief`` is missing/empty or no image source is given.
            Exception: if the API answers with a non-200 HTTP status.
        """
        # brief is validated before any image file is read.
        if not brief:
            raise ValueError("brief参数是必须的")
        params = {'brief': self._encode_brief(brief)}
        params.update(self._image_source(image_path, image_base64, url))
        if tags:
            params['tags'] = tags
        return self._post('add', params, "添加图片失败")

    def search_image(self, image_path=None, image_base64=None, url=None, tags=None, tag_logic=None, pn=0, rn=300):
        """Search the library for images similar to the given one.

        Exactly one image source is sent, chosen in priority order
        ``image_path`` > ``image_base64`` > ``url``.

        Args:
            image_path: path of a local image file (read and base64-encoded).
            image_base64: the image, already base64-encoded.
            url: URL of the image.
            tags: optional category filter, e.g. ``"1,2"``.
            tag_logic: tag logic, 0 for AND, 1 for OR.
            pn: pagination start offset (default 0).
            rn: number of results to return (default 300, max 1000).

        Returns:
            dict: the decoded API response.

        Raises:
            ValueError: if no image source is given.
            Exception: if the API answers with a non-200 HTTP status.
        """
        params = self._image_source(image_path, image_base64, url)
        if tags:
            params['tags'] = tags
        # A value of None means "omit this field from the request".
        optional = {'tag_logic': tag_logic, 'pn': pn, 'rn': rn}
        params.update({key: value for key, value in optional.items() if value is not None})
        return self._post('search', params, "检索图片失败")

    def delete_image(self, image_path=None, image_base64=None, url=None, cont_sign=None):
        """Delete images from the library.

        The target is chosen in priority order ``image_path`` >
        ``image_base64`` > ``url`` > ``cont_sign``.

        Args:
            image_path: path of a local image file (read and base64-encoded).
            image_base64: the image, already base64-encoded.
            url: URL of the image.
            cont_sign: image signature(s); several may be deleted at once, e.g.
                ``"932301884,1068006219;316336521,553141152"``.

        Returns:
            dict: the decoded API response.

        Raises:
            ValueError: if no image source or signature is given.
            Exception: if the API answers with a non-200 HTTP status.
        """
        params = self._image_source(image_path, image_base64, url, cont_sign, accept_cont_sign=True)
        return self._post('delete', params, "删除图片失败")

    def update_image(self, image_path=None, image_base64=None, url=None, cont_sign=None, brief=None, tags=None):
        """Update the brief and/or tags of an image already in the library.

        The target is chosen in priority order ``image_path`` >
        ``image_base64`` > ``url`` > ``cont_sign``.

        Args:
            image_path: path of a local image file (read and base64-encoded).
            image_base64: the image, already base64-encoded.
            url: URL of the image.
            cont_sign: image signature.
            brief: new summary info (at most 256 bytes); a dict is JSON-encoded.
            tags: new category info (at most 2 tags), e.g. ``"1,2"``.

        Returns:
            dict: the decoded API response.

        Raises:
            ValueError: if no image source or signature is given.
            Exception: if the API answers with a non-200 HTTP status.
        """
        params = self._image_source(image_path, image_base64, url, cont_sign, accept_cont_sign=True)
        if brief:
            params['brief'] = self._encode_brief(brief)
        if tags:
            params['tags'] = tags
        return self._post('update', params, "更新图片失败")

    @staticmethod
    def _encode_brief(brief):
        """JSON-encode a dict ``brief`` (keeping non-ASCII text); pass other values through."""
        if isinstance(brief, dict):
            return json.dumps(brief, ensure_ascii=False)
        return brief

    @staticmethod
    def _image_source(image_path, image_base64, url, cont_sign=None, *, accept_cont_sign=False):
        """Build the request field identifying the image, from the first source given.

        Args:
            image_path: local file path; its bytes are base64-encoded.
            image_base64: already base64-encoded image data.
            url: image URL.
            cont_sign: image signature, only considered when ``accept_cont_sign``.
            accept_cont_sign: whether the endpoint accepts ``cont_sign``.

        Returns:
            dict: a fresh dict with exactly one of ``image``, ``url`` or ``cont_sign``.

        Raises:
            ValueError: if none of the accepted sources is provided.
        """
        if image_path:
            with open(image_path, 'rb') as f:
                return {'image': base64.b64encode(f.read())}
        if image_base64:
            return {'image': image_base64}
        if url:
            return {'url': url}
        if accept_cont_sign:
            if cont_sign:
                return {'cont_sign': cont_sign}
            raise ValueError("必须提供image_path、image_base64、url或cont_sign其中之一")
        raise ValueError("必须提供image_path、image_base64或url其中之一")

    def _post(self, endpoint, params, error_prefix):
        """POST form ``params`` to ``API_BASE/endpoint``, refreshing the token once if needed.

        Returns:
            the decoded JSON body of the HTTP 200 response.

        Raises:
            Exception: ``"{error_prefix}: {response text}"`` on a non-200 status.
        """
        url = f"{self.API_BASE}/{endpoint}"
        result = self._post_once(url, params, error_prefix)
        if isinstance(result, dict) and result.get('error_code') in self.TOKEN_ERROR_CODES:
            # The cached token was revoked or expired (tokens are only fetched
            # in __init__ otherwise): get a new one and retry a single time.
            self.get_access_token()
            result = self._post_once(url, params, error_prefix)
        return result

    def _post_once(self, url, params, error_prefix):
        """Send one authenticated form POST and return its decoded JSON body."""
        response = requests.post(
            url,
            params={'access_token': self.access_token},
            data=params,
            headers={'Content-Type': 'application/x-www-form-urlencoded'},
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise Exception(f"{error_prefix}: {response.text}")
        return response.json()