DataLoader 与 DataSet
DataLoader 与 DataSet
PyTorch 数据读入是通过 Dataset+DataLoader 的方式完成的,Dataset 定义好数据的格式和数据变换形式,DataLoader 用 iterative 的方式不断读入批次数据。
torch.utils.data (opens new window)
PyTorch 的五大模块:数据、模型、损失函数、优化器和迭代训练。
数据模块又可以细分为 4 个部分:
- 数据收集:样本和标签。
- 数据划分:训练集、验证集和测试集
- 数据读取:对应于 PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
- 数据预处理:对应于 PyTorch 的 transforms
Dataloader
数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
参数
- dataset (Dataset) – 加载数据的数据集,Dataset 类,决定数据从哪里读取以及如何读取。
- batch_size (int, optional) – 每个 batch 加载多少个样本(默认: 1)。
- shuffle (bool, optional) – 设置为
True
时会在每个 epoch 重新打乱数据(默认: False). - sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略
shuffle
参数。 - num_workers (int, optional) – 用多少个子进程加载数据。0 表示数据将在主进程中加载(默认: 0)
- collate_fn (callable, optional) –
- pin_memory (bool, optional) –
- drop_last (bool, optional) – 如果数据集大小不能被 batch size 整除,则设置为 True 后可删除最后一个不完整的 batch。如果设为 False 并且数据集的大小不能被 batch size 整除,则最后一个 batch 将更小。(默认: False)
DataLoader
的使用方法示例:
from torch.utils.data import DataLoader
dataset = CifarDataset()
data_loader = DataLoader(dataset=dataset,batch_size=10,shuffle=True,num_workers=2)
#遍历,获取其中的每个batch的结果
for index, (label, context) in enumerate(data_loader):
print(index,label,context)
print("*"*100)
2
3
4
5
6
7
8
9
数据迭代器的返回结果如下:
555 ('spam', 'ham', 'spam', 'ham', 'ham', 'ham', 'ham', 'spam', 'ham', 'ham') ('URGENT! We are trying to contact U. Todays draw shows that you have won a £800 prize GUARANTEED. Call 09050003091 from....", 'swhrt how u dey,hope ur ok, tot about u 2day.love n miss.take care.')
***********************************************************************************
556 ('ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'ham', 'spam') ('He telling not to tell any one. If so treat for me hi hi hi', 'Did u got that persons story', "Don kn....1000 cash prize or a prize worth £5000')
2
3
注意:
len(dataset) = 数据集的样本数
len(data_loader) = math.ceil(样本数/batch_size) 即向上取整
Dataset
torch.utils.data.Dataset
是 PyTorch 中的一个抽象类,用于表示一个数据集。当你想要创建自己的数据集时,你需要继承这个类并实现至少两个方法:__len__
和__getitem__
。
参数说明
torch.utils.data.Dataset
本身并没有直接的参数,因为它是一个抽象基类,需要子类实现具体的方法。
需要实现的方法
__len__(self)
:- 返回数据集的大小(即数据项的总数)。
- 当你使用
len(dataset)
时,这个方法会被调用。
__getitem__(self, index)
:- 根据提供的索引返回单个数据项。
- 当你使用
dataset[index]
时,这个方法会被调用。 - 返回的数据项通常是一个元组,包含输入数据和标签(如果有的话)。
示例
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 使用示例
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 获取数据集的大小
print(len(dataset)) # 输出: 5
# 获取索引为2的数据项
print(dataset[2]) # 输出: 3
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
其他常见方法(可选实现)
虽然 __len__
和 __getitem__
是必需的,但你可能还想实现其他方法以提供更多的功能,例如:
__init__
: 用于向类中传入外部参数,同时定义样本集__add__(self, other)
: 实现数据集的加法操作,使得你可以合并两个数据集。transform(self, fn)
: 定义一个转换方法,该方法接受一个函数fn
并返回一个新的数据集,其中每个数据项都经过fn
的处理。
注意事项
- 当创建自定义数据集时,确保你的
__getitem__
方法返回的数据类型与你的模型期望的输入类型相匹配。 - 如果你想要进行批量处理,可以考虑使用
torch.utils.data.DataLoader
,它可以与你的Dataset
子类一起使用。
PyTorch 数据读取流程图
首先在 for 循环中遍历DataLoader
,然后根据是否采用多进程,决定使用单进程或者多进程的DataLoaderIter
。在DataLoaderIter
里调用Sampler
生成Index
的 list,再调DatasetFetcher
根据index
获取数据。在DatasetFetcher
里会调用Dataset
的__getitem__()
方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 collate_fn()
函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的tenser
。
通过 DataLoader 获取数据
for i, (_, inputs, targets) in enumerate(sample_loader):
其中 (_, inputs, targets)
与DummyDataset
中的__getitem__
返回值对应
class DummyDataset(Dataset):
def __init__(self, images, labels, trsf, use_path=False):
……
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
if self.use_path:
image = self.trsf(pil_loader(self.images[idx]))
else:
image = self.trsf(Image.fromarray(self.images[idx]))
label = self.labels[idx]
return idx, image, label
2
3
4
5
6
7
8
9
10
11
12
13
14
15
torchvision
torchvision 是独立于 pytorch 的关于图像操作的一些方便工具库。
torchvision 的详细介绍在:https://pypi.org/project/torchvision/
torchvision 主要包括以下几个包:
- vision.datasets (opens new window) : 几个常用视觉数据集,可以下载和加载,这里主要的高级用法就是可以看源码如何自己写自己的 Dataset 的子类
- vision.models (opens new window) : 流行的模型,例如 AlexNet, VGG, ResNet 和 Densenet 以及 与训练好的参数。
- vision.transforms (opens new window) : 常用的图像操作,例如:随机切割,旋转,数据类型转换,图像到 tensor ,numpy 数组到 tensor , tensor 到 图像等。
- vision.utils (opens new window) : 用于把形似 (3 x H x W) 的张量保存到硬盘中,给一个 mini-batch 的图像可以产生一个图像格网。
pytorch 自带的数据集
pytorch 中自带的数据集由两个上层 api 提供,分别是torchvision
和torchtext
其中:
torchvision
提供了对图片数据处理相关的 api 和数据- 数据位置:
torchvision.datasets
,例如:torchvision.datasets.MNIST
(手写数字图片数据)
- 数据位置:
torchtext
提供了对文本数据处理相关的 API 和数据- 数据位置:
torchtext.datasets
,例如:torchtext.datasets.IMDB(电影
评论文本数据)
- 数据位置:
下面我们以 Mnist 手写数字为例,来看看 pytorch 如何加载其中自带的数据集
使用方法和之前一样:
- 准备好 Dataset 实例
- 把 dataset 交给 dataloder 打乱顺序,组成 batch
torchversion.datasets
torchversoin.datasets
中的数据集类(比如torchvision.datasets.MNIST
),都是继承自Dataset
意味着:直接对torchvision.datasets.MNIST
进行实例化就可以得到Dataset
的实例
但是 MNIST API 中的参数需要注意一下:
torchvision.datasets.MNIST(root='/files/', train=True, download=True, transform=)
root
参数表示数据存放的位置train:
bool 类型,表示是使用训练集的数据还是测试集的数据download:
bool 类型,表示是否需要下载数据到 root 目录transform:
实现的对图片的处理函数
MNIST 数据集的介绍
数据集的原始地址:http://yann.lecun.com/exdb/mnist/
MNIST 是由Yann LeCun
等人提供的免费的图像识别的数据集,其中包括 60000 个训练样本和 10000 个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28X28
执行代码,下载数据,观察数据类型:
import torchvision
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)
print(dataset[0])
2
3
下载的数据如下:
代码输出结果如下:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
(<PIL.Image.Image image mode=L size=28x28 at 0x18D303B9C18>, tensor(5))
2
3
4
5
6
7
可以其中数据集返回了两条数据,可以猜测为图片的数据和目标值
返回值的第 0 个为 Image 类型,可以调用 show() 方法打开,发现为手写数字 5
import torchvision
dataset = torchvision.datasets.MNIST(root="./data",train=True,download=True,transform=None)
print(dataset[0])
img = dataset[0][0]
img.show() #打开图片
2
3
4
5
图片如下:
由上可知:返回值为(图片,目标值)
,这个结果也可以通过观察源码得到。