#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import base64
import json

import requests
from dotenv import load_dotenv

from app.api.baidu_image_search import BaiduImageSearch
from app.api.image_utils import ImageUtils
from app.api.type_manager_mongo import TypeManagerMongo

# Load environment variables from the project's .env file so the Azure
# credentials are available before AzureOpenAI is instantiated.
load_dotenv()
class AzureOpenAI:
    """Wrapper around the Azure OpenAI chat-completions API (GPT-4o multimodal).

    Enriches each request by running the image through a Baidu reverse-image
    search to guess its type, looking up (and caching) a textual description
    for that type in a MongoDB-backed store, and folding both into the system
    prompt before sending the image to the Azure deployment.
    """

    def __init__(self):
        """Read Azure OpenAI connection settings from the environment.

        Raises:
            ValueError: if any of the four required settings is missing.
        """
        self.api_key = os.getenv('AZURE_OPENAI_API_KEY')
        self.endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
        self.api_version = os.getenv('AZURE_OPENAI_API_VERSION')
        self.deployment_name = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME')
        self.baidu_image_search = BaiduImageSearch()
        self.type_manager = TypeManagerMongo()

        # Fail fast when the .env configuration is incomplete.
        if not all((self.api_key, self.endpoint,
                    self.api_version, self.deployment_name)):
            raise ValueError("Azure OpenAI配置不完整,请检查.env文件")

    def _encode_image(self, image_path):
        """Encode an image file as base64.

        Args:
            image_path: path to the image file.

        Returns:
            str: base64-encoded image data (ASCII text).
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def _get_image_type_description(self, image_path):
        """Classify the image and fetch a description of that type.

        Args:
            image_path: path to the image file.

        Returns:
            tuple: (type, description) as strings; both empty on any failure
            (this method is deliberately best-effort and never raises).
        """
        try:
            # Reverse-image search, then pick the most reliable type label
            # ("method3" heuristic of the Baidu helper).
            search_result = self.baidu_image_search.search_image(image_path=image_path)
            hits = search_result.get('result', [])
            reliable_types = self.baidu_image_search._determine_most_reliable_types(hits)
            image_type = reliable_types.get('method3', '')

            # Prefer the locally cached description; normalize None -> '' so
            # len(description) below cannot raise TypeError.
            description = self.type_manager.get_description(image_type) or ''

            # Fall back to the 'brief' JSON embedded in the search hits.
            if not description:
                for item in hits:
                    brief = item.get('brief', '{}')
                    if not isinstance(brief, str):
                        continue
                    try:
                        brief_dict = json.loads(brief)
                    except (ValueError, TypeError):
                        # Malformed JSON in this hit — keep scanning.
                        # (Was a bare `except:`, which also hid real errors.)
                        continue
                    if brief_dict.get('type') == image_type:
                        description = brief_dict.get('description', '')
                        if description:
                            # Cache so the next lookup is served locally.
                            self.type_manager.add_type(image_type, description)
                            break

            print(f"获取到的图片类型: {image_type}, 描述长度: {len(description)}")
            return image_type, description
        except Exception as e:
            print(f"获取图片类型和描述失败: {str(e)}")
            return "", ""

    def _build_system_message(self, robot_info, image_type, description):
        """Compose the system prompt.

        Uses the robot persona when `robot_info` supplies both `name` and
        `background`; appends the detected image type/description when both
        are non-empty.
        """
        if robot_info and robot_info.get('name') and robot_info.get('background'):
            system_message = f"你是{robot_info['name']},一个能够分析图片并回答问题的角色。\n\n你的背景故事:{robot_info['background']}\n\n在对话中,你应该始终保持这个角色的身份和特点,用第一人称回答问题。"
        else:
            system_message = "你是一个智能助手,能够分析图片并回答问题。"

        if image_type and description:
            system_message += f"\n\n这是一张{image_type}的图片。\n描述信息:{description}\n\n请基于这些信息和图片内容回答用户的问题。"
        return system_message

    def _build_messages(self, system_message, conversation_history, message, base64_image):
        """Assemble the chat-completions message list.

        Order: system prompt, prior turns verbatim, then the current user
        turn carrying both the text and the inlined base64 JPEG data URL.
        """
        messages = [{"role": "system", "content": system_message}]
        messages.extend(conversation_history)
        messages.append({
            "role": "user",
            "content": [
                {"type": "text", "text": message},
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
            ],
        })
        return messages

    def _request(self, messages, stream):
        """POST a chat-completions request to the Azure deployment.

        Args:
            messages: full message list (system + history + current turn).
            stream: whether to request server-sent-event streaming.

        Returns:
            requests.Response: un-inspected; callers decide how to consume it.
        """
        url = (f"{self.endpoint}/openai/deployments/{self.deployment_name}"
               f"/chat/completions?api-version={self.api_version}")
        headers = {"Content-Type": "application/json", "api-key": self.api_key}
        payload = {
            "messages": messages,
            "max_tokens": 2000,
            "temperature": 0.7,
            "top_p": 0.95,
            "stream": stream,
        }
        # NOTE(review): no timeout, as in the original — a stalled server hangs
        # the caller forever; consider timeout=(10, 300) in a follow-up.
        return requests.post(url, headers=headers, json=payload, stream=stream)

    def chat_with_image(self, image_path, message, conversation_history=None, robot_info=None):
        """Send the image and message to GPT-4o and return the full reply.

        Args:
            image_path: path to the image file.
            message: current user message.
            conversation_history: prior messages in API format (default none).
            robot_info: optional persona dict with `name` and `background`.

        Returns:
            dict: parsed JSON body of the model response.

        Raises:
            Exception: when the API answers with a non-200 status.
        """
        if conversation_history is None:
            conversation_history = []

        image_type, description = self._get_image_type_description(image_path)
        base64_image = self._encode_image(image_path)
        system_message = self._build_system_message(robot_info, image_type, description)

        # Debug trace of the prompt actually sent.
        print("\nimage_type and description:")
        print(f"{image_type} - {description}")
        print(f"系统提示: {system_message}...(总长度:{len(system_message)})")

        messages = self._build_messages(system_message, conversation_history,
                                        message, base64_image)
        response = self._request(messages, stream=False)

        if response.status_code == 200:
            return response.json()
        raise Exception(f"Azure OpenAI API请求失败: {response.text}")

    def chat_with_image_stream(self, image_path, message, conversation_history=None, robot_info=None):
        """Stream the GPT-4o reply; returns a requests.Response (stream=True)."""
        if conversation_history is None:
            conversation_history = []

        image_type, description = self._get_image_type_description(image_path)
        base64_image = self._encode_image(image_path)
        system_message = self._build_system_message(robot_info, image_type, description)
        messages = self._build_messages(system_message, conversation_history,
                                        message, base64_image)
        return self._request(messages, stream=True)