news 2025/12/30 21:43:16

第P2周:CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P2周:CIFAR10彩色图片识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 数据可视化

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、个人总结

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import matplotlib.pyplot as plt import torchvision device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

train_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), download=True) test_ds = torchvision.datasets.CIFAR10('data', train=False, transform=torchvision.transforms.ToTensor(), download=True) batch_size = 32 train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size) imgs, labels = next(iter(train_dl)) imgs.shape

3. 数据可视化

import numpy as np plt.figure(figsize=(20, 5)) for i, imgs in enumerate(imgs[:20]): npimg = imgs.numpy().transpose((1, 2, 0)) plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')

二、构建简单的CNN网络

import torch.nn.functional as F num_classes = 10 class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, kernel_size=3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3) self.pool3 = nn.MaxPool2d(kernel_size=2) self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, start_dim=1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary # 将模型转移到GPU中(我们模型运行均在GPU中进行) model = Model().to(device) summary(model)

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-2 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 10 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、个人总结

逐渐熟悉CNN模型构建过程,并逐步理解其原理。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/26 20:19:43

jQuery EasyUI 布局 - 在面板中创建复杂布局

jQuery EasyUI 布局 - 在面板中创建复杂布局 jQuery EasyUI 支持布局的嵌套(nested layout),允许在 panel(面板)或其他区域内放置另一个 easyui-layout,从而构建非常复杂的界面布局。这种方式常用于创建自…

作者头像 李华
网站建设 2025/12/17 10:58:44

开源TTS模型选型指南:为何EmotiVoice脱颖而出?

开源TTS模型选型指南:为何EmotiVoice脱颖而出? 在智能语音技术飞速发展的今天,我们早已不满足于“能说话”的AI。从车载助手到虚拟偶像,用户期待的是有情绪、有个性、像真人一样的声音。然而,大多数开源文本转语音&…

作者头像 李华
网站建设 2025/12/29 5:24:15

React RSC 新漏洞可导致 DoS 和源代码泄露

聚焦源代码安全,网罗国内外最新资讯!编译:代码卫士React团队修复了React服务器组件(RSC)中的两个新漏洞,如遭成功利用,可能导致拒绝服务(DoS)或源代码泄露。React 团队表…

作者头像 李华
网站建设 2025/12/17 10:57:39

【大模型微调】11-Prefix Tuning技术:分析Prefix Tuning的工作机制

引言Prefix Tuning技术是近年来在自然语言处理(NLP)领域崭露头角的一种创新方法。作为一种高效的模型微调技术,Prefix Tuning旨在通过在输入序列前添加可学习的"前缀"(prefix)来调整预训练语言模型的性能&am…

作者头像 李华
网站建设 2025/12/24 6:55:53

低版本ant design vue 实现年度选择器

"ant-design-vue": "^1.6.5", 实现年度选择器, 绕了一大圈才找到解决办法,特在此记录 <a-form-item class"formItem_style"><span class"form-item-title">年度</span><a-date-picker v-model"searchForm.…

作者头像 李华
网站建设 2025/12/17 10:57:22

Cursor Rule:AI如何革新代码导航与智能提示

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容&#xff1a; 开发一个基于Cursor Rule的智能代码导航插件&#xff0c;要求&#xff1a;1. 支持通过自然语言描述跳转到指定代码段&#xff08;如跳转到用户登录验证逻辑&#xff09;2. 根据当前…

作者头像 李华