news 2026/5/22 20:20:29

DeepLabv3+ 实现双输出任务(分割 + 分类)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
DeepLabv3+ 实现双输出任务(分割 + 分类)

1. 引言

DeepLabv3+ 是经典的语义分割模型。根据实际项目需求,我对其网络结构进行了修改,使其支持双输出任务:同时输出像素级分割结果与图像级分类结果。

2. 代码修改

2.1 网络结构修改

在 nets/deeplabv3_plus.py 中修改 DeepLab 类,新增图像分类头(classification_head):

import torch import torch.nn as nn import torch.nn.functional as F from nets.xception import xception from nets.mobilenetv2 import mobilenetv2 class MobileNetV2(nn.Module): def __init__(self, downsample_factor=8, pretrained=True): super(MobileNetV2, self).__init__() from functools import partial model = mobilenetv2(pretrained) self.features = model.features[:-1] self.total_idx = len(self.features) self.down_idx = [2, 4, 7, 14] if downsample_factor == 8: for i in range(self.down_idx[-2], self.down_idx[-1]): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=4) ) elif downsample_factor == 16: for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) def _nostride_dilate(self, m, dilate): classname = m.__class__.__name__ if classname.find('Conv') != -1: if m.stride == (2, 2): m.stride = (1, 1) if m.kernel_size == (3, 3): m.dilation = (dilate // 2, dilate // 2) m.padding = (dilate // 2, dilate // 2) else: if m.kernel_size == (3, 3): m.dilation = (dilate, dilate) m.padding = (dilate, dilate) def forward(self, x): low_level_features = self.features[:4](x) x = self.features[4:](low_level_features) return low_level_features, x # -----------------------------------------# # ASPP特征提取模块 # 利用不同膨胀率的膨胀卷积进行特征提取 # -----------------------------------------# class ASPP(nn.Module): def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1): super(ASPP, self).__init__() self.branch1 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch2 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch3 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 
* rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch4 = nn.Sequential( nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True) self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom) self.branch5_relu = nn.ReLU(inplace=True) self.conv_cat = nn.Sequential( nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True), nn.BatchNorm2d(dim_out, momentum=bn_mom), nn.ReLU(inplace=True), ) def forward(self, x): [b, c, row, col] = x.size() # -----------------------------------------# # 一共五个分支 # -----------------------------------------# conv1x1 = self.branch1(x) conv3x3_1 = self.branch2(x) conv3x3_2 = self.branch3(x) conv3x3_3 = self.branch4(x) # -----------------------------------------# # 第五个分支,全局平均池化+卷积 # -----------------------------------------# global_feature = torch.mean(x, 2, True) global_feature = torch.mean(global_feature, 3, True) global_feature = self.branch5_conv(global_feature) global_feature = self.branch5_bn(global_feature) global_feature = self.branch5_relu(global_feature) global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True) # -----------------------------------------# # 将五个分支的内容堆叠起来 # 然后1x1卷积整合特征。 # -----------------------------------------# feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1) result = self.conv_cat(feature_cat) return result # ==============================新增双任务模型:分割 + 分类===================================== class DeepLab(nn.Module): def __init__( self, num_classes, # 分割类别数 num_classes_classify=2, # 分类类别数 backbone="mobilenet", pretrained=True, downsample_factor=16 ): super(DeepLab, self).__init__() if backbone == "xception": self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained) in_channels = 2048 low_level_channels = 256 elif backbone 
== "mobilenet": self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained) in_channels = 320 low_level_channels = 24 else: raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone)) self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor) self.shortcut_conv = nn.Sequential( nn.Conv2d(low_level_channels, 48, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True) ) self.cat_conv = nn.Sequential( nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), ) # 原来的分割头 self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1) # ==================== 新增:图像分类头 ==================== self.num_classes_classify = num_classes_classify self.classification_head = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Dropout(0.5), nn.Linear(256, 512), nn.ReLU(inplace=True), nn.Linear(512, num_classes_classify) ) # ========================================================== def forward(self, x): H, W = x.size(2), x.size(3) # -----------------------------------------# #
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/22 20:11:29

独立开发者如何利用Taotoken管理多个副业项目的AI支出

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 独立开发者如何利用Taotoken管理多个副业项目的AI支出 对于独立开发者而言,同时维护多个小型AI应用或机器人是常见的状…

作者头像 李华
网站建设 2026/5/22 20:10:21

上班族开例会懒得记要点?2026年这3款AI总结工具,会后自动整理纪要

做互联网运营四年,开会已经成了每天的常态。部门周例会、项目复盘会、线上培训课、远程沟通会,大大小小的视频会议一场接一场。以前最让我头疼的不是参会,而是会后整理纪要。开会时既要认真听讨论、跟进工作进度,又要低头飞速记笔…

作者头像 李华
网站建设 2026/5/22 20:03:26

权威榜单:2026 年 250 克以下微型无人机推荐

在 2026 年的无人机市场,轻量化航拍需求愈发显著,特别是 250 克以下的微型无人机。这类无人机以其便携性和简单操作,逐渐成为新手用户的首选。推荐关注的产品包括备受认可的博坦 ATOM 系列,提供全面的 AI 功能和出色画质,以及…

作者头像 李华