feat: 添加用户登录认证功能,确保系统安全性,并提供修改密码和退出登录选项
This commit is contained in:
209
app.py
209
app.py
@@ -24,8 +24,9 @@ from copy import deepcopy
|
||||
from io import StringIO
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from flask import Flask, Response, redirect, render_template, request, url_for
|
||||
from flask import Flask, Response, redirect, render_template, request, session, url_for
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -37,12 +38,18 @@ DB_PATH = os.path.join(DB_DIR, "inventory.db")
|
||||
app = Flask(__name__)
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{DB_PATH}"
|
||||
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
||||
app.config["SECRET_KEY"] = os.environ.get(
|
||||
"INVENTORY_SECRET_KEY",
|
||||
hashlib.sha256(f"inventory-local::{BASE_DIR}".encode("utf-8")).hexdigest(),
|
||||
)
|
||||
db = SQLAlchemy(app)
|
||||
|
||||
# 这里集中放全局常量,避免后面函数里散落硬编码。
|
||||
LOW_STOCK_THRESHOLD = 5
|
||||
BOX_TYPES_OVERRIDE_PATH = os.path.join(DB_DIR, "box_types.json")
|
||||
AI_SETTINGS_PATH = os.path.join(DB_DIR, "ai_settings.json")
|
||||
DEFAULT_ADMIN_USERNAME = os.environ.get("INVENTORY_ADMIN_USERNAME", "admin").strip() or "admin"
|
||||
DEFAULT_ADMIN_PASSWORD = os.environ.get("INVENTORY_ADMIN_PASSWORD", "admin123456")
|
||||
LCSC_BASE_URL = "https://open-api.jlc.com"
|
||||
LCSC_BASIC_PATH = "/lcsc/openapi/sku/product/basic"
|
||||
AI_SETTINGS_DEFAULT = {
|
||||
@@ -520,6 +527,15 @@ class InventoryEvent(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = db.Column(db.Integer, primary_key=True)
|
||||
username = db.Column(db.String(64), nullable=False, unique=True)
|
||||
password_hash = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.now())
|
||||
|
||||
|
||||
def _add_column_if_missing(table_name: str, column_name: str, ddl: str) -> None:
|
||||
columns = {
|
||||
row[1]
|
||||
@@ -563,6 +579,116 @@ def ensure_schema() -> None:
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def _get_session_user() -> User | None:
|
||||
user_id = session.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
return User.query.get(int(user_id))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _is_authenticated() -> bool:
|
||||
return _get_session_user() is not None
|
||||
|
||||
|
||||
def _login_user(user: User) -> None:
|
||||
now_iso = datetime.utcnow().isoformat(timespec="seconds")
|
||||
session["user_id"] = int(user.id)
|
||||
session["username"] = user.username
|
||||
session["login_at"] = now_iso
|
||||
session["last_active_at"] = now_iso
|
||||
|
||||
|
||||
def _logout_user() -> None:
|
||||
session.pop("user_id", None)
|
||||
session.pop("username", None)
|
||||
session.pop("login_at", None)
|
||||
session.pop("last_active_at", None)
|
||||
|
||||
|
||||
def _parse_iso_datetime(raw_value: str) -> datetime | None:
|
||||
value = (raw_value or "").strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _build_session_status() -> tuple[str, str, str, str]:
|
||||
raw_login_at = (session.get("login_at") or "").strip()
|
||||
raw_last_active_at = (session.get("last_active_at") or "").strip()
|
||||
login_at = _parse_iso_datetime(raw_login_at)
|
||||
last_active_at = _parse_iso_datetime(raw_last_active_at)
|
||||
|
||||
if login_at is None:
|
||||
return "-", "-", "-", "-"
|
||||
if last_active_at is None:
|
||||
last_active_at = login_at
|
||||
|
||||
online_minutes = max(0, int((datetime.utcnow() - login_at).total_seconds() // 60))
|
||||
idle_minutes = max(0, int((datetime.utcnow() - last_active_at).total_seconds() // 60))
|
||||
login_label = login_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
last_active_label = last_active_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
online_label = "刚刚" if online_minutes <= 0 else f"{online_minutes} 分钟"
|
||||
idle_label = "刚刚" if idle_minutes <= 0 else f"{idle_minutes} 分钟"
|
||||
return login_label, online_label, last_active_label, idle_label
|
||||
|
||||
|
||||
def _ensure_default_admin_user() -> None:
|
||||
"""确保系统至少存在一个可登录用户。
|
||||
|
||||
中文说明:首次初始化时自动创建管理员账号,避免系统开启登录保护后无人可进。
|
||||
用户名和密码可通过环境变量 INVENTORY_ADMIN_USERNAME / INVENTORY_ADMIN_PASSWORD 覆盖。
|
||||
"""
|
||||
existing_user = User.query.order_by(User.id.asc()).first()
|
||||
if existing_user:
|
||||
return
|
||||
|
||||
admin = User(
|
||||
username=DEFAULT_ADMIN_USERNAME,
|
||||
password_hash=generate_password_hash(DEFAULT_ADMIN_PASSWORD),
|
||||
)
|
||||
db.session.add(admin)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@app.context_processor
|
||||
def inject_auth_context():
|
||||
current_user = _get_session_user()
|
||||
login_label, online_label, last_active_label, idle_label = _build_session_status()
|
||||
return {
|
||||
"auth_username": current_user.username if current_user else "",
|
||||
"auth_logged_in": current_user is not None,
|
||||
"auth_login_at": login_label,
|
||||
"auth_online_for": online_label,
|
||||
"auth_last_active_at": last_active_label,
|
||||
"auth_idle_for": idle_label,
|
||||
}
|
||||
|
||||
|
||||
@app.before_request
|
||||
def require_login_for_app_routes():
|
||||
open_endpoints = {
|
||||
"login_page",
|
||||
"logout_page",
|
||||
"static",
|
||||
}
|
||||
endpoint = request.endpoint or ""
|
||||
if endpoint in open_endpoints or endpoint.startswith("static"):
|
||||
return None
|
||||
|
||||
if _is_authenticated():
|
||||
session["last_active_at"] = datetime.utcnow().isoformat(timespec="seconds")
|
||||
return None
|
||||
|
||||
next_path = request.full_path if request.query_string else request.path
|
||||
return redirect(url_for("login_page", next=next_path.rstrip("?")))
|
||||
|
||||
|
||||
def slot_code_for_box(box: Box, slot_index: int) -> str:
|
||||
serial = box.start_number + slot_index - 1
|
||||
return f"{box.slot_prefix}{serial}"
|
||||
@@ -2689,6 +2815,86 @@ def _call_siliconflow_chat(
|
||||
raise RuntimeError("AI 返回格式无法解析") from exc
|
||||
|
||||
|
||||
def _is_safe_next_path(path: str) -> bool:
|
||||
candidate = (path or "").strip()
|
||||
if not candidate:
|
||||
return False
|
||||
return candidate.startswith("/") and not candidate.startswith("//")
|
||||
|
||||
|
||||
@app.route("/login", methods=["GET", "POST"])
|
||||
def login_page():
|
||||
if _is_authenticated():
|
||||
return redirect(url_for("types_page"))
|
||||
|
||||
error = ""
|
||||
notice = request.args.get("notice", "").strip()
|
||||
next_path = request.args.get("next", "").strip()
|
||||
if request.method == "POST":
|
||||
username = request.form.get("username", "").strip()
|
||||
password = request.form.get("password", "")
|
||||
next_path = request.form.get("next", "").strip()
|
||||
|
||||
user = User.query.filter_by(username=username).first()
|
||||
if not user or not check_password_hash(user.password_hash, password):
|
||||
error = "用户名或密码错误"
|
||||
else:
|
||||
_login_user(user)
|
||||
if _is_safe_next_path(next_path):
|
||||
return redirect(next_path)
|
||||
return redirect(url_for("types_page"))
|
||||
|
||||
if not _is_safe_next_path(next_path):
|
||||
next_path = ""
|
||||
|
||||
return render_template("login.html", error=error, notice=notice, next_path=next_path)
|
||||
|
||||
|
||||
@app.route("/logout")
|
||||
def logout_page():
|
||||
_logout_user()
|
||||
return redirect(url_for("login_page"))
|
||||
|
||||
|
||||
@app.route("/account/password", methods=["GET", "POST"])
|
||||
def change_password_page():
|
||||
"""登录后修改当前账号密码。
|
||||
|
||||
中文说明:为了避免长期使用默认密码,提供页面自助改密。
|
||||
改密成功后会强制重新登录,确保会话状态干净。
|
||||
"""
|
||||
current_user = _get_session_user()
|
||||
if current_user is None:
|
||||
return redirect(url_for("login_page"))
|
||||
|
||||
error = ""
|
||||
notice = ""
|
||||
|
||||
if request.method == "POST":
|
||||
current_password = request.form.get("current_password", "")
|
||||
new_password = request.form.get("new_password", "")
|
||||
confirm_password = request.form.get("confirm_password", "")
|
||||
|
||||
if not check_password_hash(current_user.password_hash, current_password):
|
||||
error = "当前密码不正确"
|
||||
elif len(new_password) < 8:
|
||||
error = "新密码至少需要 8 位"
|
||||
elif new_password != confirm_password:
|
||||
error = "两次输入的新密码不一致"
|
||||
elif check_password_hash(current_user.password_hash, new_password):
|
||||
error = "新密码不能与当前密码相同"
|
||||
else:
|
||||
current_user.password_hash = generate_password_hash(new_password)
|
||||
db.session.commit()
|
||||
_logout_user()
|
||||
return redirect(url_for("login_page", notice="密码修改成功,请使用新密码重新登录"))
|
||||
|
||||
if not error:
|
||||
notice = "建议使用强密码(字母+数字+符号),并定期更换。"
|
||||
|
||||
return render_template("change_password.html", error=error, notice=notice)
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return redirect(url_for("types_page"))
|
||||
@@ -4602,6 +4808,7 @@ def bootstrap() -> None:
|
||||
db.create_all()
|
||||
ensure_schema()
|
||||
normalize_legacy_data()
|
||||
_ensure_default_admin_user()
|
||||
|
||||
|
||||
bootstrap()
|
||||
|
||||
Reference in New Issue
Block a user