当前位置: 移动技术网 > IT编程>脚本编程>Python > 利用高斯核卷积对MINIST数据集进行去噪

利用高斯核卷积对MINIST数据集进行去噪

2020年07月20日  | 移动技术网IT编程  | 我要评论

import torch

import torchvision

from torch.autograd import Variable

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

import cv2

from torch import nn

import numpy as np

import torch.nn.functional as F

import advertorch.defenses as defenses 

from numpy import *

seed = 2014

 

torch.manual_seed(seed)

np.random.seed(seed)  # Numpy module.

random.seed(seed)  # Python random module.

torch.manual_seed(seed)

 

train_dataset =   datasets.FashionMNIST('./fashionmnist_data/', train=True, download=True,

                       transform=transforms.Compose([

                           transforms.ToTensor(),

                       ]))

 

train_loader = DataLoader(dataset = train_dataset, batch_size = 500, shuffle = True)

 

test_loader = torch.utils.data.DataLoader(

        datasets.FashionMNIST('./fashionmnist_data/', train=False, transform=transforms.Compose([

        transforms.ToTensor(),

        ])),batch_size=1, shuffle=True)

 

epoch = 12


 

class Linear_cliassifer(torch.nn.Module):

    def __init__(self) :

        super(Linear_cliassifer, self).__init__()


 

        self.Gs = defenses.GaussianSmoothing2D(3, 1, 3)

        self.Line1 = torch.nn.Linear(28 * 28, 10)

 

    def forward(self, x):


 

        x = self.Gs(x)

        x = self.Line1(x.view(-1, 28 * 28))

 

        return x


 

net = Linear_cliassifer()

cost = torch.nn.CrossEntropyLoss()



 

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

 

for k in range(epoch):

    sum_loss = 0.0

    train_correct = 0

    for i, data in enumerate(train_loader, 0):

        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)

 

        loss = cost(outputs, labels)

        loss.backward()

        optimizer.step()

 

        print(loss)

        _, id = torch.max(outputs.data, 1) 

        sum_loss += loss.data

        train_correct += torch.sum(id == labels.data)

        #print('[%d,%d] loss:%.03f' % (k + 1, k, sum_loss / len(train_loader)))

    print('        correct:%.03f%%' % (100 * train_correct / len(train_dataset)))

    torch.save(net.state_dict(), 'model/fasion_BL.pt')

本文地址:https://blog.csdn.net/qq_23144435/article/details/107430946

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网