DASD-4B-Thinking Model API Development: Efficient Service Wrapping with FastAPI
I've been tinkering with the DASD-4B-Thinking model lately, and it's genuinely interesting. As a 4-billion-parameter thinking-style large language model, it performs well on multi-step reasoning and long chains of thought. But there's an obvious problem: you can't keep invoking it by running scripts by hand. It needs a proper API service.
I happened to be working with FastAPI recently, so I decided to use it to wrap DASD-4B-Thinking in a RESTful API. FastAPI turned out to be a great fit for this scenario: development is quick, performance is good, and you get auto-generated API docs for free. In this post I'll walk through how I did it, step by step, from environment setup through interface design to performance tuning.
1. Environment Setup and Project Scaffolding
Before we start, let's get the environment ready. I recommend Python 3.9 or newer for better compatibility.
1.1 Create a Virtual Environment
First, create an isolated Python environment to avoid package conflicts:
```bash
# Create the project directory
mkdir dasd-api-service
cd dasd-api-service

# Create a virtual environment
python -m venv venv

# Activate it
# Linux/Mac
source venv/bin/activate
# Windows
venv\Scripts\activate
```
1.2 Install Core Dependencies
Next, install the required packages. Here's a fairly complete dependency list:
```bash
# Core framework and model-related packages
pip install fastapi uvicorn pydantic pydantic-settings httpx

# Inference backend (pick one based on how you deploy)
# If you serve with vLLM
pip install vllm
# If you serve with Transformers
pip install transformers torch

# Monitoring and logging
pip install prometheus-client loguru

# Optional: async task processing
pip install celery redis

# Optional: API rate limiting
pip install slowapi
```

If you serve DASD-4B-Thinking with vLLM, installing vLLM is enough. That's what I use, because its inference speed is genuinely fast and its memory management is solid.
1.3 Project Structure
A good project layout makes later development much easier. Here's how I organized mine:
```
dasd-api-service/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPI application entry point
│   ├── config.py            # Configuration
│   ├── models/              # Data models
│   │   ├── __init__.py
│   │   ├── request.py       # Request models
│   │   └── response.py      # Response models
│   ├── routers/             # Route modules
│   │   ├── __init__.py
│   │   ├── chat.py          # Chat endpoints
│   │   ├── completion.py    # Completion endpoints
│   │   └── health.py        # Health checks
│   ├── services/            # Business logic
│   │   ├── __init__.py
│   │   ├── inference.py     # Model inference service
│   │   └── cache.py         # Cache service
│   ├── utils/               # Utilities
│   │   ├── __init__.py
│   │   ├── logger.py        # Logging setup
│   │   └── metrics.py       # Monitoring metrics
│   └── dependencies.py      # Dependency injection
├── tests/                   # Tests
├── requirements.txt         # Dependency list
├── .env.example             # Example environment variables
└── README.md                # Project documentation
```

This layout keeps responsibilities clearly separated. Feel free to adjust it to your own habits; the key is staying consistent.
2. Core Configuration and Model Loading
Configuration management is the foundation of the API service. I like to manage configuration with Pydantic's BaseSettings, which is both safe and convenient.
2.1 Configuration Design
In app/config.py, I define the configuration like this:
```python
# app/config.py
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # API settings
    api_host: str = "0.0.0.0"
    api_port: int = 8000
    api_workers: int = 1
    api_debug: bool = False

    # Model settings
    model_name: str = "DASD-4B-Thinking"
    model_path: str = "/path/to/your/model"  # local model path
    model_max_length: int = 8192
    model_temperature: float = 0.7
    model_top_p: float = 0.9

    # vLLM settings
    vllm_max_model_len: int = 16384
    vllm_gpu_memory_utilization: float = 0.9
    vllm_enforce_eager: bool = False

    # Service settings
    max_concurrent_requests: int = 10
    request_timeout: int = 300  # 5 minutes
    enable_streaming: bool = True

    # Cache settings
    cache_enabled: bool = True
    cache_ttl: int = 3600  # 1 hour

    # Monitoring settings
    enable_metrics: bool = True
    metrics_port: int = 9090

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


settings = Settings()
```

Then create a .env file in the project root to override the defaults:
```bash
# API settings
API_HOST=0.0.0.0
API_PORT=8000
API_DEBUG=false

# Model settings
MODEL_PATH=/data/models/dasd-4b-thinking
MODEL_MAX_LENGTH=8192

# vLLM settings
VLLM_GPU_MEMORY_UTILIZATION=0.85
```
2.2 Model Loading Service
Model loading is the heart of the API service. I created a singleton service class to manage the model instance:
```python
# app/services/inference.py
import uuid
from typing import Optional

from loguru import logger
from vllm import AsyncLLMEngine, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs

from app.config import settings


class InferenceService:
    _instance = None
    _engine: Optional[AsyncLLMEngine] = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(InferenceService, cls).__new__(cls)
        return cls._instance

    async def initialize(self):
        """Initialize the model engine."""
        if self._engine is not None:
            return

        logger.info(f"Loading model: {settings.model_name}")

        # Engine arguments
        engine_args = AsyncEngineArgs(
            model=settings.model_path,
            tensor_parallel_size=1,  # adjust to your GPU count
            max_model_len=settings.vllm_max_model_len,
            gpu_memory_utilization=settings.vllm_gpu_memory_utilization,
            enforce_eager=settings.vllm_enforce_eager,
            disable_log_stats=True,
            disable_log_requests=True,
        )

        # Create the async engine
        self._engine = AsyncLLMEngine.from_engine_args(engine_args)

        # Warm up the model
        await self._warmup()

        logger.info("Model loaded")

    async def _warmup(self):
        """Warm up the model to avoid first-request latency."""
        try:
            warmup_prompt = "Hello, this is a warmup request."
            sampling_params = SamplingParams(
                temperature=0.1,
                top_p=0.9,
                max_tokens=10
            )
            # Issue a warmup generation
            results_generator = self._engine.generate(
                prompt=warmup_prompt,
                sampling_params=sampling_params,
                request_id="warmup"
            )
            # Drain the generator without using the results
            async for _ in results_generator:
                pass
            logger.debug("Model warmup finished")
        except Exception as e:
            logger.warning(f"Model warmup failed: {e}")

    async def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = None,
        top_p: float = None,
        stream: bool = False
    ):
        """Generate text."""
        if self._engine is None:
            raise RuntimeError("Model engine is not initialized")

        # Fall back to configured defaults when a parameter is not provided
        sampling_params = SamplingParams(
            temperature=temperature if temperature is not None else settings.model_temperature,
            top_p=top_p if top_p is not None else settings.model_top_p,
            max_tokens=min(max_tokens, settings.model_max_length),
            stop=None,  # DASD-4B-Thinking may need model-specific stop tokens
        )

        # Generate a request ID
        request_id = str(uuid.uuid4())

        # Call the model
        results_generator = self._engine.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id
        )

        if stream:
            # Streaming response: vLLM yields cumulative text, so only emit the delta
            async def stream_generator():
                previous_text = ""
                async for output in results_generator:
                    text = output.outputs[0].text if output.outputs else ""
                    delta = text[len(previous_text):]
                    previous_text = text
                    if delta:
                        yield delta
            return stream_generator()
        else:
            # Non-streaming response
            final_output = None
            async for output in results_generator:
                final_output = output
            if final_output and final_output.outputs:
                return final_output.outputs[0].text
            return ""

    async def chat(
        self,
        messages: list,
        max_tokens: int = 512,
        temperature: float = None,
        stream: bool = False
    ):
        """Chat-style generation."""
        # Convert the message list into a prompt.
        # DASD-4B-Thinking may expect a specific chat format.
        prompt = self._format_chat_prompt(messages)
        return await self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=stream
        )

    def _format_chat_prompt(self, messages: list) -> str:
        """Format chat messages into a prompt."""
        # Adjust this to DASD-4B-Thinking's actual chat template.
        # Here we assume a ChatML-like format.
        formatted = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            formatted.append(f"<|{role}|>\n{content}")
        formatted.append("<|assistant|>\n")
        return "\n".join(formatted)

    async def shutdown(self):
        """Shut down the model engine."""
        if self._engine:
            await self._engine.shutdown()
            self._engine = None
            logger.info("Model engine shut down")


# Global instance
inference_service = InferenceService()
```

This service class does a few things: it uses the singleton pattern so the model is only loaded once; it provides async initialization and inference methods; and it supports both streaming and non-streaming responses.
3. API Design and Implementation
With the model service in place, the next step is designing the API endpoints. FastAPI's routing is flexible, and I like to organize it by functional module.
3.1 Health Check Endpoints
Let's start with simple health check endpoints so the service status is easy to monitor:
```python
# app/routers/health.py
from typing import Dict

from fastapi import APIRouter
from fastapi.responses import JSONResponse

from app.services.inference import inference_service

router = APIRouter(prefix="/health", tags=["health"])


@router.get("/")
async def health_check() -> Dict[str, str]:
    """Basic liveness check."""
    return {"status": "healthy", "service": "dasd-api"}


@router.get("/ready")
async def readiness_check():
    """Readiness check: confirm the model is loaded."""
    if inference_service._engine is None:
        return JSONResponse(
            status_code=503,
            content={"status": "not ready", "message": "Model not loaded"}
        )
    return {"status": "ready", "message": "Model is loaded"}


@router.get("/metrics")
async def metrics_endpoint():
    """Prometheus metrics endpoint."""
    # prometheus-client is wired in via app/utils/metrics.py;
    # to keep things simple this just returns the current snapshot
    from app.utils.metrics import get_metrics
    return get_metrics()
```
3.2 Chat Endpoint Implementation
The chat endpoint is the one used most, so I made it fairly comprehensive:
```python
# app/routers/chat.py
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from app.services.inference import inference_service
from app.config import settings

router = APIRouter(prefix="/chat", tags=["chat"])


# Request/response models
class Message(BaseModel):
    role: str = Field(..., description="Message role, e.g. user or assistant")
    content: str = Field(..., description="Message content")


class ChatRequest(BaseModel):
    messages: List[Message] = Field(..., description="Message history")
    max_tokens: Optional[int] = Field(512, ge=1, le=settings.model_max_length, description="Maximum number of tokens to generate")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")
    thinking_enabled: Optional[bool] = Field(True, description="Whether to enable thinking mode (reserved; not yet passed to the inference service)")


class ChatResponse(BaseModel):
    message: Message = Field(..., description="Assistant reply")
    usage: dict = Field(..., description="Token usage")
    finish_reason: str = Field(..., description="Finish reason")


@router.post("/completions", response_model=ChatResponse)
async def chat_completion(request: ChatRequest):
    """Chat completion endpoint."""
    try:
        # Reject empty message lists
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")

        # The last message must come from the user
        last_message = request.messages[-1]
        if last_message.role != "user":
            raise HTTPException(
                status_code=400,
                detail="Last message must be from user"
            )

        # Call the inference service
        if request.stream:
            # Streaming response
            async def event_generator():
                stream_gen = await inference_service.chat(
                    messages=[m.model_dump() for m in request.messages],
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    stream=True
                )
                async for chunk in stream_gen:
                    # Emit each chunk as an SSE event
                    yield f"data: {chunk}\n\n"
                # End-of-stream marker
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                event_generator(),
                media_type="text/event-stream"
            )
        else:
            # Non-streaming response
            response_text = await inference_service.chat(
                messages=[m.model_dump() for m in request.messages],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                stream=False
            )

            # Build the response (word counts as a rough token approximation)
            prompt_tokens = sum(len(m.content.split()) for m in request.messages)
            completion_tokens = len(response_text.split())
            return ChatResponse(
                message=Message(role="assistant", content=response_text),
                usage={
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                },
                finish_reason="stop"
            )

    except HTTPException:
        # Let deliberate 4xx errors pass through untouched
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@router.post("/stream")
async def chat_stream(request: ChatRequest):
    """Dedicated streaming chat endpoint."""
    request.stream = True
    return await chat_completion(request)
```
3.3 Text Completion Endpoint
Besides the chat endpoint, I also added a generic text completion endpoint:
```python
# app/routers/completion.py
from typing import Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from app.services.inference import inference_service
from app.config import settings

router = APIRouter(prefix="/completions", tags=["completions"])


class CompletionRequest(BaseModel):
    prompt: str = Field(..., description="Input prompt")
    max_tokens: Optional[int] = Field(512, ge=1, le=settings.model_max_length)
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0)
    stream: Optional[bool] = Field(False)
    stop: Optional[list] = Field(None, description="Stop sequences")


class CompletionResponse(BaseModel):
    text: str = Field(..., description="Generated text")
    usage: dict = Field(..., description="Token usage")


@router.post("/")
async def create_completion(request: CompletionRequest):
    """Text completion endpoint."""
    try:
        if request.stream:
            async def event_generator():
                stream_gen = await inference_service.generate(
                    prompt=request.prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    stream=True
                )
                async for chunk in stream_gen:
                    yield f"data: {chunk}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                event_generator(),
                media_type="text/event-stream"
            )
        else:
            text = await inference_service.generate(
                prompt=request.prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                stream=False
            )
            # Word counts as a rough token approximation
            return CompletionResponse(
                text=text,
                usage={
                    "prompt_tokens": len(request.prompt.split()),
                    "completion_tokens": len(text.split()),
                    "total_tokens": len(request.prompt.split()) + len(text.split())
                }
            )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
```
4. Concurrency and Performance Optimization
Once the API is live, concurrent access has to be taken into account. DASD-4B-Thinking only has 4 billion parameters, but inference still consumes a fair amount of compute.
4.1 Rate Limiting and Queue Management
To keep the service from being overwhelmed, I added request rate limiting:
```python
# app/dependencies.py
import time
from collections import defaultdict

from fastapi import HTTPException, Request
from loguru import logger


class RateLimiter:
    def __init__(self, requests_per_minute: int = 60):
        self.requests_per_minute = requests_per_minute
        self.requests = defaultdict(list)

    async def __call__(self, request: Request):
        # Identify the client by IP
        client_ip = request.client.host

        # Drop records older than one minute
        current_time = time.time()
        self.requests[client_ip] = [
            req_time for req_time in self.requests[client_ip]
            if current_time - req_time < 60
        ]

        # Reject if the client has exceeded the limit
        if len(self.requests[client_ip]) >= self.requests_per_minute:
            logger.warning(f"Rate limit exceeded for {client_ip}")
            raise HTTPException(
                status_code=429,
                detail="Too many requests. Please try again later."
            )

        # Record this request
        self.requests[client_ip].append(current_time)


# Rate limiter instance
rate_limiter = RateLimiter(requests_per_minute=30)
```

Then wire it into the routes:
```python
# Add rate limiting to the chat endpoint
# (requires: from fastapi import Depends
#  and:      from app.dependencies import rate_limiter)
@router.post("/completions", response_model=ChatResponse)
async def chat_completion(
    request: ChatRequest,
    rate_limit: bool = Depends(rate_limiter)
):
    # ... existing code
```
4.2 Asynchronous Task Processing
For long-running requests, an asynchronous task queue is worth considering:
```python
# app/services/task_queue.py
import asyncio
import uuid
from typing import Callable
from concurrent.futures import ThreadPoolExecutor

from app.config import settings


class TaskQueue:
    def __init__(self, max_workers: int = 4):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.tasks = {}

    async def submit(
        self,
        func: Callable,
        *args,
        task_id: str = None,
        **kwargs
    ) -> str:
        """Submit a task to the queue."""
        task_id = task_id or str(uuid.uuid4())

        # Run the (synchronous) function in the thread pool
        loop = asyncio.get_event_loop()

        # Record the start time
        start_time = loop.time()

        future = loop.run_in_executor(
            self.executor,
            lambda: func(*args, **kwargs)
        )

        self.tasks[task_id] = {
            "future": future,
            "start_time": start_time,
            "status": "pending"
        }

        # Register a completion callback
        future.add_done_callback(
            lambda f: self._task_done_callback(task_id, f)
        )

        return task_id

    def _task_done_callback(self, task_id: str, future):
        """Callback invoked when a task finishes."""
        if task_id in self.tasks:
            self.tasks[task_id]["status"] = "completed"
            self.tasks[task_id]["end_time"] = asyncio.get_event_loop().time()
            try:
                result = future.result()
                self.tasks[task_id]["result"] = result
                self.tasks[task_id]["success"] = True
            except Exception as e:
                self.tasks[task_id]["error"] = str(e)
                self.tasks[task_id]["success"] = False

    async def get_result(self, task_id: str, timeout: float = None):
        """Wait for and return a task's result."""
        if task_id not in self.tasks:
            raise ValueError(f"Task {task_id} not found")

        task = self.tasks[task_id]
        future = task["future"]

        try:
            if timeout:
                result = await asyncio.wait_for(future, timeout)
            else:
                result = await future
            return {
                "task_id": task_id,
                "status": "completed",
                "result": result,
                "success": True
            }
        except asyncio.TimeoutError:
            return {
                "task_id": task_id,
                "status": "timeout",
                "success": False
            }
        except Exception as e:
            return {
                "task_id": task_id,
                "status": "error",
                "error": str(e),
                "success": False
            }

    def get_status(self, task_id: str):
        """Return the current status of a task."""
        if task_id not in self.tasks:
            return None

        task = self.tasks[task_id]
        status = task["status"]

        if status == "pending":
            elapsed = asyncio.get_event_loop().time() - task["start_time"]
            return {
                "task_id": task_id,
                "status": "running",
                "elapsed": elapsed
            }

        return {
            "task_id": task_id,
            "status": task["status"],
            "success": task.get("success", False)
        }


# Global task queue
task_queue = TaskQueue(max_workers=settings.max_concurrent_requests)
```
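The post doesn't show how this queue gets exposed over HTTP, so here's a minimal sketch of what a pair of task endpoints could look like. The `/tasks` router, the `run_blocking_job` helper, and the response shapes are my own assumptions for illustration, not part of the original project:

```python
# app/routers/tasks.py (hypothetical sketch, not from the original project)
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from app.services.task_queue import task_queue

router = APIRouter(prefix="/tasks", tags=["tasks"])


class TaskSubmitRequest(BaseModel):
    prompt: str
    max_tokens: int = 512


def run_blocking_job(prompt: str, max_tokens: int) -> str:
    # Placeholder for a synchronous, long-running job (e.g. batch scoring).
    # A real implementation would call into the model or another service.
    return f"processed {len(prompt)} chars with max_tokens={max_tokens}"


@router.post("/")
async def submit_task(request: TaskSubmitRequest):
    """Enqueue a long-running job and return its task ID immediately."""
    task_id = await task_queue.submit(
        run_blocking_job, request.prompt, request.max_tokens
    )
    return {"task_id": task_id, "status": "accepted"}


@router.get("/{task_id}")
async def get_task_status(task_id: str):
    """Poll the status (and result, once finished) of a submitted task."""
    status = task_queue.get_status(task_id)
    if status is None:
        raise HTTPException(status_code=404, detail="Task not found")
    if status["status"] == "completed":
        return await task_queue.get_result(task_id)
    return status
```

If you go this route, the router would also need to be registered in app/main.py with app.include_router like the others.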
4.3 Response Caching
For repeated or similar requests, a cache layer can speed up responses:
```python
# app/services/cache.py
import hashlib
import json
from typing import Any, Optional
from datetime import datetime, timedelta

from loguru import logger

from app.config import settings


class ResponseCache:
    def __init__(self, ttl: int = 3600):
        self.ttl = ttl
        self.cache = {}
        self.hits = 0
        self.misses = 0

    def _generate_key(self, data: Any) -> str:
        """Build a cache key from arbitrary request data."""
        data_str = json.dumps(data, sort_keys=True)
        return hashlib.md5(data_str.encode()).hexdigest()

    def get(self, key_data: Any) -> Optional[Any]:
        """Look up a cached value."""
        key = self._generate_key(key_data)
        if key in self.cache:
            entry = self.cache[key]
            # Check expiry
            if datetime.now() < entry["expires_at"]:
                self.hits += 1
                logger.debug(f"Cache hit for key: {key[:8]}...")
                return entry["data"]
            else:
                # Evict the expired entry
                del self.cache[key]
        self.misses += 1
        return None

    def set(self, key_data: Any, value: Any, ttl: Optional[int] = None):
        """Store a value in the cache."""
        key = self._generate_key(key_data)
        ttl = ttl or self.ttl
        self.cache[key] = {
            "data": value,
            "expires_at": datetime.now() + timedelta(seconds=ttl),
            "created_at": datetime.now()
        }

        # Bound the cache size
        if len(self.cache) > 10000:  # maximum number of entries
            # Evict the entries closest to expiry
            sorted_keys = sorted(
                self.cache.keys(),
                key=lambda k: self.cache[k]["expires_at"]
            )
            for old_key in sorted_keys[:1000]:  # evict 1000 entries
                del self.cache[old_key]

    def get_stats(self) -> dict:
        """Return cache statistics."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "hit_rate": f"{hit_rate:.2%}",
            "size": len(self.cache)
        }


# Global cache instance
response_cache = ResponseCache(ttl=settings.cache_ttl)
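The cache isn't wired into any route above, so here's a minimal sketch of how the non-streaming completion path could use it. The `cached_completion` helper and the exact key contents are my own assumptions; adapt them to the real route in app/routers/completion.py:

```python
# Sketch: using response_cache on the non-streaming completion path
# (illustrative only; cached_completion is a hypothetical helper)
from app.config import settings
from app.routers.completion import CompletionRequest
from app.services.cache import response_cache
from app.services.inference import inference_service


async def cached_completion(request: CompletionRequest) -> str:
    # Key the cache on everything that influences the output
    key_data = {
        "prompt": request.prompt,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
    }
    if settings.cache_enabled:
        cached = response_cache.get(key_data)
        if cached is not None:
            return cached

    text = await inference_service.generate(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
        stream=False,
    )
    if settings.cache_enabled:
        response_cache.set(key_data, text)
    return text
```

One caveat: with a non-zero temperature the model isn't deterministic, so a cached reply just means you're serving an earlier sample for identical parameters; only enable this if that trade-off is acceptable.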
5. Monitoring and Logging
Once the service is in production, monitoring and logging are indispensable. I integrated Prometheus metrics and structured logging.
5.1 Metrics Collection
```python
# app/utils/metrics.py
import time
from functools import wraps

from prometheus_client import Counter, Histogram, Gauge

# Metric definitions
REQUEST_COUNT = Counter(
    'dasd_api_requests_total',
    'Total number of API requests',
    ['method', 'endpoint', 'status']
)

REQUEST_LATENCY = Histogram(
    'dasd_api_request_duration_seconds',
    'API request latency',
    ['method', 'endpoint']
)

ACTIVE_REQUESTS = Gauge(
    'dasd_api_active_requests',
    'Number of active requests'
)

MODEL_INFERENCE_TIME = Histogram(
    'dasd_model_inference_seconds',
    'Model inference time'
)

TOKENS_GENERATED = Counter(
    'dasd_tokens_generated_total',
    'Total tokens generated'
)


def track_request(func):
    """Decorator that tracks request count, latency and in-flight requests."""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        ACTIVE_REQUESTS.inc()
        status = "success"
        try:
            response = await func(*args, **kwargs)
            return response
        except Exception:
            status = "error"
            raise
        finally:
            latency = time.time() - start_time
            ACTIVE_REQUESTS.dec()
            # Ideally method/endpoint would be extracted from the request;
            # as a simplification we label by the wrapped function's name
            REQUEST_COUNT.labels(
                method="POST",
                endpoint=func.__name__,
                status=status
            ).inc()
            REQUEST_LATENCY.labels(
                method="POST",
                endpoint=func.__name__
            ).observe(latency)
    return wrapper


def record_inference_time(duration: float):
    """Record a model inference duration."""
    MODEL_INFERENCE_TIME.observe(duration)


def record_tokens(count: int):
    """Record the number of generated tokens."""
    TOKENS_GENERATED.inc(count)


def get_metrics():
    """Return the current metrics in Prometheus exposition format."""
    import prometheus_client
    from prometheus_client import generate_latest
    return generate_latest(prometheus_client.REGISTRY)
```
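record_inference_time and record_tokens aren't called anywhere in the code above; one minimal way to hook them in is to time the non-streaming generate call. The `timed_generate` wrapper below is just my own suggestion for where that could live:

```python
# Sketch: instrumenting a non-streaming generate call with inference metrics
# (placement is a suggestion; the original code does not include these calls)
import time

from app.utils.metrics import record_inference_time, record_tokens


async def timed_generate(inference_service, prompt: str, **kwargs) -> str:
    """Wrap a non-streaming generate call with inference metrics."""
    start = time.time()
    text = await inference_service.generate(prompt=prompt, stream=False, **kwargs)
    record_inference_time(time.time() - start)
    # Word count as a rough proxy for generated tokens, matching the
    # approximation used in the routers above
    record_tokens(len(text.split()))
    return text
```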
5.2 Structured Logging Configuration

```python
# app/utils/logger.py
import sys

from loguru import logger


def setup_logging():
    """Configure structured logging."""
    # Remove the default handler
    logger.remove()

    # Console output (development)
    logger.add(
        sys.stdout,
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
        level="INFO",
        colorize=True
    )

    # File output (production)
    logger.add(
        "logs/dasd_api_{time:YYYY-MM-DD}.log",
        rotation="00:00",       # rotate daily
        retention="30 days",    # keep 30 days
        compression="zip",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
        level="INFO",
        serialize=True          # emit JSON
    )

    # Separate file for errors
    logger.add(
        "logs/error_{time:YYYY-MM-DD}.log",
        rotation="00:00",
        retention="60 days",
        compression="zip",
        format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
        level="ERROR",
        serialize=True
    )

    return logger


# Initialize logging
log = setup_logging()
```
6. Main Application Entry Point
Finally, wire all the modules together:
```python
# app/main.py
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from app.routers import chat, completion, health
from app.services.inference import inference_service
from app.utils.logger import log
from app.config import settings


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management."""
    # Startup
    log.info("Starting DASD-4B-Thinking API service...")
    # Initialize the model
    try:
        await inference_service.initialize()
        log.info("Model initialized")
    except Exception as e:
        log.error(f"Model initialization failed: {e}")
        raise

    yield

    # Shutdown
    log.info("Shutting down...")
    await inference_service.shutdown()
    log.info("Service stopped")


# Create the FastAPI application
app = FastAPI(
    title="DASD-4B-Thinking API",
    description="DASD-4B-Thinking model API service built with FastAPI",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Global exception handler
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    log.exception(f"Unhandled exception: {exc}")
    return JSONResponse(
        status_code=500,
        content={"error": "Internal server error", "detail": str(exc)}
    )


# Register routers
app.include_router(health.router)
app.include_router(chat.router)
app.include_router(completion.router)


# Root path
@app.get("/")
async def root():
    return {
        "service": "DASD-4B-Thinking API",
        "version": "1.0.0",
        "docs": "/docs",
        "health": "/health"
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "app.main:app",
        host=settings.api_host,
        port=settings.api_port,
        workers=settings.api_workers,
        log_level="info" if settings.api_debug else "warning"
    )
```
7. Deployment and Testing
7.1 Starting the Service
Create a startup script, run.sh:
```bash
#!/bin/bash
# Activate the virtual environment
source venv/bin/activate

# Set environment variables
export PYTHONPATH=$PYTHONPATH:$(pwd)

# Start the service
python -m app.main
```

Or deploy it with Docker:
```dockerfile
# Dockerfile
FROM python:3.9-slim

WORKDIR /app

# System dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy the dependency list
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY app/ ./app/

# Create the log directory
RUN mkdir -p logs

# Expose the port
EXPOSE 8000

# Start command
CMD ["python", "-m", "app.main"]
```
7.2 API Testing
Once the service is up, you can test it with curl:
```bash
# Health check
curl http://localhost:8000/health

# Chat endpoint
curl -X POST http://localhost:8000/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "Hello, please introduce yourself"}
    ],
    "max_tokens": 200
  }'

# Streaming response
curl -X POST http://localhost:8000/chat/stream \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "user", "content": "Write a short story about AI"}
    ],
    "max_tokens": 300,
    "stream": true
  }'
```

You can also test it with a Python client:
```python
# test_client.py
import asyncio

import httpx


async def test_chat():
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8000/chat/completions",
            json={
                "messages": [
                    {"role": "user", "content": "What is machine learning?"}
                ],
                "max_tokens": 150
            }
        )
        if response.status_code == 200:
            result = response.json()
            print(f"Reply: {result['message']['content']}")
            print(f"Usage: {result['usage']}")
        else:
            print(f"Request failed: {response.status_code}")
            print(response.text)


async def test_stream():
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/chat/stream",
            json={
                "messages": [
                    {"role": "user", "content": "Explain how neural networks work"}
                ],
                "max_tokens": 200,
                "stream": True
            }
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]
                    if data == "[DONE]":
                        break
                    print(data, end="", flush=True)
            print()


if __name__ == "__main__":
    asyncio.run(test_chat())
    # asyncio.run(test_stream())
```
8. Summary and Recommendations
Overall, wrapping DASD-4B-Thinking with FastAPI went pretty smoothly. The approach has a few clear advantages: development is fast, and FastAPI's async support suits the IO-bound nature of AI inference; performance is solid, with vLLM's engine keeping response times in check; and it's easy to extend, so adding monitoring, caching, or rate limiting later is straightforward.
A few recommendations for real deployments: first, tune vLLM's GPU memory utilization to your hardware, and don't set it so high that you hit OOM; second, production absolutely needs API key authentication, which I left out of this walkthrough for simplicity (a minimal sketch follows below); third, consider putting Nginx in front as a reverse proxy to handle SSL and load balancing; fourth, feed the metrics into Prometheus + Grafana so you can keep an eye on the service.
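Since the API key part is omitted above, here's a minimal sketch of a header-based check implemented as a FastAPI dependency. The header name and the in-memory key set are illustrative assumptions; in production you'd load keys from the environment or a secrets store:

```python
# Sketch: simple API key check as a FastAPI dependency
# (header name and key storage are illustrative assumptions)
from fastapi import Depends, HTTPException, Security
from fastapi.security import APIKeyHeader

api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

# In practice, load keys from the environment or a secrets store
VALID_API_KEYS = {"change-me"}


async def require_api_key(api_key: str = Security(api_key_header)) -> str:
    """Reject requests that don't carry a known API key."""
    if api_key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
    return api_key


# Usage: protect a router when registering it in app/main.py, e.g.
# app.include_router(chat.router, dependencies=[Depends(require_api_key)])
```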
This code is only a starting point; adapt it to your own needs, for example by adding conversation history management, multi-turn context handling, or a vector database for retrieval-augmented knowledge. DASD-4B-Thinking's reasoning ability is solid, and with a bit of care you can build genuinely interesting applications on top of it.