imgsearcher/app/api/azure_openai.py
eust-w 487f3af948 feat: modifications based on team suggestions
- Add MongoDB type manager implementation (TypeManagerMongo)
- Update environment variables configuration to support MongoDB connection
- Add chat functionality
- Integrate Azure OpenAI API support
- Update dependencies and startup script
2025-04-11 12:00:32 +08:00

167 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import base64
import json
import mimetypes
import os

import requests
from dotenv import load_dotenv

from app.api.baidu_image_search import BaiduImageSearch
from app.api.image_utils import ImageUtils
from app.api.type_manager_mongo import TypeManagerMongo
# Load environment variables via python-dotenv at import time, so the
# os.getenv() calls in AzureOpenAI.__init__ can see the AZURE_OPENAI_* settings.
load_dotenv()
class AzureOpenAI:
    """Wrapper around the Azure OpenAI chat-completions REST endpoint.

    Sends a user message together with an image to the configured deployment
    (the original author targets the GPT-4o multimodal model). Before calling
    the model, the image is looked up through ``BaiduImageSearch`` and
    ``TypeManagerMongo`` so that a type/description hint can be added to the
    system prompt.
    """

    # (connect, read) timeout in seconds for the HTTP call. Without a timeout
    # ``requests.post`` may block forever on a stalled connection.
    REQUEST_TIMEOUT = (10, 120)

    # MIME type used in the data URL when the file extension does not map to
    # a known image type (keeps the historical behaviour for such files).
    DEFAULT_IMAGE_MIME = "image/jpeg"

    def __init__(self):
        """Read the API configuration from the environment and create helpers.

        Raises:
            ValueError: if any of AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT,
                AZURE_OPENAI_API_VERSION or AZURE_OPENAI_DEPLOYMENT_NAME is
                missing or empty.
        """
        self.api_key = os.getenv('AZURE_OPENAI_API_KEY')
        self.endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
        self.api_version = os.getenv('AZURE_OPENAI_API_VERSION')
        self.deployment_name = os.getenv('AZURE_OPENAI_DEPLOYMENT_NAME')
        self.baidu_image_search = BaiduImageSearch()
        self.type_manager = TypeManagerMongo()

        # Fail early when the configuration is incomplete.
        required = (self.api_key, self.endpoint, self.api_version, self.deployment_name)
        if not all(required):
            raise ValueError("Azure OpenAI配置不完整请检查.env文件")

    def _encode_image(self, image_path):
        """Return the content of the file at ``image_path`` as base64 text.

        Args:
            image_path: Path of the image file to read.

        Returns:
            str: Base64-encoded file content (ASCII-safe text).
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def _guess_image_mime(self, image_path):
        """Guess the image MIME type from the file extension.

        Args:
            image_path: Path of the image file (only the name is inspected).

        Returns:
            str: The guessed ``image/*`` type, or ``DEFAULT_IMAGE_MIME`` when
            the extension is unknown or does not denote an image.
        """
        guessed, _encoding = mimetypes.guess_type(str(image_path))
        if guessed and guessed.startswith("image/"):
            return guessed
        return self.DEFAULT_IMAGE_MIME

    def _description_from_results(self, results, image_type):
        """Find a description for ``image_type`` in the raw search results.

        Each result item may carry a ``brief`` field holding a JSON string
        with ``type`` and ``description`` keys. The first non-empty
        description whose type equals ``image_type`` is stored through the
        type manager (best effort) and returned.

        Args:
            results: Iterable of result dicts from the image search.
            image_type: The type whose description is wanted.

        Returns:
            str: The description, or ``''`` when none was found.
        """
        for item in results:
            brief = item.get('brief', '{}')
            if not isinstance(brief, str):
                continue
            try:
                brief_dict = json.loads(brief)
            except ValueError:
                # Malformed JSON in one result must not abort the search.
                continue
            if not isinstance(brief_dict, dict):
                continue
            if brief_dict.get('type') != image_type:
                continue
            description = brief_dict.get('description', '')
            if not description:
                continue
            try:
                # Cache the description so later lookups hit the type manager.
                self.type_manager.add_type(image_type, description)
            except Exception as e:
                # Caching is best effort; the description is still usable.
                print(f"保存图片类型描述失败: {str(e)}")
            return description
        return ''

    def _get_image_type_description(self, image_path):
        """Determine the type of an image and a description for that type.

        The image is searched via ``BaiduImageSearch``; the ``method3`` entry
        of its reliable-types result is taken as the image type. The
        description comes from the type manager first and falls back to the
        search results.

        Args:
            image_path: Path of the image file.

        Returns:
            tuple: ``(image_type, description)``. Both are ``""`` when the
            lookup fails with an exception; ``description`` alone is ``""``
            when the type is known but no description is available.
        """
        try:
            search_result = self.baidu_image_search.search_image(image_path=image_path)
            results = search_result.get('result', [])
            reliable_types = self.baidu_image_search._determine_most_reliable_types(results)
            image_type = reliable_types.get('method3', '')

            # Normalise a missing description (e.g. None) to '' so that the
            # len() below cannot raise and discard a valid image_type.
            description = self.type_manager.get_description(image_type) or ''
            if not description:
                description = self._description_from_results(results, image_type)

            print(f"获取到的图片类型: {image_type}, 描述长度: {len(description)}")
            return image_type, description
        except Exception as e:
            # Best effort: the chat still works without the type hint.
            print(f"获取图片类型和描述失败: {str(e)}")
            return "", ""

    def _build_system_message(self, image_type, description):
        """Build the system prompt, adding the type hint when both parts exist."""
        system_message = "你是一个智能助手,能够分析图片并回答问题。"
        if image_type and description:
            system_message += f"\n\n这是一张{image_type}的图片。\n描述信息:{description}\n\n请基于这些信息和图片内容回答用户的问题。"
        return system_message

    def _build_messages(self, system_message, message, base64_image, mime_type,
                        conversation_history=None):
        """Assemble the ``messages`` list for the chat-completions payload.

        Order: system prompt, previous conversation entries (unchanged), then
        the current user turn containing the text and the image as a data URL.
        """
        messages = [{"role": "system", "content": system_message}]
        messages.extend(conversation_history or [])
        messages.append({
            "role": "user",
            "content": [
                {"type": "text", "text": message},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{mime_type};base64,{base64_image}"
                    },
                },
            ],
        })
        return messages

    def _completions_url(self):
        """Return the chat-completions URL for the configured deployment."""
        # Tolerate an endpoint configured with a trailing slash, which would
        # otherwise yield a double slash in the path.
        base = self.endpoint.rstrip('/')
        return f"{base}/openai/deployments/{self.deployment_name}/chat/completions?api-version={self.api_version}"

    def chat_with_image(self, image_path, message, conversation_history=None):
        """Send an image plus a user message to the model and return its reply.

        Args:
            image_path: Path of the image file to attach.
            message: The user's text message.
            conversation_history: Optional list of earlier message dicts that
                are inserted between the system prompt and the current turn.

        Returns:
            dict: The decoded JSON body of the API response.

        Raises:
            Exception: if the API answers with a status code other than 200.
            requests.exceptions.RequestException: on connection problems or
                when ``REQUEST_TIMEOUT`` is exceeded.
        """
        image_type, description = self._get_image_type_description(image_path)
        base64_image = self._encode_image(image_path)
        system_message = self._build_system_message(image_type, description)

        # Debug output of the prompt context.
        print("\nimage_type and description:")
        print(f"{image_type} - {description}")
        print(f"系统提示: {system_message}...(总长度:{len(system_message)}")

        messages = self._build_messages(
            system_message,
            message,
            base64_image,
            self._guess_image_mime(image_path),
            conversation_history,
        )

        headers = {
            "Content-Type": "application/json",
            "api-key": self.api_key,
        }
        payload = {
            "messages": messages,
            "max_tokens": 2000,
            "temperature": 0.7,
            "top_p": 0.95,
            "stream": False,
        }

        response = requests.post(
            self._completions_url(),
            headers=headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code == 200:
            return response.json()
        raise Exception(f"Azure OpenAI API请求失败: {response.text}")