Few-shot learning trains the model on tasks. In an N-way-K-shot setup, the meta-training portion of a task contains N classes; K samples are drawn from each class to form the support set, and the query set is then sampled from the remaining samples of those same N classes (either uniformly or non-uniformly).
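The sampling described above can be sketched as follows. This is a minimal illustration, not part of the script below; the `class_to_images` mapping and the query-set size are assumed inputs:

```python
import random

def sample_episode(class_to_images, n_way=5, k_shot=1, query_per_class=15):
    """Sample one N-way-K-shot task from a {class_label: [images]} mapping."""
    # Pick N classes for this task
    classes = random.sample(list(class_to_images), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw K support samples plus the query samples without overlap
        imgs = random.sample(class_to_images[cls], k_shot + query_per_class)
        support += [(img, label) for img in imgs[:k_shot]]
        query += [(img, label) for img in imgs[k_shot:]]
    return support, query
```

Each episode relabels its N classes as 0..N-1, since few-shot models classify only within the current task.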
For this setup we need a dataset in which each class sits in its own folder. Sometimes, however, the data is not arranged by class, and it has to be preprocessed first.
The following uses CIFAR100 as an example (it performs the by-class split and saving, not the N-way-K-shot sampling itself):
```python
import os

import numpy as np
import torchvision as tv
from skimage import io


def Cifar100(root):
    character = [[] for _ in range(100)]
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)

    # Read the images and labels of both splits into one list
    # (current torchvision exposes .data/.targets; very old versions
    # used train_data/train_labels instead)
    dataset = []
    for X, Y in zip(train_set.data, train_set.targets):
        dataset.append([X, Y])
    for X, Y in zip(test_set.data, test_set.targets):
        dataset.append([X, Y])

    # Group the images by class label; each image is 32*32*3
    for X, Y in dataset:
        character[Y].append(X)
    character = np.array(character)

    # Shuffle the classes with a fixed seed for reproducibility
    np.random.seed(6)
    shuffle_class = np.arange(len(character))
    np.random.shuffle(shuffle_class)
    character = character[shuffle_class]

    # meta_training : meta_validation : meta_testing = 64 : 16 : 20 classes
    meta_training, meta_validation, meta_testing = \
        character[:64], character[64:80], character[80:]
    dataset = []  # release memory

    # Save each split as one sub-folder per class
    for split_name, split in (('meta_training', meta_training),
                              ('meta_validation', meta_validation),
                              ('meta_testing', meta_testing)):
        os.mkdir(os.path.join(root, split_name))
        for i, per_class in enumerate(split):
            character_path = os.path.join(root, split_name, 'character_' + str(i))
            os.mkdir(character_path)
            for j, img in enumerate(per_class):
                img_path = os.path.join(character_path, str(j) + '.jpg')
                io.imsave(img_path, img)


if __name__ == '__main__':
    root = '/home/xie/文檔/datasets/cifar_100'
    Cifar100(root)
    print("-----------------")
```
Supplement: classifying the CIFAR-10 dataset with PyTorch
1. Download and preprocess the dataset
2. Define the network structure
3. Define the loss function and optimizer
4. Train the network and update its parameters
5. Evaluate the network on the test set
```python
# Data loading and preprocessing
# CIFAR-10 classification experiment
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

show = ToPILImage()  # converts a Tensor back to a PIL Image for visualization

# Preprocessing pipeline
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize
])

# Training set
trainset = tv.datasets.CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

# Test set
testset = tv.datasets.CIFAR10(
    root='./data/',
    train=False,
    download=True,
    transform=transform,
)
testloader = t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
```
The first run needs some time to download the data.
```python
import time

import torch.nn as nn
import torch.nn.functional as F

start = time.time()  # start timing

# Define the network structure
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)  # flatten to (batch, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
```
`print(net)` displays the layer structure of the network.
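The `16 * 5 * 5` input size of `fc1` follows from the two conv/pool stages: 32 → 28 → 14 → 10 → 5 spatially. This can be verified by pushing a dummy CIFAR-sized batch through equivalent layers (freshly constructed here just for the shape check):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.zeros(1, 3, 32, 32)  # one dummy CIFAR-10 image
# conv 5x5: 32 -> 28, then max-pool 2: 28 -> 14
x = F.max_pool2d(F.relu(nn.Conv2d(3, 6, 5)(x)), 2)
print(x.shape)  # torch.Size([1, 6, 14, 14])
# conv 5x5: 14 -> 10, then max-pool 2: 10 -> 5
x = F.max_pool2d(F.relu(nn.Conv2d(6, 16, 5)(x)), 2)
print(x.shape)  # torch.Size([1, 16, 5, 5]) -> flattened, 16*5*5 = 400
```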
```python
# Define the loss and optimizer
loss_func = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the network
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:  # report every 2000 mini-batches
            print('epoch:', epoch + 1, '|i:', i + 1,
                  '|loss:%.3f' % (running_loss / 2000))
            running_loss = 0.0

end = time.time()
time_using = end - start
print('finish training')
print('time:', time_using)
```
Running the loop prints the averaged loss every 2000 mini-batches, followed by the total training time.

Next, evaluate the network on the test set:
```python
# Evaluate the network
correct = 0  # number of correctly classified images
total = 0    # total number of images
with t.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predict = t.max(outputs, 1)
        total += labels.size(0)
        correct += (predict == labels).sum()
print('Accuracy on the test set: %d%%' % (100 * correct / total))
```
Even this simple network is noticeably better than the 10% accuracy of random guessing. :)
Training on the GPU:
```python
# Train on the GPU if one is available
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
net.to(device)
images = images.to(device)
labels = labels.to(device)
output = net(images)
loss = loss_func(output, labels)
loss  # in a notebook this displays the loss tensor
```
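The snippet above only moves one batch; in a full training loop every batch must be moved to the same device as the model each iteration. A self-contained sketch of one device-aware step (the tiny linear model and random batch stand in for `net` and a batch from `trainloader`):

```python
import torch as t
import torch.nn as nn

device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')

# Stand-ins for the real net, loss_func and optimizer
model = nn.Linear(10, 2).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = t.optim.SGD(model.parameters(), lr=0.01)

inputs = t.randn(4, 10).to(device)        # move each batch to the device
labels = t.randint(0, 2, (4,)).to(device)

optimizer.zero_grad()
loss = loss_func(model(inputs), labels)
loss.backward()
optimizer.step()
print(loss.item())
```

Forgetting either `.to(device)` call raises a device-mismatch error, since the model weights and the batch must live on the same device.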
The above reflects my personal experience; I hope it gives you a useful reference, and corrections for any errors or oversights are very welcome.