%matplotlib inline import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display()
3.5.1 Loading the dataset
1 2 3 4 5 6 7
# Download the dataset and load it into memory with torchvision's built-in
# dataset class. The ToTensor instance converts each image from PIL format
# to float32 and divides by 255, so every pixel value lies between 0 and 1.
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True
)
# The test set is never used for training.
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True
)
# Measure how long one complete pass over the training iterator takes.
timer = d2l.Timer()
for X, y in train_iter:
    # No per-batch work: just keep reading until everything has been loaded.
    pass
# Notebook cell result: the value of timer.stop(), shown to two decimals.
f'{timer.stop():.2f} sec'
'2.27 sec'
3.5.3 Integrating all components
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# Now we define a load_data_fashion_mnist function to obtain and read the
# Fashion-MNIST dataset. The function returns the data iterators of the
# training set and the validation set. In addition, it accepts an optional
# parameter, resize, which can reshape the images.
def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """Download the Fashion-MNIST dataset and load it into memory.

    Args:
        batch_size: Batch size passed to both ``DataLoader`` objects.
        resize: Optional size passed to ``transforms.Resize``. When falsy
            (the default ``None``), no resize step is added.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``data.DataLoader``
        objects. The training loader is created with ``shuffle=True`` and
        the test loader with ``shuffle=False``.
    """
    # ToTensor converts a PIL image to a tensor and rescales pixel values
    # from 0-255 to 0-1.
    trans = [transforms.ToTensor()]
    if resize:
        # Insert the resize step at the head of the list so that it is
        # applied before ToTensor.
        trans.insert(0, transforms.Resize(resize))
    # Compose chains the transforms, applying them one by one from front to
    # back; that is why ``trans`` is built as a list.
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True
    )
    return (
        data.DataLoader(
            mnist_train,
            batch_size,
            shuffle=True,
            num_workers=get_dataloader_workers(),
        ),
        data.DataLoader(
            mnist_test,
            batch_size,
            shuffle=False,
            num_workers=get_dataloader_workers(),
        ),
    )
1 2 3 4 5
# Pass the resize parameter to check the image-resizing feature of
# load_data_fashion_mnist.
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    # As an example, show the shape and dtype of one training minibatch.
    print(X.shape, X.dtype, y.shape, y.dtype)
    # Author's note: each batch holds 32 single-channel images of size 64.
    break
# Exercise: does changing batch_size (for instance, reducing it to 1) affect
# read performance? Edit the batch_size value below and re-run the cell to
# compare timings.
train_iter = data.DataLoader(
    mnist_train,
    batch_size=256,
    shuffle=True,
    num_workers=get_dataloader_workers(),
)
timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop(): .2f} sec'
# Author's observation: reducing batch_size severely hurts performance.
# Increasing it to 1024 was also tried; 1024 and 2056 perform almost the same.