news 2026/1/10 16:52:16

04_残差网络

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
04_残差网络

描述

残差网络是现代卷积神经网络的一种,有效的抑制了深层神经网络的梯度弥散和梯度爆炸现象,使得深度网络训练不那么困难。

下面以cifar-10-batches-py数据集,实现一个ResNet18的残差网络,通过继承nn.Module实现残差块(Residual Block),网络模型类。

定义Block

ResNetBlock派生至nn.Module,需要自己实现forward函数。

torch.nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法,可以从这个类派生自己的模型类。

nn.Module重要的函数:

  • forward(self,*input):forward函数为前向传播函数,需要自己重写,它用来实现模型的功能,并实现各个层的连接关系;
  • __call__(self, *input, **kwargs): __call__()的作用是使class实例能够像函数一样被调用,以“对象名()”的形式使用;
  • __repr__(self):__repr__函数为Python的一个内置函数,它能把一个对象用字符串的形式表达出来;
  • __init__(self):构造函数,自定义模型的网络层对象一般在这个函数中定义。
classResNetBlock(nn.Module):def__init__(self,input_channels,num_channels,stride=1):''' 构造函数:定义网络层 '''super().__init__()self.conv1=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=stride)self.btn1=nn.BatchNorm2d(num_channels)self.conv2=nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1,stride=1)self.btn2=nn.BatchNorm2d(num_channels)ifstride!=1:self.downsample=nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=stride)else:self.downsample=lambdax:xdefforward(self,X):''' 实现反向传播 '''Y=self.btn1(self.conv1(X))Y=nn.functional.relu(Y)Y=self.btn2(self.conv2(Y))Y+=self.downsample(X)returnnn.functional.relu(Y)

定义模型

ResNet同样派生于nn.Module,与ResNetBlock类似,需要实现forward。

torch.nn.Sequential是PyTorch 中一个用于构建顺序神经网络模型的容器类,它将多个神经网络层或模块按顺序组合在一起,简化模型搭建过程。‌Sequential器会严格按照添加的顺序执行内部的子模块,前向传播时自动传递数据,适用于简单神经网络的构建。

classResNet(nn.Module):def__init__(self,layer_dism,num_class=10):''' 构造函数:定义预处理model;构建block层 '''super(ResNet,self).__init__()# 预处理self.stem=nn.Sequential(nn.Conv2d(3,64,3,1),# 3x30x30nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2,2)# 64x15x15)self.layer1=self.build_resblock(64,64,layer_dism[0])self.layer2=self.build_resblock(64,128,layer_dism[1],2)self.layer3=self.build_resblock(128,256,layer_dism[2],2)self.layer4=self.build_resblock(256,512,layer_dism[3],2)self.avgpool=nn.AvgPool2d(1,1)self.btn=nn.Flatten()self.fc=nn.Linear(512,num_class)defbuild_resblock(self,input_channels,num_channels,block,stride=1):res_block=nn.Sequential()res_block.append(ResNetBlock(input_channels,num_channels,stride))for_inrange(1,block):res_block.append(ResNetBlock(num_channels,num_channels,stride))returnres_blockdefforward(self,X):out=self.stem(X)out=self.layer1(out)out=self.layer2(out)out=self.layer3(out)out=self.layer4(out)out=self.avgpool(out)returnself.fc(self.btn(out))

模型训练

加载数据

使用torchvision.datasets加载本地数据,如果本地没有数据,可以设置download=True自动下载。

# 定义数据转换transform=transforms.Compose([transforms.ToTensor(),# 将PIL图像转换为Tensortransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))# 归一化])# 加载CIFAR-10训练集trainset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=True,download=False,transform=transform)trainloader=th.utils.data.DataLoader(trainset,batch_size=16,shuffle=False,num_workers=2)# 加载CIFAR-10测试集testset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=False,download=False,transform=transform)testloader=th.utils.data.DataLoader(testset,batch_size=16,shuffle=False,num_workers=2)

模型初始化

模型初始化是确保网络能够有效学习的关键步骤,一个好的初始值,会使模型收敛速度提高,使模型准确率更精确。

torch.nn.init模块提供了一系列的权重初始化函数:

  • torch.nn.init.uniform_ :均匀分布
  • torch.nn.init.normal_ :正态分布
  • torch.nn.init.constant_:初始化为指定常数
  • torch.nn.init.kaiming_uniform_:凯明均匀分布
  • torch.nn.init.kaiming_normal_:凯明正态分布
  • torch.nn.init.xavier_uniform_:Xavier均匀分布
  • torch.nn.init.xavier_normal_:Xavier正态分布

在初始化时,最好不要将模型的参数初始化为0,因为这样会导致梯度消失,进而影响训练效果。可以将模型初始化为一个很小的值,如0.01,0.001等。

definitialize_weight(m):ifisinstance(m,nn.Conv2d)orisinstance(m,nn.Linear):nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')# mode:权重方差计算方式,可选 'fan_in' 或 'fan_out'(输入、输出神经元数量)# nonlinearity:激活函数类型,用于调整计算公式 ,一般是relu、leaky_reluifm.biasisnotNone:nn.init.constant_(m.bias,0)

[2,2,2,2] 参数分别代表四个block的中的残差块数量(可以仔细看一下build_resblock函数)

resnet_18=ResNet([2,2,2,2])resnet_18.apply(initialize_weight)# 初始化模型loss_cross=nn.CrossEntropyLoss()trainer=th.optim.SGD(resnet_18.parameters())

训练

训练过程比较漫长,这里训练只有20轮,测试精度0.51。如果有N卡加持的话,可以适当调高epoch,精度能进一步提高。

forepochinrange(0,20):running_loss=0.0forinputs,labelsintrainloader:trainer.zero_grad()outputs=resnet_18(inputs)loss=loss_cross(outputs,labels)loss.backward()trainer.step()running_loss+=loss.item()print(f'[{epoch+1}] ev loss:{running_loss/3125}')running_loss=0.0
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2025/12/18 14:56:56

【气象灾害Agent预警阈值优化】:掌握精准预测的5大核心参数配置

第一章:气象灾害Agent预警阈值的核心意义在智能气象监测系统中,Agent技术被广泛应用于实时数据采集与灾害预警。预警阈值作为核心参数,直接决定了系统对异常气象事件的响应灵敏度与准确性。设定合理的阈值,能够在极端天气发生前及…

作者头像 李华
网站建设 2025/12/18 14:56:24

Luckysheet数据验证:告别数据录入烦恼的完整指南

还在为员工录入错误数据而头疼吗?财务报表中出现不合规的数值?客户信息表中的手机号格式五花八门?Luckysheet的数据验证功能正是你需要的解决方案。这个强大的功能可以确保表格数据的准确性和一致性,让你从繁琐的数据校对工作中解…

作者头像 李华
网站建设 2026/1/9 22:50:31

为什么显示器分辨率越高越清晰?——从像素到 4K/8K 的视觉革命

🖥️ 为什么显示器分辨率越高越清晰?——从像素到 4K/8K 的视觉革命 👁️大家好,我是无限大,欢迎收看十万个为什么系列文章今天咱们来聊聊显示器这个"电脑的脸"!从模糊的老式显示器到如今的4K/8K…

作者头像 李华
网站建设 2025/12/18 14:55:48

为什么顶尖实验室都在布局量子-经典Agent协同?真相曝光

第一章:量子 - 经典 Agent 的协同在混合计算架构日益普及的背景下,量子计算资源与经典计算系统的协同工作成为实现实际应用的关键路径。通过构建量子 - 经典 Agent 协同框架,开发者能够将传统算法逻辑与量子加速能力有机结合,充分…

作者头像 李华
网站建设 2025/12/18 14:55:40

iOS降级神器:macOS平台A6/A7设备终极降级攻略

iOS降级神器:macOS平台A6/A7设备终极降级攻略 【免费下载链接】LeetDown a GUI macOS Downgrade Tool for A6 and A7 iDevices 项目地址: https://gitcode.com/gh_mirrors/le/LeetDown 对于拥有iPhone 5s、iPad 4等A6/A7芯片设备的用户来说,系统降…

作者头像 李华