news 2026/6/7 21:35:47

深度学习实验——PyTorch实现CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
深度学习实验——PyTorch实现CIFAR10彩色图片识别
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

文章目录

  • 1. 简介
  • 2. 环境
  • 3. 数据集介绍
  • 4. 代码实现
    • 4.1 前期准备
      • 4.1.1 导入库 & GPU设置
      • 4.1.2 数据下载和数据集划分
      • 4.1.3 数据可视化
    • 4.2 模型构建
    • 4.3 模型训练
      • 4.3.1 设置超参数 & 编写训练和测试函数
      • 4.3.2 正式训练
  • 5. 结果可视化

1. 简介

利用Pytorch构建CNN模型以用于识别彩色图片

2. 环境

  • 语言环境:Python 3.12.7
  • 编译器:Jupyter Notebook
  • 深度学习环境:torch—2.8.0 + cu126 / torchvision—0.23.1+cu126

3. 数据集介绍

CIFAR-10数据集,又称加拿大高等研究院数据集是一个常用于训练机器学习和计算机视觉算法的图像集合。它是最广泛使用的机器学习研究数据集之一。CIFAR-10数据集包含60,000张32×32像素的彩色图像,分为10个不同的类别。

4. 代码实现

4.1 前期准备

4.1.1 导入库 & GPU设置

importtorchimporttorch.nnasnnimportmatplotlib.pyplotaspltimporttorchvisionimportnumpyasnpimporttorch.nn.functionalasFfromtorchinfoimportsummaryimportwarningsfromdatetimeimportdatetime warnings.filterwarnings("ignore")plt.rcParams['font.sans-serif']=['SimHei']plt.rcParams['axes.unicode_minus']=Falseplt.rcParams['figure.dpi']=100device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")device

4.1.2 数据下载和数据集划分

先使用torchvision的datasets下载CIFAR10数据集,并划分好训练集与测试集。

train_ds=torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)test_ds=torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)


然后使用DataLoader()加载数据,并设置好基本的batch_size。

batch_size=32train_dl=torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True)test_dl=torch.utils.data.DataLoader(test_ds,batch_size=batch_size)imgs,labels=next(iter(train_dl))imgs.shape

4.1.3 数据可视化

使用transpose()对NumPy数组进行轴变换,将轴的顺序从PyTorch存储图像的(C, H, W)格式转换为(H, W, C)格式,使得数据格式更适合Matplotlib imshow() 函数可视化和处理。

plt.figure(figsize=(20,5))fori,imgsinenumerate(imgs[:20]):npimg=imgs.numpy().transpose((1,2,0))plt.subplot(2,10,i+1)plt.imshow(npimg,cmap=plt.cm.binary)plt.axis('off')

4.2 模型构建

这个模型专门为32×32像素的CIFAR-10图像设计(10个类别),包含3个卷积层和2个全连接层。
首先通过三个卷积层逐级提取图像特征:第一层将RGB三通道转换为64个特征图,第二层保持64个特征图进行深度特征提取,第三层进一步扩展到128个特征图以捕获更复杂的模式,每个卷积层后都使用2×2最大池化层逐步降低空间分辨率。然后网络将三维特征图展平为一维向量,通过两个全连接层进行分类决策:第一层将512维特征压缩到256维并应用ReLU激活函数,第二层输出最终的10个类别分数。

num_classes=10classModel(nn.Module):def__init__(self):super().__init__()self.conv1=nn.Conv2d(3,64,kernel_size=3)self.pool1=nn.MaxPool2d(kernel_size=2)self.conv2=nn.Conv2d(64,64,kernel_size=3)self.pool2=nn.MaxPool2d(kernel_size=2)self.conv3=nn.Conv2d(64,128,kernel_size=3)self.pool3=nn.MaxPool2d(kernel_size=2)self.fc1=nn.Linear(512,256)self.fc2=nn.Linear(256,num_classes)defforward(self,x):x=self.pool1(F.relu(self.conv1(x)))x=self.pool2(F.relu(self.conv2(x)))x=self.pool3(F.relu(self.conv3(x)))x=torch.flatten(x,start_dim=1)x=F.relu(self.fc1(x))x=self.fc2(x)returnx model=Model().to(device)summary(model)

4.3 模型训练

4.3.1 设置超参数 & 编写训练和测试函数

训练函数train在每个批次中执行前向传播计算预测值,使用交叉熵损失评估误差,通过反向传播计算梯度并利用SGD优化器更新模型参数,同时统计训练准确率和损失;测试函数test则在禁用梯度计算的模式下进行前向传播,评估模型在验证集上的表现而不更新权重,最终返回模型在测试数据上的平均准确率和损失,两个函数共同构成了一个典型的有监督深度学习训练评估循环。

loss_fn=nn.CrossEntropyLoss()learn_rate=1e-2opt=torch.optim.SGD(model.parameters(),lr=learn_rate)deftrain(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset)num_batches=len(dataloader)train_loss,train_acc=0,0forX,yindataloader:X,y=X.to(device),y.to(device)pred=model(X)loss=loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()train_loss+=loss.item()train_acc/=size train_loss/=num_batchesreturntrain_acc,train_lossdeftest(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)test_loss,test_acc=0,0withtorch.no_grad():forimgs,targetindataloader:imgs,target=imgs.to(device),target.to(device)target_pred=model(imgs)loss=loss_fn(target_pred,target)test_loss+=loss.item()test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()test_acc/=size test_loss/=num_batchesreturntest_acc,test_loss

4.3.2 正式训练

epochs=10train_loss=[]train_acc=[]test_loss=[]test_acc=[]forepochinrange(epochs):model.train()epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template=('Epoch:{:2d}, train_acc:{:.1f}%, train_loss:{:.3f}, test_acc:{:.1f}%, test_loss:{:.3f}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss))print('Done')

5. 结果可视化

current_time=datetime.now()epochs_range=range(epochs)plt.figure(figsize=(12,3))plt.subplot(1,2,1)plt.plot(epochs_range,train_acc,label='Training Accuracy')plt.plot(epochs_range,test_acc,label='Test Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.xlabel(current_time)plt.subplot(1,2,2)plt.plot(epochs_range,train_loss,label='Training Loss')plt.plot(epochs_range,test_loss,label='Test Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/2 20:13:42

【纤维协程资源释放全攻略】:掌握高效内存管理的5大核心技巧

第一章:纤维协程资源释放的核心意义在现代高并发系统中,纤维(Fiber)作为一种轻量级的用户态线程,被广泛应用于提升程序的执行效率与资源利用率。然而,若未能妥善管理其生命周期,尤其是未及时释放…

作者头像 李华
网站建设 2026/5/22 12:28:06

掌握这3种R语言方法,轻松实现气象数据中百年一遇极值识别

第一章:气象数据的 R 语言极端值检测在气象数据分析中,识别极端天气事件(如极端高温、强降雨等)是风险评估与气候建模的关键步骤。R 语言提供了丰富的统计工具和可视化函数,能够高效实现极端值检测。常用方法包括基于广…

作者头像 李华
网站建设 2026/6/7 9:43:03

为什么你的甲基化分析结果不显著?这4个R语言常见错误你可能正在犯

第一章:为什么你的甲基化分析结果不显著?在进行DNA甲基化数据分析时,许多研究者常遇到统计结果不显著的问题。这并非总是因为生物学效应不存在,而更可能是实验设计或数据处理中的关键环节被忽视。样本量不足导致统计效能低下 甲基…

作者头像 李华
网站建设 2026/6/4 7:22:44

RTL8852BE Linux驱动终极指南:轻松解决无线网卡兼容性问题

RTL8852BE Linux驱动终极指南:轻松解决无线网卡兼容性问题 【免费下载链接】rtl8852be Realtek Linux WLAN Driver for RTL8852BE 项目地址: https://gitcode.com/gh_mirrors/rt/rtl8852be 还在为Linux系统下Realtek RTL8852BE无线网卡无法正常工作而烦恼吗&…

作者头像 李华