Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

      • DataLoader 与 DataSet
      • torchvision.transforms
      • torch.utils.data
      • 模型

      • 训练

      • 并行计算

      • 可视化

      • 实战

      • timm

      • Pytorch Lightning

      • 数据增强

      • 面经与bug解决

      • 常用代码片段

      • Reference
    • CL

    • CIL

    • 小样本类增量学习FSCIL

    • UCIL

    • 多模态增量学习MMCL

    • LTCIL

    • DIL

    • 论文阅读与写作

    • 分布外检测

    • GPU

    • 深度学习调参指南

    • AINotes
    • PyTorch
    • 数据处理
    Geeks_Z
    2024-02-01
    目录

    torch.utils.data

    TensorDataset

    一文搞懂PyTorch中的TensorDataset (opens new window)

    简介

    顾名思义,torch.utils.data 中的 TensorDataset 基于一系列张量构建数据集。这些张量的形状可以不尽相同,但第一个维度必须具有相同大小,这是为了保证在使用 DataLoader 时可以正常地返回一个批量的数据。

    源码解读

    以下是 TensorDataset 的源码:

    class TensorDataset(Dataset[Tuple[Tensor, ...]]):
        r"""Dataset wrapping tensors.
    
        Each sample will be retrieved by indexing tensors along the first dimension.
    
        Args:
            *tensors (Tensor): tensors that have the same size of the first dimension.
        """
        tensors: Tuple[Tensor, ...]
    
        def __init__(self, *tensors: Tensor) -> None:
            assert all(tensors[0].size(0) == tensor.size(0)
                       for tensor in tensors), "Size mismatch between tensors"
            self.tensors = tensors
    
        def __getitem__(self, index):
            return tuple(tensor[index] for tensor in self.tensors)
    
        def __len__(self):
            return self.tensors[0].size(0)
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20

    *tensors 告诉我们实例化 TensorDataset 时传入的是一系列张量,即:

    dataset = TensorDataset(tensor_1, tensor_2, ..., tensor_n)
    
    1

    随后的 assert 是用来确保传入的这些张量中,每个张量在第一个维度的大小都等于第一个张量在第一个维度的大小,即要求所有张量在第一个维度的大小都相同。

    __getitem__ 方法返回的结果等价于

    return tensor_1[index], tensor_2[index], ..., tensor_n[index]
    
    1

    从这行代码可以看出,如果n 张量在第一个维度的大小不完全相同,则必然会有一个张量出现 IndexError。确保第一个维度大小相同也是为了之后传入DataLoader 中能够正常地以一个批量的形式加载。

    __len__就不用多说了,因为所有张量的第一个维度大小都相同,所以直接返回传入的第一个张量在第一个维度的大小即可。

    📌 TensorDataset 将张量的第一个维度视为数据集大小的维度,数据集在传入 DataLoader 后,该维度也是 batch_size 所在的维度

    通过例子进一步理解

    假设当前目录下存放一个 data.csv 文件,其中的每一行的后六个数字代表样本对应的特征向量,第一个数字代表该样本对应的标签。

    1.0000, 0.9449, -0.8295, -0.7112, -0.7005, -0.2167, -0.7059
    1.0000, -2.1290, 0.3062, -0.2188, -1.3525, 1.6726, -0.8547
    -1.0000, -1.5803, 0.6320, -1.9216, -0.0722, 1.4919, -0.3219
    1.0000, -0.2993, 0.3256, 0.3015, 0.4959, -0.1034, -1.0536
    -1.0000, -0.0025, 0.8698, 0.9149, 1.4535, 1.1784, 0.1983
    -1.0000, -0.5881, -0.5728, 2.5740, 0.9449, 1.9096, 0.3761
    1.0000, -0.9585, -1.3368, -1.1004, 0.6487, 1.7098, 1.5862
    -1.0000, 1.4861, 1.3814, 0.7968, 0.5741, 1.0919, -0.1592
    
    1
    2
    3
    4
    5
    6
    7
    8

    接下来我们分别用普通方法和 TensorDataset 方法来构建数据集。

    普通方法:

    import torch
    from torch.utils.data import Dataset
    import pandas as pd
    
    class MyDataset(Dataset):
    
        def __init__(self):
            self.data = pd.read_csv('data.csv', header=None).values
    
        def __getitem__(self, idx):
            feature = torch.from_numpy(self.data[idx, 1:])
            label = torch.tensor(self.data[idx, 0])
            return feature, label
    
        def __len__(self):
            return len(self.data)
    
    
    mydataset = MyDataset()
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19

    TensorDataset 方法

    import torch
    from torch.utils.data import TensorDataset
    import pandas as pd
    
    data = pd.read_csv('data.csv', header=None).values
    features = torch.from_numpy(data[:, 1:])
    labels = torch.from_numpy(data[:, 0])
    
    mydataset = TensorDataset(features, labels)
    
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    #PyTorch
    上次更新: 2025/06/25, 11:25:50
    torchvision.transforms
    模型构建

    ← torchvision.transforms 模型构建→

    最近更新
    01
    帮助信息查看
    06-08
    02
    常用命令
    06-08
    03
    学习资源
    06-07
    更多文章>
    Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
    京公网安备 11010802040735号 | 京ICP备2022029989号-1
    • 跟随系统
    • 浅色模式
    • 深色模式
    • 阅读模式