您现在的位置是:首页 > 文章详情

使用 **迭代器** 获取 Cifar 等常用数据集

日期:2018-07-17点击:291

CifarMNIST 等常用数据集的坑:

  • 每次在一台新的机器上使用它们去训练模型都需要重新下载(国内网络往往都不给力,需要花费大量的时间,有时还下载不了);
  • 即使下载到本地,然而不同的模型对它们的处理方式各不相同,我们又需要花费一些时间去了解如何读取数据。

为了解决上述的坑,我在Bunch 转换为 HDF5 文件:高效存储 Cifar 等数据集中将一些常用的数据集封装为 HDF5 文件。

下面的 X.h5c 可以参考Bunch 转换为 HDF5 文件:高效存储 Cifar 等数据集自己制作,也可以直接下载使用(链接:https://pan.baidu.com/s/1hsbMhv3MDlOES3UDDmOQiw 密码:qlb7)。

使用方法很简单:

访问数据集

# 载入所需要的包 import tables as tb import numpy as np
xpath = 'E:/xdata/X.h5' # 文件所在路径 h5 = tb.open_file(xpath)

下面我们来看看此文件中有那些数据集:

h5.root
/ (RootGroup) "Xinet's dataset" children := ['cifar10' (Group), 'cifar100' (Group), 'fashion_mnist' (Group), 'mnist' (Group)] 

下面我们以 Cifar 为例,来详细说明该文件的使用:

cifar = h5.root.cifar100 # 获取 cifar100

为了高效使用数据集,我们使用迭代器的方式来获取它:

class Loader: """ 方法 ======== L 为该类的实例 len(L)::返回 batch 的批数 iter(L)::即为数据迭代器 Return ======== 可迭代对象(numpy 对象) """ def __init__(self, X, Y, batch_size, shuffle): ''' X, Y 均为类 numpy ''' self.X = X self.Y = Y self.batch_size = batch_size self.shuffle = shuffle def __iter__(self): n = len(self.X) idx = np.arange(n) if self.shuffle: np.random.shuffle(idx) for k in range(0, n, self.batch_size): K = idx[k:min(k + self.batch_size, n)].tolist() yield np.take(self.X, K, 0), np.take(self.Y, K, 0) def __len__(self): return round(len(self.X) / self.batch_size)

下面我们可以使用 Loader 来实例化我们的数据集:

batch_size = 512 train_cifar = Loader(cifar.trainX, cifar.train_fine_labels, batch_size, True) test_cifar = Loader(cifar.testX, cifar.test_fine_labels, batch_size, False)

读取一个 Batch 的数据:

for imgs, labels in iter(train_cifar): break
names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U') names[:7]
array(['orchid', 'spider', 'rabbit', 'shark', 'shrew', 'clock', 'bed'], dtype='<U13') 

可视化

需要注意,这里的 Cifarfirst channel 的,即:

imgs.shape
(512, 3, 32, 32) 
names.shape
(512,) 
from pylab import plt, mpl mpl.rcParams['font.sans-serif'] = ['SimHei'] # 指定默认字体 mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号 '-' 显示为方块的问题 def show_imgs(imgs, labels): ''' 展示 多张图片 ''' imgs = np.transpose(imgs, (0, 2, 3, 1)) n = imgs.shape[0] h, w = 5, int(n / 5) fig, ax = plt.subplots(h, w, figsize=(7, 7)) K = np.arange(n).reshape((h, w)) names = np.asanyarray([cifar.fine_label_names[label] for label in labels], dtype='U') names = names.reshape((h, w)) for i in range(h): for j in range(w): img = imgs[K[i, j]] ax[i][j].imshow(img) ax[i][j].axes.get_yaxis().set_visible(False) ax[i][j].axes.set_xlabel(names[i][j]) ax[i][j].set_xticks([]) plt.show()
show_imgs(imgs[:25], labels[:25])

output_19_0.png-89.9kB

$2$ 个深度学习框架 & 数据集

因为,上面的数据集是 NumPyarray 形式,故而:

TensorFlow

import tensorflow as tf
for imgs, labels in iter(train_cifar): imgs = tf.constant(imgs) labels = tf.constant(labels) break
imgs
<tf.Tensor 'Const:0' shape=(512, 3, 32, 32) dtype=uint8> 
labels
<tf.Tensor 'Const_1:0' shape=(512,) dtype=int32> 

MXNet

from mxnet import nd, cpu, gpu
for imgs, labels in iter(train_cifar): imgs = nd.array(imgs, ctx = gpu(0)) labels = nd.array(labels, ctx = cpu(0)) break
imgs.context
gpu(0) 
labels.context
cpu(0) 

Matlab 读取 HDF

参考:h5read
捕获.PNG-65.5kB

原文链接:https://yq.aliyun.com/articles/614332
关注公众号

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。

持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。

转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。

文章评论

共有0条评论来说两句吧...

文章二维码

扫描即可查看该文章

点击排行

推荐阅读

最新文章