特性:添加支持数据库查询的 AI 聊天功能

- 实现了一个新的 AI 聊天页面,用于自然语言查询,该页面会生成用于库存数据的只读 SQL 查询。
- 添加了本地内存存储,用于用户交互,允许 AI 记住最近的对话和笔记。
- 增强了聊天界面,增加了网络搜索和数据库查询执行选项。
- 更新了 README,包含了关于新 AI 聊天功能和其使用方法的详细信息。
- 引入了新的 CSS 样式以改善聊天界面的用户体验。
- 修改了现有模板以集成新的聊天功能,并提供从库存概览页面轻松访问。
This commit is contained in:
2026-03-14 01:34:29 +08:00
parent dc7efb8ff8
commit 21ad22a105
7 changed files with 1206 additions and 1 deletions

769
app.py
View File

@@ -39,6 +39,7 @@ DB_DIR = os.path.join(BASE_DIR, "data")
os.makedirs(DB_DIR, exist_ok=True)
DB_PATH = os.path.join(DB_DIR, "inventory.db")
APP_LOG_PATH = os.path.join(DB_DIR, "app.log")
AI_CHAT_MEMORY_PATH = os.path.join(DB_DIR, "ai_chat_memory.json")
# Flask 和 SQLAlchemy 基础初始化。
app = Flask(__name__)
@@ -100,6 +101,105 @@ def _read_log_lines(limit: int = 200) -> list[str]:
return [line.rstrip("\n") for line in file_obj.readlines()[-limit:]]
def _load_ai_chat_memory_store() -> dict:
if not os.path.exists(AI_CHAT_MEMORY_PATH):
return {}
try:
with open(AI_CHAT_MEMORY_PATH, "r", encoding="utf-8") as file_obj:
payload = json.load(file_obj)
return payload if isinstance(payload, dict) else {}
except (OSError, json.JSONDecodeError):
return {}
def _save_ai_chat_memory_store(payload: dict) -> None:
safe_payload = payload if isinstance(payload, dict) else {}
with open(AI_CHAT_MEMORY_PATH, "w", encoding="utf-8") as file_obj:
json.dump(safe_payload, file_obj, ensure_ascii=False, indent=2)
def _normalize_memory_text(text: str, max_len: int = 220) -> str:
cleaned = re.sub(r"\s+", " ", str(text or "")).strip()
if len(cleaned) > max_len:
return cleaned[: max_len - 3] + "..."
return cleaned
def _get_ai_chat_memory(username: str) -> dict:
user_key = (username or "guest").strip() or "guest"
store = _load_ai_chat_memory_store()
block = store.get(user_key, {}) if isinstance(store, dict) else {}
notes = block.get("notes", []) if isinstance(block, dict) else []
turns = block.get("turns", []) if isinstance(block, dict) else []
if not isinstance(notes, list):
notes = []
if not isinstance(turns, list):
turns = []
return {
"user": user_key,
"notes": [str(item) for item in notes[:30]],
"turns": [item for item in turns[-20:] if isinstance(item, dict)],
}
def _memory_context_text(memory: dict) -> str:
notes = memory.get("notes", []) if isinstance(memory, dict) else []
turns = memory.get("turns", []) if isinstance(memory, dict) else []
lines = []
if notes:
lines.append("用户长期记忆:")
for note in notes[-12:]:
lines.append(f"- {_normalize_memory_text(note, max_len=120)}")
if turns:
lines.append("最近本地记忆对话:")
for turn in turns[-6:]:
q = _normalize_memory_text(turn.get("q", ""), max_len=90)
a = _normalize_memory_text(turn.get("a", ""), max_len=120)
lines.append(f"Q: {q}")
lines.append(f"A: {a}")
return "\n".join(lines)
def _append_ai_chat_memory(username: str, question: str, answer: str) -> None:
user_key = (username or "guest").strip() or "guest"
store = _load_ai_chat_memory_store()
block = store.get(user_key)
if not isinstance(block, dict):
block = {"notes": [], "turns": []}
notes = block.get("notes")
turns = block.get("turns")
if not isinstance(notes, list):
notes = []
if not isinstance(turns, list):
turns = []
q = _normalize_memory_text(question)
a = _normalize_memory_text(answer)
turns.append({"q": q, "a": a, "ts": datetime.utcnow().isoformat(timespec="seconds")})
turns = turns[-30:]
match = re.match(r"^(记住|记一下|请记住)[:\s]*(.+)$", q)
if match:
note = _normalize_memory_text(match.group(2), max_len=140)
if note and note not in notes:
notes.append(note)
notes = notes[-30:]
block["notes"] = notes
block["turns"] = turns
store[user_key] = block
_save_ai_chat_memory_store(store)
def _clear_ai_chat_memory(username: str) -> None:
user_key = (username or "guest").strip() or "guest"
store = _load_ai_chat_memory_store()
if user_key in store:
store.pop(user_key, None)
_save_ai_chat_memory_store(store)
_setup_app_logger()
# 这里集中放全局常量,避免后面函数里散落硬编码。
@@ -2875,6 +2975,675 @@ def _call_siliconflow_chat(
raise RuntimeError("AI 返回格式无法解析") from exc
DB_CHAT_ALLOWED_TABLES = {"boxes", "components", "inventory_events"}
DB_CHAT_FORBIDDEN_KEYWORDS = {
"insert",
"update",
"delete",
"drop",
"alter",
"create",
"replace",
"truncate",
"attach",
"detach",
"pragma",
"vacuum",
"reindex",
"begin",
"commit",
"rollback",
}
def _extract_json_object_text(raw_text: str) -> str:
text = (raw_text or "").strip()
if not text:
return ""
if text.startswith("```"):
text = re.sub(r"^```(?:json)?\s*", "", text)
text = re.sub(r"\s*```$", "", text)
first = text.find("{")
last = text.rfind("}")
if first >= 0 and last > first:
return text[first : last + 1]
return text
def _build_db_chat_schema_hint() -> str:
return (
"可用表和字段:\n"
"1) boxes(id, name, description, box_type, slot_capacity, slot_prefix, start_number)\n"
"2) components(id, box_id, slot_index, part_no, name, specification, quantity, location, note, is_enabled)\n"
"3) inventory_events(id, box_id, box_type, component_id, part_no, event_type, delta, created_at)\n"
"禁止访问 users 表。"
)
def _is_safe_readonly_sql(raw_sql: str) -> tuple[bool, str, str]:
sql = (raw_sql or "").strip()
if not sql:
return False, "SQL 为空", ""
if sql.endswith(";"):
sql = sql[:-1].strip()
lowered = re.sub(r"\s+", " ", sql.lower())
if ";" in sql:
return False, "仅允许单条 SQL 查询", ""
if not (lowered.startswith("select ") or lowered.startswith("with ")):
return False, "仅允许 SELECT 只读查询", ""
for keyword in DB_CHAT_FORBIDDEN_KEYWORDS:
if re.search(rf"\b{re.escape(keyword)}\b", lowered):
return False, f"SQL 包含危险关键字: {keyword}", ""
if re.search(r"\busers\b", lowered):
return False, "禁止访问 users 表", ""
referenced_tables = re.findall(r"\b(?:from|join)\s+([a-zA-Z_][\w]*)", lowered)
if not referenced_tables:
return False, "SQL 未包含可识别的数据表", ""
for table_name in referenced_tables:
if table_name not in DB_CHAT_ALLOWED_TABLES:
return False, f"不允许访问表: {table_name}", ""
return True, "", sql
def _ensure_query_limit(sql: str, row_limit: int = 80) -> str:
if re.search(r"\blimit\s+\d+", sql, flags=re.IGNORECASE):
return sql
return f"{sql} LIMIT {row_limit}"
def _serialize_sql_rows(rows, columns: list[str], *, max_cell_len: int = 120) -> list[dict]:
serialized = []
for row in rows:
item = {}
for idx, column in enumerate(columns):
value = row[idx] if idx < len(row) else None
if isinstance(value, datetime):
text = value.strftime("%Y-%m-%d %H:%M:%S")
else:
text = "" if value is None else str(value)
if len(text) > max_cell_len:
text = text[: max_cell_len - 3] + "..."
item[column] = text
serialized.append(item)
return serialized
def _build_db_chat_sql_plan(
question: str,
history: list[dict],
memory_text: str,
settings: dict,
) -> tuple[str, str, str]:
history_lines = []
for row in history[-6:]:
role = str(row.get("role", "user") or "user")
content = str(row.get("content", "") or "").strip()
if not content:
continue
history_lines.append(f"[{role}] {content[:200]}")
system_prompt = (
"你是库存数据库查询规划助手。"
"请根据用户问题生成一条 SQLite 只读 SQL。"
"只允许 SELECT/CTE不要写入操作不要访问 users 表。"
"输出必须是 JSON: {\"sql\":string,\"reason\":string},不要输出其他文字。"
)
user_prompt = (
_build_db_chat_schema_hint()
+ "\n\n对话历史:\n"
+ "\n".join(history_lines or ["(无)"])
+ "\n\n当前问题:\n"
+ question
)
if memory_text:
user_prompt += "\n\n本地记忆摘要:\n" + memory_text
raw_text = _call_siliconflow_chat(
system_prompt,
user_prompt,
api_url=settings["api_url"],
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
)
parsed = json.loads(_extract_json_object_text(raw_text))
sql = str(parsed.get("sql", "") or "").strip()
reason = str(parsed.get("reason", "") or "").strip() or "AI 未提供 SQL 说明"
if not sql:
raise RuntimeError("AI 未生成有效 SQL")
return sql, reason, raw_text
def _build_db_chat_web_context(question: str, rows: list[dict], timeout: int, max_queries: int = 2) -> list[dict]:
"""为数据库聊天构建可选联网参考上下文。"""
query_candidates = []
base_query = (question or "").strip()
if base_query:
query_candidates.append(base_query)
for row in rows[:3]:
part_no = str(row.get("part_no", "") or "").strip()
name = str(row.get("name", "") or "").strip()
if part_no and name:
query_candidates.append(f"{part_no} {name}")
elif part_no:
query_candidates.append(part_no)
elif name:
query_candidates.append(name)
contexts = []
used = set()
for query in query_candidates:
q = query.strip()
if len(q) < 2 or q in used:
continue
used.add(q)
result = _fetch_open_search_context(q, timeout=timeout)
sources = result.get("sources") or []
if not sources:
continue
contexts.append(
{
"query": result.get("query", q),
"sources": sources[:4],
}
)
if len(contexts) >= max_queries:
break
return contexts
def _extract_weather_location(question: str) -> str:
text = (question or "").strip()
if not text:
return ""
patterns = [
r"查询(.{1,20}?)的?天气",
r"(.{1,20}?)的?天气",
]
for pattern in patterns:
match = re.search(pattern, text)
if not match:
continue
candidate = re.sub(r"[\s?、]+", "", match.group(1) or "").strip()
if candidate:
return candidate
return ""
def _weather_code_to_text(code: int) -> str:
mapping = {
0: "",
1: "大部晴",
2: "多云",
3: "",
45: "",
48: "雾凇",
51: "小毛毛雨",
53: "毛毛雨",
55: "浓毛毛雨",
61: "小雨",
63: "中雨",
65: "大雨",
71: "小雪",
73: "中雪",
75: "大雪",
80: "阵雨",
81: "较强阵雨",
82: "强阵雨",
95: "雷阵雨",
}
return mapping.get(int(code), f"天气代码 {code}")
def _fetch_weather_context(question: str, timeout: int) -> dict | None:
"""天气问题联网兜底:使用 Open-Meteo 地理编码与实时天气接口。"""
location = _extract_weather_location(question)
if not location:
location = "广州"
try:
geo_params = urllib.parse.urlencode(
{
"name": location,
"count": "1",
"language": "zh",
"format": "json",
}
)
geo_url = f"https://geocoding-api.open-meteo.com/v1/search?{geo_params}"
geo_req = urllib.request.Request(
geo_url,
method="GET",
headers={"User-Agent": "inventory-ai-chat/1.0"},
)
with urllib.request.urlopen(geo_req, timeout=timeout) as resp:
geo_raw = resp.read().decode("utf-8", errors="ignore")
geo_data = json.loads(geo_raw)
results = geo_data.get("results") or []
if not results:
return None
first = results[0]
latitude = first.get("latitude")
longitude = first.get("longitude")
if latitude is None or longitude is None:
return None
display_name = str(first.get("name") or location).strip() or location
country = str(first.get("country") or "").strip()
admin1 = str(first.get("admin1") or "").strip()
place_name = " / ".join([v for v in [display_name, admin1, country] if v])
weather_params = urllib.parse.urlencode(
{
"latitude": str(latitude),
"longitude": str(longitude),
"current": "temperature_2m,weather_code,wind_speed_10m",
"timezone": "Asia/Shanghai",
}
)
weather_url = f"https://api.open-meteo.com/v1/forecast?{weather_params}"
weather_req = urllib.request.Request(
weather_url,
method="GET",
headers={"User-Agent": "inventory-ai-chat/1.0"},
)
with urllib.request.urlopen(weather_req, timeout=timeout) as resp:
weather_raw = resp.read().decode("utf-8", errors="ignore")
weather_data = json.loads(weather_raw)
current = weather_data.get("current") or {}
if not current:
return None
temp = current.get("temperature_2m")
code = int(current.get("weather_code", 0) or 0)
wind = current.get("wind_speed_10m")
time_text = str(current.get("time") or "").strip()
weather_text = _weather_code_to_text(code)
snippet = f"{place_name} 当前{weather_text},气温约 {temp}°C风速约 {wind} km/h时间 {time_text}"
source = {
"title": f"Open-Meteo 实时天气({place_name}",
"snippet": snippet,
"url": "https://open-meteo.com/",
"reliability_level": "high",
"reliability_label": "高可信",
"reliability_reason": "公开气象 API 实时数据",
"domain": "open-meteo.com",
}
return {
"query": f"{location} 天气",
"sources": [source],
}
except Exception:
return None
def _build_db_chat_answer(
question: str,
sql: str,
sql_reason: str,
rows: list[dict],
web_context: list[dict],
settings: dict,
) -> str:
if not rows:
return "查询已执行,但没有匹配数据。你可以换个条件,例如指定料号、时间范围或盒型。"
system_prompt = (
"你是库存分析助手。"
"请仅根据提供的 SQL 结果回答,禁止虚构不存在的数据。"
"回答用简明中文,优先给结论,再给关键证据。"
"若提供了联网参考,请明确标注为参考信息,不得当作数据库事实。"
)
user_prompt = (
"用户问题:\n"
+ question
+ "\n\n执行SQL:\n"
+ sql
+ "\n\nSQL说明:\n"
+ sql_reason
+ "\n\n查询结果(JSON):\n"
+ json.dumps(rows, ensure_ascii=False)
)
if web_context:
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
return _call_siliconflow_chat(
system_prompt,
user_prompt,
api_url=settings["api_url"],
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
)
def _is_general_chat_question(question: str) -> bool:
text = (question or "").strip().lower()
if not text:
return False
general_hints = [
"你是什么",
"你是谁",
"你能做什么",
"聊天",
"模型",
"介绍",
"hello",
"hi",
"你好",
"谢谢",
]
data_hints = [
"库存",
"器件",
"料号",
"",
"",
"出库",
"入库",
"数量",
"part_no",
"component",
"box",
]
has_general_hint = any(hint in text for hint in general_hints)
has_data_hint = any(hint in text for hint in data_hints)
return has_general_hint and not has_data_hint
def _build_general_chat_answer(
question: str,
history: list[dict],
web_context: list[dict],
memory_text: str,
settings: dict,
) -> str:
q = (question or "").strip()
q_lower = q.lower()
if "模型" in q or "model" in q_lower:
model_name = str(settings.get("model", "") or "未配置")
return f"当前系统配置的 AI 模型是:{model_name}"
history_lines = []
for row in history[-8:]:
role = str(row.get("role", "user") or "user")
content = str(row.get("content", "") or "").strip()
if not content:
continue
history_lines.append(f"[{role}] {content[:220]}")
system_prompt = (
"你是系统内置通用助手。"
"请直接回答用户问题,不要把自己限制为仅库存问题。"
"对于时效性问题(如天气、新闻、实时价格),若提供了联网参考就优先依据参考;"
"若没有联网参考,需明确说明可能不够实时。"
"回答保持简洁中文。"
)
user_prompt = (
"对话历史:\n"
+ "\n".join(history_lines or ["(无)"])
+ "\n\n用户当前问题:\n"
+ question
)
if memory_text:
user_prompt += "\n\n本地记忆摘要:\n" + memory_text
if web_context:
user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False)
return _call_siliconflow_chat(
system_prompt,
user_prompt,
api_url=settings["api_url"],
model=settings["model"],
api_key=settings["api_key"],
timeout=settings["timeout"],
)
@app.route("/ai/chat")
def ai_chat_page():
return render_template("ai_chat.html")
@app.route("/ai/chat/memory/clear", methods=["POST"])
def ai_chat_memory_clear():
username = (session.get("username") or "guest").strip() or "guest"
_clear_ai_chat_memory(username)
_log_event(logging.INFO, "ai_chat_memory_cleared", user=username)
return {"ok": True, "message": "本地记忆已清空"}
@app.route("/ai/chat/query", methods=["POST"])
def ai_chat_query():
payload = request.get_json(silent=True) or {}
question = str(payload.get("question", "") or "").strip()
history = payload.get("history", [])
allow_web_search = _is_truthy_form_value(str(payload.get("allow_web_search", "")))
allow_db_query = _is_truthy_form_value(str(payload.get("allow_db_query", "")))
if not isinstance(history, list):
history = []
if not question:
return {"ok": False, "message": "请输入问题"}, 400
username = (session.get("username") or "guest").strip() or "guest"
memory_payload = _get_ai_chat_memory(username)
memory_text = _memory_context_text(memory_payload)
settings = _get_ai_settings()
if not settings.get("api_key") or not settings.get("api_url") or not settings.get("model"):
return {"ok": False, "message": "AI 参数不完整,请先到参数页配置"}, 400
if not allow_db_query:
web_context = []
if allow_web_search:
web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1)
if not web_context and ("天气" in question or "weather" in question.lower()):
weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30))
if weather_ctx:
web_context = [weather_ctx]
answer = _build_general_chat_answer(question, history, web_context, memory_text, settings)
_append_ai_chat_memory(username, question, answer)
_log_event(
logging.INFO,
"ai_chat_general_success",
question=question,
allow_web_search=allow_web_search,
allow_db_query=allow_db_query,
web_sources=sum(len(item.get("sources", [])) for item in web_context),
memory_notes=len(memory_payload.get("notes", [])),
)
return {
"ok": True,
"answer": answer,
"sql": "",
"planner_reason": "未启用数据库查询,按通用问答模式处理",
"planner_raw": "",
"row_count": 0,
"rows_preview": [],
"web_context": web_context,
"allow_web_search": allow_web_search,
"allow_db_query": allow_db_query,
"chat_mode": "general",
}
try:
sql_raw, planner_reason, planner_text = _build_db_chat_sql_plan(
question,
history,
memory_text,
settings,
)
except Exception as exc:
_log_event(
logging.WARNING,
"ai_chat_plan_error",
error=str(exc),
question=question,
model=settings.get("model", ""),
)
if not _is_general_chat_question(question):
return {"ok": False, "message": f"SQL 规划失败: {exc}"}, 400
web_context = []
if allow_web_search:
web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1)
if not web_context and ("天气" in question or "weather" in question.lower()):
weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30))
if weather_ctx:
web_context = [weather_ctx]
answer = _build_general_chat_answer(question, history, web_context, memory_text, settings)
_append_ai_chat_memory(username, question, answer)
return {
"ok": True,
"answer": answer,
"sql": "",
"planner_reason": "已切换到通用对话模式",
"planner_raw": "",
"row_count": 0,
"rows_preview": [],
"web_context": web_context,
"allow_web_search": allow_web_search,
"allow_db_query": allow_db_query,
"chat_mode": "general",
}
safe, reject_reason, safe_sql = _is_safe_readonly_sql(sql_raw)
if not safe:
_log_event(
logging.WARNING,
"ai_chat_sql_rejected",
reason=reject_reason,
sql=sql_raw,
question=question,
)
can_fallback_general = (
"未包含可识别的数据表" in reject_reason
or _is_general_chat_question(question)
)
if not can_fallback_general:
return {"ok": False, "message": f"SQL 被安全策略拒绝: {reject_reason}"}, 400
web_context = []
if allow_web_search:
web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1)
if not web_context and ("天气" in question or "weather" in question.lower()):
weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30))
if weather_ctx:
web_context = [weather_ctx]
answer = _build_general_chat_answer(question, history, web_context, memory_text, settings)
_append_ai_chat_memory(username, question, answer)
return {
"ok": True,
"answer": answer,
"sql": "",
"planner_reason": "已切换到通用对话模式",
"planner_raw": planner_text,
"row_count": 0,
"rows_preview": [],
"web_context": web_context,
"allow_web_search": allow_web_search,
"allow_db_query": allow_db_query,
"chat_mode": "general",
}
final_sql = _ensure_query_limit(safe_sql, row_limit=80)
try:
query_result = db.session.execute(db.text(final_sql))
columns = list(query_result.keys())
rows = query_result.fetchall()
except Exception as exc:
_log_event(
logging.ERROR,
"ai_chat_query_execute_error",
error=str(exc),
sql=final_sql,
question=question,
traceback=traceback.format_exc(),
)
return {"ok": False, "message": "SQL 执行失败,请调整问题后重试"}, 400
serialized_rows = _serialize_sql_rows(rows, columns)
web_context = []
if allow_web_search:
try:
web_context = _build_db_chat_web_context(
question,
serialized_rows,
timeout=settings.get("timeout", 30),
)
except Exception as exc:
_log_event(
logging.WARNING,
"ai_chat_web_context_error",
error=str(exc),
question=question,
)
web_context = []
try:
answer = _build_db_chat_answer(
question,
final_sql,
planner_reason,
serialized_rows,
web_context,
settings,
)
except Exception as exc:
_log_event(
logging.WARNING,
"ai_chat_answer_error",
error=str(exc),
sql=final_sql,
row_count=len(serialized_rows),
)
answer = (
"查询已完成,但 AI 总结超时/失败。"
f"共返回 {len(serialized_rows)} 行数据,你可以查看本次 SQL 并缩小问题范围后重试。"
)
_append_ai_chat_memory(username, question, answer)
_log_event(
logging.INFO,
"ai_chat_query_success",
question=question,
sql=final_sql,
row_count=len(serialized_rows),
web_sources=sum(len(item.get("sources", [])) for item in web_context),
memory_notes=len(memory_payload.get("notes", [])),
)
return {
"ok": True,
"answer": answer,
"sql": final_sql,
"planner_reason": planner_reason,
"planner_raw": planner_text,
"row_count": len(serialized_rows),
"rows_preview": serialized_rows[:8],
"web_context": web_context,
"allow_web_search": allow_web_search,
"allow_db_query": allow_db_query,
"chat_mode": "db",
}
def _is_safe_next_path(path: str) -> bool:
candidate = (path or "").strip()
if not candidate: