神经网络基础及Keras入门
神经网络定义
人工神经网络,简称神经网络,在机器学习和认知科学领域,是一种模仿生物神经网络(动物的中枢神经系统,特别是大脑)的结构和功能的数学模型或计算模型,用于对函数进行估计或近似。
为了描述神经网络,我们先从最简单的神经网络讲起,这个神经网络仅由一个“神经元”构成,以下即是这个“神经元”的图示:
可以看出,这个单一“神经元”的输入-输出映射关系其实就是一个逻辑回归(logistic regression)。
神经网络模型
所谓神经网络就是将许多个单一“神经元”联结在一起,这样,一个“神经元”的输出就可以是另一个“神经元”的输入。例如,下图就是一个简单的神经网络:
Keras实战
使用keras实现如下网络结构, 并训练模型:
输入值(x1,x2,x3)代表人的身高体重和年龄, 输出值(y1,y2)
importnumpyasnp # 总人数是1000, 一半是男生 n =1000 # 所有的身体指标数据都是标准化数据, 平均值0, 标准差1 tizhong = np.random.normal(size = n) shengao = np.random.normal(size=n) nianling = np.random.normal(size=n) # 性别数据, 前500名学生是男生, 用数字1表示 gender = np.zeros(n) gender[:500] =1 # 男生的体重比较重,所以让男生的体重+1 tizhong[:500] +=1 # 男生的身高比较高, 所以让男生的升高 + 1 shengao[:500] +=1 # 男生的年龄偏小, 所以让男生年龄降低 1 nianling[:500] -=1
创建模型
fromkerasimportSequential fromkeras.layersimportDense, Activation model = Sequential() # 只有一个神经元, 三个输入数值 model.add(Dense(4, input_dim=3, kernel_initializer='random_normal', name="Dense1")) # 激活函数使用softmax model.add(Activation('relu', name="hidden")) # 添加输出层 model.add(Dense(2, input_dim=4, kernel_initializer='random_normal', name="Dense2")) # 激活函数使用softmax model.add(Activation('softmax', name="output"))
编译模型
需要指定优化器和损失函数:
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
训练模型
# 转换成one-hot格式 fromkerasimportutils gender_one_hot = utils.to_categorical(gender, num_classes=2) # 身体指标都放入一个矩阵data data = np.array([tizhong, shengao, nianling]).T # 训练模型 model.fit(data, gender_one_hot, epochs=10, batch_size=8)
输出(stream): Epoch1/10 1000/1000[==============================] -0s235us/step - loss:0.6743- acc:0.7180 Epoch2/10 1000/1000[==============================] -0s86us/step - loss:0.6162- acc:0.7310 Epoch3/10 1000/1000[==============================] -0s88us/step - loss:0.5592- acc:0.7570 Epoch4/10 1000/1000[==============================] -0s87us/step - loss:0.5162- acc:0.7680 Epoch5/10 1000/1000[==============================] -0s89us/step - loss:0.4867- acc:0.7770 Epoch6/10 1000/1000[==============================] -0s88us/step - loss:0.4663- acc:0.7830 Epoch7/10 1000/1000[==============================] -0s87us/step - loss:0.4539- acc:0.7890 Epoch8/10 1000/1000[==============================] -0s86us/step - loss:0.4469- acc:0.7920 Epoch9/10 1000/1000[==============================] -0s88us/step - loss:0.4431- acc:0.7940 Epoch10/10 1000/1000[==============================] -0s88us/step - loss:0.4407- acc:0.7900 输出(plain)://Python学习开发705673780
进行预测
test_data = np.array([[0,0,0]]) probability = model.predict(test_data) ifprobability[0,0]>0.5: print('女生') else: print('男生') ### 输出(stream): 女生
关键词解释
input_dim: 输入的维度数
kernel_initializer: 数值初始化方法, 通常是正太分布
batch_size: 一次训练中, 样本数据被分割成多个小份, 每一小份包含的样本数叫做batch_size
epochs: 如果说将所有数据训练一次叫做一轮的话。epochs决定了总共进行几轮训练。
optimizer: 优化器, 可以理解为求梯度的方法
loss: 损失函数, 可以理解为用于衡量估计值和观察值之间的差距, 差距越小, loss越小
metrics: 类似loss, 只是metrics不参与梯度计算, 只是一个衡量算法准确性的指标, 分类模型就用accuracy
看完觉得有所收获的朋友可以点赞加关注哦,谢谢支持!

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
2 张图,让你一秒理解 CountDownLatch、CyclicBarrier
CountDownLatch(倒数闩,Latch:门闩) 经常用于 监听某些初始化操作,等 初始化线程 全部执行完毕后,才通知 主线程 继续工作 a) 即 一个线程处于阻塞的状态下,他要收到 多少次通知,才能被 苏醒,并继续往下执行 b) 注意:只能阻塞 一个线程 c) "countDown.countDown() 到了 0,并使得 countDown.await() 苏醒" 之后,仍旧能进行 countDown.countDown(),并且不会报错;但是countDown.getCount() 始终为 0 CyclicBarrier(同步屏障,cyclic:周期的,循环的,barrier:屏障) 场景假设:每个线程代表一个 跑步运动员,当 所有运动员 都准备好,才能一起出发,只要有一个人没有准备好,那么大家都要等待他 a) 注意:阻塞的是 每个线程
- 下一篇
Dubbo分析之Cluster层
系列文章 Dubbo分析Serialize层 Dubbo分析之Transport层 Dubbo分析之Exchange 层 Dubbo分析之Protocol层 Dubbo分析之Cluster层 Dubbo分析之Registry层 前言 紧接上文Dubbo分析之Protocol层,本文继续分析dubbo的cluster层,此层封装多个提供者的路由及负载均衡,并桥接注册中心,以Invoker为中心,扩展接口为Cluster, Directory, Router, LoadBalance; Cluster接口 整个cluster层可以使用如下图片概括: 各节点关系: 这里的Invoker是Provider的一个可调用Service的抽象,Invoker封装了Provider地址及Service接口信息; Directory代表多个Invoker,可以把它看成List ,但与List不同的是,它的值可能是动态变化的,比如注册中心推送变更; Cluster将Directory中的多个Invoker伪装成一个 Invoker,对上层透明,伪装过程包含了容错逻辑,调用失败后,重试另一个; Route...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- CentOS7,8上快速安装Gitea,搭建Git服务器
- CentOS8编译安装MySQL8.0.19
- SpringBoot2编写第一个Controller,响应你的http请求并返回结果
- SpringBoot2整合Thymeleaf,官方推荐html解决方案
- CentOS8,CentOS7,CentOS6编译安装Redis5.0.7
- CentOS7安装Docker,走上虚拟化容器引擎之路
- SpringBoot2全家桶,快速入门学习开发网站教程
- Eclipse初始化配置,告别卡顿、闪退、编译时间过长
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- CentOS7编译安装Gcc9.2.0,解决mysql等软件编译问题