灵活的数据读取

from torchvision.datasets import ImageFolder
# 三个文件夹,每个文件夹一共有 3 张图片作为例子
folder_set = ImageFolder('./example_data/image/')
# 查看名称和类别下标的对应
folder_set.class_to_idx
{'class_1': 0, 'class_2': 1, 'class_3': 2}
# 得到所有的图片名字和标签
folder_set.imgs
[('./example_data/image/class_1/1.png', 0),
 ('./example_data/image/class_1/2.png', 0),
 ('./example_data/image/class_1/3.png', 0),
 ('./example_data/image/class_2/10.png', 1),
 ('./example_data/image/class_2/11.png', 1),
 ('./example_data/image/class_2/12.png', 1),
 ('./example_data/image/class_3/16.png', 2),
 ('./example_data/image/class_3/17.png', 2),
 ('./example_data/image/class_3/18.png', 2)]
# 取出其中一个数据
im, label = folder_set[0]
im

png

label
0
from torchvision import transforms as tfs
# 传入数据预处理方式
data_tf = tfs.ToTensor()

folder_set = ImageFolder('./example_data/image/', transform=data_tf)

im, label = folder_set[0]
im
(0 ,.,.) = 
  0.2314  0.1686  0.1961  ...   0.6196  0.5961  0.5804
  0.0627  0.0000  0.0706  ...   0.4824  0.4667  0.4784
  0.0980  0.0627  0.1922  ...   0.4627  0.4706  0.4275
           ...             ⋱             ...          
  0.8157  0.7882  0.7765  ...   0.6275  0.2196  0.2078
  0.7059  0.6784  0.7294  ...   0.7216  0.3804  0.3255
  0.6941  0.6588  0.7020  ...   0.8471  0.5922  0.4824

(1 ,.,.) = 
  0.2431  0.1804  0.1882  ...   0.5176  0.4902  0.4863
  0.0784  0.0000  0.0314  ...   0.3451  0.3255  0.3412
  0.0941  0.0275  0.1059  ...   0.3294  0.3294  0.2863
           ...             ⋱             ...          
  0.6667  0.6000  0.6314  ...   0.5216  0.1216  0.1333
  0.5451  0.4824  0.5647  ...   0.5804  0.2431  0.2078
  0.5647  0.5059  0.5569  ...   0.7216  0.4627  0.3608

(2 ,.,.) = 
  0.2471  0.1765  0.1686  ...   0.4235  0.4000  0.4039
  0.0784  0.0000  0.0000  ...   0.2157  0.1961  0.2235
  0.0824  0.0000  0.0314  ...   0.1961  0.1961  0.1647
           ...             ⋱             ...          
  0.3765  0.1333  0.1020  ...   0.2745  0.0275  0.0784
  0.3765  0.1647  0.1176  ...   0.3686  0.1333  0.1333
  0.4549  0.3686  0.3412  ...   0.5490  0.3294  0.2824
[torch.FloatTensor of size 3x32x32]
label
0

可以看到通过这种方式能够非常方便的访问每个数据点

Dataset

from torch.utils.data import Dataset
# 定义一个子类叫 custom_dataset,继承与 Dataset
class custom_dataset(Dataset):
    def __init__(self, txt_path, transform=None):
        self.transform = transform # 传入数据预处理
        with open(txt_path, 'r') as f:
            lines = f.readlines()

        self.img_list = [i.split()[0] for i in lines] # 得到所有的图像名字
        self.label_list = [i.split()[1] for i in lines] # 得到所有的 label 

    def __getitem__(self, idx): # 根据 idx 取出其中一个
        img = self.img_list[idx]
        label = self.label_list[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self): # 总数据的多少
        return len(self.label_list)
txt_dataset = custom_dataset('./example_data/train.txt') # 读入 txt 文件
# 取得其中一个数据
data, label = txt_dataset[0]
print(data)
print(label)
1009_2.png
YOU
# 再取一个
data2, label2 = txt_dataset[34]
print(data2)
print(label2)
1046_7.png
LIFE

所以通过这种方式我们也能够非常方便的定义一个数据读入,同时也能够方便的定义数据预处理

DataLoader

from torch.utils.data import DataLoader
train_data1 = DataLoader(folder_set, batch_size=2, shuffle=True) # 将 2 个数据作为一个 batch
for im, label in train_data1: # 访问迭代器
    print(label)
 1
 2
[torch.LongTensor of size 2]


 0
 1
[torch.LongTensor of size 2]


 0
 2
[torch.LongTensor of size 2]


 0
 2
[torch.LongTensor of size 2]


 1
[torch.LongTensor of size 1]

可以看到,通过训练我们可以访问到所有的数据,这些数据被分为了 5 个 batch,前面 4 个都有两个数据,最后一个 batch 只有一个数据,因为一共有 9 个数据,同时顺序也被打乱了

下面我们用自定义的数据读入举例子

train_data2 = DataLoader(txt_dataset, 8, True) # batch size 设置为 8
im, label = next(iter(train_data2)) # 使用这种方式访问迭代器中第一个 batch 的数据
im
('377_10.png',
 '178_1.png',
 '5008_4.png',
 '5050_5.png',
 '716_3.png',
 '415_8.png',
 '858_6.png',
 '5086_10.png')
label
('AUGUST',
 'OTKRIJTE',
 'ASTAIRE',
 'BOONMEE',
 'OF',
 'CAUTION',
 'PROPANE',
 'PECC')

现在有一个需求,希望能够将上面一个 batch 输出的 label 补成相同的长度,短的 label 用 0 填充,我们就需要使用 collate_fn 来自定义我们 batch 的处理方式,下面直接举例子

def collate_fn(batch):
    batch.sort(key=lambda x: len(x[1]), reverse=True) # 将数据集按照 label 的长度从大到小排序
    img, label = zip(*batch) # 将数据和 label 配对取出
    # 填充
    pad_label = []
    lens = []
    max_len = len(label[0])
    for i in range(len(label)):
        temp_label = label[i]
        temp_label += '0' * (max_len - len(label[i]))
        pad_label.append(temp_label)
        lens.append(len(label[i]))
    pad_label 
    return img, pad_label, lens # 输出 label 的真实长度

使用我们自己定义 collate_fn 看看效果

train_data3 = DataLoader(txt_dataset, 8, True, collate_fn=collate_fn) # batch size 设置为 8
im, label, lens = next(iter(train_data3))
im
('5016_1.png',
 '2314_3.png',
 '731_9.png',
 '5019_4.png',
 '208_4.png',
 '5017_12.png',
 '5190_1.png',
 '855_12.png')
label
['LINDSAY',
 'ADDRESS',
 'MAIDEN0',
 'EINER00',
 'INDIA00',
 'GERE000',
 'JAWS000',
 'TD00000']
lens
[7, 7, 6, 5, 5, 4, 4, 2]

可以看到一个 batch 中所有的 label 都从长到短进行排列,同时短的 label 都被补长了,所以使用 collate_fn 能够非常方便的处理一个 batch 中的数据,一般情况下,没有特别的要求,使用 pytorch 中内置的 collate_fn 就可以满足要求了

results matching ""

    No results matching ""