feat: 添加用户登录认证功能,确保系统安全性,并提供修改密码和退出登录选项

This commit is contained in:
2026-03-14 00:11:16 +08:00
parent 847ec32144
commit d2d63d5e61
15 changed files with 437 additions and 5 deletions

209
app.py
View File

@@ -24,8 +24,9 @@ from copy import deepcopy
from io import StringIO
from datetime import datetime, timedelta
from flask import Flask, Response, redirect, render_template, request, url_for
from flask import Flask, Response, redirect, render_template, request, session, url_for
from flask_sqlalchemy import SQLAlchemy
from werkzeug.security import check_password_hash, generate_password_hash
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -37,12 +38,18 @@ DB_PATH = os.path.join(DB_DIR, "inventory.db")
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{DB_PATH}"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["SECRET_KEY"] = os.environ.get(
"INVENTORY_SECRET_KEY",
hashlib.sha256(f"inventory-local::{BASE_DIR}".encode("utf-8")).hexdigest(),
)
db = SQLAlchemy(app)
# 这里集中放全局常量,避免后面函数里散落硬编码。
LOW_STOCK_THRESHOLD = 5
BOX_TYPES_OVERRIDE_PATH = os.path.join(DB_DIR, "box_types.json")
AI_SETTINGS_PATH = os.path.join(DB_DIR, "ai_settings.json")
DEFAULT_ADMIN_USERNAME = os.environ.get("INVENTORY_ADMIN_USERNAME", "admin").strip() or "admin"
DEFAULT_ADMIN_PASSWORD = os.environ.get("INVENTORY_ADMIN_PASSWORD", "admin123456")
LCSC_BASE_URL = "https://open-api.jlc.com"
LCSC_BASIC_PATH = "/lcsc/openapi/sku/product/basic"
AI_SETTINGS_DEFAULT = {
@@ -520,6 +527,15 @@ class InventoryEvent(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
class User(db.Model):
__tablename__ = "users"
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(64), nullable=False, unique=True)
password_hash = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
def _add_column_if_missing(table_name: str, column_name: str, ddl: str) -> None:
columns = {
row[1]
@@ -563,6 +579,116 @@ def ensure_schema() -> None:
db.session.commit()
def _get_session_user() -> User | None:
user_id = session.get("user_id")
if not user_id:
return None
try:
return User.query.get(int(user_id))
except (TypeError, ValueError):
return None
def _is_authenticated() -> bool:
return _get_session_user() is not None
def _login_user(user: User) -> None:
now_iso = datetime.utcnow().isoformat(timespec="seconds")
session["user_id"] = int(user.id)
session["username"] = user.username
session["login_at"] = now_iso
session["last_active_at"] = now_iso
def _logout_user() -> None:
session.pop("user_id", None)
session.pop("username", None)
session.pop("login_at", None)
session.pop("last_active_at", None)
def _parse_iso_datetime(raw_value: str) -> datetime | None:
value = (raw_value or "").strip()
if not value:
return None
try:
return datetime.fromisoformat(value)
except ValueError:
return None
def _build_session_status() -> tuple[str, str, str, str]:
raw_login_at = (session.get("login_at") or "").strip()
raw_last_active_at = (session.get("last_active_at") or "").strip()
login_at = _parse_iso_datetime(raw_login_at)
last_active_at = _parse_iso_datetime(raw_last_active_at)
if login_at is None:
return "-", "-", "-", "-"
if last_active_at is None:
last_active_at = login_at
online_minutes = max(0, int((datetime.utcnow() - login_at).total_seconds() // 60))
idle_minutes = max(0, int((datetime.utcnow() - last_active_at).total_seconds() // 60))
login_label = login_at.strftime("%Y-%m-%d %H:%M:%S")
last_active_label = last_active_at.strftime("%Y-%m-%d %H:%M:%S")
online_label = "刚刚" if online_minutes <= 0 else f"{online_minutes} 分钟"
idle_label = "刚刚" if idle_minutes <= 0 else f"{idle_minutes} 分钟"
return login_label, online_label, last_active_label, idle_label
def _ensure_default_admin_user() -> None:
"""确保系统至少存在一个可登录用户。
中文说明:首次初始化时自动创建管理员账号,避免系统开启登录保护后无人可进。
用户名和密码可通过环境变量 INVENTORY_ADMIN_USERNAME / INVENTORY_ADMIN_PASSWORD 覆盖。
"""
existing_user = User.query.order_by(User.id.asc()).first()
if existing_user:
return
admin = User(
username=DEFAULT_ADMIN_USERNAME,
password_hash=generate_password_hash(DEFAULT_ADMIN_PASSWORD),
)
db.session.add(admin)
db.session.commit()
@app.context_processor
def inject_auth_context():
current_user = _get_session_user()
login_label, online_label, last_active_label, idle_label = _build_session_status()
return {
"auth_username": current_user.username if current_user else "",
"auth_logged_in": current_user is not None,
"auth_login_at": login_label,
"auth_online_for": online_label,
"auth_last_active_at": last_active_label,
"auth_idle_for": idle_label,
}
@app.before_request
def require_login_for_app_routes():
open_endpoints = {
"login_page",
"logout_page",
"static",
}
endpoint = request.endpoint or ""
if endpoint in open_endpoints or endpoint.startswith("static"):
return None
if _is_authenticated():
session["last_active_at"] = datetime.utcnow().isoformat(timespec="seconds")
return None
next_path = request.full_path if request.query_string else request.path
return redirect(url_for("login_page", next=next_path.rstrip("?")))
def slot_code_for_box(box: Box, slot_index: int) -> str:
serial = box.start_number + slot_index - 1
return f"{box.slot_prefix}{serial}"
@@ -2689,6 +2815,86 @@ def _call_siliconflow_chat(
raise RuntimeError("AI 返回格式无法解析") from exc
def _is_safe_next_path(path: str) -> bool:
candidate = (path or "").strip()
if not candidate:
return False
return candidate.startswith("/") and not candidate.startswith("//")
@app.route("/login", methods=["GET", "POST"])
def login_page():
if _is_authenticated():
return redirect(url_for("types_page"))
error = ""
notice = request.args.get("notice", "").strip()
next_path = request.args.get("next", "").strip()
if request.method == "POST":
username = request.form.get("username", "").strip()
password = request.form.get("password", "")
next_path = request.form.get("next", "").strip()
user = User.query.filter_by(username=username).first()
if not user or not check_password_hash(user.password_hash, password):
error = "用户名或密码错误"
else:
_login_user(user)
if _is_safe_next_path(next_path):
return redirect(next_path)
return redirect(url_for("types_page"))
if not _is_safe_next_path(next_path):
next_path = ""
return render_template("login.html", error=error, notice=notice, next_path=next_path)
@app.route("/logout")
def logout_page():
_logout_user()
return redirect(url_for("login_page"))
@app.route("/account/password", methods=["GET", "POST"])
def change_password_page():
"""登录后修改当前账号密码。
中文说明:为了避免长期使用默认密码,提供页面自助改密。
改密成功后会强制重新登录,确保会话状态干净。
"""
current_user = _get_session_user()
if current_user is None:
return redirect(url_for("login_page"))
error = ""
notice = ""
if request.method == "POST":
current_password = request.form.get("current_password", "")
new_password = request.form.get("new_password", "")
confirm_password = request.form.get("confirm_password", "")
if not check_password_hash(current_user.password_hash, current_password):
error = "当前密码不正确"
elif len(new_password) < 8:
error = "新密码至少需要 8 位"
elif new_password != confirm_password:
error = "两次输入的新密码不一致"
elif check_password_hash(current_user.password_hash, new_password):
error = "新密码不能与当前密码相同"
else:
current_user.password_hash = generate_password_hash(new_password)
db.session.commit()
_logout_user()
return redirect(url_for("login_page", notice="密码修改成功,请使用新密码重新登录"))
if not error:
notice = "建议使用强密码(字母+数字+符号),并定期更换。"
return render_template("change_password.html", error=error, notice=notice)
@app.route("/")
def index():
return redirect(url_for("types_page"))
@@ -4602,6 +4808,7 @@ def bootstrap() -> None:
db.create_all()
ensure_schema()
normalize_legacy_data()
_ensure_default_admin_user()
bootstrap()