
上QQ阅读APP看书,第一时间看更新
1.1.2 导入数据集
在PyTorch中,有一个非常重要且好用的包是torchvision,该包主要由3个子包组成,分别是models、datasets和transforms。models定义了许多用来完成图像方面深度学习的任务模型。datasets中包含MNIST、Fake Data、COCO、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR等一些常用的数据集,并且提供了数据集设置的一些重要参数,可以通过简单数据集设置来进行数据集的调用。transforms用来对数据进行预处理,预处理会加快神经网络的训练,常见的预处理包括从数组转成张量(tensor)、归一化等常见的变化。本章导入主要涉及datasets和transforms,下面通过例子来讲解。

上述代码最外层调用了DataLoader对数据进行封装,而里面涉及了datasets和transforms。对于root目录,PyTorch会检测数据是否存在,当数据不存在时,系统会自动将数据下载到data文件夹中。其中的transforms对原数据进行了两个操作,一个是ToTensor,用来把PIL.Image(RGB)或者numpy.ndarray(H×W×C)0~255的值映射到0~1的范围内,并转换成Tensor格式;另一个是Normalize(mean,std),用来实现归一化,不同数据集中图像通道的均值(mean)和标准差(std)这两个数值是不一样的,MNIST数据集的均值是0.1307,标准差是0.3081,这些系数是数据集提供方计算好的,有利于加快神经网络的训练。我们随机取一个batch下的数据进行观察,并将其可视化画出来,结果如图1-2所示。



图1-2 数据加载后的部分示例