BGE-M3 API接口开发指南:FastAPI封装、请求限流、JWT鉴权集成
1. 引言:从模型服务到企业级API
如果你已经按照部署说明,成功在本地或服务器上跑起了BGE-M3的Gradio界面,看着那个简单的网页能返回文本向量,可能会想:这离真正的生产环境还差得远。
没错,原始的Gradio服务更像一个演示Demo。它缺少企业应用最核心的几个要素:标准的API接口、安全可控的访问权限、稳定的性能保障,以及方便集成的调用方式。直接把这个服务暴露给外部系统或客户端调用,会面临诸多挑战:如何验证调用者身份?如何防止接口被恶意刷爆?如何以更通用(如JSON)而非网页表单的形式交互?
本文将手把手带你进行二次开发,将一个基础的BGE-M3模型服务,封装成具备FastAPI框架、请求限流和JWT鉴权的成熟生产级API。完成后,你将获得一个类似下面这样的、可直接集成到业务系统中的接口:
POST /api/v1/embedding HTTP/1.1 Host: your-api-server.com Authorization: Bearer eyJhbGciOiJIUzI1NiIs... Content-Type: application/json { "texts": ["苹果是一种水果", "Apple is a technology company"], "mode": "dense" }2. 项目架构设计与核心组件
在动手写代码之前,我们先理清整个增强版API服务的架构。核心思想是分层和解耦,让每一层只负责一件事。
2.1 整体架构图(逻辑层面)
[ 客户端 ] | | (HTTPS + JWT Token) v [ FastAPI网关层 ] ——— 负责路由、请求验证、JWT鉴权、限流 | | (内部调用) v [ 模型服务层 ] ——— 封装原始的BGE-M3模型调用逻辑 | | (加载模型、计算) v [ BGE-M3模型 ] ——— 核心的FlagEmbedding模型2.2 核心组件介绍
- FastAPI: 现代、高性能的Python Web框架。它帮我们自动生成API文档(Swagger UI),用Python类型提示进行数据验证,并且异步支持好,性能出色。这是我们API的“大门”和“总控台”。
- JWT (JSON Web Token): 一种流行的鉴权方案。用户登录后获得一个Token,后续请求都在Header中带上这个Token来证明身份。服务端无需存储会话状态,只需验证Token的签名和有效性即可。它解决了“谁可以调用API”的问题。
- 请求限流 (Rate Limiting): 保护API免受滥用或DDoS攻击的关键机制。它可以限制每个用户(或每个IP)在特定时间窗口(如1分钟)内能发起的请求次数。这解决了“调用频率不可控”的问题。
- 原始模型服务: 我们将保留并改造之前部署的BGE-M3模型加载与推理代码,将其封装成一个内部可调用的Python类或函数,供FastAPI层调用。
3. 基础环境搭建与依赖安装
我们从最干净的环境开始。假设你已经在服务器上(比如之前部署的同一台),我们创建一个全新的项目目录。
# 1. 创建项目目录 mkdir bge-m3-api && cd bge-m3-api # 2. 创建虚拟环境(推荐,避免包冲突) python3 -m venv venv source venv/bin/activate # Linux/Mac # venv\Scripts\activate # Windows # 3. 安装核心依赖 pip install fastapi uvicorn[standard] python-jose[cryptography] passlib[bcrypt] python-multipart # fastapi: web框架 # uvicorn: ASGI服务器,用于运行FastAPI # python-jose: 用于生成和验证JWT令牌 # passlib: 密码哈希工具(用于用户认证,本例简化处理) # python-multipart: 处理表单数据 # 4. 安装限流依赖 pip install slowapi # 5. 安装模型推理依赖 (与原始部署一致) pip install FlagEmbedding sentence-transformers torch4. 核心代码实现:分步构建API
我们将代码分成几个文件,保持结构清晰。首先创建主应用文件。
4.1 第一步:创建FastAPI应用骨架与模型封装 (main.py)
我们先搭建FastAPI应用,并编写一个类来封装BGE-M3模型的加载和调用。这部分代码参考了原始app.py的核心逻辑。
# main.py import time from typing import List, Optional, Dict, Any from fastapi import FastAPI, HTTPException, Depends, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from pydantic import BaseModel, Field from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi.util import get_remote_address from slowapi.errors import RateLimitExceeded from jose import JWTError, jwt from passlib.context import CryptContext # 导入模型相关库 from FlagEmbedding import BGEM3FlagModel import numpy as np import logging # --- 配置部分 --- # JWT配置 SECRET_KEY = "your-secret-key-change-this-in-production" # 务必在生产环境中更改! ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 # 模型配置 MODEL_PATH = "BAAI/bge-m3" # 会自动从HuggingFace下载,或指向本地路径如 `/root/.cache/.../bge-m3` MAX_LENGTH = 8192 DEVICE = "cuda" # 或 "cpu" # 限流器配置 limiter = Limiter(key_func=get_remote_address) # --- 初始化FastAPI应用 --- app = FastAPI( title="BGE-M3 Embedding API", description="提供BGE-M3文本嵌入模型的API服务,支持密集、稀疏、多向量模式。", version="1.0.0" ) # 将限流器挂载到app,并设置默认限流(示例:每分钟100次) app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # 安全方案(用于提取Bearer Token) security = HTTPBearer() # 密码上下文(简化示例,实际生产需连接数据库) pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") # 日志配置 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # --- 数据模型定义 (Pydantic) --- class EmbeddingRequest(BaseModel): """嵌入请求体""" texts: List[str] = Field(..., min_length=1, max_length=100, description="待编码的文本列表,最多100条") mode: str = Field(default="dense", description="编码模式:dense(密集), sparse(稀疏), colbert(多向量), 或 hybrid(混合)") normalize_embeddings: bool = Field(default=True, description="是否对密集向量进行L2归一化") return_dense: bool = Field(default=True, description="hybrid模式时是否返回密集向量") return_sparse: bool = Field(default=False, description="hybrid模式时是否返回稀疏向量") return_colbert_vecs: bool = Field(default=False, description="hybrid模式时是否返回Colbert多向量") class Config: schema_extra = { "example": { "texts": ["苹果是一种水果", "深度学习是人工智能的一个分支"], "mode": "dense", "normalize_embeddings": True } } class EmbeddingResponse(BaseModel): """嵌入响应体""" embeddings: Optional[List[List[float]]] = None sparse_embeddings: Optional[List[Dict[str, float]]] = None colbert_vecs: Optional[List[List[List[float]]]] = None model: str = MODEL_PATH mode: str request_id: str process_time_ms: float class TokenData(BaseModel): username: Optional[str] = None # --- BGE-M3模型封装类 --- class BGE_M3_Encoder: """封装BGE-M3模型的加载和推理""" _instance = None def __new__(cls): if cls._instance is None: cls._instance = super(BGE_M3_Encoder, cls).__new__(cls) cls._instance._initialize_model() return cls._instance def _initialize_model(self): """初始化模型,单例模式确保只加载一次""" logger.info(f"正在加载模型: {MODEL_PATH},设备: {DEVICE}") start_time = time.time() try: # 设置环境变量,禁用TensorFlow(如果系统里有) import os os.environ['TRANSFORMERS_NO_TF'] = '1' self.model = BGEM3FlagModel( MODEL_PATH, use_fp16=True, # 使用FP16加速 device=DEVICE ) load_time = time.time() - start_time logger.info(f"模型加载完成,耗时: {load_time:.2f}秒") except Exception as e: logger.error(f"模型加载失败: {e}") raise RuntimeError(f"无法加载模型 {MODEL_PATH}: {e}") def encode(self, request: EmbeddingRequest) -> Dict[str, Any]: """根据请求参数进行编码""" start_time = time.time() try: # 调用FlagEmbedding库的encode方法 # 注意:原始库的返回格式可能根据参数不同而不同,这里做适配 if request.mode == 'dense': result = self.model.encode( request.texts, batch_size=32, max_length=MAX_LENGTH, return_dense=True, return_sparse=False, return_colbert_vecs=False, dense_dim=1024 # BGE-M3的密集向量维度是1024 ) embeddings = result['dense_vecs'].tolist() if hasattr(result['dense_vecs'], 'tolist') else result['dense_vecs'] output = { "embeddings": embeddings, "sparse_embeddings": None, "colbert_vecs": None } elif request.mode == 'sparse': result = self.model.encode( request.texts, return_dense=False, return_sparse=True, return_colbert_vecs=False ) # 稀疏向量通常表示为字典 {token_id: weight} output = { "embeddings": None, "sparse_embeddings": result.get('sparse_vecs', []), "colbert_vecs": None } elif request.mode == 'colbert': result = self.model.encode( request.texts, return_dense=False, return_sparse=False, return_colbert_vecs=True ) output = { "embeddings": None, "sparse_embeddings": None, "colbert_vecs": result.get('colbert_vecs', []) } elif request.mode == 'hybrid': result = self.model.encode( request.texts, return_dense=request.return_dense, return_sparse=request.return_sparse, return_colbert_vecs=request.return_colbert_vecs, max_length=MAX_LENGTH ) output = { "embeddings": result.get('dense_vecs', []), "sparse_embeddings": result.get('sparse_vecs', []), "colbert_vecs": result.get('colbert_vecs', []) } # 转换numpy数组为列表 if output['embeddings'] is not None and hasattr(output['embeddings'], 'tolist'): output['embeddings'] = output['embeddings'].tolist() if output['colbert_vecs'] is not None and hasattr(output['colbert_vecs'], 'tolist'): output['colbert_vecs'] = output['colbert_vecs'].tolist() else: raise ValueError(f"不支持的mode: {request.mode}") process_time = (time.time() - start_time) * 1000 # 毫秒 logger.info(f"编码完成,文本数: {len(request.texts)},模式: {request.mode},耗时: {process_time:.1f}ms") return output except Exception as e: logger.error(f"编码过程出错: {e}") raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") # 初始化模型编码器(单例) encoder = BGE_M3_Encoder()4.2 第二步:实现JWT鉴权与用户依赖 (auth.py)
我们将鉴权逻辑分离到一个单独的文件中。为了简化,我们使用一个内存中的“用户数据库”。生产环境请务必连接真实的数据库(如MySQL、PostgreSQL)。
# auth.py from datetime import datetime, timedelta from typing import Optional from jose import JWTError, jwt from passlib.context import CryptContext from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from pydantic import BaseModel # 复用main.py中的配置(实际项目可用配置模块) SECRET_KEY = "your-secret-key-change-this-in-production" ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 # 模拟用户数据库 fake_users_db = { "test_user": { "username": "test_user", "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW", # 密码是 "secret" "disabled": False, } } class Token(BaseModel): access_token: str token_type: str class TokenData(BaseModel): username: Optional[str] = None class User(BaseModel): username: str disabled: Optional[bool] = None class UserInDB(User): hashed_password: str pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") security = HTTPBearer() def verify_password(plain_password, hashed_password): """验证密码""" return pwd_context.verify(plain_password, hashed_password) def get_password_hash(password): """生成密码哈希""" return pwd_context.hash(password) def get_user(db, username: str): """从数据库获取用户""" if username in db: user_dict = db[username] return UserInDB(**user_dict) return None def authenticate_user(fake_db, username: str, password: str): """用户认证""" user = get_user(fake_db, username) if not user: return False if not verify_password(password, user.hashed_password): return False return user def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): """创建JWT访问令牌""" to_encode = data.copy() if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=15) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)): """依赖项:从JWT Token中获取当前用户""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的认证凭证", headers={"WWW-Authenticate": "Bearer"}, ) try: # 解码JWT Token payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM]) username: str = payload.get("sub") if username is None: raise credentials_exception token_data = TokenData(username=username) except JWTError: raise credentials_exception # 从“数据库”获取用户 user = get_user(fake_users_db, username=token_data.username) if user is None: raise credentials_exception return user async def get_current_active_user(current_user: User = Depends(get_current_user)): """依赖项:检查用户是否活跃""" if current_user.disabled: raise HTTPException(status_code=400, detail="用户已被禁用") return current_user4.3 第三步:完善主应用,添加API端点 (main.py续)
现在回到main.py,添加具体的API路由,并集成鉴权和限流。
# main.py (续) # 在文件开头导入auth模块 from auth import get_current_active_user, create_access_token, authenticate_user, Token, User from datetime import timedelta # --- API端点定义 --- @app.post("/api/v1/token", response_model=Token) async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): """ 用户登录,获取JWT访问令牌。 使用表单格式提交 username 和 password。 """ # 验证用户凭证 user = authenticate_user(fake_users_db, form_data.username, form_data.password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误", headers={"WWW-Authenticate": "Bearer"}, ) # 设置Token过期时间 access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) # 创建Token,主题(sub)通常是用户名 access_token = create_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return {"access_token": access_token, "token_type": "bearer"} @app.get("/api/v1/health") @limiter.limit("100/minute") async def health_check(): """健康检查端点,无需鉴权""" return { "status": "healthy", "model_loaded": encoder.model is not None, "timestamp": time.time() } @app.post("/api/v1/embedding", response_model=EmbeddingResponse) @limiter.limit("30/minute") # 对嵌入端点进行更严格的限流 async def create_embedding( request: EmbeddingRequest, current_user: User = Depends(get_current_active_user), # 依赖鉴权 request_id: str = str(int(time.time() * 1000)) # 生成简单请求ID ): """ 核心端点:获取文本的嵌入向量。 需要Bearer Token认证。 """ logger.info(f"用户 [{current_user.username}] 请求嵌入,模式: {request.mode}, 文本数: {len(request.texts)}") # 调用模型编码器 start_time = time.time() result = encoder.encode(request) process_time_ms = (time.time() - start_time) * 1000 # 构建响应 response = EmbeddingResponse( embeddings=result.get("embeddings"), sparse_embeddings=result.get("sparse_embeddings"), colbert_vecs=result.get("colbert_vecs"), model=MODEL_PATH, mode=request.mode, request_id=request_id, process_time_ms=process_time_ms ) return response @app.get("/api/v1/info") async def get_model_info(current_user: User = Depends(get_current_active_user)): """获取模型信息端点""" return { "model_name": "BGE-M3", "vector_dimension": 1024, "max_sequence_length": MAX_LENGTH, "supported_modes": ["dense", "sparse", "colbert", "hybrid"], "device": DEVICE } # 注意:需要导入 OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm # --- 启动应用 --- if __name__ == "__main__": import uvicorn uvicorn.run( "main:app", host="0.0.0.0", # 监听所有网络接口 port=8000, # 使用8000端口,避免与原始Gradio服务(7860)冲突 reload=True # 开发模式,代码修改自动重启 )4.4 第四步:创建启动脚本与配置文件
为了方便部署,我们创建几个辅助文件。
启动脚本start_api.sh:
#!/bin/bash # start_api.sh cd /path/to/your/bge-m3-api source venv/bin/activate export TRANSFORMERS_NO_TF=1 uvicorn main:app --host 0.0.0.0 --port 8000 --workers 2后台运行:
nohup bash start_api.sh > /tmp/bge-m3-api.log 2>&1 &环境变量配置文件.env(可选,使用python-dotenv管理):
SECRET_KEY=your-very-strong-secret-key-here MODEL_PATH=BAAI/bge-m3 DEVICE=cuda5. 测试与使用你的生产级API
服务启动后,你可以通过多种方式测试API。
5.1 使用Swagger UI交互文档
FastAPI自动生成了交互式API文档。在浏览器中访问:
http://你的服务器IP:8000/docs你会看到一个漂亮的界面,可以在这里直接尝试/token端点登录,然后使用获得的Token去调用/embedding端点。这是最方便的测试方式。
5.2 使用cURL命令行测试
# 1. 获取Token curl -X POST "http://localhost:8000/api/v1/token" \ -H "Content-Type: application/x-www-form-urlencoded" \ -d "username=test_user&password=secret" # 响应示例: {"access_token":"eyJ...","token_type":"bearer"} # 2. 使用Token调用嵌入接口 curl -X POST "http://localhost:8000/api/v1/embedding" \ -H "Authorization: Bearer eyJ..." \ -H "Content-Type: application/json" \ -d '{ "texts": ["今天天气真好", "The weather is nice today"], "mode": "dense" }'5.3 使用Python客户端调用
# client_example.py import requests import json BASE_URL = "http://localhost:8000" USERNAME = "test_user" PASSWORD = "secret" # 1. 登录获取Token token_response = requests.post( f"{BASE_URL}/api/v1/token", data={"username": USERNAME, "password": PASSWORD} ) token_response.raise_for_status() access_token = token_response.json()["access_token"] print(f"获取到Token: {access_token[:20]}...") # 2. 调用嵌入API headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json" } embedding_request = { "texts": ["苹果是一种水果", "Apple Inc. is a tech giant"], "mode": "dense", "normalize_embeddings": True } response = requests.post( f"{BASE_URL}/api/v1/embedding", headers=headers, json=embedding_request ) response.raise_for_status() result = response.json() print(f"请求ID: {result['request_id']}") print(f"处理时间: {result['process_time_ms']} ms") print(f"向量维度: {len(result['embeddings'][0]) if result['embeddings'] else 'N/A'}") print(f"第一条文本的向量前5维: {result['embeddings'][0][:5] if result['embeddings'] else 'N/A'}")6. 总结:从Demo到生产的关键步骤
回顾整个过程,我们通过几个关键步骤,将一个基础的模型演示服务升级成了适合集成的生产API:
- 框架升级:用FastAPI替代Gradio作为HTTP服务框架,获得了高性能、自动文档、数据验证等现代化特性。
- 安全加固:集成JWT鉴权,确保只有授权用户或系统可以调用API,并通过Token管理访问生命周期。
- 稳定性保障:引入请求限流,防止资源被过度消耗,保护后端模型服务稳定运行。
- 代码结构化:采用分层设计,将模型逻辑、认证逻辑、API路由分离,使代码更易维护和扩展。
- 接口标准化:定义清晰的请求/响应数据模型(Pydantic),提供符合RESTful风格的API端点。
这个增强版的API服务现在可以更安全、更稳定地接入到你的搜索系统、推荐系统、知识库问答等任何需要文本嵌入能力的业务场景中。你可以在此基础上进一步扩展,例如添加API密钥管理、更细粒度的权限控制、请求审计日志、Prometheus监控指标等,以满足更复杂的企业级需求。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。