227 lines
7.8 KiB
Python
227 lines
7.8 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import base64
|
||
import json
|
||
import os
|
||
import requests
|
||
from dotenv import load_dotenv
|
||
|
||
# Load environment variables via python-dotenv at import time, so that
# BAIDU_API_KEY / BAIDU_SECRET_KEY (read with os.getenv by BaiduImageSearch)
# can be supplied from a .env file instead of the process environment.
load_dotenv()
|
||
|
||
class BaiduImageSearch:
    """Client for Baidu's similar-image search API (realtime_search/similar).

    Wraps the add / search / delete / update endpoints. An OAuth access token
    is requested once in ``__init__``. If an API call later reports the token
    as invalid or expired, a new token is fetched and the call is retried once.
    """

    TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
    API_BASE = "https://aip.baidubce.com/rest/2.0/image-classify/v1/realtime_search/similar"
    # Baidu AI platform error codes: 110 = access token invalid,
    # 111 = access token expired. Both are fixed by fetching a new token.
    TOKEN_ERROR_CODES = (110, 111)
    # Default per-request timeout in seconds. Without a timeout, requests can
    # block forever on a stalled connection.
    timeout = 30

    def __init__(self, api_key=None, secret_key=None, timeout=30):
        """Store credentials and fetch an access token.

        Args:
            api_key: Baidu API key. Defaults to the BAIDU_API_KEY environment variable.
            secret_key: Baidu secret key. Defaults to the BAIDU_SECRET_KEY
                environment variable.
            timeout: timeout in seconds applied to every HTTP request.

        Raises:
            ValueError: if either credential is missing.
            Exception: if the access token cannot be obtained.
        """
        self.api_key = api_key or os.getenv('BAIDU_API_KEY')
        self.secret_key = secret_key or os.getenv('BAIDU_SECRET_KEY')
        self.timeout = timeout
        self.access_token = None
        self.get_access_token()

    def get_access_token(self):
        """Request a new OAuth access token and store it in ``self.access_token``.

        Returns:
            str: the access token.

        Raises:
            ValueError: if the API key or secret key is not configured.
            Exception: if the HTTP status is not 200 or the response has no token.
        """
        if not self.api_key or not self.secret_key:
            raise ValueError("未配置BAIDU_API_KEY或BAIDU_SECRET_KEY")
        params = {
            "grant_type": "client_credentials",
            "client_id": self.api_key,
            "client_secret": self.secret_key,
        }
        response = requests.post(self.TOKEN_URL, params=params, timeout=self.timeout)
        token = response.json().get('access_token') if response.status_code == 200 else None
        # A 200 response without a token must fail here. Otherwise every later
        # call would silently send "access_token=None".
        if not token:
            raise Exception(f"获取access_token失败: {response.text}")
        self.access_token = token
        return token

    def add_image(self, image_path=None, image_base64=None, url=None, brief=None, tags=None):
        """Add an image to the gallery.

        One image source is used: the first truthy one of image_path,
        image_base64 and url.

        Args:
            image_path: path of a local image file, read and base64-encoded.
            image_base64: image content that is already base64-encoded.
            url: image URL.
            brief: required summary, at most 256 bytes. A dict is serialized
                to JSON, e.g. {"name": "image name", "id": "123"}.
            tags: category tags, at most 2, e.g. "1,2".

        Returns:
            dict: decoded JSON response from the API.

        Raises:
            ValueError: if brief is missing or no image source is given.
            Exception: if the HTTP request returns a non-200 status.
        """
        # brief is mandatory for this endpoint. It is validated before the
        # image source, as in the original ordering.
        if not brief:
            raise ValueError("brief参数是必须的")
        if isinstance(brief, dict):
            brief = json.dumps(brief, ensure_ascii=False)
        params = {'brief': brief}
        params.update(self._image_source(image_path, image_base64, url))
        if tags:
            params['tags'] = tags
        return self._post('add', params, "添加图片失败")

    def search_image(self, image_path=None, image_base64=None, url=None, tags=None, tag_logic=None, pn=0, rn=300):
        """Search the gallery for images similar to the given one.

        Args:
            image_path: path of a local image file, read and base64-encoded.
            image_base64: image content that is already base64-encoded.
            url: image URL.
            tags: tag filter, e.g. "1,2".
            tag_logic: tag logic, 0 = AND, 1 = OR.
            pn: pagination start offset, default 0.
            rn: number of results to return, default 300, maximum 1000.

        Returns:
            dict: decoded JSON response from the API.

        Raises:
            ValueError: if no image source is given.
            Exception: if the HTTP request returns a non-200 status.
        """
        params = self._image_source(image_path, image_base64, url)
        if tags:
            params['tags'] = tags
        # These may legitimately be 0, so only None means "omit".
        for key, value in (('tag_logic', tag_logic), ('pn', pn), ('rn', rn)):
            if value is not None:
                params[key] = value
        return self._post('search', params, "检索图片失败")

    def delete_image(self, image_path=None, image_base64=None, url=None, cont_sign=None):
        """Delete images from the gallery.

        Args:
            image_path: path of a local image file, read and base64-encoded.
            image_base64: image content that is already base64-encoded.
            url: image URL.
            cont_sign: image signature(s). Batch deletion is supported, e.g.
                "932301884,1068006219;316336521,553141152".

        Returns:
            dict: decoded JSON response from the API.

        Raises:
            ValueError: if none of image_path, image_base64, url or cont_sign is given.
            Exception: if the HTTP request returns a non-200 status.
        """
        params = self._image_source(image_path, image_base64, url, cont_sign, allow_cont_sign=True)
        return self._post('delete', params, "删除图片失败")

    def update_image(self, image_path=None, image_base64=None, url=None, cont_sign=None, brief=None, tags=None):
        """Update the brief and/or tags of an image already in the gallery.

        Args:
            image_path: path of a local image file, read and base64-encoded.
            image_base64: image content that is already base64-encoded.
            url: image URL.
            cont_sign: image signature.
            brief: new summary, at most 256 bytes. A dict is serialized to JSON.
            tags: new category tags, at most 2, e.g. "1,2".

        Returns:
            dict: decoded JSON response from the API.

        Raises:
            ValueError: if none of image_path, image_base64, url or cont_sign is given.
            Exception: if the HTTP request returns a non-200 status.
        """
        params = self._image_source(image_path, image_base64, url, cont_sign, allow_cont_sign=True)
        if brief:
            if isinstance(brief, dict):
                brief = json.dumps(brief, ensure_ascii=False)
            params['brief'] = brief
        if tags:
            params['tags'] = tags
        return self._post('update', params, "更新图片失败")

    @staticmethod
    def _image_source(image_path=None, image_base64=None, url=None, cont_sign=None, *, allow_cont_sign=False):
        """Build the request fields that identify the image.

        Sources are checked in the order image_path, image_base64, url, and
        then cont_sign (only when allow_cont_sign is true). The first truthy
        source wins.

        Returns:
            dict: one of {'image': ...}, {'url': ...} or {'cont_sign': ...}.

        Raises:
            ValueError: if no usable source is given.
        """
        if image_path:
            with open(image_path, 'rb') as f:
                # Decode to str. The bytes sent over the wire are identical,
                # because base64 output is pure ASCII.
                return {'image': base64.b64encode(f.read()).decode('ascii')}
        if image_base64:
            return {'image': image_base64}
        if url:
            return {'url': url}
        if not allow_cont_sign:
            raise ValueError("必须提供image_path、image_base64或url其中之一")
        if cont_sign:
            return {'cont_sign': cont_sign}
        raise ValueError("必须提供image_path、image_base64、url或cont_sign其中之一")

    def _post(self, action, params, error_label):
        """POST ``params`` to the ``action`` endpoint and return the decoded JSON.

        If the API answers with a token error code (see TOKEN_ERROR_CODES),
        a new access token is fetched and the request is retried once.

        Raises:
            Exception: ``"<error_label>: <response body>"`` on a non-200 status.
        """
        result = self._post_once(action, params, error_label)
        if isinstance(result, dict) and result.get('error_code') in self.TOKEN_ERROR_CODES:
            self.get_access_token()
            result = self._post_once(action, params, error_label)
        return result

    def _post_once(self, action, params, error_label):
        """Send a single form-encoded POST to ``<API_BASE>/<action>`` using the current token."""
        response = requests.post(
            f"{self.API_BASE}/{action}",
            # Passed as a query parameter, as in the original URL, but properly encoded.
            params={'access_token': self.access_token},
            data=params,
            headers={'Content-Type': 'application/x-www-form-urlencoded'},
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise Exception(f"{error_label}: {response.text}")
        return response.json()