AI:拿来主义——预训练网络(一)
我们已经训练过几个神经网络了,识别手写数字,房价预测或者是区分猫和狗,那随之而来就有一个问题,这些训练出的网络怎么用,每个问题我都需要重新去训练网络吗?因为程序员都不太喜欢做重复的事情,因此答案肯定是已经有轮子了。
我们先来介绍一个数据集,ImageNet。这就不得不提一个大名鼎鼎的华裔 AI 科学家李飞飞。
2005 年左右,李飞飞结束了他的博士生涯,开始了他的学术研究不就她就意识到了一个问题,在此之前,人们都尽可能优化算法,认为无论数据如何,只要算法够好,就能做出更好的决策,李飞飞意识到了这个问题的局限性,恰巧她还是一个行动派,她要做出一个无比庞大的数据集,尽可能描述世界上一切物体的数据集,下载图片,给没一张图片做标注,简单而无聊,当然后来这项工作放到了亚马逊的众包平台上,全世界无数的人参与了这个伟大的项目,到此刻为止,已经有 14,197,122 张图片(一千四百万张),21841 个分类。在这个发展的过程中,人们也发现了这个数据集带来的成功远比预想的要多,甚至现在被认为最有前景的深度卷积神经网络的提出也与 ImageNet 不无关系。我忘记了谁这么说过:“就单单这一个数据集,就可以让李飞飞数据科学这个领域拥有一席之地”。暂且不说这么说是否准确,但这个数据集仍然在创造新的突破。(我曾经在台下听过李飞飞一次演讲,现在想想还觉得甚是激动,她真的充满热情)。
基于这个数据集,我们是不是可以训练出一些网络,一般情况下,大家就不用耗时再去训练网络了呢?答案是肯定的,并且在 Keras 就有个一些这样的模型,还是内置的,Keras 就是这么懂你,那就不用客气了,我们拿来用就好了,谢谢啦!
特征提取
我们之前用到的卷积神经网络都是分成了两部分,第一部分是由池化层和卷积层组成的卷积积,第二部分是由分类器,特征提取的含义就是第一部分不变,改变第二部分。
为什么可以这么做?我们之前解释过神经网络的运行原理,跟人脑的认识过程非常类似,还记得吗?我们还是看一看原来的图吧。
我们可以看出来,网络识别图像是有层次结构的,比如一开始的网络层是用来识别图像或者拼装线条的,这是通用且类似的,因此我们可以复用。而后面的分类器往往是根据具体的问题所决定的,比如识别猫或狗的眼睛就与识别桌子腿是不一样的,因此有越靠前越具有通用性的特点。Keras 中很多的内置模型都可以直接下载,如果你没有下载在使用的时候会自动下载:
https://github.com/fchollet/deep-learning-models/releases
我们举一个例子,用 VGG16 去识别猫或狗,这次的解释都比较简单且都是以前说明过的,因此放在代码注释中:
#!/usr/bin/env python3 import os import time import matplotlib.pyplot as plt import numpy as np from keras import layers from keras import models from keras import optimizers from keras.applications import VGG16 from keras.preprocessing.image import ImageDataGenerator def extract_features(directory, sample_count): # 图片转换区间 datagen = ImageDataGenerator(rescale=1. / 255) batch_size = 20 conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3)) conv_base.summary() features = np.zeros(shape=(sample_count, 4, 4, 512)) labels = np.zeros(shape=(sample_count)) # 读出图片,处理成神经网络需要的数据格式,上一篇文章中有介绍 generator = datagen.flow_from_directory( directory, target_size=(150, 150), batch_size=batch_size, class_mode='binary') i = 0 for inputs_batch, labels_batch in generator: print(i, '/', len(generator)) # 提取特征 features_batch = conv_base.predict(inputs_batch) features[i * batch_size: (i + 1) * batch_size] = features_batch labels[i * batch_size: (i + 1) * batch_size] = labels_batch i += 1 if i * batch_size >= sample_count: break # 特征和标签 return features, labels def cat(): base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small' train_dir = os.path.join(base_dir, 'train') validation_dir = os.path.join(base_dir, 'validation') # 提取出的特征 train_features, train_labels = extract_features(train_dir, 2000) validation_features, validation_labels = extract_features(validation_dir, 1000) # 对特征进行变形展平 train_features = np.reshape(train_features, (2000, 4 * 4 * 512)) validation_features = np.reshape(validation_features, (1000, 4 * 4 * 512)) # 定义密集连接分类器 model = models.Sequential() model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512)) model.add(layers.Dropout(0.5)) model.add(layers.Dense(1, activation='sigmoid')) # 对模型进行配置 model.compile(optimizer=optimizers.RMSprop(lr=2e-5), loss='binary_crossentropy', metrics=['acc']) # 对模型进行训练 history = model.fit(train_features, train_labels, epochs=30, batch_size=20, validation_data=(validation_features, validation_labels)) # 画图 acc = history.history['acc'] val_acc = history.history['val_acc'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(1, len(acc) + 1) plt.plot(epochs, acc, 'bo', label='Training acc') plt.plot(epochs, val_acc, 'b', label='Validation acc') plt.title('Training and validation accuracy') plt.legend() plt.show() plt.figure() plt.plot(epochs, loss, 'bo', label='Training loss') plt.plot(epochs, val_loss, 'b', label='Validation loss') plt.title('Training and validation loss') plt.legend() plt.show() if __name__ == "__main__": time_start = time.time() cat() time_end = time.time() print('Time Used: ', time_end - time_start)
有点巧合的是这里居然看不到太多的过拟合的痕迹,其实也是有可能会有过拟合的隐患的,那样就需要进行数据增强,与以前是一样的,只不过这里的区别就是用到了内置模型,模型的参数需要冻结,我们是不希望对已经训练好的模型进行更改的,具体关键代码写法如下:
conv_base.trainable = False model = models.Sequential() model.add(conv_base) model.add(layers.Flatten()) model.add(layers.Dense(256, activation='relu')) model.add(layers.Dense(1, activation='sigmoid'))
以上就是模型复用的一种方法,我们对模型都是原封不动的拿来用,我们下一篇文章将介绍另外一种方法,对模型进行微调。
> 首发自公众号:RAIS

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
探究 Go 语言 defer 语句的三种机制
Golang 的 1.13 版本 与 1.14 版本对 defer 进行了两次优化,使得 defer 的性能开销在大部分场景下都得到大幅降低,其中到底经历了什么原理? 这是因为这两个版本对 defer 各加入了一项新的机制,使得 defer 语句在编译时,编译器会根据不同版本与情况,对每个 defer 选择不同的机制,以更轻量的方式运行调用。 堆上分配 在 Golang 1.13 之前的版本中,所有 defer 都是在堆上分配,该机制在编译时会进行两个步骤: 在 defer 语句的位置插入 runtime.deferproc,当被执行时,延迟调用会被保存为一个 _defer 记录,并将被延迟调用的入口地址及其参数复制保存,存入 Goroutine 的调用链表中。 在函数返回之前的位置插入 runtime.deferreturn,当被执行时,会将延迟调用从 Goroutine 链表中取出并执行,多个延迟调用则以 jmpdefer 尾递归调用方式连续执行。 这种机制的主要性能问题存在于每个 defer 语句产生记录时的内存分配,以及记录参数和完成调用时参数移动的系统调用开销。 栈上分配 G...
- 下一篇
Spring Security基于Oauth2的SSO单点登录怎样做?一个注解搞定
一、说明 单点登录顾名思义就是在多个应用系统中,只需要登录一次,就可以访问其他相互信任的应用系统,免除多次登录的烦恼。本文主要介绍 同域 和 跨域 两种不同场景单点登录的实现原理,并使用 Spring Security 来实现一个最简单的跨域 SSO客户端 。 二、原理说明 单点登录主流都是基于共享 cookie 来实现的,下面分别介绍 同域 和 跨域 下的两种场景具体怎样实现共享 cookie 的 2.1. 同域单点登录 适用场景:都是企业自己的系统,所有系统都使用同一个一级域名通过不同的二级域名来区分。 举个例子:公司有一个一级域名为 zlt.com ,我们有三个系统分别是:门户系统(sso.zlt.com)、应用1(app1.zlt.com)和应用2(app2.zlt.com),需要实现系统之间的单点登录,实现架构如下: 核心原理: 门户系统设置 Cookie 的 domain 为一级域名也就是 zlt.com,这样就可以共享门户的 Cookie 给所有的使用该域名(xxx.zlt.com)的系统 使用 Spring Session 等技术让所有系统共享 Session 这样只...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- CentOS7,8上快速安装Gitea,搭建Git服务器
- SpringBoot2配置默认Tomcat设置,开启更多高级功能
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- Windows10,CentOS7,CentOS8安装MongoDB4.0.16
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- Eclipse初始化配置,告别卡顿、闪退、编译时间过长
- 设置Eclipse缩进为4个空格,增强代码规范
- SpringBoot2整合Redis,开启缓存,提高访问速度
- Windows10,CentOS7,CentOS8安装Nodejs环境