AI:拿来主义——预训练网络(一)
我们已经训练过几个神经网络了,识别手写数字,房价预测或者是区分猫和狗,那随之而来就有一个问题,这些训练出的网络怎么用,每个问题我都需要重新去训练网络吗?因为程序员都不太喜欢做重复的事情,因此答案肯定是已经有轮子了。
我们先来介绍一个数据集,ImageNet。这就不得不提一个大名鼎鼎的华裔 AI 科学家李飞飞。
2005 年左右,李飞飞结束了他的博士生涯,开始了他的学术研究不就她就意识到了一个问题,在此之前,人们都尽可能优化算法,认为无论数据如何,只要算法够好,就能做出更好的决策,李飞飞意识到了这个问题的局限性,恰巧她还是一个行动派,她要做出一个无比庞大的数据集,尽可能描述世界上一切物体的数据集,下载图片,给没一张图片做标注,简单而无聊,当然后来这项工作放到了亚马逊的众包平台上,全世界无数的人参与了这个伟大的项目,到此刻为止,已经有 14,197,122 张图片(一千四百万张),21841 个分类。在这个发展的过程中,人们也发现了这个数据集带来的成功远比预想的要多,甚至现在被认为最有前景的深度卷积神经网络的提出也与 ImageNet 不无关系。我忘记了谁这么说过:“就单单这一个数据集,就可以让李飞飞数据科学这个领域拥有一席之地”。暂且不说这么说是否准确,但这个数据集仍然在创造新的突破。(我曾经在台下听过李飞飞一次演讲,现在想想还觉得甚是激动,她真的充满热情)。
基于这个数据集,我们是不是可以训练出一些网络,一般情况下,大家就不用耗时再去训练网络了呢?答案是肯定的,并且在 Keras 就有个一些这样的模型,还是内置的,Keras 就是这么懂你,那就不用客气了,我们拿来用就好了,谢谢啦!
特征提取
我们之前用到的卷积神经网络都是分成了两部分,第一部分是由池化层和卷积层组成的卷积积,第二部分是由分类器,特征提取的含义就是第一部分不变,改变第二部分。
为什么可以这么做?我们之前解释过神经网络的运行原理,跟人脑的认识过程非常类似,还记得吗?我们还是看一看原来的图吧。
我们可以看出来,网络识别图像是有层次结构的,比如一开始的网络层是用来识别图像或者拼装线条的,这是通用且类似的,因此我们可以复用。而后面的分类器往往是根据具体的问题所决定的,比如识别猫或狗的眼睛就与识别桌子腿是不一样的,因此有越靠前越具有通用性的特点。Keras 中很多的内置模型都可以直接下载,如果你没有下载在使用的时候会自动下载:
https://github.com/fchollet/deep-learning-models/releases
我们举一个例子,用 VGG16 去识别猫或狗,这次的解释都比较简单且都是以前说明过的,因此放在代码注释中:
#!/usr/bin/env python3
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
def extract_features(directory, sample_count):
# 图片转换区间
datagen = ImageDataGenerator(rescale=1. / 255)
batch_size = 20
conv_base = VGG16(weights='imagenet',
include_top=False,
input_shape=(150, 150, 3))
conv_base.summary()
features = np.zeros(shape=(sample_count, 4, 4, 512))
labels = np.zeros(shape=(sample_count))
# 读出图片,处理成神经网络需要的数据格式,上一篇文章中有介绍
generator = datagen.flow_from_directory(
directory,
target_size=(150, 150),
batch_size=batch_size,
class_mode='binary')
i = 0
for inputs_batch, labels_batch in generator:
print(i, '/', len(generator))
# 提取特征
features_batch = conv_base.predict(inputs_batch)
features[i * batch_size: (i + 1) * batch_size] = features_batch
labels[i * batch_size: (i + 1) * batch_size] = labels_batch
i += 1
if i * batch_size >= sample_count:
break
# 特征和标签
return features, labels
def cat():
base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
# 提取出的特征
train_features, train_labels = extract_features(train_dir, 2000)
validation_features, validation_labels = extract_features(validation_dir, 1000)
# 对特征进行变形展平
train_features = np.reshape(train_features, (2000, 4 * 4 * 512))
validation_features = np.reshape(validation_features, (1000, 4 * 4 * 512))
# 定义密集连接分类器
model = models.Sequential()
model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
# 对模型进行配置
model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
loss='binary_crossentropy',
metrics=['acc'])
# 对模型进行训练
history = model.fit(train_features, train_labels,
epochs=30,
batch_size=20,
validation_data=(validation_features, validation_labels))
# 画图
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.show()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
if __name__ == "__main__":
time_start = time.time()
cat()
time_end = time.time()
print('Time Used: ', time_end - time_start)
有点巧合的是这里居然看不到太多的过拟合的痕迹,其实也是有可能会有过拟合的隐患的,那样就需要进行数据增强,与以前是一样的,只不过这里的区别就是用到了内置模型,模型的参数需要冻结,我们是不希望对已经训练好的模型进行更改的,具体关键代码写法如下:
conv_base.trainable = False
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
以上就是模型复用的一种方法,我们对模型都是原封不动的拿来用,我们下一篇文章将介绍另外一种方法,对模型进行微调。
> 首发自公众号:RAIS
关注公众号
低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
-
上一篇
两个桶兑出特定容积的水
面试的时候,可能会经常碰到这样一个问题:嘉定区有两个桶,一个容量为 3 升,一个容量为 5 升,我们怎么能够不通过其他度量工具的帮助兑出 1 升的水来。假定水是无限的。 !! 此处限制条件:会给定先倒入哪个桶,并且在倒的过程中,不能出现如下情况:给定先倒的桶空了,而另一个桶是满的。 例如题:https://exercism.io/my/solutions/3b849bba3ee840ccac12eb7ca734ba8e 问题分析 如果单纯针对这个问题来看,相信我们还是可以很容易的得到一个推导过程。既然我们有两个桶,一个是 3 升,一个是 5 升,那么我们可能需要将一个桶装满水,然后倒到另外一个桶里,通过将一个桶倒满而另外一个桶可能还有盈余来达到最终兑换出期望容量的水的目的。按照这个思路,我们可以开始第一步分析。 初步分析 上个例子问题中,我们整个兑水的过程可以描述如下(假设先倒入 3 升的桶): (3, 0)将 a 桶倒满; (0, 3) 将 a 桶倒入 b 桶;此情况可以出现,因为虽然 a 桶满了,但是 b 桶未满 (3, 3) 将 a 桶倒满; (1, 5) 将 a 桶倒入 b 桶...
-
下一篇
jSqlBox 4.0.3 发布,Java 全功能数据库持久层 ORM 工具
jSqlBox是一个Java全功能数据库持久层工具,主页 https://gitee.com/drinkjava2/jsqlbox jSqlBox主要特点是架构优、尺寸小、功能全,基本上所有与数据库操作相关的功能,jSqlBox都已提供。其它数据库工具的主要功能,jSqlBox通常都有对应的方案提供,不提供的往往是不重要甚至是反模式的功能,不信可以随便找一个其它持久层工具的功能,而jSqlBox却没有提供或无法轻易实现的功能来,我会给一个合理解释。反之,如下jSqlBox具备的功能其它持久层工具往往并不提供: 1.支持在Java里存放和定位多行(SQL)文本 2.参数内嵌式SQL写法 3.自带分库分表 4.自带声明式事务、分布式事务 5.ActiveRecord模式实体类可以只声明接口,不占用宝贵的单继承 6.支持多种数据库方言的DDL生成 jSqlBox的目标是做最好的通用持久层工具,正如我喜欢给别的项目找缺点一样,也欢迎大家来找出jSqlBox的缺点。开源项目就是在不断的对比和相互学习中进步的。另外如果有谁想听我唠叨其它持久层工具优缺点的,也可以在评论里提出,一起来讨论,jSqlB...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- Dcoker安装(在线仓库),最新的服务器搭配容器使用
- CentOS7,CentOS8安装Elasticsearch6.8.6
- CentOS6,CentOS7官方镜像安装Oracle11G
- CentOS8安装MyCat,轻松搞定数据库的读写分离、垂直分库、水平分库
- CentOS6,7,8上安装Nginx,支持https2.0的开启
- Linux系统CentOS6、CentOS7手动修改IP地址
- Windows10,CentOS7,CentOS8安装Nodejs环境
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- SpringBoot2整合Redis,开启缓存,提高访问速度
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装

微信收款码
支付宝收款码