news 2026/5/13 14:29:15

【深度学习】 —— VGG-16 网络实现猫狗识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【深度学习】 —— VGG-16 网络实现猫狗识别
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

文章目录

  • 1. 简介 & 数据集介绍
  • 2. 环境
  • 3. 代码实现
    • 3.1 前期准备
      • 3.1.1 设置GPU & 导入库
      • 3.1.2 数据集统计与预览
    • 3.2 数据预处理
      • 3.2.1 数据集划分与预处理
      • 3.2.2 类别识别
      • 3.2.3 可视化
      • 3.2.4 整体数据检查
    • 3.3 模型建立与训练
      • 3.3.1 构建 VGG16 模型
      • 3.3.2 模型编译与训练
  • 4. 模型评估
    • 4.1 绘制训练集与验证集的 Accuracy 和 Loss 趋势图
    • 4.2 模型预测

1. 简介 & 数据集介绍

利用 TensorFlow,通过构建 VGG-16 网络实现猫狗识别。数据集中有 dog 和 cat 2 类图片,每类图片数量各有 1700 张图片。

2. 环境

  • 语言环境:Python 3.12.7
  • 编译器:Jupyter Notebook
  • 深度学习环境:TensorFlow 2.21.0

3. 代码实现

3.1 前期准备

3.1.1 设置GPU & 导入库

导入必要的库并配置 GPU 显存增长,以解决在 Windows 环境下可能出现的显存占用或驱动兼容性问题。

importtensorflowastfimportmatplotlib.pyplotaspltimportos,PIL,pathlibimportwarningsfromtqdmimporttqdmimporttensorflow.keras.backendasKfromtensorflow.kerasimportlayers,models,Inputfromtensorflow.keras.modelsimportModelfromtensorflow.keras.layersimportConv2D,MaxPooling2D,Dense,Flatten,Dropoutfromdatetimeimportdatetimeimportnumpyasnp plt.rcParams['font.sans-serif']=['SimHei']plt.rcParams['axes.unicode_minus']=Falsewarnings.filterwarnings('ignore')gpus=tf.config.list_physical_devices("GPU")ifgpus:tf.config.experimental.set_memory_growth(gpus[0],True)tf.config.set_visible_devices([gpus[0]],"GPU")print(gpus)

3.1.2 数据集统计与预览

通过 pathlib 扫描本地目录统计 3400 张图像。

data_dir="./Data/365-7-data"data_dir=pathlib.Path(data_dir)image_count=len(list(data_dir.glob('*/*')))print("图片总数为:",image_count)

3.2 数据预处理

3.2.1 数据集划分与预处理

使用 Keras 提供的便捷接口构建了训练数据集,将图像统一缩放至 VGG16 标准的 224x224 尺寸,设置批次大小为 8,并按 8:2 的比例划出了 80%(2720 张图片)的数据用于模型训练。

batch_size=8img_height=224img_width=224train_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height,img_width),batch_size=batch_size)


采用与构建训练集完全相同的参数和随机种子(seed=123),从同一个目录中划分出剩余的20%(680 张图片)作为验证数据集,以确保训练集和验证集互不重叠,用于评估模型性能。

val_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height,img_width),batch_size=batch_size)

3.2.2 类别识别

从创建好的训练数据集中提取并打印了分类的类别名称,输出的列表为[‘cat’, ‘dog’],明确了当前实验是一个简单的图像二分类任务。

class_names=train_ds.class_namesprint(class_names)

3.2.3 可视化

从训练集中抽取一个批次的数据(包含图像和标签),并打印它们的维度。输出结果 (8, 224, 224, 3) 验证了每批次包含 8 张长宽为 224 的 RGB 三通道彩色图片。

AUTOTUNE=tf.data.AUTOTUNEdefpreprocess_image(image,label):return(image/255.0,label)train_ds=train_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)val_ds=val_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)plt.figure(figsize=(15,10))forimages,labelsintrain_ds.take(1):foriinrange(8):ax=plt.subplot(5,8,i+1)plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")

3.2.4 整体数据检查

通过打印第一个批次中图像和标签的维度(shape)来验证数据结构的正确性,输出显示图像张量维度为 (8, 224, 224, 3),标签张量维度为(8,)。

forimage_batch,labels_batchintrain_ds:print(image_batch.shape)print(labels_batch.shape)break

3.3 模型建立与训练

3.3.1 构建 VGG16 模型

利用 Keras 函数式 API 从零开始手动搭建了一个经典的 VGG16 卷积神经网络架构,包含5个特征提取卷积块和末端的高维全连接层,并打印了包含约 1.38 亿个参数的模型结构摘要。

defVGG16(nb_classes,input_shape):input_tensor=Input(shape=input_shape)# 1st blockx=Conv2D(64,(3,3),activation='relu',padding='same',name='block1_conv1')(input_tensor)x=Conv2D(64,(3,3),activation='relu',padding='same',name='block1_conv2')(x)x=MaxPooling2D((2,2),strides=(2,2),name='block1_pool')(x)# 2nd blockx=Conv2D(128,(3,3),activation='relu',padding='same',name='block2_conv1')(x)x=Conv2D(128,(3,3),activation='relu',padding='same',name='block2_conv2')(x)x=MaxPooling2D((2,2),strides=(2,2),name='block2_pool')(x)# 3rd blockx=Conv2D(256,(3,3),activation='relu',padding='same',name='block3_conv1')(x)x=Conv2D(256,(3,3),activation='relu',padding='same',name='block3_conv2')(x)x=Conv2D(256,(3,3),activation='relu',padding='same',name='block3_conv3')(x)x=MaxPooling2D((2,2),strides=(2,2),name='block3_pool')(x)# 4th blockx=Conv2D(512,(3,3),activation='relu',padding='same',name='block4_conv1')(x)x=Conv2D(512,(3,3),activation='relu',padding='same',name='block4_conv2')(x)x=Conv2D(512,(3,3),activation='relu',padding='same',name='block4_conv3')(x)x=MaxPooling2D((2,2),strides=(2,2),name='block4_pool')(x)# 5th blockx=Conv2D(512,(3,3),activation='relu',padding='same',name='block5_conv1')(x)x=Conv2D(512,(3,3),activation='relu',padding='same',name='block5_conv2')(x)x=Conv2D(512,(3,3),activation='relu',padding='same',name='block5_conv3')(x)x=MaxPooling2D((2,2),strides=(2,2),name='block5_pool')(x)# full connectionx=Flatten()(x)x=Dense(4096,activation='relu',name='fc1')(x)x=Dense(4096,activation='relu',name='fc2')(x)output_tensor=Dense(nb_classes,activation='softmax',name='predictions')(x)model=Model(input_tensor,output_tensor)returnmodel model=VGG16(1000,(img_width,img_height,3))model.summary()

3.3.2 模型编译与训练

配置初始值为 0.0001 且呈指数衰减的学习率策略,使用 Adam 优化器对模型进行编译,随后在训练集上进行了 10 个周期的训练。

model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])epochs=10lr=1e-4history_train_loss=[]history_train_accuracy=[]history_val_loss=[]history_val_accuracy=[]forepochinrange(epochs):train_total=len(train_ds)val_total=len(val_ds)withtqdm(total=train_total,desc=f'Epoch{epoch+1}/{epochs}',mininterval=1,ncols=100)aspbar:lr=lr*0.92model.optimizer.learning_rate.assign(lr)forimage,labelintrain_ds:history=model.train_on_batch(image,label)train_loss=history[0]train_accuracy=history[1]pbar.set_postfix({"loss":"%.4f"%train_loss,"accuracy":"%.4f"%train_accuracy,"lr":model.optimizer.learning_rate.numpy()})pbar.update(1)history_train_loss.append(train_loss)history_train_accuracy.append(train_accuracy)print('开始验证!')withtqdm(total=val_total,desc=f'Epoch{epoch+1}/{epochs}',mininterval=0.3,ncols=100)aspbar:forimage,labelinval_ds:history=model.test_on_batch(image,label)val_loss=history[0]val_accuracy=history[1]pbar.set_postfix({"loss":"%.4f"%val_loss,"accuracy":"%.4f"%val_accuracy})pbar.update(1)history_val_loss.append(val_loss)history_val_accuracy.append(val_accuracy)print('结束验证!')print("验证loss为:%.4f"%val_loss)print("验证准确率为:%.4f"%val_accuracy)
Epoch1/10:100%|████████|340/340[19:30<00:00,3.44s/it,loss=0.8238,accuracy=0.5651,lr=9.2e-5]开始验证! Epoch1/10:100%|█████████████████████|85/85[13:52<00:00,9.79s/it,loss=0.7740,accuracy=0.5979]结束验证! 验证loss为:0.7740 验证准确率为:0.5979 Epoch2/10:100%|███████|340/340[14:52<00:00,2.62s/it,loss=0.5792,accuracy=0.7113,lr=8.46e-5]开始验证! Epoch2/10:100%|█████████████████████|85/85[00:37<00:00,2.25it/s,loss=0.5325,accuracy=0.7354]结束验证! 验证loss为:0.5325 验证准确率为:0.7354 Epoch3/10:100%|███████|340/340[14:15<00:00,2.52s/it,loss=0.4109,accuracy=0.8016,lr=7.79e-5]开始验证! Epoch3/10:100%|█████████████████████|85/85[00:38<00:00,2.19it/s,loss=0.3878,accuracy=0.8131]结束验证! 验证loss为:0.3878 验证准确率为:0.8131 Epoch4/10:100%|███████|340/340[15:28<00:00,2.73s/it,loss=0.3213,accuracy=0.8474,lr=7.16e-5]开始验证! Epoch4/10:100%|█████████████████████|85/85[00:51<00:00,1.65it/s,loss=0.3072,accuracy=0.8543]结束验证! 验证loss为:0.3072 验证准确率为:0.8543 Epoch5/10:100%|███████|340/340[16:07<00:00,2.85s/it,loss=0.2651,accuracy=0.8756,lr=6.59e-5]开始验证! Epoch5/10:100%|█████████████████████|85/85[00:48<00:00,1.75it/s,loss=0.2557,accuracy=0.8802]结束验证! 验证loss为:0.2557 验证准确率为:0.8802 Epoch6/10:100%|███████|340/340[16:27<00:00,2.91s/it,loss=0.2247,accuracy=0.8953,lr=6.06e-5]开始验证! Epoch6/10:100%|█████████████████████|85/85[00:44<00:00,1.92it/s,loss=0.2185,accuracy=0.8983]结束验证! 验证loss为:0.2185 验证准确率为:0.8983 Epoch7/10:100%|███████|340/340[15:44<00:00,2.78s/it,loss=0.1956,accuracy=0.9093,lr=5.58e-5]开始验证! Epoch7/10:100%|█████████████████████|85/85[00:42<00:00,2.00it/s,loss=0.1906,accuracy=0.9117]结束验证! 验证loss为:0.1906 验证准确率为:0.9117 Epoch8/10:100%|███████|340/340[15:45<00:00,2.78s/it,loss=0.1731,accuracy=0.9200,lr=5.13e-5]开始验证! Epoch8/10:100%|█████████████████████|85/85[00:42<00:00,1.98it/s,loss=0.1691,accuracy=0.9219]结束验证! 验证loss为:0.1691 验证准确率为:0.9219 Epoch9/10:100%|███████|340/340[15:19<00:00,2.70s/it,loss=0.1542,accuracy=0.9289,lr=4.72e-5]开始验证! Epoch9/10:100%|█████████████████████|85/85[00:41<00:00,2.05it/s,loss=0.1508,accuracy=0.9305]结束验证! 验证loss为:0.1508 验证准确率为:0.9305 Epoch10/10:100%|██████|340/340[15:14<00:00,2.69s/it,loss=0.1398,accuracy=0.9358,lr=4.34e-5]开始验证! Epoch10/10:100%|████████████████████|85/85[00:39<00:00,2.16it/s,loss=0.1375,accuracy=0.9369]结束验证! 验证loss为:0.1375 验证准确率为:0.9369

4. 模型评估

4.1 绘制训练集与验证集的 Accuracy 和 Loss 趋势图

提取了 model.fit() 返回的历史训练数据,并使用 Matplotlib 将训练集与验证集的准确率(Accuracy)和损失值(Loss)随时间变化的趋势绘制成了两幅直观的折线图。

current_time=datetime.now()epochs_range=range(epochs)plt.figure(figsize=(12,4))plt.subplot(1,2,1)plt.plot(epochs_range,history_train_accuracy,label='Training Accuracy')plt.plot(epochs_range,history_val_accuracy,label='Validation Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.xlabel(current_time)plt.subplot(1,2,2)plt.plot(epochs_range,history_train_loss,label='Training Loss')plt.plot(epochs_range,history_val_loss,label='Validation Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()

4.2 模型预测

从验证集中抽取第一个批次(8张图片),利用已经存在于上下文中的模型(model)对其逐一进行张量维度扩充和推理预测。最后使用 Matplotlib 绘制一行 8 列的画板,将图像打印出来,并将模型识别出的分类名称(猫或狗)写在各子图的标题上。

plt.figure(figsize=(18,3))plt.suptitle("预测结果展示")forimages,labelsinval_ds.take(1):foriinrange(8):ax=plt.subplot(1,8,i+1)plt.imshow(images[i].numpy())img_array=tf.expand_dims(images[i],0)predictions=model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")


版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/13 14:19:12

McpManager:统一AI模型与工具调用的MCP协议管理器实战

1. 项目概述与核心价值 最近在折腾AI应用开发&#xff0c;特别是想把不同的大模型能力整合到自己的项目里时&#xff0c;遇到了一个挺普遍的问题&#xff1a;每个模型、每个工具都有自己的API&#xff0c;调用方式千差万别。今天想和大家深入聊聊一个我最近在用的、感觉能极大提…

作者头像 李华
网站建设 2026/5/13 14:18:58

AI编码工具箱实战:从原理到应用,提升开发效率

1. 项目概述&#xff1a;一个面向开发者的AI编码工具箱最近在GitHub上看到一个挺有意思的项目&#xff0c;叫Lu7474/ai-coding-toolkit。光看名字&#xff0c;你大概能猜到这是个和AI编程相关的工具集。但具体是什么&#xff0c;能解决什么问题&#xff0c;值不值得花时间去研究…

作者头像 李华
网站建设 2026/5/13 14:13:32

半导体假冒芯片识别与全流程防御实战指南

1. 项目概述&#xff1a;一场看不见硝烟的战争在电子行业摸爬滚打了十几年&#xff0c;我见过太多因为一颗小小的芯片而引发的“血案”。生产线突然停摆&#xff0c;出货的产品在客户现场批量失效&#xff0c;甚至在某些极端案例中&#xff0c;导致了严重的安全事故。追根溯源&…

作者头像 李华
网站建设 2026/5/13 14:11:18

65.人工智能实战:模型幻觉怎么前置发现?从无答案问题集到拒答策略、证据校验与幻觉率监控

人工智能实战:模型幻觉怎么前置发现?从无答案问题集到拒答策略、证据校验与幻觉率监控 一、问题场景:资料里没有答案,模型却说得很像真的 大模型最危险的问题之一是: 不知道时也会说。在企业知识库场景中,这尤其严重。 用户问: 公司是否报销宠物托运费?资料里没有任…

作者头像 李华