feat: 添加聊天最大输出 Token 和流式输出选项,优化 AI 聊天体验
This commit is contained in:
364
app.py
364
app.py
@@ -221,6 +221,8 @@ AI_SETTINGS_DEFAULT = {
|
||||
),
|
||||
"api_key": os.environ.get("SILICONFLOW_API_KEY", ""),
|
||||
"timeout": int(os.environ.get("SILICONFLOW_TIMEOUT", "30") or "30"),
|
||||
"chat_max_tokens": int(os.environ.get("SILICONFLOW_CHAT_MAX_TOKENS", "4096") or "4096"),
|
||||
"chat_stream_enabled": True,
|
||||
"restock_threshold": LOW_STOCK_THRESHOLD,
|
||||
"restock_limit": 24,
|
||||
"lcsc_timeout": int(os.environ.get("LCSC_TIMEOUT", "20") or "20"),
|
||||
@@ -426,6 +428,11 @@ def _get_ai_settings() -> dict:
|
||||
except (TypeError, ValueError):
|
||||
settings["timeout"] = 30
|
||||
|
||||
try:
|
||||
settings["chat_max_tokens"] = max(256, int(settings.get("chat_max_tokens", 4096)))
|
||||
except (TypeError, ValueError):
|
||||
settings["chat_max_tokens"] = 4096
|
||||
|
||||
try:
|
||||
settings["restock_threshold"] = max(0, int(settings.get("restock_threshold", LOW_STOCK_THRESHOLD)))
|
||||
except (TypeError, ValueError):
|
||||
@@ -449,6 +456,7 @@ def _get_ai_settings() -> dict:
|
||||
settings["lcsc_app_id"] = (settings.get("lcsc_app_id") or "").strip()
|
||||
settings["lcsc_access_key"] = (settings.get("lcsc_access_key") or "").strip()
|
||||
settings["lcsc_secret_key"] = (settings.get("lcsc_secret_key") or "").strip()
|
||||
settings["chat_stream_enabled"] = bool(settings.get("chat_stream_enabled", True))
|
||||
settings["lock_storage_mode"] = bool(settings.get("lock_storage_mode", False))
|
||||
return settings
|
||||
|
||||
@@ -2928,6 +2936,7 @@ def _call_siliconflow_chat(
|
||||
model: str,
|
||||
api_key: str,
|
||||
timeout: int,
|
||||
max_tokens: int = 4096, # 默认 4096,SQL规划等短回复场景可传 700
|
||||
) -> str:
|
||||
api_key = (api_key or "").strip()
|
||||
if not api_key:
|
||||
@@ -2940,7 +2949,7 @@ def _call_siliconflow_chat(
|
||||
payload = {
|
||||
"model": model,
|
||||
"temperature": 0.2,
|
||||
"max_tokens": 700,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
@@ -2975,6 +2984,74 @@ def _call_siliconflow_chat(
|
||||
raise RuntimeError("AI 返回格式无法解析") from exc
|
||||
|
||||
|
||||
def _call_siliconflow_chat_stream(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
*,
|
||||
api_url: str,
|
||||
model: str,
|
||||
api_key: str,
|
||||
timeout: int,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
"""流式调用 SiliconFlow chat API,按 SSE 协议逐片 yield 文本内容。"""
|
||||
api_key = (api_key or "").strip()
|
||||
if not api_key:
|
||||
raise RuntimeError("SILICONFLOW_API_KEY 未配置")
|
||||
if not api_url:
|
||||
raise RuntimeError("AI API URL 未配置")
|
||||
if not model:
|
||||
raise RuntimeError("AI 模型名称未配置")
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"temperature": 0.2,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
}
|
||||
body = json.dumps(payload).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
api_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
try:
|
||||
resp = urllib.request.urlopen(req, timeout=timeout)
|
||||
except urllib.error.HTTPError as exc:
|
||||
detail = exc.read().decode("utf-8", errors="ignore")
|
||||
raise RuntimeError(f"AI 服务返回 HTTP {exc.code}: {detail[:200]}") from exc
|
||||
except urllib.error.URLError as exc:
|
||||
raise RuntimeError(f"AI 服务连接失败: {exc.reason}") from exc
|
||||
except (TimeoutError, socket.timeout) as exc:
|
||||
raise RuntimeError(f"AI 服务连接超时(>{timeout}秒)") from exc
|
||||
|
||||
try:
|
||||
# HTTP 响应逐行读取,每行格式: "data: {...}" 或 "data: [DONE]"
|
||||
for line_bytes in resp:
|
||||
line = line_bytes.decode("utf-8").rstrip("\r\n")
|
||||
if not line or line == "data: [DONE]":
|
||||
continue
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
try:
|
||||
chunk_obj = json.loads(data_str)
|
||||
delta = chunk_obj["choices"][0]["delta"].get("content") or ""
|
||||
if delta:
|
||||
yield delta
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
resp.close()
|
||||
|
||||
|
||||
DB_CHAT_ALLOWED_TABLES = {"boxes", "components", "inventory_events"}
|
||||
DB_CHAT_FORBIDDEN_KEYWORDS = {
|
||||
"insert",
|
||||
@@ -3113,6 +3190,7 @@ def _build_db_chat_sql_plan(
|
||||
model=settings["model"],
|
||||
api_key=settings["api_key"],
|
||||
timeout=settings["timeout"],
|
||||
max_tokens=700, # SQL规划只需短 JSON 输出
|
||||
)
|
||||
parsed = json.loads(_extract_json_object_text(raw_text))
|
||||
sql = str(parsed.get("sql", "") or "").strip()
|
||||
@@ -3320,6 +3398,7 @@ def _build_db_chat_answer(
|
||||
if web_context:
|
||||
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
|
||||
|
||||
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
|
||||
return _call_siliconflow_chat(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
@@ -3327,6 +3406,51 @@ def _build_db_chat_answer(
|
||||
model=settings["model"],
|
||||
api_key=settings["api_key"],
|
||||
timeout=settings["timeout"],
|
||||
max_tokens=chat_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _build_db_chat_answer_stream(
|
||||
question: str,
|
||||
sql: str,
|
||||
sql_reason: str,
|
||||
rows: list[dict],
|
||||
web_context: list[dict],
|
||||
settings: dict,
|
||||
):
|
||||
"""流式版数据库查询回答,逐片 yield 文本内容。"""
|
||||
if not rows:
|
||||
yield "查询已执行,但没有匹配数据。你可以换个条件,例如指定料号、时间范围或盒型。"
|
||||
return
|
||||
|
||||
system_prompt = (
|
||||
"你是库存分析助手。"
|
||||
"请仅根据提供的 SQL 结果回答,禁止虚构不存在的数据。"
|
||||
"回答用简明中文,优先给结论,再给关键证据。"
|
||||
"若提供了联网参考,请明确标注为参考信息,不得当作数据库事实。"
|
||||
)
|
||||
user_prompt = (
|
||||
"用户问题:\n"
|
||||
+ question
|
||||
+ "\n\n执行SQL:\n"
|
||||
+ sql
|
||||
+ "\n\nSQL说明:\n"
|
||||
+ sql_reason
|
||||
+ "\n\n查询结果(JSON):\n"
|
||||
+ json.dumps(rows, ensure_ascii=False)
|
||||
)
|
||||
if web_context:
|
||||
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
|
||||
|
||||
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
|
||||
yield from _call_siliconflow_chat_stream(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
api_url=settings["api_url"],
|
||||
model=settings["model"],
|
||||
api_key=settings["api_key"],
|
||||
timeout=settings["timeout"],
|
||||
max_tokens=chat_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
@@ -3405,6 +3529,7 @@ def _build_general_chat_answer(
|
||||
if web_context:
|
||||
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
|
||||
|
||||
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
|
||||
return _call_siliconflow_chat(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
@@ -3412,12 +3537,68 @@ def _build_general_chat_answer(
|
||||
model=settings["model"],
|
||||
api_key=settings["api_key"],
|
||||
timeout=settings["timeout"],
|
||||
max_tokens=chat_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _build_general_chat_answer_stream(
|
||||
question: str,
|
||||
history: list[dict],
|
||||
web_context: list[dict],
|
||||
memory_text: str,
|
||||
settings: dict,
|
||||
):
|
||||
"""流式版通用聊天回答,逐片 yield 文本内容。"""
|
||||
q = (question or "").strip()
|
||||
q_lower = q.lower()
|
||||
# 模型名询问直接返回,无需调 API
|
||||
if "模型" in q or "model" in q_lower:
|
||||
model_name = str(settings.get("model", "") or "未配置")
|
||||
yield f"当前系统配置的 AI 模型是:{model_name}。"
|
||||
return
|
||||
|
||||
history_lines = []
|
||||
for row in history[-8:]:
|
||||
role = str(row.get("role", "user") or "user")
|
||||
content = str(row.get("content", "") or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
history_lines.append(f"[{role}] {content[:220]}")
|
||||
|
||||
system_prompt = (
|
||||
"你是系统内置通用助手。"
|
||||
"请直接回答用户问题,不要把自己限制为仅库存问题。"
|
||||
"对于时效性问题(如天气、新闻、实时价格),若提供了联网参考就优先依据参考;"
|
||||
"若没有联网参考,需明确说明可能不够实时。"
|
||||
"回答保持简洁中文。"
|
||||
)
|
||||
user_prompt = (
|
||||
"对话历史:\n"
|
||||
+ "\n".join(history_lines or ["(无)"])
|
||||
+ "\n\n用户当前问题:\n"
|
||||
+ question
|
||||
)
|
||||
if memory_text:
|
||||
user_prompt += "\n\n本地记忆摘要:\n" + memory_text
|
||||
if web_context:
|
||||
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
|
||||
|
||||
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
|
||||
yield from _call_siliconflow_chat_stream(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
api_url=settings["api_url"],
|
||||
model=settings["model"],
|
||||
api_key=settings["api_key"],
|
||||
timeout=settings["timeout"],
|
||||
max_tokens=chat_max_tokens,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/ai/chat")
def ai_chat_page():
    """Render the AI chat page, passing current AI settings to the template.

    The previous bare ``return render_template("ai_chat.html")`` left the
    settings-loading lines unreachable; the template needs ``settings``
    (e.g. ``chat_stream_enabled``) to configure the chat UI.
    """
    settings = _get_ai_settings()
    return render_template("ai_chat.html", settings=settings)
|
||||
|
||||
|
||||
@app.route("/ai/chat/memory/clear", methods=["POST"])
|
||||
@@ -3644,6 +3825,171 @@ def ai_chat_query():
|
||||
}
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# SSE 流式聊天接口 /ai/chat/stream
|
||||
# 与 ai_chat_query 执行相同的预处理逻辑,但最终 AI 回答改为流式 SSE 输出
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
@app.route("/ai/chat/stream", methods=["POST"])
def ai_chat_stream():
    """SSE streaming chat endpoint.

    Runs the same preprocessing as ``ai_chat_query`` synchronously
    (optional web search, SQL planning and execution), then streams the
    AI answer as Server-Sent Events: one ``meta`` event, then ``chunk``
    events, and a final ``done`` event — or a single ``error`` event on
    failure.

    Returns:
        Response: ``text/event-stream`` response.
    """
    payload = request.get_json(silent=True) or {}
    question = str(payload.get("question", "") or "").strip()
    history = payload.get("history", [])
    allow_web_search = _is_truthy_form_value(str(payload.get("allow_web_search", "")))
    allow_db_query = _is_truthy_form_value(str(payload.get("allow_db_query", "")))
    if not isinstance(history, list):
        history = []

    # Local (lower-case: it is not a module constant) SSE headers.
    sse_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",  # prevent Nginx from buffering the stream
        "Connection": "keep-alive",
    }

    def _sse_err(msg: str):
        """Build a Response carrying a single SSE error event."""
        evt = json.dumps({"type": "error", "message": msg}, ensure_ascii=False)
        return Response(iter([f"data: {evt}\n\n"]), mimetype="text/event-stream", headers=sse_headers)

    if not question:
        return _sse_err("请输入问题")

    username = (session.get("username") or "guest").strip() or "guest"
    memory_payload = _get_ai_chat_memory(username)
    memory_text = _memory_context_text(memory_payload)

    settings = _get_ai_settings()
    if not settings.get("api_key") or not settings.get("api_url") or not settings.get("model"):
        return _sse_err("AI 参数不完整,请先到参数页配置")
    if not bool(settings.get("chat_stream_enabled", True)):
        return _sse_err("当前已关闭流式输出,请在 AI 参数页开启后重试")

    def _general_web_context() -> list[dict]:
        """Best-effort web context for general-chat mode (weather fallback).

        Deduplicates the identical fallback block previously repeated at
        three call sites. Returns [] when web search is disabled.
        """
        if not allow_web_search:
            return []
        ctx = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1)
        if not ctx and ("天气" in question or "weather" in question.lower()):
            weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30))
            if weather_ctx:
                ctx = [weather_ctx]
        return ctx

    # ── Preprocessing: same logic as ai_chat_query ──────────────────────────
    chat_mode = "general"
    sql = ""
    planner_reason = ""
    planner_raw = ""
    row_count = 0
    web_context = []
    serialized_rows = []
    answer_gen = None  # final streaming generator

    if not allow_db_query:
        # General Q&A mode.
        web_context = _general_web_context()
        planner_reason = "未启用数据库查询,按通用问答模式处理"
        chat_mode = "general"
        answer_gen = _build_general_chat_answer_stream(question, history, web_context, memory_text, settings)
    else:
        # Database query mode.
        try:
            sql_raw, planner_reason, planner_raw = _build_db_chat_sql_plan(question, history, memory_text, settings)
        except Exception as exc:
            _log_event(logging.WARNING, "ai_chat_stream_plan_error", error=str(exc), question=question)
            if not _is_general_chat_question(question):
                return _sse_err(f"SQL 规划失败: {exc}")
            # Fall back to general chat.
            web_context = _general_web_context()
            planner_reason = "已切换到通用对话模式"
            chat_mode = "general"
            answer_gen = _build_general_chat_answer_stream(question, history, web_context, memory_text, settings)
        else:
            safe, reject_reason, safe_sql = _is_safe_readonly_sql(sql_raw)
            if not safe:
                _log_event(logging.WARNING, "ai_chat_stream_sql_rejected", reason=reject_reason, sql=sql_raw)
                can_fallback = "未包含可识别的数据表" in reject_reason or _is_general_chat_question(question)
                if not can_fallback:
                    return _sse_err(f"SQL 被安全策略拒绝: {reject_reason}")
                # Fall back to general chat.
                web_context = _general_web_context()
                planner_reason = "已切换到通用对话模式"
                chat_mode = "general"
                answer_gen = _build_general_chat_answer_stream(question, history, web_context, memory_text, settings)
            else:
                # Execute SQL inside the request context — the SSE generator
                # has no db.session guarantee.
                final_sql = _ensure_query_limit(safe_sql, row_limit=80)
                sql = final_sql
                try:
                    query_result = db.session.execute(db.text(final_sql))
                    columns = list(query_result.keys())
                    rows_data = query_result.fetchall()
                    serialized_rows = _serialize_sql_rows(rows_data, columns)
                    row_count = len(serialized_rows)
                except Exception as exc:
                    _log_event(logging.ERROR, "ai_chat_stream_exec_error", error=str(exc), sql=final_sql)
                    return _sse_err("SQL 执行失败,请调整问题后重试")

                if allow_web_search:
                    try:
                        web_context = _build_db_chat_web_context(question, serialized_rows, timeout=settings.get("timeout", 30))
                    except Exception:
                        # Web enrichment is optional; never fail the request on it.
                        web_context = []

                chat_mode = "db"
                answer_gen = _build_db_chat_answer_stream(question, final_sql, planner_reason, serialized_rows, web_context, settings)

    # ── SSE generator ───────────────────────────────────────────────────────
    meta = {
        "type": "meta",
        "sql": sql,
        "planner_reason": planner_reason,
        "planner_raw": planner_raw,
        "row_count": row_count,
        "rows_preview": serialized_rows[:8],
        "web_context": web_context,
        "chat_mode": chat_mode,
        "allow_web_search": allow_web_search,
        "allow_db_query": allow_db_query,
    }

    def _generate():
        """Yield the SSE event stream: meta, chunks, then done/error."""
        # Push metadata first (SQL, mode, web sources, ...).
        yield f"data: {json.dumps(meta, ensure_ascii=False)}\n\n"

        accumulated = []
        try:
            for chunk in answer_gen:
                accumulated.append(chunk)
                evt = json.dumps({"type": "chunk", "text": chunk}, ensure_ascii=False)
                yield f"data: {evt}\n\n"
        except Exception as exc:
            err_evt = json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)
            yield f"data: {err_evt}\n\n"
            return

        yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"

        # Persist memory only after the stream completed successfully.
        full_answer = "".join(accumulated)
        _append_ai_chat_memory(username, question, full_answer)
        _log_event(
            logging.INFO,
            "ai_chat_stream_success",
            question=question,
            chat_mode=chat_mode,
            allow_web_search=allow_web_search,
            allow_db_query=allow_db_query,
            web_sources=sum(len(item.get("sources", [])) for item in web_context),
        )

    return Response(_generate(), mimetype="text/event-stream", headers=sse_headers)
|
||||
|
||||
|
||||
def _is_safe_next_path(path: str) -> bool:
|
||||
candidate = (path or "").strip()
|
||||
if not candidate:
|
||||
@@ -5302,6 +5648,7 @@ def ai_settings_page():
|
||||
api_url = request.form.get("api_url", "").strip()
|
||||
model = request.form.get("model", "").strip()
|
||||
api_key = request.form.get("api_key", "").strip()
|
||||
chat_stream_enabled = _is_truthy_form_value(request.form.get("chat_stream_enabled", ""))
|
||||
lcsc_app_id = request.form.get("lcsc_app_id", "").strip()
|
||||
lcsc_access_key = request.form.get("lcsc_access_key", "").strip()
|
||||
lcsc_secret_key = request.form.get("lcsc_secret_key", "").strip()
|
||||
@@ -5315,6 +5662,15 @@ def ai_settings_page():
|
||||
error = "超时时间必须是大于等于 5 的整数"
|
||||
timeout = settings["timeout"]
|
||||
|
||||
try:
|
||||
chat_max_tokens = int((request.form.get("chat_max_tokens", "4096") or "4096").strip())
|
||||
if chat_max_tokens < 256:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
if not error:
|
||||
error = "聊天最大输出 token 必须是大于等于 256 的整数"
|
||||
chat_max_tokens = settings["chat_max_tokens"]
|
||||
|
||||
try:
|
||||
restock_threshold = int((request.form.get("restock_threshold", "5") or "5").strip())
|
||||
if restock_threshold < 0:
|
||||
@@ -5355,6 +5711,8 @@ def ai_settings_page():
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"timeout": timeout,
|
||||
"chat_max_tokens": chat_max_tokens,
|
||||
"chat_stream_enabled": chat_stream_enabled,
|
||||
"restock_threshold": restock_threshold,
|
||||
"restock_limit": restock_limit,
|
||||
"lcsc_base_url": LCSC_BASE_URL,
|
||||
@@ -5374,6 +5732,8 @@ def ai_settings_page():
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"timeout": timeout,
|
||||
"chat_max_tokens": chat_max_tokens,
|
||||
"chat_stream_enabled": chat_stream_enabled,
|
||||
"restock_threshold": restock_threshold,
|
||||
"restock_limit": restock_limit,
|
||||
"lcsc_base_url": LCSC_BASE_URL,
|
||||
|
||||
Reference in New Issue
Block a user