【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像
前言
深度学习作为人工智能的重要手段,迎来了爆发,在NLP、CV、物联网、无人机等多个领域都发挥了非常重要的作用。最近几年,各种深度学习算法层出不穷, Generative Adverarial Network(GAN)自2014年提出以来,引起广泛关注,身为深度学习三巨头之一的Yan Lecun对GAN的评价颇高,认为GAN是近年来在深度学习上最大的突破,是近十年来机器学习上最有意思的工作。围绕GAN的论文数量也迅速增多,各种版本的GAN出现,主要在CV领域带来了一些贡献,如下图所示。
我们可以利用GAN生成一些我们需要的图像或者文本,比如二次元头像。
GAN简介
GAN主要的应用是自动生成一些东西,包括图像和文本等,比如随机给一个向量作为输入,通过GAN的Generator生成一张图片,或者生成一串语句。Conditional GAN的应用更多一些,比如数据集是一段文字和图像的数据对,通过训练,GAN可以通过给定一段文字生成对应的图像。
GAN主要可以分为Generator(生成器)和Discriminator(判别器)两个部分,其中Generator其实就是一个神经网络,输入一个向量,可以输出一张图像(即一个高维的向量表示),如下图示。
Discriminator也是一个神经网络,输入为一张图像,输出为一个数值,输出的数值用于判断输入的图像是否是真的,数值越大,说明图像是真的,数值越小,说明图像为假的,如下图示。
Generator负责生成图像,Discriminator负责对Generator生成的图像和真实图像去进行对比,区别出真假,Generator需要不断优化来欺骗Discriminator,以假乱真;而Discriminator也不断优化,来提高识别能力,能够识别出Generator的把戏。二者的这种关系可以形象地通过下图展示。
Generator和Discriminator连接起来,形成一个比较大的深层网络,即为GAN网络。
场景描述
深度学习的各种算法在PAI上可以通过PAI-DSW进行实现,在PAI-DSW上进行训练数据,利用GAN自动生成二次元头像。
数据准备
首先需要准备真实的二次元头像作为数据集,这里从网上找到一些共享的资源,存储在了钉钉钉盘中,钉盘地址 ,提取密码: c2pz,数据集如下图示,约5万多张:
算法实践
利用PAI-DSW进行GAN算法实践,首先需要安装准备好环境。
首先进入到Notebook建模,创建新实例,之后打开实例,进入Terminal,在Terminal下用户可以像在自己本地一样安装相应的依赖包,进行操作。
准备好环境之后,我们可以通过如下图示方法,将基于Tensorflow的DCGAN代码和数据集上传上去。
用于训练的DCGAN代码地址:https://github.com/carpedm20/DCGAN-tensorflow,关于DCGAN的网络框架图如下,详细介绍可以参考论文:https://arxiv.org/abs/1511.06434,这里我们不做详述。
数据集和代码上传成功,如下图示。
其中,data目录下的faces即为数据集,该文件夹下为对应的5万多张真实二次元头像。DCGAN-tensorflow为整个代码路径,其中最主要的两个代码文件是main.py和model.py,其中最主要的核心代码如下。
def main(_): pp.pprint(flags.FLAGS.__flags) if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir) if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir) #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) run_config = tf.ConfigProto() run_config.gpu_options.allow_growth=True with tf.Session(config=run_config) as sess: if FLAGS.dataset == 'mnist': dcgan = DCGAN( sess, input_width=FLAGS.input_width, input_height=FLAGS.input_height, output_width=FLAGS.output_width, output_height=FLAGS.output_height, batch_size=FLAGS.batch_size, sample_num=FLAGS.batch_size, y_dim=10, z_dim=FLAGS.generate_test_images, dataset_name=FLAGS.dataset, input_fname_pattern=FLAGS.input_fname_pattern, crop=FLAGS.crop, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir, data_dir=FLAGS.data_dir) else: dcgan = DCGAN( sess, input_width=FLAGS.input_width, input_height=FLAGS.input_height, output_width=FLAGS.output_width, output_height=FLAGS.output_height, batch_size=FLAGS.batch_size, sample_num=FLAGS.batch_size, z_dim=FLAGS.generate_test_images, dataset_name=FLAGS.dataset, input_fname_pattern=FLAGS.input_fname_pattern, crop=FLAGS.crop, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir, data_dir=FLAGS.data_dir) show_all_variables() if FLAGS.train: dcgan.train(FLAGS)
else: # Update D network _, summary_str = self.sess.run([d_optim, self.d_sum], feed_dict={ self.inputs: batch_images, self.z: batch_z }) self.writer.add_summary(summary_str, counter) # Update G network _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict={ self.z: batch_z }) self.writer.add_summary(summary_str, counter) # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict={ self.z: batch_z }) self.writer.add_summary(summary_str, counter) errD_fake = self.d_loss_fake.eval({ self.z: batch_z }) errD_real = self.d_loss_real.eval({ self.inputs: batch_images }) errG = self.g_loss.eval({self.z: batch_z})
一切就绪之后,我们执行命令进行训练,调用命令如下:
python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"
其中,参数dateset指定数据集的目录,epoch指定循环迭代的次数,input_height、input_width用于指定输入文件的大小,输出文件的大小同样也需要参数设定,代码执行过程如下图示:
我们来看下执行结果,分别看一下epoch为1,30,100的时候生成的二次元头像效果图。
epoch=1
epoch=30
epoch=100
我们发现,随着不断迭代,生成的二次元头像也越来越逼真。
总结
通过上面的实践,我们领略到了GAN的魅力,GAN的变种有很多,除此之外我们还可以利用GAN做非常多的有意思的事情,比如通过文字生成图像,通过简单文字生成宣传海报等。PAI-DSW像是一个练武场,为我们准备好了深度学习所需要的环境和条件,让我们可以尽情享受大数据和深度学习的乐趣,除了GAN,像比较火热的Bert等模型,我们也都可以试一试。
人人用得起的机器学习平台↓↓↓↓
低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
【机器学习PAI实战】—— 玩转人工智能之综述
模型训练与在线预测服务、推荐算法四部曲、机器学习PAI实战、更多精彩,尽在开发者分会场 【机器学习PAI实战】—— 玩转人工智能之商品价格预测 【机器学习PAI实战】—— 玩转人工智能之你最喜欢哪个男生? 【机器学习PAI实战】—— 玩转人工智能之美食推荐 【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像 绪论 人工智能并非新的术语,这个概念由来已久,大约从80年代初开始,计算机科学家们开始设计可以学习和模仿人类行为的算法。人工智能的发展曲折向前,伴随着数据量的上涨、计算力的提升,机器学习的火热,以及深度学习的爆发,人工智能迎来快速发展,迅速席卷全球。 人工智能的研究领域也在不断扩大,已经涵盖专家系统、机器学习、进化计算、模糊逻辑、计算机视觉、自然语言处理、推荐系统等多个领域。可以毫不夸张地说,人工智能技术正在像100多年前的电力一样,即将改变每个行业。每个企业都不希望在这次浪潮中掉队,如何才能利用AI帮助自己的企业进行转型呢?AI领域著名学者吴恩达在前不久针对该问题,发表了《AI转型指南》。 机器学习,作为实现人工智能的一种方法,对于人工智能的发展起着十分重要的...
- 下一篇
本田计划2025年回收废旧锂离子电池生产镍钴合金
据悉,本田一高管表示,计划到2025年,使用废旧锂离子电池作为原料,开始生产镍钴合金。 本田的混合乘用车车型都配备了锂离子电池。本田一高管在会上表示:“从2025年,本田将回收利用大量废旧锂离子电池。”来源:https://xincailiao.ofweek.com/
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- MySQL8.0.19开启GTID主从同步CentOS8
- SpringBoot2全家桶,快速入门学习开发网站教程
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- SpringBoot2更换Tomcat为Jetty,小型站点的福音
- Docker安装Oracle12C,快速搭建Oracle学习环境
- CentOS7编译安装Cmake3.16.3,解决mysql等软件编译问题
- Windows10,CentOS7,CentOS8安装Nodejs环境
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- Red5直播服务器,属于Java语言的直播服务器
- SpringBoot2初体验,简单认识spring boot2并且搭建基础工程