feat: 添加聊天最大输出 Token 和流式输出选项,优化 AI 聊天体验

This commit is contained in:
2026-03-14 12:48:31 +08:00
parent 21ad22a105
commit f97fad81e6
5 changed files with 694 additions and 50 deletions

364
app.py
View File

@@ -221,6 +221,8 @@ AI_SETTINGS_DEFAULT = {
),
"api_key": os.environ.get("SILICONFLOW_API_KEY", ""),
"timeout": int(os.environ.get("SILICONFLOW_TIMEOUT", "30") or "30"),
"chat_max_tokens": int(os.environ.get("SILICONFLOW_CHAT_MAX_TOKENS", "4096") or "4096"),
"chat_stream_enabled": True,
"restock_threshold": LOW_STOCK_THRESHOLD,
"restock_limit": 24,
"lcsc_timeout": int(os.environ.get("LCSC_TIMEOUT", "20") or "20"),
@@ -426,6 +428,11 @@ def _get_ai_settings() -> dict:
except (TypeError, ValueError):
settings["timeout"] = 30
try:
settings["chat_max_tokens"] = max(256, int(settings.get("chat_max_tokens", 4096)))
except (TypeError, ValueError):
settings["chat_max_tokens"] = 4096
try:
settings["restock_threshold"] = max(0, int(settings.get("restock_threshold", LOW_STOCK_THRESHOLD)))
except (TypeError, ValueError):
@@ -449,6 +456,7 @@ def _get_ai_settings() -> dict:
settings["lcsc_app_id"] = (settings.get("lcsc_app_id") or "").strip()
settings["lcsc_access_key"] = (settings.get("lcsc_access_key") or "").strip()
settings["lcsc_secret_key"] = (settings.get("lcsc_secret_key") or "").strip()
settings["chat_stream_enabled"] = bool(settings.get("chat_stream_enabled", True))
settings["lock_storage_mode"] = bool(settings.get("lock_storage_mode", False))
return settings
@@ -2928,6 +2936,7 @@ def _call_siliconflow_chat(
model: str,
api_key: str,
timeout: int,
max_tokens: int = 4096,  # 默认 4096;SQL 规划等短回复场景可传 700
) -> str:
api_key = (api_key or "").strip()
if not api_key:
@@ -2940,7 +2949,7 @@ def _call_siliconflow_chat(
payload = {
"model": model,
"temperature": 0.2,
"max_tokens": 700,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
@@ -2975,6 +2984,74 @@ def _call_siliconflow_chat(
raise RuntimeError("AI 返回格式无法解析") from exc
def _call_siliconflow_chat_stream(
system_prompt: str,
user_prompt: str,
*,
api_url: str,
model: str,
api_key: str,
timeout: int,
max_tokens: int = 4096,
):
"""流式调用 SiliconFlow chat API按 SSE 协议逐片 yield 文本内容。"""
api_key = (api_key or "").strip()
if not api_key:
raise RuntimeError("SILICONFLOW_API_KEY 未配置")
if not api_url:
raise RuntimeError("AI API URL 未配置")
if not model:
raise RuntimeError("AI 模型名称未配置")
payload = {
"model": model,
"temperature": 0.2,
"max_tokens": max_tokens,
"stream": True,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
}
body = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
api_url,
data=body,
method="POST",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
)
try:
resp = urllib.request.urlopen(req, timeout=timeout)
except urllib.error.HTTPError as exc:
detail = exc.read().decode("utf-8", errors="ignore")
raise RuntimeError(f"AI 服务返回 HTTP {exc.code}: {detail[:200]}") from exc
except urllib.error.URLError as exc:
raise RuntimeError(f"AI 服务连接失败: {exc.reason}") from exc
except (TimeoutError, socket.timeout) as exc:
raise RuntimeError(f"AI 服务连接超时(>{timeout}秒)") from exc
try:
# HTTP 响应逐行读取,每行格式: "data: {...}" 或 "data: [DONE]"
for line_bytes in resp:
line = line_bytes.decode("utf-8").rstrip("\r\n")
if not line or line == "data: [DONE]":
continue
if line.startswith("data: "):
data_str = line[6:]
try:
chunk_obj = json.loads(data_str)
delta = chunk_obj["choices"][0]["delta"].get("content") or ""
if delta:
yield delta
except Exception:
pass
finally:
resp.close()
DB_CHAT_ALLOWED_TABLES = {"boxes", "components", "inventory_events"}
DB_CHAT_FORBIDDEN_KEYWORDS = {
"insert",
@@ -3113,6 +3190,7 @@ def _build_db_chat_sql_plan(
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
max_tokens=700, # SQL规划只需短 JSON 输出
)
parsed = json.loads(_extract_json_object_text(raw_text))
sql = str(parsed.get("sql", "") or "").strip()
@@ -3320,6 +3398,7 @@ def _build_db_chat_answer(
if web_context:
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
return _call_siliconflow_chat(
system_prompt,
user_prompt,
@@ -3327,6 +3406,51 @@ def _build_db_chat_answer(
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
max_tokens=chat_max_tokens,
)
def _build_db_chat_answer_stream(
question: str,
sql: str,
sql_reason: str,
rows: list[dict],
web_context: list[dict],
settings: dict,
):
"""流式版数据库查询回答,逐片 yield 文本内容。"""
if not rows:
yield "查询已执行,但没有匹配数据。你可以换个条件,例如指定料号、时间范围或盒型。"
return
system_prompt = (
"你是库存分析助手。"
"请仅根据提供的 SQL 结果回答,禁止虚构不存在的数据。"
"回答用简明中文,优先给结论,再给关键证据。"
"若提供了联网参考,请明确标注为参考信息,不得当作数据库事实。"
)
user_prompt = (
"用户问题:\n"
+ question
+ "\n\n执行SQL:\n"
+ sql
+ "\n\nSQL说明:\n"
+ sql_reason
+ "\n\n查询结果(JSON):\n"
+ json.dumps(rows, ensure_ascii=False)
)
if web_context:
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
yield from _call_siliconflow_chat_stream(
system_prompt,
user_prompt,
api_url=settings["api_url"],
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
max_tokens=chat_max_tokens,
)
@@ -3405,6 +3529,7 @@ def _build_general_chat_answer(
if web_context:
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
return _call_siliconflow_chat(
system_prompt,
user_prompt,
@@ -3412,12 +3537,68 @@ def _build_general_chat_answer(
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
max_tokens=chat_max_tokens,
)
def _build_general_chat_answer_stream(
question: str,
history: list[dict],
web_context: list[dict],
memory_text: str,
settings: dict,
):
"""流式版通用聊天回答,逐片 yield 文本内容。"""
q = (question or "").strip()
q_lower = q.lower()
# 模型名询问直接返回,无需调 API
if "模型" in q or "model" in q_lower:
model_name = str(settings.get("model", "") or "未配置")
yield f"当前系统配置的 AI 模型是:{model_name}"
return
history_lines = []
for row in history[-8:]:
role = str(row.get("role", "user") or "user")
content = str(row.get("content", "") or "").strip()
if not content:
continue
history_lines.append(f"[{role}] {content[:220]}")
system_prompt = (
"你是系统内置通用助手。"
"请直接回答用户问题,不要把自己限制为仅库存问题。"
"对于时效性问题(如天气、新闻、实时价格),若提供了联网参考就优先依据参考;"
"若没有联网参考,需明确说明可能不够实时。"
"回答保持简洁中文。"
)
user_prompt = (
"对话历史:\n"
+ "\n".join(history_lines or ["(无)"])
+ "\n\n用户当前问题:\n"
+ question
)
if memory_text:
user_prompt += "\n\n本地记忆摘要:\n" + memory_text
if web_context:
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
chat_max_tokens = int(settings.get("chat_max_tokens", 4096) or 4096)
yield from _call_siliconflow_chat_stream(
system_prompt,
user_prompt,
api_url=settings["api_url"],
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
max_tokens=chat_max_tokens,
)
@app.route("/ai/chat")
def ai_chat_page():
return render_template("ai_chat.html")
settings = _get_ai_settings()
return render_template("ai_chat.html", settings=settings)
@app.route("/ai/chat/memory/clear", methods=["POST"])
@@ -3644,6 +3825,171 @@ def ai_chat_query():
}
# ──────────────────────────────────────────────────────────────────────────────
# SSE 流式聊天接口 /ai/chat/stream
# 与 ai_chat_query 执行相同的预处理逻辑,但最终 AI 回答改为流式 SSE 输出
# ──────────────────────────────────────────────────────────────────────────────
@app.route("/ai/chat/stream", methods=["POST"])
def ai_chat_stream():
"""SSE 流式聊天接口,先同步完成预处理(联网/SQL再以 SSE 逐片推送 AI 回答。"""
payload = request.get_json(silent=True) or {}
question = str(payload.get("question", "") or "").strip()
history = payload.get("history", [])
allow_web_search = _is_truthy_form_value(str(payload.get("allow_web_search", "")))
allow_db_query = _is_truthy_form_value(str(payload.get("allow_db_query", "")))
if not isinstance(history, list):
history = []
_SSE_HEADERS = {
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # 防止 Nginx 缓冲
"Connection": "keep-alive",
}
def _sse_err(msg: str):
"""单条错误事件的 Response"""
evt = json.dumps({"type": "error", "message": msg}, ensure_ascii=False)
return Response(iter([f"data: {evt}\n\n"]), mimetype="text/event-stream", headers=_SSE_HEADERS)
if not question:
return _sse_err("请输入问题")
username = (session.get("username") or "guest").strip() or "guest"
memory_payload = _get_ai_chat_memory(username)
memory_text = _memory_context_text(memory_payload)
settings = _get_ai_settings()
if not settings.get("api_key") or not settings.get("api_url") or not settings.get("model"):
return _sse_err("AI 参数不完整,请先到参数页配置")
if not bool(settings.get("chat_stream_enabled", True)):
return _sse_err("当前已关闭流式输出,请在 AI 参数页开启后重试")
# ── 预处理:与 ai_chat_query 相同逻辑 ────────────────────────────────────
chat_mode = "general"
sql = ""
planner_reason = ""
planner_raw = ""
row_count = 0
web_context = []
serialized_rows = []
answer_gen = None # 最终的流式生成器
if not allow_db_query:
# 通用问答模式
if allow_web_search:
web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1)
if not web_context and ("天气" in question or "weather" in question.lower()):
weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30))
if weather_ctx:
web_context = [weather_ctx]
planner_reason = "未启用数据库查询,按通用问答模式处理"
chat_mode = "general"
answer_gen = _build_general_chat_answer_stream(question, history, web_context, memory_text, settings)
else:
# 数据库查询模式
try:
sql_raw, planner_reason, planner_raw = _build_db_chat_sql_plan(question, history, memory_text, settings)
except Exception as exc:
_log_event(logging.WARNING, "ai_chat_stream_plan_error", error=str(exc), question=question)
if not _is_general_chat_question(question):
return _sse_err(f"SQL 规划失败: {exc}")
# 回退通用
if allow_web_search:
web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1)
if not web_context and ("天气" in question or "weather" in question.lower()):
weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30))
if weather_ctx:
web_context = [weather_ctx]
planner_reason = "已切换到通用对话模式"
chat_mode = "general"
answer_gen = _build_general_chat_answer_stream(question, history, web_context, memory_text, settings)
else:
safe, reject_reason, safe_sql = _is_safe_readonly_sql(sql_raw)
if not safe:
_log_event(logging.WARNING, "ai_chat_stream_sql_rejected", reason=reject_reason, sql=sql_raw)
can_fallback = "未包含可识别的数据表" in reject_reason or _is_general_chat_question(question)
if not can_fallback:
return _sse_err(f"SQL 被安全策略拒绝: {reject_reason}")
# 回退通用
if allow_web_search:
web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1)
if not web_context and ("天气" in question or "weather" in question.lower()):
weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30))
if weather_ctx:
web_context = [weather_ctx]
planner_reason = "已切换到通用对话模式"
chat_mode = "general"
answer_gen = _build_general_chat_answer_stream(question, history, web_context, memory_text, settings)
else:
# 执行 SQL需要在路由函数里做生成器无 db.session 保证)
final_sql = _ensure_query_limit(safe_sql, row_limit=80)
sql = final_sql
try:
query_result = db.session.execute(db.text(final_sql))
columns = list(query_result.keys())
rows_data = query_result.fetchall()
serialized_rows = _serialize_sql_rows(rows_data, columns)
row_count = len(serialized_rows)
except Exception as exc:
_log_event(logging.ERROR, "ai_chat_stream_exec_error", error=str(exc), sql=final_sql)
return _sse_err("SQL 执行失败,请调整问题后重试")
if allow_web_search:
try:
web_context = _build_db_chat_web_context(question, serialized_rows, timeout=settings.get("timeout", 30))
except Exception:
web_context = []
chat_mode = "db"
answer_gen = _build_db_chat_answer_stream(question, final_sql, planner_reason, serialized_rows, web_context, settings)
# ── SSE 生成器 ────────────────────────────────────────────────────────────
meta = {
"type": "meta",
"sql": sql,
"planner_reason": planner_reason,
"planner_raw": planner_raw,
"row_count": row_count,
"rows_preview": serialized_rows[:8],
"web_context": web_context,
"chat_mode": chat_mode,
"allow_web_search": allow_web_search,
"allow_db_query": allow_db_query,
}
def _generate():
# 先推送元数据SQL、模式、联网来源等
yield f"data: {json.dumps(meta, ensure_ascii=False)}\n\n"
accumulated = []
try:
for chunk in answer_gen:
accumulated.append(chunk)
evt = json.dumps({"type": "chunk", "text": chunk}, ensure_ascii=False)
yield f"data: {evt}\n\n"
except Exception as exc:
err_evt = json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)
yield f"data: {err_evt}\n\n"
return
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
# 流结束后保存记忆
full_answer = "".join(accumulated)
_append_ai_chat_memory(username, question, full_answer)
_log_event(
logging.INFO,
"ai_chat_stream_success",
question=question,
chat_mode=chat_mode,
allow_web_search=allow_web_search,
allow_db_query=allow_db_query,
web_sources=sum(len(item.get("sources", [])) for item in web_context),
)
return Response(_generate(), mimetype="text/event-stream", headers=_SSE_HEADERS)
def _is_safe_next_path(path: str) -> bool:
candidate = (path or "").strip()
if not candidate:
@@ -5302,6 +5648,7 @@ def ai_settings_page():
api_url = request.form.get("api_url", "").strip()
model = request.form.get("model", "").strip()
api_key = request.form.get("api_key", "").strip()
chat_stream_enabled = _is_truthy_form_value(request.form.get("chat_stream_enabled", ""))
lcsc_app_id = request.form.get("lcsc_app_id", "").strip()
lcsc_access_key = request.form.get("lcsc_access_key", "").strip()
lcsc_secret_key = request.form.get("lcsc_secret_key", "").strip()
@@ -5315,6 +5662,15 @@ def ai_settings_page():
error = "超时时间必须是大于等于 5 的整数"
timeout = settings["timeout"]
try:
chat_max_tokens = int((request.form.get("chat_max_tokens", "4096") or "4096").strip())
if chat_max_tokens < 256:
raise ValueError
except ValueError:
if not error:
error = "聊天最大输出 token 必须是大于等于 256 的整数"
chat_max_tokens = settings["chat_max_tokens"]
try:
restock_threshold = int((request.form.get("restock_threshold", "5") or "5").strip())
if restock_threshold < 0:
@@ -5355,6 +5711,8 @@ def ai_settings_page():
"model": model,
"api_key": api_key,
"timeout": timeout,
"chat_max_tokens": chat_max_tokens,
"chat_stream_enabled": chat_stream_enabled,
"restock_threshold": restock_threshold,
"restock_limit": restock_limit,
"lcsc_base_url": LCSC_BASE_URL,
@@ -5374,6 +5732,8 @@ def ai_settings_page():
"model": model,
"api_key": api_key,
"timeout": timeout,
"chat_max_tokens": chat_max_tokens,
"chat_stream_enabled": chat_stream_enabled,
"restock_threshold": restock_threshold,
"restock_limit": restock_limit,
"lcsc_base_url": LCSC_BASE_URL,