From e29eced691d7aab8d8fec88d2ff1d831236999c2 Mon Sep 17 00:00:00 2001
From: eust-w
Date: Wed, 18 Jun 2025 11:28:39 +0800
Subject: [PATCH] add chat stream and audio response by tts

---
Review notes (ignored by git-am):
- The TTS API key is read from INFINI_API_KEY only; the key that used to be
  embedded as a fallback must be rotated.
- /api/chat-stream reports upstream failures as JSON (502) before streaming,
  uses request timeouts, and always closes the upstream response.
- The session write inside the SSE generator happens after the headers are
  sent; see the NOTE in chat_stream().
- chat.js queues audio clips, re-renders the accumulated reply text, and always
  removes the typing indicator when the stream ends.

 app.py                  | 140 ++++++++++++++++++++++++++++++++++++++--
 app/api/azure_openai.py |  40 ++++++++++++
 app/static/js/chat.js   | 128 ++++++++++++++++++++----------------
 app/templates/chat.html |  15 +++++
 4 files changed, 268 insertions(+), 55 deletions(-)

diff --git a/app.py b/app.py
index 07dc7af..56efbc1 100644
--- a/app.py
+++ b/app.py
@@ -5,7 +5,7 @@ import os
 import json
 import base64
 import requests
-from flask import Flask, render_template, request, jsonify, redirect, url_for, session, Response
+from flask import Flask, render_template, request, jsonify, redirect, url_for, session, Response, stream_with_context
 from flask_cors import CORS
 from werkzeug.utils import secure_filename
 from app.api.baidu_image_search import BaiduImageSearch
@@ -457,6 +457,135 @@ def delete_robot(robot_id):
     except Exception as e:
         return jsonify({'error': str(e)}), 500
 
+# ==================== Chat Stream ====================
+@app.route('/imgsearcherApi/api/chat-stream', methods=['POST'])
+def chat_stream():
+    """Stream a chat reply about the uploaded image as Server-Sent Events.
+
+    Expects a JSON body with ``message`` and an optional ``voice`` (the TTS
+    prompt wav). Emits ``data:`` events of type ``text`` (one per model
+    delta) and ``audio`` (base64 audio, one per completed sentence), then a
+    final ``event: end``. Validation and upstream failures are returned as a
+    JSON error response before any streaming starts.
+    """
+    if not azure_openai_api:
+        return jsonify({'error': 'Azure OpenAI API未正确配置'}), 500
+
+    data = request.get_json()
+    if not data or 'message' not in data:
+        return jsonify({'error': '缺少消息内容'}), 400
+
+    message = data['message']
+    image_path = session.get('chat_image_path')
+    robot_id = session.get('robot_id', '')
+    if not image_path or not os.path.exists(image_path):
+        return jsonify({'error': '没有上传图片或图片已失效'}), 400
+
+    conversation_history = session.get('conversation_history', [])
+    robot_info = None
+    if robot_id:
+        robot = robot_manager.get_robot(robot_id)
+        if robot:
+            robot_info = {'name': robot.get('name', ''), 'background': robot.get('background', '')}
+
+    # Voice selection: fall back to the default prompt wav when none is sent.
+    voice_wav = data.get('voice') or 'zh-CN-XiaoXiao-Assistant-Audio.wav'
+    # Open the streaming completion; report failures as JSON before SSE starts.
+    try:
+        openai_resp = azure_openai_api.chat_with_image_stream(
+            image_path=image_path,
+            message=message,
+            conversation_history=conversation_history,
+            robot_info=robot_info
+        )
+    except requests.RequestException as e:
+        return jsonify({'error': f'Azure OpenAI API请求失败: {e}'}), 502
+    if openai_resp.status_code != 200:
+        # An error body has no "data:" lines, so streaming it would end silently.
+        detail = openai_resp.text
+        openai_resp.close()
+        return jsonify({'error': f'Azure OpenAI API请求失败: {detail}'}), 502
+
+    def event_stream():
+        accumulated_text = ""
+        pending = ""
+        punctuation = "。!?.!?"
+        def send_tts(sentence):
+            """Yield one SSE ``audio`` event for *sentence*; yield nothing on failure."""
+            if not sentence.strip():
+                return
+            tts_params = {
+                'tts_text': sentence,
+                'prompt_wav': voice_wav,
+                'text_split_method': 'cut5'
+            }
+            try:
+                tts_resp = requests.get(
+                    "https://cloud.infini-ai.com/AIStudio/v1/inference/api/te-c7zfd4hrdoj2vqmw/tts/CosyVoice/v1/zero_shot",
+                    params=tts_params,
+                    # The key comes from the environment only; never embed it in source.
+                    headers={'Authorization': f'Bearer {os.getenv("INFINI_API_KEY", "")}'},
+                    stream=True,
+                    timeout=30
+                )
+                if tts_resp.status_code == 200:
+                    audio_bytes = bytearray()
+                    for chunk in tts_resp.iter_content(chunk_size=8192):
+                        if chunk:
+                            audio_bytes.extend(chunk)
+                    if audio_bytes:
+                        b64_audio = base64.b64encode(audio_bytes).decode()
+                        yield f"data: {{\"type\": \"audio\", \"content\": \"{b64_audio}\" }}\n\n"
+            except Exception as e:
+                print("TTS 调用失败", e)
+        try:
+            # Consume the model's SSE stream line by line.
+            for line in openai_resp.iter_lines():
+                if not line:
+                    continue
+                try:
+                    if line.strip() == b'data: [DONE]':
+                        break
+                    if line.startswith(b'data:'):
+                        content_json = json.loads(line[5:].strip())
+                        choices = content_json.get('choices', [])
+                        if not choices:
+                            continue
+                        delta = choices[0].get('delta', {}).get('content', '')
+                        if not delta:
+                            continue
+                        accumulated_text += delta
+                        pending += delta
+                        # Forward the text delta right away.
+                        yield f"data: {{\"type\": \"text\", \"content\": {json.dumps(delta, ensure_ascii=False)} }}\n\n"
+                        # Send every completed sentence to TTS as soon as it ends.
+                        while True:
+                            idx = next((i for i, ch in enumerate(pending) if ch in punctuation), -1)
+                            if idx == -1:
+                                break
+                            sentence = pending[:idx+1]
+                            pending = pending[idx+1:]
+                            yield from send_tts(sentence)
+                except Exception as e:
+                    print("解析OpenAI流错误", e)
+                    continue
+            # Speak whatever text remains after the last sentence terminator.
+            if pending.strip():
+                yield from send_tts(pending)
+            # NOTE(review): this runs after the response headers were sent, so with
+            # Flask's default cookie-based session the update below is not written
+            # back to the client; store history server-side if it must persist.
+            conversation_history.append({"role": "user", "content": message})
+            conversation_history.append({"role": "assistant", "content": accumulated_text})
+            session['conversation_history'] = conversation_history
+            yield "event: end\ndata: {}\n\n"
+        finally:
+            # Release the upstream connection even if the client disconnects early.
+            openai_resp.close()
+
+    # stream_with_context keeps the request context alive for the generator.
+    return Response(stream_with_context(event_stream()), content_type='text/event-stream')
+
 # ==================== TTS ====================
 @app.route('/imgsearcherApi/api/tts', methods=['POST'])
 def tts():
@@ -467,18 +596,19 @@ def tts():
 
     text = data['text']
     prompt_text = data.get('prompt_text') or "我是威震天,我只代表月亮消灭你"
-    prompt_wav = data.get('prompt_wav') or "data_workspace/data/workspace_170/我是威震天.wav"
+    prompt_wav = data.get('voice') or data.get('prompt_wav') or "zh-CN-XiaoXiao-Assistant-Audio.wav"
 
-    tts_base_url = "http://180.76.186.85:20099/CosyVoice/v1/zero_shot"
+    tts_base_url = "https://cloud.infini-ai.com/AIStudio/v1/inference/api/te-c7zfd4hrdoj2vqmw/tts/CosyVoice/v1/zero_shot"
+    api_key = os.getenv('INFINI_API_KEY', '')
+    headers = {'Authorization': f'Bearer {api_key}'}
     params = {
         'tts_text': text,
-        'prompt_text': prompt_text,
         'prompt_wav': prompt_wav,
         'text_split_method': 'cut5'
     }
 
     try:
-        r = requests.get(tts_base_url, params=params, timeout=20)
+        r = requests.get(tts_base_url, params=params, headers=headers, timeout=20)
         if r.status_code == 200:
             return Response(r.content, content_type=r.headers.get('Content-Type', 'audio/wav'))
         else:
diff --git a/app/api/azure_openai.py b/app/api/azure_openai.py
index 8bfa8b4..6b7a2a8 100644
--- a/app/api/azure_openai.py
+++ b/app/api/azure_openai.py
@@ -172,3 +172,43 @@ class AzureOpenAI:
             return response.json()
         else:
             raise Exception(f"Azure OpenAI API请求失败: {response.text}")
+
+    def chat_with_image_stream(self, image_path, message, conversation_history=None, robot_info=None):
+        """Start a streaming GPT-4o chat completion about an image.
+
+        Returns the ``requests.Response`` opened with ``stream=True``. The caller
+        must check ``status_code``, iterate the SSE lines and close the response.
+        """
+        if conversation_history is None:
+            conversation_history = []
+        image_type, description = self._get_image_type_description(image_path)
+        base64_image = self._encode_image(image_path)
+
+        if robot_info and robot_info.get('name') and robot_info.get('background'):
+            system_message = f"你是{robot_info['name']},一个能够分析图片并回答问题的角色。\n\n你的背景故事:{robot_info['background']}\n\n在对话中,你应该始终保持这个角色的身份和特点,用第一人称回答问题。"
+        else:
+            system_message = "你是一个智能助手,能够分析图片并回答问题。"
+        if image_type and description:
+            system_message += f"\n\n这是一张{image_type}的图片。\n描述信息:{description}\n\n请基于这些信息和图片内容回答用户的问题。"
+
+        messages = [{"role": "system", "content": system_message}]
+        messages.extend(conversation_history)
+        messages.append({
+            "role": "user",
+            "content": [
+                {"type": "text", "text": message},
+                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+            ]
+        })
+
+        headers = {"Content-Type": "application/json", "api-key": self.api_key}
+        payload = {
+            "messages": messages,
+            "max_tokens": 2000,
+            "temperature": 0.7,
+            "top_p": 0.95,
+            "stream": True
+        }
+        url = f"{self.endpoint}/openai/deployments/{self.deployment_name}/chat/completions?api-version={self.api_version}"
+        # (connect, read) timeouts so a stalled upstream cannot hang the request.
+        return requests.post(url, headers=headers, json=payload, stream=True, timeout=(10, 120))
diff --git a/app/static/js/chat.js b/app/static/js/chat.js
index 963ce62..795dcb2 100644
--- a/app/static/js/chat.js
+++ b/app/static/js/chat.js
@@ -178,28 +178,91 @@ function initChatForm() {
         chatMessages.scrollTop = chatMessages.scrollHeight;
 
         try {
-            const response = await fetch('/imgsearcherApi/api/chat', {
+            // Use the streaming (SSE) endpoint.
+            const response = await fetch('/imgsearcherApi/api/chat-stream', {
                 method: 'POST',
                 headers: {
                     'Content-Type': 'application/json'
                 },
                 body: JSON.stringify({
-                    message: message
+                    message: message,
+                    voice: document.getElementById('voiceSelect') ? document.getElementById('voiceSelect').value : ''
                 })
             });
-
-            const data = await response.json();
-
-            // 移除正在输入指示器
-            chatMessages.removeChild(typingIndicator);
-
-            if (data.error) {
-                alert(`对话失败: ${data.error}`);
+
+            if (!response.ok) {
+                chatMessages.removeChild(typingIndicator);
+                // Show the server's JSON error message when one is available.
+                const failure = await response.json().catch(() => ({}));
+                alert(failure.error ? `对话失败: ${failure.error}` : '对话失败');
                 return;
             }
-
-            // 添加AI回复到聊天界面
-            addMessage('assistant', data.reply);
+
+            // Create the assistant message placeholder that the stream fills in.
+            const assistantElement = document.createElement('div');
+            assistantElement.className = 'message message-assistant';
+            const contentDiv = document.createElement('div');
+            contentDiv.className = 'message-content';
+            assistantElement.appendChild(contentDiv);
+            const timeDiv = document.createElement('div');
+            timeDiv.className = 'message-time';
+            timeDiv.textContent = getCurrentTime();
+            assistantElement.appendChild(timeDiv);
+            chatMessages.appendChild(assistantElement);
+            chatMessages.scrollTop = chatMessages.scrollHeight;
+
+            const reader = response.body.getReader();
+            const decoder = new TextDecoder('utf-8');
+            let buf = '';
+            let replyText = '';
+            // Clips are chained on this promise so sentences play one after another.
+            let audioQueue = Promise.resolve();
+            let streamEnded = false;
+            while (!streamEnded) {
+                const { value, done } = await reader.read();
+                if (done) break;
+                buf += decoder.decode(value, { stream: true });
+
+                let pos;
+                while ((pos = buf.indexOf('\n\n')) !== -1) {
+                    const raw = buf.slice(0, pos).trim();
+                    buf = buf.slice(pos + 2);
+                    if (!raw) continue;
+
+                    if (raw.startsWith('event: end')) {
+                        streamEnded = true;
+                        break;
+                    }
+
+                    if (!raw.startsWith('data:')) continue;
+                    const payload = raw.slice(5).trim();
+                    if (!payload) continue;
+                    try {
+                        const obj = JSON.parse(payload);
+                        if (obj.type === 'text') {
+                            // Re-render the whole reply so formatting is not applied per fragment.
+                            replyText += obj.content;
+                            contentDiv.innerHTML = formatMessage(replyText);
+                            chatMessages.scrollTop = chatMessages.scrollHeight;
+                        } else if (obj.type === 'audio') {
+                            const audioBytes = Uint8Array.from(atob(obj.content), c => c.charCodeAt(0));
+                            const audioBlob = new Blob([audioBytes], { type: 'audio/wav' });
+                            audioQueue = audioQueue.then(() => new Promise((resolve) => {
+                                const audioUrl = URL.createObjectURL(audioBlob);
+                                const clip = new Audio(audioUrl);
+                                const finish = () => { URL.revokeObjectURL(audioUrl); resolve(); };
+                                clip.onended = finish;
+                                clip.onerror = finish;
+                                clip.play().catch(finish);
+                            }));
+                        }
+                    } catch (e) {
+                        console.error('解析事件失败', e, payload);
+                    }
+                }
+            }
+            // Also covers a stream that closes without sending "event: end".
+            typingIndicator.remove();
 
         } catch (error) {
             // 移除正在输入指示器
@@ -226,43 +289,8 @@ async function addMessage(role, content) {
 
     chatMessages.appendChild(messageElement);
 
-    // 如果是AI回复,调用TTS并播放
-    if (role === 'assistant') {
-        try {
-            const ttsResp = await fetch('/imgsearcherApi/api/tts', {
-                method: 'POST',
-                headers: {
-                    'Content-Type': 'application/json'
-                },
-                body: JSON.stringify({
-                    text: content
-                })
-            });
-            if (ttsResp.ok) {
-                const blob = await ttsResp.blob();
-                const audioUrl = URL.createObjectURL(blob);
-                const audio = document.createElement('audio');
-                audio.src = audioUrl;
-                audio.autoplay = true;
-                // 可选:提供播放控制按钮
-                const controls = document.createElement('div');
-                controls.className = 'audio-controls mt-1';
-                const playBtn = document.createElement('button');
-                playBtn.type = 'button';
-                playBtn.className = 'btn btn-sm btn-outline-secondary';
-                playBtn.textContent = '🔊 播放';
-                playBtn.addEventListener('click', () => {
-                    audio.play();
-                });
-                controls.appendChild(playBtn);
-                messageElement.appendChild(controls);
-            } else {
-                console.error('TTS 接口错误', await ttsResp.text());
-            }
-        } catch (err) {
-            console.error('TTS 播放失败', err);
-        }
-    }
+
+
     // 滚动到底部
     chatMessages.scrollTop = chatMessages.scrollHeight;
 
diff --git a/app/templates/chat.html b/app/templates/chat.html
index 7096b06..64bbd99 100644
--- a/app/templates/chat.html
+++ b/app/templates/chat.html
@@ -44,6 +44,21 @@
+
+
+
+
+ + +
+
+
管理机器人角色