news 2026/6/18 17:19:13

食物图像分类代码实战

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
食物图像分类代码实战

前言

延续之前所讲,基本上项目代码都是数据集的读入和处理,模型定义、训练之前的各种准备设置以及训练流程,那么接下来也是按照这个顺序进行。

数据集读入和处理

train_transform = transforms.Compose( [ transforms.ToPILImage(), # Convert image to PIL format (224,224,3) -> (3,224,224) transforms.RandomResizedCrop(224), # Random crop and resize to 224x224 transforms.RandomRotation(50), # Apply random rotation up to 50 degrees transforms.ToTensor() # Convert to tensor ] ) val_transform = transforms.Compose( [ transforms.ToPILImage(), # Convert image to PIL format transforms.ToTensor() # Convert to tensor ] )

图像数据集有点区别于前面的回归模型特征数据集,通常,在读入数据集阶段可以选用数据增强技术对图像处理,这种技术对图片进行随机放大裁剪旋转等操作,通过内置强化学习算法自动选择最优方式,这个过程类似在分类任务中让模型见识不同角度各种各样的某类物体,可以提高模型识别能力。另一方面也可以拓宽训练集,抑制模型过拟合。但是,在测试阶段不使用该技术,测试集上数据模型都没见过,可以检验模型泛化能力。

class FoodDataset(Dataset): def __init__(self, path, mode="train"): self.mode = mode self.transform = train_transform if mode == "train" else val_transform self.X, self.Y = self._load_data(path) def _load_data(self, path): X, Y = None, None for class_idx in range(11): class_dir = os.path.join(path, f"{class_idx:02d}") img_files = os.listdir(class_dir) class_images = np.zeros((len(img_files), HW, HW, 3), dtype=np.uint8) class_labels = np.full(len(img_files), class_idx, dtype=np.uint8) for idx, filename in enumerate(img_files): img_path = os.path.join(class_dir, filename) img = Image.open(img_path).resize((HW, HW)) class_images[idx] = img if class_idx == 0: X, Y = class_images, class_labels else: X = np.concatenate((X, class_images), axis=0) Y = np.concatenate((Y, class_labels), axis=0) print(f"Loaded {len(Y)} samples") return X, Y def __getitem__(self, index): return self.transform(self.X[index]), self.Y[index] def __len__(self): return len(self.X)

这次项目训练集和测试集在不同文件,因此,不需要像之前一样拆分,直接根据mode用读取不同的文件即可。

不同项目的实现方式各有差异:文件读取功能会根据数据集在本地的存储路径进行配置。Dataset类负责数据读取,将原始数据转换为三维数值表示;而Dataloader则用于处理这些数据集,如批量加载数据,支持数据打乱和多批次处理功能。

模型定义

class MyModel(nn.Module): def __init__(self, num_class): super(MyModel, self).__init__() # Input: 3x224x224 -> Output: 512x7x7 -> Flatten -> Fully connected layers # Initial convolution block self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) # Output: 64x224x224 self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU() self.pool1 = nn.MaxPool2d(2) # Output: 64x112x112 # Feature extraction layers self.layer1 = nn.Sequential( nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), # 128x112x112 nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2) # 128x56x56 ) self.layer2 = nn.Sequential( nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), # 256x56x56 nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2) # 256x28x28 ) self.layer3 = nn.Sequential( nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), # 512x28x28 nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d(2) # 512x14x14 ) # Final pooling and classifier self.pool2 = nn.MaxPool2d(2) # 512x7x7 self.fc1 = nn.Linear(512*7*7, 1000) # 25088 -> 1000 self.relu2 = nn.ReLU() self.fc2 = nn.Linear(1000, num_class) # 1000 -> num_class def forward(self, x): x = self.pool1(self.relu(self.bn1(self.conv1(x)))) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.pool2(x) x = x.view(x.size(0), -1) # Flatten x = self.relu2(self.fc1(x)) x = self.fc2(x) return x

模型定义相对简单且相似。值得一提的是,数据作为新时代的石油资源,我们训练的模型通常难以与投入数百万美元训练的大模型相媲美,因此可以采用迁移学习策略。迁移学习主要分为微调线性探测两种方式,二者的核心区别在于是否冻结主干网络的参数。

训练流程

def train_val(model, train_loader, val_loader, no_label_loader, device, epochs, optimizer, loss, thres, save_path): model = model.to(device) plt_train_loss = [] plt_val_loss = [] plt_train_acc = [] plt_val_acc = [] max_acc = 0.0 for epoch in range(epochs): train_loss = 0.0 val_loss = 0.0 train_acc = 0.0 val_acc = 0.0 start_time = time.time() # Training phase model.train() for batch_x, batch_y in train_loader: x, target = batch_x.to(device), batch_y.to(device) pred = model(x) train_bat_loss = loss(pred, target) train_bat_loss.backward() optimizer.step() optimizer.zero_grad() train_loss += train_bat_loss.item() train_acc += (pred.argmax(dim=1) == target).sum().item() avg_train_loss = train_loss / len(train_loader) avg_train_acc = train_acc / len(train_loader.dataset) plt_train_loss.append(avg_train_loss) plt_train_acc.append(avg_train_acc) # Validation phase model.eval() with torch.no_grad(): for batch_x, batch_y in val_loader: x, target = batch_x.to(device), batch_y.to(device) pred = model(x) val_bat_loss = loss(pred, target) val_loss += val_bat_loss.item() val_acc += (pred.argmax(dim=1) == target).sum().item() avg_val_loss = val_loss / len(val_loader) avg_val_acc = val_acc / len(val_loader.dataset) plt_val_loss.append(avg_val_loss) plt_val_acc.append(avg_val_acc) # Semi-supervised learning if epoch % 3 == 0 and avg_val_acc > 0.6: semi_loader = get_semi_loader(no_label_loader, model, device, thres) # Save best model if avg_val_acc > max_acc: torch.save(model, save_path) max_acc = avg_val_acc # Print progress elapsed = time.time() - start_time print(f'[{epoch:03d}/{epochs:03d}] {elapsed:.2f}s | ' f'TrainLoss: {avg_train_loss:.6f} | ValLoss: {avg_val_loss:.6f} | ' f'TrainAcc: {avg_train_acc:.6f} | ValAcc: {avg_val_acc:.6f}') # Plot training curves plt.figure(figsize=(12, 4)) plt.subplot(1, 2, 1) plt.plot(plt_train_loss, label='Train') plt.plot(plt_val_loss, label='Val') plt.title('Loss Curve') plt.legend() plt.subplot(1, 2, 2) plt.plot(plt_train_acc, label='Train') plt.plot(plt_val_acc, label='Val') plt.title('Accuracy Curve') plt.legend() plt.show()

主干训练流程也是大体相似,这里和回归模型主要区别在于多了计算准确率(识别正确图片占比),因为分类任务输出 label 可以直接知道整体预测正确率。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/18 17:16:06

Gemini多模态原生架构解析:统一token空间与硬件感知推理

1. 项目概述:这不是一次普通模型发布,而是一场多模态能力的系统性重构“谷歌发布最新大模型Gemini,包含多模态、三大版本,还有哪些特点?能力是否超越 GPT-4了?”——这句话在2023年12月6日刷屏科技圈时&…

作者头像 李华
网站建设 2026/6/18 17:12:13

深入解析SCF5250微控制器:从ColdFire V2内核到音频处理实战

1. SCF5250微控制器:一款被低估的嵌入式音频处理利器在嵌入式音频处理、工业控制和消费电子领域,选对一颗微控制器(MCU)往往意味着项目成功了一半。今天我想和大家深入聊聊飞思卡尔(Freescale,现为NXP的一部…

作者头像 李华
网站建设 2026/6/18 17:07:50

【相机内参标定】张氏标定法

一、算法概述 张正友标定法是计算机视觉领域的里程碑成果。它巧妙地介于“传统标定法(需要高精度三维标定物)”与“自标定法(鲁棒性差)”之间,只需一个打印出来的二维平面棋盘格,在不同角度下拍摄几张照片,即可精确解算出相机的内参、外参及畸变系数。 二、单目相机成像…

作者头像 李华
网站建设 2026/6/18 17:04:39

AWVS专业Web漏洞扫描器部署与实战指南:从安装到深度扫描

1. 项目概述:为什么需要AWVS这样的专业扫描器?在Web应用安全领域,漏洞扫描器就像一位不知疲倦的“安全审计员”。想象一下,你开发了一个电商网站,上线前信心满满,但没过多久,用户数据泄露、支付…

作者头像 李华
网站建设 2026/6/18 17:02:55

三步极速部署方案:OpenCore Legacy Patcher让旧Mac重获新生

三步极速部署方案:OpenCore Legacy Patcher让旧Mac重获新生 【免费下载链接】OpenCore-Legacy-Patcher Experience macOS just like before 项目地址: https://gitcode.com/GitHub_Trending/op/OpenCore-Legacy-Patcher 你是否还在为苹果官方放弃的老旧Mac无…

作者头像 李华