imgsearcher/app.py

607 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import json
import os
import uuid

import requests
from flask import Flask, render_template, request, jsonify, redirect, url_for, session, Response, stream_with_context
from flask_cors import CORS
from werkzeug.utils import secure_filename

from app.api.baidu_image_search import BaiduImageSearch
from app.api.image_utils import ImageUtils
from app.api.azure_openai import AzureOpenAI
from app.api.type_manager_mongo import TypeManagerMongo
from app.api.robot_manager import RobotManager
# Flask application; templates and static assets are looked up under app/.
app = Flask(__name__, template_folder='app/templates', static_folder='app/static')
CORS(app)  # enable CORS so pages served from other origins can call the API

# Session signing key. A key generated with os.urandom() differs per process
# and per restart, which invalidates existing session cookies (the chat routes
# keep the image path and conversation history in the session). Prefer a
# stable key from the environment; the random key remains the fallback.
app.secret_key = os.getenv('FLASK_SECRET_KEY') or os.urandom(24)

# Directory that receives uploaded images, created next to this file.
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # avoids the exists()/makedirs() race
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# File extensions accepted by the upload endpoints.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}

# Baidu image search API client.
image_search_api = BaiduImageSearch()

# Azure OpenAI client. A ValueError from the constructor is tolerated: the
# client is left as None and the chat routes report it as not configured.
try:
    azure_openai_api = AzureOpenAI()
except ValueError as e:
    print(f"警告: {str(e)}")
    azure_openai_api = None

# Store for image type -> description records.
type_manager = TypeManagerMongo()

# Store for robot personas.
robot_manager = RobotManager()
def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The comparison is case-insensitive; a name without a dot is rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
@app.route('/imgsearcherApi/')
def index():
    """Serve the home page (index.html)."""
    page = render_template('index.html')
    return page
@app.route('/imgsearcherApi/upload', methods=['POST'])
def upload_image():
    """Upload an image to the image library.

    Reads a multipart ``file`` part plus the optional form fields ``type``,
    ``description``, ``name`` and ``tags`` (default ``'1'``).

    Returns:
        The JSON result of ``image_search_api.add_image`` with an added
        ``file_path`` key, a 400 JSON error when the upload is missing or has
        a disallowed extension, or a 500 JSON error when saving or the API
        call fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': '没有文件上传'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400
    if not allowed_file(file.filename):
        return jsonify({'error': '不支持的文件类型'}), 400

    # Form fields.
    image_type = request.form.get('type', '')
    description = request.form.get('description', '')
    name = request.form.get('name', '')
    tags = request.form.get('tags', '1')  # tag defaults to 1

    # The brief sent to the API deliberately omits the description.
    brief = {
        'type': image_type,
        'name': name,
    }

    # Store under a generated name: reusing the client's filename lets two
    # uploads with the same name overwrite each other, and sanitising a
    # non-ASCII name can strip it down to a bare extension. The extension was
    # already validated by allowed_file(), so it is safe to reuse.
    extension = file.filename.rsplit('.', 1)[1].lower()
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    try:
        # Keep the type/description pair in the local store.
        type_manager.add_type(image_type, description)
        file.save(file_path)
        result = image_search_api.add_image(
            image_path=file_path,
            brief=brief,
            tags=tags,
        )
        # Report where the file was stored locally.
        result['file_path'] = file_path
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/imgsearcherApi/search', methods=['POST'])
def search_image():
    """Search the image library for images similar to the uploaded one.

    Reads a multipart ``file`` part plus the optional form fields ``type``,
    ``tags`` (default ``'1'``) and ``tag_logic`` (default ``'0'``). Each
    result item whose ``brief`` is a JSON string gets the locally stored
    description for its type merged into that brief.

    Returns:
        The JSON search result, a 400 JSON error for a missing or disallowed
        upload, or a 500 JSON error when saving or the API call fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': '没有文件上传'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400
    if not allowed_file(file.filename):
        return jsonify({'error': '不支持的文件类型'}), 400

    # Form fields.
    image_type = request.form.get('type', '')
    tags = request.form.get('tags', '1')  # tag defaults to 1
    tag_logic = request.form.get('tag_logic', '0')

    # Generated name so concurrent searches with equally named files do not
    # overwrite each other; the extension was validated by allowed_file().
    extension = file.filename.rsplit('.', 1)[1].lower()
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    try:
        file.save(file_path)
        result = image_search_api.search_image(
            image_path=file_path,
            tags=tags,
            tag_logic=tag_logic,
            type_filter=image_type,
        )
        # Enrich each hit with the description kept in the local store.
        for item in result.get('result') or []:
            try:
                brief = item.get('brief', '{}')
                if not isinstance(brief, str):
                    # Only JSON-string briefs are rewritten; anything else is
                    # passed through untouched instead of reusing state from
                    # a previous item.
                    continue
                brief_dict = json.loads(brief)
                item_type = brief_dict.get('type', '')
                brief_dict['description'] = type_manager.get_description(item_type)
                item['brief'] = json.dumps(brief_dict, ensure_ascii=False)
            except Exception as e:
                # Best effort: a malformed item must not fail the whole search.
                print(f"处理搜索结果项出错: {e}")
                continue
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/imgsearcherApi/delete', methods=['POST'])
def delete_image():
    """Delete an image from the image library.

    Expects a JSON object with a ``cont_sign`` key.

    Returns:
        The JSON result of ``image_search_api.delete_image``, a 400 JSON error
        when the body is not a JSON object or lacks ``cont_sign``, or a 500
        JSON error when the API call fails.
    """
    # silent=True yields None for a missing/malformed JSON body so the client
    # gets this route's JSON error instead of a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'cont_sign' not in data:
        return jsonify({'error': '缺少必要参数'}), 400
    cont_sign = data['cont_sign']
    try:
        result = image_search_api.delete_image(cont_sign=cont_sign)
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/imgsearcherApi/update', methods=['POST'])
def update_image():
    """Update the brief and/or tags of an image in the image library.

    Expects a JSON object with ``cont_sign`` and optional ``brief`` and
    ``tags``. When ``brief`` is an object carrying both ``type`` and
    ``description``, that pair is saved to the local store; the description
    is always removed from the brief before it is sent to the API.

    Returns:
        The JSON result of ``image_search_api.update_image``, a 400 JSON error
        when the body is not a JSON object or lacks ``cont_sign``, or a 500
        JSON error when the API call fails.
    """
    # silent=True yields None for a missing/malformed JSON body so the client
    # gets this route's JSON error instead of a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'cont_sign' not in data:
        return jsonify({'error': '缺少必要参数'}), 400
    cont_sign = data['cont_sign']
    brief = data.get('brief')
    tags = data.get('tags')

    if brief and isinstance(brief, dict):
        image_type = brief.get('type', '')
        description = brief.get('description', '')
        # Persist the description locally only when both parts are present.
        if image_type and description:
            type_manager.add_type(image_type, description)
        # The description is kept locally and never uploaded with the brief.
        brief.pop('description', None)

    try:
        result = image_search_api.update_image(
            cont_sign=cont_sign,
            brief=brief,
            tags=tags,
        )
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/imgsearcherApi/api/token')
def get_token():
    """Return the image search API access token as ``{'access_token': ...}``.

    Any exception is reported as a 500 JSON error.
    """
    try:
        access_token = image_search_api.get_access_token()
        payload = {'access_token': access_token}
        return jsonify(payload)
    except Exception as err:
        failure = {'error': str(err)}
        return jsonify(failure), 500
@app.route('/imgsearcherApi/chat-with-image', methods=['GET'])
def chat_with_image_page():
    """Serve the image chat page (chat.html)."""
    page = render_template('chat.html')
    return page
@app.route('/imgsearcherApi/api/upload-chat-image', methods=['POST'])
def upload_chat_image():
    """Upload the image that the following chat turns will be about.

    Reads a multipart ``file`` part and an optional ``robot_id`` form field.
    The saved path and robot id are stored in the session and the
    conversation history is reset. The image is then run through the image
    search API to pick a type, and a description for that type is looked up
    locally, falling back to the briefs in the search result.

    Returns:
        JSON with ``success``, ``image_path``, ``image_type``, ``description``
        and ``robot`` (``None`` when no robot was selected or found), a 400
        JSON error for a missing or disallowed upload, or a 500 JSON error
        when saving or the lookup fails.
    """
    if 'file' not in request.files:
        return jsonify({'error': '没有文件上传'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400
    if not allowed_file(file.filename):
        return jsonify({'error': '不支持的文件类型'}), 400

    # Selected robot persona id.
    robot_id = request.form.get('robot_id', '')

    # Generated name: the path is kept per session, so two users uploading
    # equally named files must not end up sharing one file on disk. The
    # extension was validated by allowed_file().
    extension = file.filename.rsplit('.', 1)[1].lower()
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}.{extension}")

    try:
        file.save(file_path)

        # Remember the image and persona, and start a fresh conversation.
        session['chat_image_path'] = file_path
        session['robot_id'] = robot_id
        session['conversation_history'] = []

        search_result = image_search_api.search_image(image_path=file_path)
        matches = search_result.get('result') or []

        # Use the 'method3' verdict as the most reliable type.
        reliable_types = image_search_api._determine_most_reliable_types(matches)
        image_type = reliable_types.get('method3', '')

        # Local store first ...
        description = type_manager.get_description(image_type)
        # ... then the briefs of the search hits.
        if not description:
            for item in matches:
                brief = item.get('brief', '{}')
                if not isinstance(brief, str):
                    continue
                try:
                    brief_dict = json.loads(brief)
                    if brief_dict.get('type') == image_type:
                        description = brief_dict.get('description', '')
                        if description:
                            # Cache the recovered description locally.
                            type_manager.add_type(image_type, description)
                            break
                except Exception:
                    # Best effort per item; unlike a bare except this lets
                    # SystemExit/KeyboardInterrupt propagate.
                    continue

        # Persona details echoed back to the client.
        robot_info = None
        if robot_id:
            robot = robot_manager.get_robot(robot_id)
            if robot:
                robot_info = {
                    'name': robot.get('name', ''),
                    'background': robot.get('background', ''),
                }

        return jsonify({
            'success': True,
            'image_path': file_path,
            'image_type': image_type,
            'description': description,
            'robot': robot_info,
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/imgsearcherApi/api/chat', methods=['POST'])
def chat():
    """Answer one chat message about the image stored in the session.

    Expects a JSON object with a ``message`` key. The image path, robot id
    and conversation history are read from the session; after a successful
    reply the user message and the reply are appended to the history.

    Returns:
        JSON with ``reply`` and the updated ``conversation_history``; a 500
        JSON error when the Azure OpenAI client is not configured or the call
        fails; a 400 JSON error when the body is not a JSON object with
        ``message`` or no valid image is in the session.
    """
    if not azure_openai_api:
        return jsonify({'error': 'Azure OpenAI API未正确配置'}), 500

    # silent=True yields None for a missing/malformed JSON body so the client
    # gets this route's JSON error instead of a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'message' not in data:
        return jsonify({'error': '缺少消息内容'}), 400
    message = data['message']

    image_path = session.get('chat_image_path')
    robot_id = session.get('robot_id', '')
    if not image_path or not os.path.exists(image_path):
        return jsonify({'error': '没有上传图片或图片已失效'}), 400

    conversation_history = session.get('conversation_history', [])

    # Persona details passed to the model, when a robot was selected.
    robot_info = None
    if robot_id:
        robot = robot_manager.get_robot(robot_id)
        if robot:
            robot_info = {
                'name': robot.get('name', ''),
                'background': robot.get('background', ''),
            }

    try:
        response = azure_openai_api.chat_with_image(
            image_path=image_path,
            message=message,
            conversation_history=conversation_history,
            robot_info=robot_info,
        )
        reply = response['choices'][0]['message']['content']

        # Record this turn and write the history back to the session.
        conversation_history.append({"role": "user", "content": message})
        conversation_history.append({"role": "assistant", "content": reply})
        session['conversation_history'] = conversation_history

        return jsonify({
            'reply': reply,
            'conversation_history': conversation_history,
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500
@app.route('/imgsearcherApi/robots', methods=['GET'])
def robot_page():
    """Serve the robot persona management page (robots.html)."""
    page = render_template('robots.html')
    return page
@app.route('/imgsearcherApi/api/robots', methods=['GET'])
def get_robots():
    """Return every robot persona as ``{'success': True, 'robots': [...]}``.

    Any exception is reported as a 500 JSON error.
    """
    try:
        all_robots = robot_manager.get_all_robots()
        payload = {'success': True, 'robots': all_robots}
        return jsonify(payload)
    except Exception as err:
        return jsonify({'error': str(err)}), 500
@app.route('/imgsearcherApi/api/robots', methods=['POST'])
def add_robot():
    """Create a robot persona from form fields ``name`` and ``background``.

    An optional ``avatar`` file part is forwarded to the robot manager; an
    avatar with an empty filename is treated as absent, and one with a
    disallowed extension is rejected with a 400. A result containing an
    ``error`` key is returned as a 400; any exception becomes a 500.
    """
    try:
        form = request.form
        name, background = form.get('name'), form.get('background')
        if not (name and background):
            return jsonify({'error': '机器人名称和背景故事不能为空'}), 400

        avatar_file = request.files['avatar'] if 'avatar' in request.files else None
        if avatar_file is not None:
            if avatar_file.filename == '':
                # An empty file input counts as "no avatar".
                avatar_file = None
            elif not allowed_file(avatar_file.filename):
                return jsonify({'error': '不支持的文件类型'}), 400

        outcome = robot_manager.add_robot(name, background, avatar_file)
        if 'error' in outcome:
            return jsonify(outcome), 400
        return jsonify({'success': True, 'robot': outcome})
    except Exception as err:
        return jsonify({'error': str(err)}), 500
@app.route('/imgsearcherApi/api/robots/<robot_id>', methods=['GET'])
def get_robot(robot_id):
    """Return the robot persona with id *robot_id*.

    Responds with 404 when the manager returns nothing for the id and with
    500 when the lookup raises.
    """
    try:
        found = robot_manager.get_robot(robot_id)
        if found:
            return jsonify({'success': True, 'robot': found})
        return jsonify({'error': '找不到该机器人'}), 404
    except Exception as err:
        return jsonify({'error': str(err)}), 500
@app.route('/imgsearcherApi/api/robots/<robot_id>', methods=['PUT'])
def update_robot(robot_id):
    """Update the robot persona with id *robot_id* from form data.

    Form fields ``name`` and ``background`` and an optional ``avatar`` file
    part are forwarded to the robot manager. An avatar with an empty filename
    is treated as absent; one with a disallowed extension is rejected with a
    400. A result containing an ``error`` key is returned as a 400; any
    exception becomes a 500.
    """
    try:
        form = request.form
        name, background = form.get('name'), form.get('background')

        avatar_file = request.files['avatar'] if 'avatar' in request.files else None
        if avatar_file is not None:
            if avatar_file.filename == '':
                # An empty file input counts as "no avatar".
                avatar_file = None
            elif not allowed_file(avatar_file.filename):
                return jsonify({'error': '不支持的文件类型'}), 400

        outcome = robot_manager.update_robot(robot_id, name, background, avatar_file)
        if 'error' in outcome:
            return jsonify(outcome), 400
        return jsonify({'success': True, 'robot': outcome})
    except Exception as err:
        return jsonify({'error': str(err)}), 500
@app.route('/imgsearcherApi/api/robots/<robot_id>', methods=['DELETE'])
def delete_robot(robot_id):
    """Delete the robot persona with id *robot_id*.

    Responds with ``{'success': True}`` when the manager reports success,
    404 when it reports failure, and 500 when it raises.
    """
    try:
        removed = robot_manager.delete_robot(robot_id)
        if removed:
            return jsonify({'success': True})
        return jsonify({'error': '找不到该机器人或删除失败'}), 404
    except Exception as err:
        return jsonify({'error': str(err)}), 500
# ==================== Chat Stream ====================
# Remote TTS endpoint used to voice each finished sentence of a streamed reply.
_CHAT_STREAM_TTS_URL = "https://cloud.infini-ai.com/AIStudio/v1/inference/api/te-c7zfd4hrdoj2vqmw/tts/CosyVoice/v1/zero_shot"

# Characters that end a sentence and trigger TTS for the text gathered so far.
# Both the full-width CJK forms and their ASCII counterparts are listed so
# Chinese replies ending in "!" or "?" are voiced sentence by sentence too.
_SENTENCE_TERMINATORS = "。!?.!?"


def _split_complete_sentences(buffer, terminators=_SENTENCE_TERMINATORS):
    """Cut *buffer* after every terminator; return (sentences, unfinished_tail)."""
    sentences = []
    start = 0
    for pos, ch in enumerate(buffer):
        if ch in terminators:
            sentences.append(buffer[start:pos + 1])
            start = pos + 1
    return sentences, buffer[start:]


def _sse_data_event(payload):
    """Serialise *payload* as one server-sent ``data:`` event."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _tts_audio_event(sentence, voice_wav):
    """Synthesise *sentence* and return it as an SSE audio event, or None on failure."""
    if not sentence.strip():
        return None
    # NOTE(review): the literal fallback below is a credential committed to
    # source; it is kept so existing deployments keep working, but it should
    # be rotated and supplied only through INFINI_API_KEY.
    api_key = os.getenv("INFINI_API_KEY", "sk-daooolufaf7ienn6")
    tts_params = {
        'tts_text': sentence,
        'prompt_wav': voice_wav,
        'text_split_method': 'cut5',
    }
    try:
        # The context manager closes the streamed response, releasing the
        # connection even when reading fails part-way.
        with requests.get(
            _CHAT_STREAM_TTS_URL,
            params=tts_params,
            headers={'Authorization': f'Bearer {api_key}'},
            stream=True,
            timeout=30,
        ) as tts_resp:
            if tts_resp.status_code != 200:
                print("TTS 调用失败", tts_resp.status_code)
                return None
            audio_bytes = b"".join(
                chunk for chunk in tts_resp.iter_content(chunk_size=8192) if chunk
            )
    except Exception as e:
        # Audio is best effort; the text stream must keep going.
        print("TTS 调用失败", e)
        return None
    if not audio_bytes:
        return None
    return _sse_data_event({"type": "audio", "content": base64.b64encode(audio_bytes).decode()})


@app.route('/imgsearcherApi/api/chat-stream', methods=['POST'])
def chat_stream():
    """Stream a chat reply as server-sent events: text deltas plus per-sentence audio.

    Expects a JSON object with ``message`` and an optional ``voice`` (the
    prompt wav name passed to the TTS service). The image path, robot id and
    conversation history are read from the session.

    Events emitted:
        ``data: {"type": "text", "content": <delta>}`` for every model delta,
        ``data: {"type": "audio", "content": <base64 audio>}`` for every
        completed sentence that could be synthesised, and a final
        ``event: end``.

    Returns:
        A ``text/event-stream`` response; a 500 JSON error when the Azure
        OpenAI client is not configured or the upstream call cannot be
        started; a 400 JSON error when the body is not a JSON object with
        ``message`` or no valid image is in the session.
    """
    if not azure_openai_api:
        return jsonify({'error': 'Azure OpenAI API未正确配置'}), 500

    # silent=True yields None for a missing/malformed JSON body so the client
    # gets this route's JSON error instead of a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'message' not in data:
        return jsonify({'error': '缺少消息内容'}), 400
    message = data['message']

    image_path = session.get('chat_image_path')
    robot_id = session.get('robot_id', '')
    if not image_path or not os.path.exists(image_path):
        return jsonify({'error': '没有上传图片或图片已失效'}), 400

    conversation_history = session.get('conversation_history', [])

    robot_info = None
    if robot_id:
        robot = robot_manager.get_robot(robot_id)
        if robot:
            robot_info = {'name': robot.get('name', ''), 'background': robot.get('background', '')}

    # Voice (prompt wav) selection.
    voice_wav = data.get('voice') or 'zh-CN-XiaoXiao-Assistant-Audio.wav'

    # Start the streaming completion before the response begins so a failure
    # can still be reported as a JSON error, as the non-streaming chat does.
    try:
        openai_resp = azure_openai_api.chat_with_image_stream(
            image_path=image_path,
            message=message,
            conversation_history=conversation_history,
            robot_info=robot_info,
        )
    except Exception as e:
        return jsonify({'error': str(e)}), 500

    def event_stream():
        reply_parts = []
        pending = ""  # text received but not yet voiced

        for line in openai_resp.iter_lines():
            if not line:
                continue
            try:
                if line.strip() == b'data: [DONE]':
                    break
                if not line.startswith(b'data:'):
                    continue
                content_json = json.loads(line[5:].strip())
                choices = content_json.get('choices', [])
                delta = choices[0].get('delta', {}).get('content', '') if choices else ''
            except Exception as e:
                # Skip lines that cannot be parsed instead of ending the stream.
                print("解析OpenAI流错误", e)
                continue
            if not delta:
                continue

            reply_parts.append(delta)
            pending += delta

            # Forward the text delta immediately.
            yield _sse_data_event({"type": "text", "content": delta})

            # Voice every sentence completed by this delta.
            sentences, pending = _split_complete_sentences(pending)
            for sentence in sentences:
                audio_evt = _tts_audio_event(sentence, voice_wav)
                if audio_evt:
                    yield audio_evt

        # Voice whatever is left without a closing terminator.
        if pending.strip():
            audio_evt = _tts_audio_event(pending, voice_wav)
            if audio_evt:
                yield audio_evt

        # Record this turn in the conversation history.
        # NOTE(review): this runs after the response headers were sent; with a
        # cookie-backed session the update is likely not persisted - confirm
        # against the configured session backend.
        conversation_history.append({"role": "user", "content": message})
        conversation_history.append({"role": "assistant", "content": "".join(reply_parts)})
        session['conversation_history'] = conversation_history

        yield "event: end\ndata: {}\n\n"

    # stream_with_context keeps the request context alive so the generator
    # can touch the session without raising.
    return Response(stream_with_context(event_stream()), content_type='text/event-stream')
# ==================== TTS ====================
@app.route('/imgsearcherApi/api/tts', methods=['POST'])
def tts():
    """Convert text to speech through the remote TTS service.

    Expects a JSON object with ``text`` and an optional ``voice`` or
    ``prompt_wav`` naming the prompt wav (``voice`` takes precedence).

    Returns:
        The audio bytes with the upstream Content-Type (``audio/wav`` when
        upstream sends none); a 400 JSON error when the body is not a JSON
        object with ``text``; a 502 JSON error when the service answers with
        a non-200 status; a 500 JSON error when the request itself fails.
    """
    # silent=True yields None for a missing/malformed JSON body so the client
    # gets this route's JSON error instead of a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'error': '缺少text参数'}), 400

    text = data['text']
    prompt_wav = data.get('voice') or data.get('prompt_wav') or "zh-CN-XiaoXiao-Assistant-Audio.wav"

    tts_base_url = "https://cloud.infini-ai.com/AIStudio/v1/inference/api/te-c7zfd4hrdoj2vqmw/tts/CosyVoice/v1/zero_shot"
    # NOTE(review): the literal fallback below is a credential committed to
    # source; it is kept so existing deployments keep working, but it should
    # be rotated and supplied only through INFINI_API_KEY.
    api_key = os.getenv('INFINI_API_KEY', 'sk-daooolufaf7ienn6')
    headers = {'Authorization': f'Bearer {api_key}'}
    params = {
        'tts_text': text,
        'prompt_wav': prompt_wav,
        'text_split_method': 'cut5',
    }

    try:
        r = requests.get(tts_base_url, params=params, headers=headers, timeout=20)
        if r.status_code == 200:
            return Response(r.content, content_type=r.headers.get('Content-Type', 'audio/wav'))
        print(f"TTS remote error {r.status_code}: {r.text}")
        return jsonify({'error': f'TTS服务错误: {r.text}'}), 502
    except Exception as e:
        return jsonify({'error': f'TTS请求失败: {str(e)}'}), 500
if __name__ == '__main__':
    # Debug mode enables the interactive debugger, which must not be reachable
    # from the network; since the server binds to 0.0.0.0 it is now opt-in
    # through FLASK_DEBUG instead of always on.
    debug_enabled = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled, host='0.0.0.0', port=5001)