確保安裝
一個(gè)例子:
# 導(dǎo)入需要的包 import torch import torch.utils.data.dataset as Dataset import numpy as np # 編造數(shù)據(jù) Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) # 數(shù)據(jù)[1,2],對(duì)應(yīng)的標(biāo)簽是[0],數(shù)據(jù)[3,4],對(duì)應(yīng)的標(biāo)簽是[1] #創(chuàng)建子類(lèi) class subDataset(Dataset.Dataset): #初始化,定義數(shù)據(jù)內(nèi)容和標(biāo)簽 def __init__(self, Data, Label): self.Data = Data self.Label = Label #返回?cái)?shù)據(jù)集大小 def __len__(self): return len(self.Data) #得到數(shù)據(jù)內(nèi)容和標(biāo)簽 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.IntTensor(self.Label[index]) return data, label # 主函數(shù) if __name__ == '__main__': dataset = subDataset(Data, Label) print(dataset) print('dataset大小為:', dataset.__len__()) print(dataset.__getitem__(0)) print(dataset[0])
輸出的結(jié)果
我們有了對(duì)Dataset的一個(gè)整體的把握,再來(lái)分析里面的細(xì)節(jié):
#創(chuàng)建子類(lèi) class subDataset(Dataset.Dataset):
創(chuàng)建子類(lèi)時(shí),繼承的時(shí)Dataset.Dataset,不是一個(gè)Dataset。因?yàn)镈ataset是module模塊,不是class類(lèi),所以需要調(diào)用module里的class才行,因此是Dataset.Dataset!
len和getitem這兩個(gè)函數(shù),前者給出數(shù)據(jù)集的大小**,后者是用于查找數(shù)據(jù)和標(biāo)簽。是最重要的兩個(gè)函數(shù),我們后續(xù)如果要對(duì)數(shù)據(jù)做一些操作基本上都是再這兩個(gè)函數(shù)的基礎(chǔ)上進(jìn)行。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_works=0, clollate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
功能:構(gòu)建可迭代的數(shù)據(jù)裝載器;
dataset:Dataset類(lèi),決定數(shù)據(jù)從哪里讀取及如何讀??;數(shù)據(jù)集的路徑
batchsize:批大?。?br />
num_works:是否多進(jìn)程讀取數(shù)據(jù);只對(duì)于CPU
shuffle:每個(gè)epoch是否打亂;
drop_last:當(dāng)樣本數(shù)不能被batchsize整除時(shí),是否舍棄最后一批數(shù)據(jù);
Epoch:所有訓(xùn)練樣本都已輸入到模型中,稱為一個(gè)Epoch;
Iteration:一批樣本輸入到模型中,稱之為一個(gè)Iteration;
Batchsize:批大小,決定一個(gè)Epoch中有多少個(gè)Iteration;
還是舉一個(gè)實(shí)例:
import torch import torch.utils.data.dataset as Dataset import torch.utils.data.dataloader as DataLoader import numpy as np Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) #創(chuàng)建子類(lèi) class subDataset(Dataset.Dataset): #初始化,定義數(shù)據(jù)內(nèi)容和標(biāo)簽 def __init__(self, Data, Label): self.Data = Data self.Label = Label #返回?cái)?shù)據(jù)集大小 def __len__(self): return len(self.Data) #得到數(shù)據(jù)內(nèi)容和標(biāo)簽 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.IntTensor(self.Label[index]) return data, label if __name__ == '__main__': dataset = subDataset(Data, Label) print(dataset) print('dataset大小為:', dataset.__len__()) print(dataset.__getitem__(0)) print(dataset[0]) #創(chuàng)建DataLoader迭代器,相當(dāng)于我們要先定義好前面說(shuō)的Dataset,然后再用Dataloader來(lái)對(duì)數(shù)據(jù)進(jìn)行一些操作,比如是否需要打亂,則shuffle=True,是否需要多個(gè)進(jìn)程讀取數(shù)據(jù)num_workers=4,就是四個(gè)進(jìn)程 dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4) for i, item in enumerate(dataloader): #可以用enumerate來(lái)提取出里面的數(shù)據(jù) print('i:', i) data, label = item #數(shù)據(jù)是一個(gè)元組 print('data:', data) print('label:', label)
這部分可以直接去看博客:Dataset和DataLoader
總結(jié)下來(lái)時(shí)有兩種方法解決
1.如果在創(chuàng)建Dataset的類(lèi)時(shí),定義__getitem__方法的時(shí)候,將數(shù)據(jù)轉(zhuǎn)變?yōu)镚PU類(lèi)型。則需要將Dataloader里面的參數(shù)num_workers設(shè)置為0,因?yàn)檫@個(gè)參數(shù)是對(duì)于CPU而言的。如果數(shù)據(jù)改成了GPU,則只能單進(jìn)程。如果是在Dataloader的部分,先多個(gè)子進(jìn)程讀取,再轉(zhuǎn)變?yōu)镚PU,則num_wokers不用修改。就是上述__getitem__部分的代碼,移到Dataloader部分。
2.不過(guò)一般來(lái)講,數(shù)據(jù)集和標(biāo)簽不會(huì)像我們上述編輯的那么簡(jiǎn)單。一般再kaggle上的標(biāo)簽都是存在CSV這種文件中。需要pandas的配合。
這個(gè)進(jìn)階可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人臉圖片作為數(shù)據(jù)和人臉特征點(diǎn)作為標(biāo)簽。
到此這篇關(guān)于Pytorch數(shù)據(jù)讀取之Dataset和DataLoader知識(shí)總結(jié)的文章就介紹到這了,更多相關(guān)詳解Dataset和DataLoader內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
標(biāo)簽:六盤(pán)水 宿遷 常州 江蘇 山東 蘭州 駐馬店 成都
巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《Pytorch數(shù)據(jù)讀取之Dataset和DataLoader知識(shí)總結(jié)》,本文關(guān)鍵詞 Pytorch,數(shù)據(jù),讀,取之,Dataset,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。