基于 OpenAI CLIP 模型的自动图像分类
一 环境安装
pip install git+https://github.com/openai/CLIP.git
pip install torch torchvision
pip install git+https://github.com/openai/CLIP.git
pip install pillow matplotlib
二 基本使用
from typing import Any, Dict, Iterable, List, Optional, Sequence

import clip
import numpy as np
import torch
from PIL import Image


class CLIPImageClassifier:
    """Zero-shot image classifier built on an OpenAI CLIP model.

    Each candidate class name is turned into the text prompt
    ``"a photo of <name>"``; an image is scored against every prompt and the
    scores are converted to probabilities with a softmax.
    """

    def __init__(self, model_name: str = "ViT-B/32", device: Optional[str] = None):
        """Load the CLIP model and its matching image preprocessing function.

        Args:
            model_name: CLIP model name, e.g. "ViT-B/32", "ViT-B/16" or
                "ViT-L/14".
            device: Device to run on, "cuda" or "cpu". When ``None``, CUDA is
                used if ``torch.cuda.is_available()`` and CPU otherwise.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        print(f"使用设备: {self.device}")
        print(f"加载模型: {model_name}")

        # clip.load returns the model together with its preprocessing transform.
        self.model, self.preprocess = clip.load(model_name, device=self.device)
        self.model.eval()

    def _load_image(self, image_path):
        """Open ``image_path`` as an RGB image; print the error and return None on failure."""
        try:
            # The context manager closes the underlying file as soon as the
            # converted copy has been produced.
            with Image.open(image_path) as source:
                return source.convert('RGB')
        except Exception as e:
            # Deliberately broad: an unreadable image must not abort a batch.
            print(f"无法加载图像: {e}")
            return None

    def _encode_text(self, class_names: Sequence[str]):
        """Return L2-normalized CLIP text features, one row per class prompt.

        Raises:
            ValueError: If ``class_names`` is empty.
        """
        if len(class_names) == 0:
            raise ValueError("class_names must contain at least one class name")

        prompts = [f"a photo of {c}" for c in class_names]
        # clip.tokenize accepts a list of strings and returns one row per prompt.
        text_inputs = clip.tokenize(prompts).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_inputs)
            return text_features / text_features.norm(dim=-1, keepdim=True)

    def _rank_classes(self, image, class_names: Sequence[str], text_features, top_k: int):
        """Score one loaded image against precomputed text features; return the top-k."""
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # Similarity of the normalized features, scaled by 100 before softmax.
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        probs = similarity.cpu().numpy()[0]
        # Indices of the highest probabilities first, truncated to top_k.
        top_indices = np.argsort(probs)[::-1][:top_k]
        return [
            {'class': class_names[idx], 'probability': float(probs[idx])}
            for idx in top_indices
        ]

    def classify_image(
        self, image_path, class_names: Sequence[str], top_k: int = 5
    ) -> Optional[List[Dict[str, Any]]]:
        """Classify a single image.

        Args:
            image_path: Path of the image to classify.
            class_names: Candidate class names (an indexable sequence).
            top_k: Number of highest-probability predictions to return.

        Returns:
            A list of ``{'class': name, 'probability': float}`` dicts sorted
            by descending probability, or ``None`` if the image could not be
            loaded.

        Raises:
            ValueError: If ``class_names`` is empty.
        """
        image = self._load_image(image_path)
        if image is None:
            return None

        text_features = self._encode_text(class_names)
        return self._rank_classes(image, class_names, text_features, top_k)

    def classify_batch(
        self, image_paths: Iterable, class_names: Sequence[str], top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Classify several images against the same set of classes.

        The text prompts are encoded once and reused for every image instead
        of being re-encoded per image.

        Args:
            image_paths: Iterable of image paths.
            class_names: Candidate class names (an indexable sequence).
            top_k: Number of highest-probability predictions per image.

        Returns:
            A list of ``{'image_path': path, 'predictions': [...]}`` dicts.
            Images that cannot be loaded, or that yield no predictions, are
            skipped.

        Raises:
            ValueError: If ``class_names`` is empty and at least one image
                was loaded successfully.
        """
        results = []
        text_features = None

        for image_path in image_paths:
            image = self._load_image(image_path)
            if image is None:
                continue

            # Encode the prompts lazily so nothing is computed when no image
            # can be loaded; afterwards the features are reused as-is.
            if text_features is None:
                text_features = self._encode_text(class_names)

            predictions = self._rank_classes(image, class_names, text_features, top_k)
            if predictions:
                results.append({'image_path': image_path, 'predictions': predictions})

        return results


def main():
    """Usage example: classify one image against a small set of classes."""
    # Initialize the classifier.
    classifier = CLIPImageClassifier(model_name="ViT-B/32")

    # Candidate classes (these can be any classes you like).
    class_names = [
        "cat", "dog", "bird", "car", "airplane",
        "beach", "mountain", "forest", "city", "ocean",
        "apple", "banana", "orange", "person", "bicycle",
    ]

    # Single-image classification.
    image_path = "test_image.jpg"  # Replace with the path to your own image.
    results = classifier.classify_image(image_path, class_names, top_k=3)

    if results:
        print("\n分类结果:")
        for i, result in enumerate(results, 1):
            print(f"{i}. {result['class']}: {result['probability']:.2%}")


if __name__ == "__main__":
    main()
![]()