使用 **迭代器** 获取 Cifar 等常用数据集
Cifar
、MNIST
等常用数据集的坑:
- 每次在一台新的机器上使用它们去训练模型都需要重新下载(国内网络往往都不给力,需要花费大量的时间,有时还下载不了);
- 即使下载到本地,然而不同的模型对它们的处理方式各不相同,我们又需要花费一些时间去了解如何读取数据。
为了解决上述的坑,我在Bunch 转换为 HDF5 文件:高效存储 Cifar 等数据集中将一些常用的数据集封装为 HDF5
文件。
下面的 X.h5c
可以参考Bunch 转换为 HDF5 文件:高效存储 Cifar 等数据集自己制作,也可以直接下载使用(链接:https://pan.baidu.com/s/1hsbMhv3MDlOES3UDDmOQiw 密码:qlb7)。
使用方法很简单:
访问数据集
# 载入所需要的包
import tables as tb
import numpy as np
xpath = 'E:/xdata/X.h5' # 文件所在路径
h5 = tb.open_file(xpath)
下面我们来看看此文件中有那些数据集:
h5.root
/ (RootGroup) "Xinet's dataset"
children := ['cifar10' (Group), 'cifar100' (Group), 'fashion_mnist' (Group), 'mnist' (Group)]
下面我们以 Cifar
为例,来详细说明该文件的使用:
cifar = h5.root.cifar100 # 获取 cifar100
为了高效使用数据集,我们使用迭代器的方式来获取它:
class Loader:
"""
方法
========
L 为该类的实例
len(L)::返回 batch 的批数
iter(L)::即为数据迭代器
Return
========
可迭代对象(numpy 对象)
"""
def __init__(self, X, Y, batch_size, shuffle):
'''
X, Y 均为类 numpy
'''
self.X = X
self.Y = Y
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
n = len(self.X)
idx = np.arange(n)
if self.shuffle:
np.random.shuffle(idx)
for k in range(0, n, self.batch_size):
K = idx[k:min(k + self.batch_size, n)].tolist()
yield np.take(self.X, K, 0), np.take(self.Y, K, 0)
def __len__(self):
return round(len(self.X) / self.batch_size)
下面我们可以使用 Loader
来实例化我们的数据集:
batch_size = 512
train_cifar = Loader(cifar.trainX, cifar.train_fine_labels, batch_size, True)
test_cifar = Loader(cifar.testX, cifar.test_fine_labels, batch_size, False)
读取一个 Batch 的数据:
for imgs, labels in iter(train_cifar):
break
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U')
names[:7]
array(['orchid', 'spider', 'rabbit', 'shark', 'shrew', 'clock', 'bed'],
dtype='<U13')
可视化
需要注意,这里的 Cifar
是 first channel
的,即:
imgs.shape
(512, 3, 32, 32)
names.shape
(512,)
from pylab import plt, mpl
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 指定默认字体
mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号 '-' 显示为方块的问题
def show_imgs(imgs, labels):
'''
展示 多张图片
'''
imgs = np.transpose(imgs, (0, 2, 3, 1))
n = imgs.shape[0]
h, w = 5, int(n / 5)
fig, ax = plt.subplots(h, w, figsize=(7, 7))
K = np.arange(n).reshape((h, w))
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U')
names = names.reshape((h, w))
for i in range(h):
for j in range(w):
img = imgs[K[i, j]]
ax[i][j].imshow(img)
ax[i][j].axes.get_yaxis().set_visible(False)
ax[i][j].axes.set_xlabel(names[i][j])
ax[i][j].set_xticks([])
plt.show()
show_imgs(imgs[:25], labels[:25])
$2$ 个深度学习框架 & 数据集
因为,上面的数据集是 NumPy
的 array
形式,故而:
TensorFlow
import tensorflow as tf
for imgs, labels in iter(train_cifar):
imgs = tf.constant(imgs)
labels = tf.constant(labels)
break
imgs
<tf.Tensor 'Const:0' shape=(512, 3, 32, 32) dtype=uint8>
labels
<tf.Tensor 'Const_1:0' shape=(512,) dtype=int32>
MXNet
from mxnet import nd, cpu, gpu
for imgs, labels in iter(train_cifar):
imgs = nd.array(imgs, ctx = gpu(0))
labels = nd.array(labels, ctx = cpu(0))
break
imgs.context
gpu(0)
labels.context
cpu(0)
Matlab 读取 HDF
参考:h5read

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
-
上一篇
数据预处理:自定义PDF格式批量转换TXT系统
数据预处理:自定义文件格式转换系统 ( 白宁超 2018年8月29日15:36:24 ) 导读:随着大数据的快速发展,自然语言处理、数据挖掘、机器学习技术应用愈加广泛。针对大数据的预处理工作是一项庞杂、棘手的工作。首先数据采集和存储,尤其高质量数据采集往往不是那么简单。采集后的信息文件格式不一,诸如pdf,doc,docx,Excel,ppt等多种形式。然而最常见便是txt、pdf和word类型的文档。本文主要对pdf和word文档进行文本格式转换成txt。格式一致化以后再进行后续预处理工作。笔者采用一些工具转换效果都不理想,于是才出现本系统的研究与实现。(本文原创,转载必须注明出处: 数据预处理:自定义文件格式转换系统 ) 1 本文概述 1.1 背景介绍 为什么要文件格式转换? 无论读者现在是做数据挖掘、数据分析、自然语言处理、智能对话系统、商品推荐系统等等,都不可避免的涉及语料的问题即大数据。数据来源无非分为结构化数据、半结构化数据和非结构化数据。其中结构化数据以规范的文档、数据库文件等等为代表;半结构化数据以网页、json文件等为代表;非结构化数据以自由文本为主,诸如随想录、中医...
-
下一篇
Java底层学习
最近在看几本Java的书,也做了很多笔记,主要是关于Java虚拟机、Java GC、Java 并发编程等方面,参考的主要几本书籍有: 《深入理解Java虚拟机》——周志明 《深入理解Java虚拟机 第二版》——美 Bill Venners 《Java性能调优指南》——也是老美的 《Java高并发程序设计》——葛一鸣 本来想自己把这些书的pdf传上来的,可惜已经有人上传了,大家自己去找资源吧 当然在写作过程中也参考了很多大神的文章,下面给几个链接,大家也可以看看: 【Java成神之路】—-死磕Java系列博客 《成神之路系列文章》 JVM调优总结 等全部写完,我就写个目录方便大家查看
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- CentOS关闭SELinux安全模块
- 设置Eclipse缩进为4个空格,增强代码规范
- Eclipse初始化配置,告别卡顿、闪退、编译时间过长
- Jdk安装(Linux,MacOS,Windows),包含三大操作系统的最全安装
- CentOS8编译安装MySQL8.0.19
- Springboot2将连接池hikari替换为druid,体验最强大的数据库连接池
- SpringBoot2整合MyBatis,连接MySql数据库做增删改查操作
- SpringBoot2整合Thymeleaf,官方推荐html解决方案
- MySQL数据库在高并发下的优化方案
- SpringBoot2更换Tomcat为Jetty,小型站点的福音