如何用Deeplearning4j实现GAN
一、Gan的思想
Gan的核心所做的事情是在解决一个argminmax的问题,公式:
1、求解一个Discriminator,可以最大尺度的丈量Generator 产生的数据和真实数据之间的分布距离
2、求解一个Generator,可以最大程度减小产生数据和真实数据之间的距离
gan的原始公式如下:
实际上,我们不可能真求期望,只能sample出data来近似求解,于是,公式变成如下:
于是,求解V的最大值,变成了一个二分类问题,变成了求交叉熵的最小值。
二、代码
public class Gan {
static double lr = 0.01;
public static void main(String[] args) throws Exception {
final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
.weightInit(WeightInit.XAVIER);
final GraphBuilder graphBuilder = builder.graphBuilder().backpropType(BackpropType.Standard)
.addInputs("input1", "input2")
.addLayer("g1",
new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build(),
"input1")
.addLayer("g2",
new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build(),
"g1")
.addLayer("g3",
new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build(),
"g2")
.addVertex("stack", new StackVertex(), "input2", "g3")
.addLayer("d1",
new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build(),
"stack")
.addLayer("d2",
new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build(),
"d1")
.addLayer("d3",
new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
.weightInit(WeightInit.XAVIER).build(),
"d2")
.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
.activation(Activation.SIGMOID).build(), "d3")
.setOutputs("out");
ComputationGraph net = new ComputationGraph(graphBuilder.build());
net.init();
System.out.println(net.summary());
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
net.setListeners(new ScoreIterationListener(100));
net.getLayers();
DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
INDArray labelG = Nd4j.ones(60, 1);
for (int i = 1; i <= 100000; i++) {
if (!train.hasNext()) {
train.reset();
}
INDArray trueExp = train.next().getFeatures();
INDArray z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
MultiDataSet dataSetD = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { z, trueExp },
new INDArray[] { labelD });
for(int m=0;m<10;m++){
trainD(net, dataSetD);
}
z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
MultiDataSet dataSetG = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[] { z, trueExp },
new INDArray[] { labelG });
trainG(net, dataSetG);
if (i % 10000 == 0) {
net.save(new File("E:/gan.zip"), true);
}
}
}
public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
net.setLearningRate("g1", 0);
net.setLearningRate("g2", 0);
net.setLearningRate("g3", 0);
net.setLearningRate("d1", lr);
net.setLearningRate("d2", lr);
net.setLearningRate("d3", lr);
net.setLearningRate("out", lr);
net.fit(dataSet);
}
public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
net.setLearningRate("g1", lr);
net.setLearningRate("g2", lr);
net.setLearningRate("g3", lr);
net.setLearningRate("d1", 0);
net.setLearningRate("d2", 0);
net.setLearningRate("d3", 0);
net.setLearningRate("out", 0);
net.fit(dataSet);
}
}
说明:
1、dl4j并没有提供像keras那样冻结某些层参数的方法,这里采用设置learningrate为0的方法,来冻结某些层的参数
2、这个的更新器,用的是sgd,不能用其他的(比方说Adam、Rmsprop),因为这些自适应更新器会考虑前面batch的梯度作为本次更新的梯度,达不到不更新参数的目的
3、这里用了StackVertex,沿着第一维合并张量,也就是合并真实数据样本和Generator产生的数据样本,共同训练Discriminator
4、训练过程中多次update Discriminator的参数,以便量出最大距离,让后更新Generator一次
5、进行10w次迭代
三、Generator生成手写数字
加载训练好的模型,随机从NormalDistribution取出一些噪音数据,丢给模型,经过feedForward,取出最后一层Generator的激活值,便是我们想要的结果,代码如下:
public class LoadGan {
public static void main(String[] args) throws Exception {
ComputationGraph restored = ComputationGraph.load(new File("E:/gan.zip"), true);
DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
INDArray trueExp = train.next().getFeatures();
Map<String, INDArray> map = restored.feedForward(
new INDArray[] { Nd4j.rand(new long[] { 50, 10 }, new NormalDistribution()), trueExp }, false);
INDArray indArray = map.get("g3");// .reshape(20,28,28);
List<INDArray> list = new ArrayList<>();
for (int j = 0; j < indArray.size(0); j++) {
list.add(indArray.getRow(j));
}
MNISTVisualizer bestVisualizer = new MNISTVisualizer(1, list, "Gan");
bestVisualizer.visualize();
}
public static class MNISTVisualizer {
private double imageScale;
private List<INDArray> digits; // Digits (as row vectors), one per
// INDArray
private String title;
private int gridWidth;
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title) {
this(imageScale, digits, title, 5);
}
public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth) {
this.imageScale = imageScale;
this.digits = digits;
this.title = title;
this.gridWidth = gridWidth;
}
public void visualize() {
JFrame frame = new JFrame();
frame.setTitle(title);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
JPanel panel = new JPanel();
panel.setLayout(new GridLayout(0, gridWidth));
List<JLabel> list = getComponents();
for (JLabel image : list) {
panel.add(image);
}
frame.add(panel);
frame.setVisible(true);
frame.pack();
}
public List<JLabel> getComponents() {
List<JLabel> images = new ArrayList<>();
for (INDArray arr : digits) {
BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
for (int i = 0; i < 784; i++) {
bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * arr.getDouble(i)));
}
ImageIcon orig = new ImageIcon(bi);
Image imageScaled = orig.getImage().getScaledInstance((int) (imageScale * 28), (int) (imageScale * 28),
Image.SCALE_DEFAULT);
ImageIcon scaled = new ImageIcon(imageScaled);
images.add(new JLabel(scaled));
}
return images;
}
}
}
实际效果,还算比较清晰
快乐源于分享。
此博客乃作者原创, 转载请注明出处

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
-
上一篇
啤酒节上尿意浓-SVG低级艺术展示
领导:阿仁,有个很重要的任务需要你去执行。 阿仁:什么任务这么重要? 领导:最近我们有个新园区开张,还请了一个写手写了篇软文来炒作,《云破月来花弄影-SVG多种技术组合实现》,可惜这个写手除了长得帅,文章一点都不好,根本没有阅读量,所以决定派你卧底去这个园区,打探一下。 阿仁:这个园区有啥问题吗,是有黑帮在那里做交易,还是有境外势力在那边有基地? 领导:比这些重要。主要是园区人气太差,希望你假扮一个游客,帮忙了解下其他游客的想法,另外多说说我们园区的好处,鼓动游客参与,活跃气氛。 阿仁:领导,这个工作其实不叫卧底,有个更熟悉的名字 领导:是吗,啥名字? 阿仁:线上这个叫水军,线下叫托儿... 领导:你不愿意去? 阿仁:领导,我堂堂计算机专业毕业,干这个... 领导:请你用十条理由来说服所有的程序员,PHP是世界上最好的语言。 阿仁:领导,你这个园区叫啥名字? 领导:叫SVG园区。 1、啤酒节霓虹灯 阿仁经过一番乔装打扮(主要是换掉了格子衫,背了斜肩的挎包),来到了SVG园区。巧的很,一眼就看到了园区的霓虹灯牌子,正在举办啤酒节。 啤酒节霓虹灯的代码是这样的 <style> ...
-
下一篇
superset nginx 反向代理配置遇到的一个小问题
在用 nginx 配置 superset 反向代理,并且使用 map 通过 cookie 分流的时候,遇到十分诡异的问题,访问主页的时候总是被重定向到 upstream 同名的域名 upstream release { server 127.0.0.1:8088 weight=1 max_fails=1 fail_timeout=30s; } upstream development { server 127.0.0.1:8089 weight=1 max_fails=1 fail_timeout=30s; } map $COOKIE_version $env { default release; release release; development development; } server { listen 10001; server_name localhost; location / { proxy_pass http://$env; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "u...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- MySQL数据库在高并发下的优化方案
- SpringBoot2初体验,简单认识spring boot2并且搭建基础工程
- Docker容器配置,解决镜像无法拉取问题
- Docker安装Oracle12C,快速搭建Oracle学习环境
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- 2048小游戏-低调大师作品
- SpringBoot2编写第一个Controller,响应你的http请求并返回结果
- Docker使用Oracle官方镜像安装(12C,18C,19C)
- SpringBoot2整合MyBatis,连接MySql数据库做增删改查操作