news 2026/4/27 11:54:22

《动手学深度学习》-48全连接卷积神经网络FCN实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
《动手学深度学习》-48全连接卷积神经网络FCN实现

全连接神经网络通过卷积神经网络CNN实现特征提取,然后通过1x1的卷积将通道数转换为类别个数,最后通过转置卷积层将图像的高宽变换为原输入图的尺寸大小

一、代码

1.构建net

(1)框架

pretrained_net=torchvision.models.resnet18(pretrained=True) # print(list(pretrained_net.children())[-3:])#最后两层为AdaptiveAvgPool2d、Linear去掉 net=nn.Sequential(*list(pretrained_net.children())[:-2])
num_classes=21 net.add_module('final_conv',nn.Conv2d(in_features=512, out_features=num_classes,kernel_size=1)) net.add_module('Transposed_conv',nn.ConvTranspose2d(num_classes,num_classes,kernel_size=64,padding=16,stride=32))

(2)初始化

def bilinear_kernel(in_channel,out_channel,kernel_size): factor=(kernel_size+1)//2 #上采样放大倍数 if kernel_size %2==1: center=factor-1 else: center=factor-0.5 og=(torch.arange(kernel_size).reshape(-1,1),torch.arange(kernel_size).reshape(1,-1))#og[0]是行向量kx1,ogp[1]列向量1xk,广播之后变成kxk, filt=(1-torch.abs(og[0]-center)/factor)*(1-torch.abs(og[1]-center)/factor)#kxk的矩阵,中心大,周围小 weight=torch.zeros((in_channel,out_channel,kernel_size,kernel_size)) weight[range(in_channel),range(out_channel),:,:]=filt#让输入通道c只影响同编号C’输出,不进行混合,只改变对角线上的K初始化 return weight
W=bilinear_kernel(num_classes,num_classes,64) net.Transposed_conv.weight.data.copy_(W)

(3)测试

conv_transopsed=nn.ConvTranspose2d(3,3,kernel_size=4,padding=1,stride=2,bias=False) conv_transopsed.weight.data.copy_(bilinear_kernel(3,3,4)) img=torchvision.transforms.ToTensor()(Image.open('D:/PycharmDocument/limu/data/dogcat.png').convert('RGB')) X=img.unsqueeze(0) Y=conv_transopsed(X) out_img=Y[0].permute(1,2,0).detach() print('input image shape',img.permute(1,2,0).shape) print('output image shape',out_img.shape) d2l.set_figsize() fig,axes=plt.subplots(1,2) axes[0].imshow(img.permute(1,2,0)) axes[0].set_title('input image') axes[1].imshow(out_img) axes[1].set_title('output image') d2l.plt.show()

输入一张图,采用conv_transopsed操作,看一下大小,可以看出经过转置卷积,输出图片尺寸大一倍,

2.读取数据

batch_size,crop_size=36,(320,480)
train_iter,test_iter=test46SemanticSegmentation.load_data_voc(batch_size=batch_size,crop_size=crop_size)
voc_dir = 'D:/VOCtrainval_11-May-2012/VOCdevkit/VOC2012'
def read_voc_images(voc_dir, is_train=True):
"""读取所有VOC图像并标注"""
# 这里代码会自动拼路径:voc_dir + ImageSets + Segmentation + train.txt
txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',
'train.txt' if is_train else 'val.txt')
mode = torchvision.io.image.ImageReadMode.RGB
with open(txt_fname, 'r') as f:
images = f.read().split()
features, labels = [], []
for i, fname in enumerate(images):
# 读取原始图片
features.append(torchvision.io.read_image(os.path.join(
voc_dir, 'JPEGImages', f'{fname}.jpg')))
# 读取语义分割标签图
labels.append(torchvision.io.read_image(os.path.join(
voc_dir, 'SegmentationClass' ,f'{fname}.png'), mode))
return features, labels

3.训练

def loss(inputs,targets):
return F.cross_entropy(inputs,targets,reduction='none').mean(1).mean(1)
num_epochs,lr,wd,device=5,0.01,1e-3,d2l.try_gpu()
trainer=torch.optim.SGD(net.parameters(),lr=lr,weight_decay=wd)
d2l.train_ch3(net,trainer,num_epochs,batch_size,device)
4.预测
def predect(img):
X=test_iter.dataset.normalize_image(img).unsqueeze(0)#(1,3,h,w,)
pred=net(X.to(device)).argmax(dim=1)#(1,h,w)
return pred.reshape(pred.shape[1],pred.shape[2])#(h,w)
#根据类别反向找对应的rgb,将像素点涂对应的颜色
def label2image(pred):
colormap=torch.tensor(test46SemanticSegmentation.VOC_COLORMAP,device=device)
X=pred.long()
return colormap[X,:]
test_images,test_labels=read_voc_images(voc_dir,is_train=False)
n,imags=4,[]
for i in range(n):
crop_rect=(0,0,320,480)
X=torchvision.transforms.functional.crop(test_images[i],*crop_rect)
pred=label2image(predect(X))
imags+=[X.permute(1,2,0),pred.cpu(),torchvision.transforms.functional.crop(test_labels[i],*crop_rect).permute(1,2,0)]
d2l.show_images(imags[::3]+imags[1::3]+imags[2::3],3,n,scale=2)

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/27 9:30:53

社交网络影响力分析:大数据方法与实践

社交网络影响力分析:从大数据方法到实践落地的全指南 摘要/引言:为什么你需要重新理解“影响力”? 去年双11,某美妆品牌的市场部犯了愁: 他们花50万找了一位“百万粉小红书KOL”推广新品,结果笔记点赞破1…

作者头像 李华
网站建设 2026/4/24 23:18:26

11-3 register integration

文章目录 原始代码 详细解读:Direct vs Layered UVM验证框架 一、第一段代码(Direct框架)详细解读 1. APB从设备模块(slave) 2. 寄存器模型(dut_regmodel) 3. 测试平台(tb_top) 4. 测试环境(tb_env)关键代码 二、第二段代码(Layered框架)详细解读 1. APB从设备模块…

作者头像 李华
网站建设 2026/4/26 5:50:42

12款智能论文生成工具分析:数学建模论文快速复现与专业格式调整方法

还在为数学建模论文的复现与排版问题困扰?面对时间紧迫、任务繁重的挑战,AI工具或许能成为你的得力助手。本次评测将针对10款热门AI论文写作工具进行深度分析,帮助你快速找到最适合提升写作效率与排版质量的解决方案,让学术创作事…

作者头像 李华
网站建设 2026/4/25 4:06:06

9D VR体验馆设备多少钱的投资分析与运营策略探讨

9D VR体验馆设备投资成本详解与市场分析 在考虑9D VR体验馆设备的投资成本时,首先需要评估几个关键因素,包括设施建设、设备采购及日常运营等方面。通常,初期投资大约在10万至15万元之间,这包括了VR双人蛋椅、VR魔力互动设备和VR3…

作者头像 李华
网站建设 2026/4/25 8:08:25

基于51/STM32单片机智能水杯保温杯恒温温度控制防干烧水质设计(设计源文件+万字报告+讲解)(支持资料、图片参考_相关定制)_文章底部可以扫码

STM32-S264-水量监测保温常温温度灯光指示定时提醒定时开关加热防干烧参数可设OLED屏声光提醒(无线方式选择)STM32-S264N无无线-无APP板(硬件操作详细): STM32-S264B蓝牙无线-APP版: STM32-S264W-WIFI无线-APP版: STM32-S264CAN-视频监控WIFI无线-APP版: STM32-S264I-云平台-AP…

作者头像 李华