文章目录
- 一、前期工作
- 1. 设置GPU
- 2. 导入数据
- 3. 查看数据
- 二、数据预处理
- 1. 加载数据
- 2. 可视化数据
- 3. 再次检查数据
- 4. 配置数据集
- 5. 归一化
- 三、构建VGG-19网络
- 1. 官方模型(已打包好)
- 2. 自建模型
- 3. 网络结构图
- 四、编译
- 五、训练模型
- 六、模型评估
- 七、保存and加载模型
- 八、预测
一、前期工作
本文将实现灵笼中人物角色的识别。较上一篇文章,这次我采用了VGG-19结构,并增加了预测与保存and加载模型两个部分。
我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2
往期精彩内容:
- 深度学习100例-卷积神经网络(CNN)实现mnist手写数字识别 | 第1天
- 深度学习100例-卷积神经网络(CNN)彩色图片分类 | 第2天
- 深度学习100例-卷积神经网络(CNN)服装图像分类 | 第3天
- 深度学习100例-卷积神经网络(CNN)花朵识别 | 第4天
- 深度学习100例-卷积神经网络(CNN)天气识别 | 第5天
- 深度学习100例-卷积神经网络(VGG-16)识别海贼王草帽一伙 | 第6天
来自专栏:【深度学习100例】
1. 设置GPU
如果使用的是CPU可以忽略这步
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用
tf.config.set_visible_devices([gpus[0]],"GPU")
2. 导入数据
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
import os,PIL
# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)
# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)
from tensorflow import keras
from tensorflow.keras import layers,models
import pathlib
data_dir = "D:/jupyter notebook/DL-100-days/datasets/linglong_photos"
data_dir = pathlib.Path(data_dir)
3. 查看数据
数据集中一共有白月魁、查尔斯、红蔻、马克、摩根、冉冰等6个人物角色。
文件夹 | 含义 | 数量 |
---|---|---|
baiyuekui | 白月魁 | 40 张 |
chaersi | 查尔斯 | 76 张 |
hongkou | 红蔻 | 36 张 |
make | 马克 | 38张 |
mogen | 摩根 | 30 张 |
ranbing | 冉冰 | 60张 |
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 280
二、数据预处理
1. 加载数据
使用image_dataset_from_directory
方法将磁盘中的数据加载到tf.data.Dataset
中
batch_size = 16
img_height = 224
img_width = 224
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.1,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
Found 280 files belonging to 6 classes.
Using 252 files for training.
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.1,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
Found 280 files belonging to 6 classes.
Using 28 files for validation.
我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。
class_names = train_ds.class_names
print(class_names)
['baiyuekui', 'chaersi', 'hongkou', 'make', 'mogen', 'ranbing']
2. 可视化数据
plt.figure(figsize=(10, 5)) # 图形的宽为10高为5
plt.suptitle("
本文链接:http://m.zhangshiyu.com/post/20773.html
- 已完结小说《我的怀表连通八零》全章节在线阅读
- 排行通历史:盘点史上十大猛将最新章节,排行通历史:盘点史上十大猛将章节目录阅读
- StableDiffusionWebUI 让我找到了宫崎骏动漫里的夏天
- 初始Python篇(11)—— 面向对象三大特征