Java教程

七、Data Augmentation技巧

本文主要是介绍七、Data Augmentation技巧,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

目录
  • 前文
  • 数据生成器+数据部分展示
  • 数据增强模型
  • 数据增强模型的编译与拟合
  • GitHub下载地址:

前文

  • 一、Windows系统下安装Tensorflow2.x(2.6)
  • 二、深度学习-读取数据
  • 三、Tensorflow图像处理预算
  • 四、线性回归模型的tensorflow实现
  • 五、深度学习-逻辑回归模型
  • 六、AlexNet实现中文字体识别——隶书和行楷
  • 七、VGG16实现鸟类数据库分类
  • 七、VGG16+BN(Batch Normalization)实现鸟类数据库分类
  • 七、BatchNormalization使用技巧

数据生成器+数据部分展示

#猫狗分类。数据增强
#数据生成器生成测试集
from keras.preprocessing.image import ImageDataGenerator

IMSIZE = 128
validation_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    '../../data/dogs-vs-cats/smallData/validation',
    target_size=(IMSIZE, IMSIZE),
    batch_size=10,
    class_mode='categorical'
)

在这里插入图片描述

# 利用数据增强技术生成的训练集
train_generator = ImageDataGenerator(rescale=1. / 255, shear_range=0.5, rotation_range=30,
                                     zoom_range=0.2, width_shift_range=0.2, height_shift_range=0.2
                                     ).flow_from_directory('../../data/dogs-vs-cats/smallData/train',
                                                           target_size=(IMSIZE, IMSIZE), batch_size=10,
                                                           class_mode='categorical')

在这里插入图片描述

数据来源kaggle的猫狗数据

#展示数据增强后的图像
from matplotlib import pyplot as plt

plt.figure()
fig, ax = plt.subplots(2, 5)
fig.set_figheight(6)
fig.set_figwidth(15)
ax = ax.flatten()
X, Y = next(validation_generator)
for i in range(10): ax[i].imshow(X[i, :, :, ])

在这里插入图片描述

数据增强模型

#数据增强模型
IMSIZE = 128
from keras.layers import BatchNormalization, Conv2D, Dense, Flatten, Input, MaxPooling2D
from keras import Model

n_channel = 100
input_layer = Input([IMSIZE, IMSIZE, 3])
x = input_layer
x =BatchNormalization()(x)
for _ in range(7):
    x =BatchNormalization()(x)
    x =Conv2D(n_channel,[2,2],padding='same',activation='relu')(x)
    x =MaxPooling2D([2,2])(x)

x =Flatten()(x)
x =Dense(2,activation='softmax')(x)
output_layer = x
model = Model(input_layer,output_layer)
model.summary()

在这里插入图片描述

数据增强模型的编译与拟合

#数据增强模型的编译与拟合
from keras.optimizers import Adam
model.compile(loss='categorical_crossentropy',
               optimizer=Adam(lr=0.0001),
               metrics=['accuracy'])
model.fit_generator(train_generator,
                     epochs=200,
                     validation_data=validation_generator)

在这里插入图片描述

GitHub下载地址:

Tensorflow1.15深度学习

这篇关于七、Data Augmentation技巧的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!