提示:默认是通用聊天 + 联网补充;勾选“启用数据库查询”后再做库存数据分析。
+ + + +diff --git a/.gitignore b/.gitignore index a0fec98..2fc3fbc 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ data/*.sqlite data/*.sqlite3 data/ai_settings.json data/box_types.json +data/ai_chat_memory.json data/*.log *.db *.sqlite diff --git a/README.md b/README.md index 95a331c..85a6749 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,34 @@ $env:SILICONFLOW_MODEL="Qwen/Qwen2.5-7B-Instruct" 当 AI 补货建议出现“请求失败,请稍后重试”时,优先打开系统日志查看最近的 `ERROR` 或 `WARNING` 记录。 +### 2.7 AI 数据库聊天 + +系统提供数据库聊天页,支持自然语言提问并自动生成只读 SQL 查询库存数据: + +- 页面:`/ai/chat` +- 入口:`仓库概览` 右侧 AI 卡片中的 `聊天` + +安全边界: + +- 只允许 `SELECT/CTE` 查询 +- 禁止写操作(INSERT/UPDATE/DELETE/DDL 等) +- 禁止访问 `users` 表 +- 默认自动补 `LIMIT`,避免一次性返回过多数据 + +若提问失败,可在系统日志中查看 `ai_chat_*` 相关记录。 + +联网补充(可选): + +- 聊天页可勾选 `允许联网补充` +- 开启后会在数据库查询结果基础上补充公开来源线索 +- 回答会区分“数据库结论”和“联网参考”,并展示来源可信度与链接 + +本地记忆: + +- AI 聊天会为当前登录用户保存本地记忆(最近对话与“记住 ...”条目) +- 记忆文件:`data/ai_chat_memory.json` +- 聊天页可点击 `清空本地记忆` 按钮重置 + ## 3. 页面说明 ### 3.1 首页 `/` diff --git a/app.py b/app.py index 740c62e..8d9f2c8 100644 --- a/app.py +++ b/app.py @@ -39,6 +39,7 @@ DB_DIR = os.path.join(BASE_DIR, "data") os.makedirs(DB_DIR, exist_ok=True) DB_PATH = os.path.join(DB_DIR, "inventory.db") APP_LOG_PATH = os.path.join(DB_DIR, "app.log") +AI_CHAT_MEMORY_PATH = os.path.join(DB_DIR, "ai_chat_memory.json") # Flask 和 SQLAlchemy 基础初始化。 app = Flask(__name__) @@ -100,6 +101,105 @@ def _read_log_lines(limit: int = 200) -> list[str]: return [line.rstrip("\n") for line in file_obj.readlines()[-limit:]] +def _load_ai_chat_memory_store() -> dict: + if not os.path.exists(AI_CHAT_MEMORY_PATH): + return {} + try: + with open(AI_CHAT_MEMORY_PATH, "r", encoding="utf-8") as file_obj: + payload = json.load(file_obj) + return payload if isinstance(payload, dict) else {} + except (OSError, json.JSONDecodeError): + return {} + + +def _save_ai_chat_memory_store(payload: dict) -> None: + safe_payload = payload if isinstance(payload, dict) else {} + with open(AI_CHAT_MEMORY_PATH, "w", encoding="utf-8") as file_obj: + json.dump(safe_payload, file_obj, ensure_ascii=False, indent=2) + + +def _normalize_memory_text(text: str, max_len: int = 220) -> str: + cleaned = re.sub(r"\s+", " ", str(text or "")).strip() + if len(cleaned) > max_len: + return cleaned[: max_len - 3] + "..." + return cleaned + + +def _get_ai_chat_memory(username: str) -> dict: + user_key = (username or "guest").strip() or "guest" + store = _load_ai_chat_memory_store() + block = store.get(user_key, {}) if isinstance(store, dict) else {} + notes = block.get("notes", []) if isinstance(block, dict) else [] + turns = block.get("turns", []) if isinstance(block, dict) else [] + if not isinstance(notes, list): + notes = [] + if not isinstance(turns, list): + turns = [] + return { + "user": user_key, + "notes": [str(item) for item in notes[:30]], + "turns": [item for item in turns[-20:] if isinstance(item, dict)], + } + + +def _memory_context_text(memory: dict) -> str: + notes = memory.get("notes", []) if isinstance(memory, dict) else [] + turns = memory.get("turns", []) if isinstance(memory, dict) else [] + lines = [] + if notes: + lines.append("用户长期记忆:") + for note in notes[-12:]: + lines.append(f"- {_normalize_memory_text(note, max_len=120)}") + if turns: + lines.append("最近本地记忆对话:") + for turn in turns[-6:]: + q = _normalize_memory_text(turn.get("q", ""), max_len=90) + a = _normalize_memory_text(turn.get("a", ""), max_len=120) + lines.append(f"Q: {q}") + lines.append(f"A: {a}") + return "\n".join(lines) + + +def _append_ai_chat_memory(username: str, question: str, answer: str) -> None: + user_key = (username or "guest").strip() or "guest" + store = _load_ai_chat_memory_store() + block = store.get(user_key) + if not isinstance(block, dict): + block = {"notes": [], "turns": []} + + notes = block.get("notes") + turns = block.get("turns") + if not isinstance(notes, list): + notes = [] + if not isinstance(turns, list): + turns = [] + + q = _normalize_memory_text(question) + a = _normalize_memory_text(answer) + turns.append({"q": q, "a": a, "ts": datetime.utcnow().isoformat(timespec="seconds")}) + turns = turns[-30:] + + match = re.match(r"^(记住|记一下|请记住)[::\s]*(.+)$", q) + if match: + note = _normalize_memory_text(match.group(2), max_len=140) + if note and note not in notes: + notes.append(note) + notes = notes[-30:] + + block["notes"] = notes + block["turns"] = turns + store[user_key] = block + _save_ai_chat_memory_store(store) + + +def _clear_ai_chat_memory(username: str) -> None: + user_key = (username or "guest").strip() or "guest" + store = _load_ai_chat_memory_store() + if user_key in store: + store.pop(user_key, None) + _save_ai_chat_memory_store(store) + + _setup_app_logger() # 这里集中放全局常量,避免后面函数里散落硬编码。 @@ -2875,6 +2975,675 @@ def _call_siliconflow_chat( raise RuntimeError("AI 返回格式无法解析") from exc +DB_CHAT_ALLOWED_TABLES = {"boxes", "components", "inventory_events"} +DB_CHAT_FORBIDDEN_KEYWORDS = { + "insert", + "update", + "delete", + "drop", + "alter", + "create", + "replace", + "truncate", + "attach", + "detach", + "pragma", + "vacuum", + "reindex", + "begin", + "commit", + "rollback", +} + + +def _extract_json_object_text(raw_text: str) -> str: + text = (raw_text or "").strip() + if not text: + return "" + if text.startswith("```"): + text = re.sub(r"^```(?:json)?\s*", "", text) + text = re.sub(r"\s*```$", "", text) + first = text.find("{") + last = text.rfind("}") + if first >= 0 and last > first: + return text[first : last + 1] + return text + + +def _build_db_chat_schema_hint() -> str: + return ( + "可用表和字段:\n" + "1) boxes(id, name, description, box_type, slot_capacity, slot_prefix, start_number)\n" + "2) components(id, box_id, slot_index, part_no, name, specification, quantity, location, note, is_enabled)\n" + "3) inventory_events(id, box_id, box_type, component_id, part_no, event_type, delta, created_at)\n" + "禁止访问 users 表。" + ) + + +def _is_safe_readonly_sql(raw_sql: str) -> tuple[bool, str, str]: + sql = (raw_sql or "").strip() + if not sql: + return False, "SQL 为空", "" + + if sql.endswith(";"): + sql = sql[:-1].strip() + + lowered = re.sub(r"\s+", " ", sql.lower()) + if ";" in sql: + return False, "仅允许单条 SQL 查询", "" + + if not (lowered.startswith("select ") or lowered.startswith("with ")): + return False, "仅允许 SELECT 只读查询", "" + + for keyword in DB_CHAT_FORBIDDEN_KEYWORDS: + if re.search(rf"\b{re.escape(keyword)}\b", lowered): + return False, f"SQL 包含危险关键字: {keyword}", "" + + if re.search(r"\busers\b", lowered): + return False, "禁止访问 users 表", "" + + referenced_tables = re.findall(r"\b(?:from|join)\s+([a-zA-Z_][\w]*)", lowered) + if not referenced_tables: + return False, "SQL 未包含可识别的数据表", "" + + for table_name in referenced_tables: + if table_name not in DB_CHAT_ALLOWED_TABLES: + return False, f"不允许访问表: {table_name}", "" + + return True, "", sql + + +def _ensure_query_limit(sql: str, row_limit: int = 80) -> str: + if re.search(r"\blimit\s+\d+", sql, flags=re.IGNORECASE): + return sql + return f"{sql} LIMIT {row_limit}" + + +def _serialize_sql_rows(rows, columns: list[str], *, max_cell_len: int = 120) -> list[dict]: + serialized = [] + for row in rows: + item = {} + for idx, column in enumerate(columns): + value = row[idx] if idx < len(row) else None + if isinstance(value, datetime): + text = value.strftime("%Y-%m-%d %H:%M:%S") + else: + text = "" if value is None else str(value) + if len(text) > max_cell_len: + text = text[: max_cell_len - 3] + "..." + item[column] = text + serialized.append(item) + return serialized + + +def _build_db_chat_sql_plan( + question: str, + history: list[dict], + memory_text: str, + settings: dict, +) -> tuple[str, str, str]: + history_lines = [] + for row in history[-6:]: + role = str(row.get("role", "user") or "user") + content = str(row.get("content", "") or "").strip() + if not content: + continue + history_lines.append(f"[{role}] {content[:200]}") + + system_prompt = ( + "你是库存数据库查询规划助手。" + "请根据用户问题生成一条 SQLite 只读 SQL。" + "只允许 SELECT/CTE,不要写入操作,不要访问 users 表。" + "输出必须是 JSON: {\"sql\":string,\"reason\":string},不要输出其他文字。" + ) + user_prompt = ( + _build_db_chat_schema_hint() + + "\n\n对话历史:\n" + + "\n".join(history_lines or ["(无)"]) + + "\n\n当前问题:\n" + + question + ) + if memory_text: + user_prompt += "\n\n本地记忆摘要:\n" + memory_text + + raw_text = _call_siliconflow_chat( + system_prompt, + user_prompt, + api_url=settings["api_url"], + model=settings["model"], + api_key=settings["api_key"], + timeout=settings["timeout"], + ) + parsed = json.loads(_extract_json_object_text(raw_text)) + sql = str(parsed.get("sql", "") or "").strip() + reason = str(parsed.get("reason", "") or "").strip() or "AI 未提供 SQL 说明" + if not sql: + raise RuntimeError("AI 未生成有效 SQL") + return sql, reason, raw_text + + +def _build_db_chat_web_context(question: str, rows: list[dict], timeout: int, max_queries: int = 2) -> list[dict]: + """为数据库聊天构建可选联网参考上下文。""" + query_candidates = [] + base_query = (question or "").strip() + if base_query: + query_candidates.append(base_query) + + for row in rows[:3]: + part_no = str(row.get("part_no", "") or "").strip() + name = str(row.get("name", "") or "").strip() + if part_no and name: + query_candidates.append(f"{part_no} {name}") + elif part_no: + query_candidates.append(part_no) + elif name: + query_candidates.append(name) + + contexts = [] + used = set() + for query in query_candidates: + q = query.strip() + if len(q) < 2 or q in used: + continue + used.add(q) + result = _fetch_open_search_context(q, timeout=timeout) + sources = result.get("sources") or [] + if not sources: + continue + contexts.append( + { + "query": result.get("query", q), + "sources": sources[:4], + } + ) + if len(contexts) >= max_queries: + break + + return contexts + + +def _extract_weather_location(question: str) -> str: + text = (question or "").strip() + if not text: + return "" + + patterns = [ + r"查询(.{1,20}?)的?天气", + r"(.{1,20}?)的?天气", + ] + for pattern in patterns: + match = re.search(pattern, text) + if not match: + continue + candidate = re.sub(r"[\s,。!??、]+", "", match.group(1) or "").strip() + if candidate: + return candidate + return "" + + +def _weather_code_to_text(code: int) -> str: + mapping = { + 0: "晴", + 1: "大部晴", + 2: "多云", + 3: "阴", + 45: "雾", + 48: "雾凇", + 51: "小毛毛雨", + 53: "毛毛雨", + 55: "浓毛毛雨", + 61: "小雨", + 63: "中雨", + 65: "大雨", + 71: "小雪", + 73: "中雪", + 75: "大雪", + 80: "阵雨", + 81: "较强阵雨", + 82: "强阵雨", + 95: "雷阵雨", + } + return mapping.get(int(code), f"天气代码 {code}") + + +def _fetch_weather_context(question: str, timeout: int) -> dict | None: + """天气问题联网兜底:使用 Open-Meteo 地理编码与实时天气接口。""" + location = _extract_weather_location(question) + if not location: + location = "广州" + + try: + geo_params = urllib.parse.urlencode( + { + "name": location, + "count": "1", + "language": "zh", + "format": "json", + } + ) + geo_url = f"https://geocoding-api.open-meteo.com/v1/search?{geo_params}" + geo_req = urllib.request.Request( + geo_url, + method="GET", + headers={"User-Agent": "inventory-ai-chat/1.0"}, + ) + with urllib.request.urlopen(geo_req, timeout=timeout) as resp: + geo_raw = resp.read().decode("utf-8", errors="ignore") + geo_data = json.loads(geo_raw) + results = geo_data.get("results") or [] + if not results: + return None + + first = results[0] + latitude = first.get("latitude") + longitude = first.get("longitude") + if latitude is None or longitude is None: + return None + + display_name = str(first.get("name") or location).strip() or location + country = str(first.get("country") or "").strip() + admin1 = str(first.get("admin1") or "").strip() + place_name = " / ".join([v for v in [display_name, admin1, country] if v]) + + weather_params = urllib.parse.urlencode( + { + "latitude": str(latitude), + "longitude": str(longitude), + "current": "temperature_2m,weather_code,wind_speed_10m", + "timezone": "Asia/Shanghai", + } + ) + weather_url = f"https://api.open-meteo.com/v1/forecast?{weather_params}" + weather_req = urllib.request.Request( + weather_url, + method="GET", + headers={"User-Agent": "inventory-ai-chat/1.0"}, + ) + with urllib.request.urlopen(weather_req, timeout=timeout) as resp: + weather_raw = resp.read().decode("utf-8", errors="ignore") + weather_data = json.loads(weather_raw) + current = weather_data.get("current") or {} + if not current: + return None + + temp = current.get("temperature_2m") + code = int(current.get("weather_code", 0) or 0) + wind = current.get("wind_speed_10m") + time_text = str(current.get("time") or "").strip() + weather_text = _weather_code_to_text(code) + + snippet = f"{place_name} 当前{weather_text},气温约 {temp}°C,风速约 {wind} km/h,时间 {time_text}。" + source = { + "title": f"Open-Meteo 实时天气({place_name})", + "snippet": snippet, + "url": "https://open-meteo.com/", + "reliability_level": "high", + "reliability_label": "高可信", + "reliability_reason": "公开气象 API 实时数据", + "domain": "open-meteo.com", + } + return { + "query": f"{location} 天气", + "sources": [source], + } + except Exception: + return None + + +def _build_db_chat_answer( + question: str, + sql: str, + sql_reason: str, + rows: list[dict], + web_context: list[dict], + settings: dict, +) -> str: + if not rows: + return "查询已执行,但没有匹配数据。你可以换个条件,例如指定料号、时间范围或盒型。" + + system_prompt = ( + "你是库存分析助手。" + "请仅根据提供的 SQL 结果回答,禁止虚构不存在的数据。" + "回答用简明中文,优先给结论,再给关键证据。" + "若提供了联网参考,请明确标注为参考信息,不得当作数据库事实。" + ) + user_prompt = ( + "用户问题:\n" + + question + + "\n\n执行SQL:\n" + + sql + + "\n\nSQL说明:\n" + + sql_reason + + "\n\n查询结果(JSON):\n" + + json.dumps(rows, ensure_ascii=False) + ) + if web_context: + user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False) + + return _call_siliconflow_chat( + system_prompt, + user_prompt, + api_url=settings["api_url"], + model=settings["model"], + api_key=settings["api_key"], + timeout=settings["timeout"], + ) + + +def _is_general_chat_question(question: str) -> bool: + text = (question or "").strip().lower() + if not text: + return False + + general_hints = [ + "你是什么", + "你是谁", + "你能做什么", + "聊天", + "模型", + "介绍", + "hello", + "hi", + "你好", + "谢谢", + ] + data_hints = [ + "库存", + "器件", + "料号", + "盒", + "袋", + "出库", + "入库", + "数量", + "part_no", + "component", + "box", + ] + + has_general_hint = any(hint in text for hint in general_hints) + has_data_hint = any(hint in text for hint in data_hints) + return has_general_hint and not has_data_hint + + +def _build_general_chat_answer( + question: str, + history: list[dict], + web_context: list[dict], + memory_text: str, + settings: dict, +) -> str: + q = (question or "").strip() + q_lower = q.lower() + if "模型" in q or "model" in q_lower: + model_name = str(settings.get("model", "") or "未配置") + return f"当前系统配置的 AI 模型是:{model_name}。" + + history_lines = [] + for row in history[-8:]: + role = str(row.get("role", "user") or "user") + content = str(row.get("content", "") or "").strip() + if not content: + continue + history_lines.append(f"[{role}] {content[:220]}") + + system_prompt = ( + "你是系统内置通用助手。" + "请直接回答用户问题,不要把自己限制为仅库存问题。" + "对于时效性问题(如天气、新闻、实时价格),若提供了联网参考就优先依据参考;" + "若没有联网参考,需明确说明可能不够实时。" + "回答保持简洁中文。" + ) + user_prompt = ( + "对话历史:\n" + + "\n".join(history_lines or ["(无)"]) + + "\n\n用户当前问题:\n" + + question + ) + if memory_text: + user_prompt += "\n\n本地记忆摘要:\n" + memory_text + if web_context: + user_prompt += "\n\n联网参考(JSON):\n" + json.dumps(web_context, ensure_ascii=False) + + return _call_siliconflow_chat( + system_prompt, + user_prompt, + api_url=settings["api_url"], + model=settings["model"], + api_key=settings["api_key"], + timeout=settings["timeout"], + ) + + +@app.route("/ai/chat") +def ai_chat_page(): + return render_template("ai_chat.html") + + +@app.route("/ai/chat/memory/clear", methods=["POST"]) +def ai_chat_memory_clear(): + username = (session.get("username") or "guest").strip() or "guest" + _clear_ai_chat_memory(username) + _log_event(logging.INFO, "ai_chat_memory_cleared", user=username) + return {"ok": True, "message": "本地记忆已清空"} + + +@app.route("/ai/chat/query", methods=["POST"]) +def ai_chat_query(): + payload = request.get_json(silent=True) or {} + question = str(payload.get("question", "") or "").strip() + history = payload.get("history", []) + allow_web_search = _is_truthy_form_value(str(payload.get("allow_web_search", ""))) + allow_db_query = _is_truthy_form_value(str(payload.get("allow_db_query", ""))) + if not isinstance(history, list): + history = [] + + if not question: + return {"ok": False, "message": "请输入问题"}, 400 + + username = (session.get("username") or "guest").strip() or "guest" + memory_payload = _get_ai_chat_memory(username) + memory_text = _memory_context_text(memory_payload) + + settings = _get_ai_settings() + if not settings.get("api_key") or not settings.get("api_url") or not settings.get("model"): + return {"ok": False, "message": "AI 参数不完整,请先到参数页配置"}, 400 + + if not allow_db_query: + web_context = [] + if allow_web_search: + web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1) + if not web_context and ("天气" in question or "weather" in question.lower()): + weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30)) + if weather_ctx: + web_context = [weather_ctx] + answer = _build_general_chat_answer(question, history, web_context, memory_text, settings) + _append_ai_chat_memory(username, question, answer) + _log_event( + logging.INFO, + "ai_chat_general_success", + question=question, + allow_web_search=allow_web_search, + allow_db_query=allow_db_query, + web_sources=sum(len(item.get("sources", [])) for item in web_context), + memory_notes=len(memory_payload.get("notes", [])), + ) + return { + "ok": True, + "answer": answer, + "sql": "", + "planner_reason": "未启用数据库查询,按通用问答模式处理", + "planner_raw": "", + "row_count": 0, + "rows_preview": [], + "web_context": web_context, + "allow_web_search": allow_web_search, + "allow_db_query": allow_db_query, + "chat_mode": "general", + } + + try: + sql_raw, planner_reason, planner_text = _build_db_chat_sql_plan( + question, + history, + memory_text, + settings, + ) + except Exception as exc: + _log_event( + logging.WARNING, + "ai_chat_plan_error", + error=str(exc), + question=question, + model=settings.get("model", ""), + ) + if not _is_general_chat_question(question): + return {"ok": False, "message": f"SQL 规划失败: {exc}"}, 400 + + web_context = [] + if allow_web_search: + web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1) + if not web_context and ("天气" in question or "weather" in question.lower()): + weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30)) + if weather_ctx: + web_context = [weather_ctx] + answer = _build_general_chat_answer(question, history, web_context, memory_text, settings) + _append_ai_chat_memory(username, question, answer) + return { + "ok": True, + "answer": answer, + "sql": "", + "planner_reason": "已切换到通用对话模式", + "planner_raw": "", + "row_count": 0, + "rows_preview": [], + "web_context": web_context, + "allow_web_search": allow_web_search, + "allow_db_query": allow_db_query, + "chat_mode": "general", + } + + safe, reject_reason, safe_sql = _is_safe_readonly_sql(sql_raw) + if not safe: + _log_event( + logging.WARNING, + "ai_chat_sql_rejected", + reason=reject_reason, + sql=sql_raw, + question=question, + ) + can_fallback_general = ( + "未包含可识别的数据表" in reject_reason + or _is_general_chat_question(question) + ) + if not can_fallback_general: + return {"ok": False, "message": f"SQL 被安全策略拒绝: {reject_reason}"}, 400 + + web_context = [] + if allow_web_search: + web_context = _build_db_chat_web_context(question, [], timeout=settings.get("timeout", 30), max_queries=1) + if not web_context and ("天气" in question or "weather" in question.lower()): + weather_ctx = _fetch_weather_context(question, timeout=settings.get("timeout", 30)) + if weather_ctx: + web_context = [weather_ctx] + answer = _build_general_chat_answer(question, history, web_context, memory_text, settings) + _append_ai_chat_memory(username, question, answer) + return { + "ok": True, + "answer": answer, + "sql": "", + "planner_reason": "已切换到通用对话模式", + "planner_raw": planner_text, + "row_count": 0, + "rows_preview": [], + "web_context": web_context, + "allow_web_search": allow_web_search, + "allow_db_query": allow_db_query, + "chat_mode": "general", + } + + final_sql = _ensure_query_limit(safe_sql, row_limit=80) + try: + query_result = db.session.execute(db.text(final_sql)) + columns = list(query_result.keys()) + rows = query_result.fetchall() + except Exception as exc: + _log_event( + logging.ERROR, + "ai_chat_query_execute_error", + error=str(exc), + sql=final_sql, + question=question, + traceback=traceback.format_exc(), + ) + return {"ok": False, "message": "SQL 执行失败,请调整问题后重试"}, 400 + + serialized_rows = _serialize_sql_rows(rows, columns) + web_context = [] + if allow_web_search: + try: + web_context = _build_db_chat_web_context( + question, + serialized_rows, + timeout=settings.get("timeout", 30), + ) + except Exception as exc: + _log_event( + logging.WARNING, + "ai_chat_web_context_error", + error=str(exc), + question=question, + ) + web_context = [] + + try: + answer = _build_db_chat_answer( + question, + final_sql, + planner_reason, + serialized_rows, + web_context, + settings, + ) + except Exception as exc: + _log_event( + logging.WARNING, + "ai_chat_answer_error", + error=str(exc), + sql=final_sql, + row_count=len(serialized_rows), + ) + answer = ( + "查询已完成,但 AI 总结超时/失败。" + f"共返回 {len(serialized_rows)} 行数据,你可以查看本次 SQL 并缩小问题范围后重试。" + ) + + _append_ai_chat_memory(username, question, answer) + + _log_event( + logging.INFO, + "ai_chat_query_success", + question=question, + sql=final_sql, + row_count=len(serialized_rows), + web_sources=sum(len(item.get("sources", [])) for item in web_context), + memory_notes=len(memory_payload.get("notes", [])), + ) + return { + "ok": True, + "answer": answer, + "sql": final_sql, + "planner_reason": planner_reason, + "planner_raw": planner_text, + "row_count": len(serialized_rows), + "rows_preview": serialized_rows[:8], + "web_context": web_context, + "allow_web_search": allow_web_search, + "allow_db_query": allow_db_query, + "chat_mode": "db", + } + + def _is_safe_next_path(path: str) -> bool: candidate = (path or "").strip() if not candidate: diff --git a/data/ai_settings.json b/data/ai_settings.json index 995b9a3..2c33b7f 100644 --- a/data/ai_settings.json +++ b/data/ai_settings.json @@ -1,6 +1,6 @@ { "api_url": "https://api.siliconflow.cn/v1/chat/completions", - "model": "Pro/zai-org/GLM-5", + "model": "Pro/moonshotai/Kimi-K2.5", "api_key": "sk-pekgnbdvwgydxzteabnykswjadkitoopwcekmksydfoslmlo", "timeout": 120, "restock_threshold": 2, diff --git a/static/css/style.css b/static/css/style.css index b147816..1ce36ea 100644 --- a/static/css/style.css +++ b/static/css/style.css @@ -618,6 +618,86 @@ body { overflow: auto; } +.ai-chat-shell { + display: grid; + gap: var(--space-2); +} + +.ai-chat-messages { + display: grid; + gap: var(--space-1); + max-height: 58vh; + overflow: auto; + padding-right: 4px; +} + +.ai-chat-item { + border: 1px solid var(--line); + border-radius: var(--radius); + background: color-mix(in srgb, var(--card) 88%, var(--card-alt)); + padding: var(--space-1) var(--space-2); +} + +.ai-chat-item.user { + border-color: color-mix(in srgb, var(--accent) 55%, var(--line)); +} + +.ai-chat-item h3 { + margin: 0 0 6px; + font-size: 14px; +} + +.ai-chat-item p { + margin: 0; +} + +.md-content h1, +.md-content h2, +.md-content h3 { + margin: 0 0 8px; + line-height: 1.35; +} + +.md-content p { + margin: 0 0 8px; +} + +.md-content ul { + margin: 0 0 8px; + padding-left: 20px; +} + +.md-content li { + margin: 4px 0; +} + +.md-content .md-code { + margin: 8px 0; + border: 1px solid var(--line); + border-radius: var(--radius); + background: color-mix(in srgb, var(--card) 90%, var(--card-alt)); + padding: 10px; + overflow: auto; +} + +.md-content code { + font: 12px/1.45 Consolas, "Cascadia Mono", monospace; +} + +.md-content a { + color: var(--accent-press); + text-decoration: none; +} + +.md-content a:hover { + text-decoration: underline; +} + +.ai-chat-form-wrap { + border-top: 1px dashed var(--line); + padding-top: var(--space-2); +} + .box-list { display: grid; grid-template-columns: repeat(auto-fill, minmax(300px, 1fr)); diff --git a/templates/ai_chat.html b/templates/ai_chat.html new file mode 100644 index 0000000..4275421 --- /dev/null +++ b/templates/ai_chat.html @@ -0,0 +1,326 @@ + + +
+ + +用自然语言提问,系统会生成只读 SQL 查询库存数据并给出结论
+提示:默认是通用聊天 + 联网补充;勾选“启用数据库查询”后再做库存数据分析。
+ + + +