import time
import torch
from torch import nn,optim
import numpy as np
import torch.nn.functional as F
from torch.optim import lr_scheduler
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict
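# GlobalAvgPool2d averages each feature map over its full spatial extent,
# producing a (batch, channels, 1, 1) output.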
class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        super(GlobalAvgPool2d,self).__init__()
    def forward(self,x):
        return F.avg_pool2d(x,kernel_size=x.size()[2:])
class FlattenLayer(nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)  # dim 0 is the batch; keep it and flatten the rest
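# Residual block: two 3x3 convs, each followed by BatchNorm, with a ReLU after
# the first. When the main path changes shape (channels or stride), an optional
# 1x1 conv brings the shortcut to the same shape before the addition.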
class Residual(nn.Module):
    def __init__(self,in_channels,out_channels,use_1x1conv=False,stride=1):
        super(Residual,self).__init__()
        self.conv1=nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,stride=stride)
        self.conv2=nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1)
        if use_1x1conv:
            self.conv3=nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=stride)
        else:
            self.conv3=None
        self.bn1=nn.BatchNorm2d(out_channels)
        self.bn2=nn.BatchNorm2d(out_channels)
    def forward(self,X):
        Y=F.relu(self.bn1(self.conv1(X)))
        Y=self.bn2(self.conv2(Y))
        if self.conv3:
            X=self.conv3(X)
        return F.relu(Y+X)
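# A stage of num_residuals blocks. Except in the first stage, the first block
# halves the spatial resolution (stride=2) and switches to the new channel count.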
def resnet_block(in_channels,out_channels,num_residuals,first_block=False):
    if first_block:
        assert in_channels==out_channels
    blk=[]
    for i in range(num_residuals):
        if i==0 and not first_block:
            blk.append(Residual(in_channels,out_channels,use_1x1conv=True,stride=2))
        else:
            blk.append(Residual(out_channels,out_channels))
    return nn.Sequential(*blk)
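# ResNet-18-style network for CIFAR-10: a 3x3 stem conv (no 7x7 conv / max-pool,
# since inputs are only 32x32), four residual stages of 64/128/256/512 channels,
# global average pooling, and a 10-way linear classifier.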
net=nn.Sequential(nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1),nn.BatchNorm2d(64),
                  nn.ReLU())#,nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
net.add_module("resnet_block1",resnet_block(64,64,2,first_block=True))
net.add_module("resnet_block3",resnet_block(64,128,2))
net.add_module("resnet_block4",resnet_block(128,256,2))
net.add_module("resnet_block5",resnet_block(256,512,2))
net.add_module("global_avg_pool", GlobalAvgPool2d())
net.add_module("fc",nn.Sequential(FlattenLayer(),nn.Linear(512,10)))

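# Load the CIFAR-10 python-version batches (expected in the working directory).
# Each batch holds 10000 images as flat uint8 rows plus integer labels; reshape
# the pixels to (10000, 3, 32, 32).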
test1=unpickle('test_batch')
test_1=np.reshape(test1[b'data'],(10000,3,32,32))
test_lable=test1[b'labels']
train1=unpickle('data_batch_1')
train_1=np.reshape(train1[b'data'],(10000,3,32,32))
lable1=train1[b'labels']
train2=unpickle('data_batch_2')
train_2=np.reshape(train2[b'data'],(10000,3,32,32))
lable2=train2[b'labels']
train3=unpickle('data_batch_3')
train_3=np.reshape(train3[b'data'],(10000,3,32,32))
lable3=train3[b'labels']
train4=unpickle('data_batch_4')
train_4=np.reshape(train4[b'data'],(10000,3,32,32))
lable4=train4[b'labels']
train5=unpickle('data_batch_5')
train_5=np.reshape(train5[b'data'],(10000,3,32,32))
lable5=train5[b'labels']
train1=torch.Tensor(train_1)
train2=torch.Tensor(train_2)
train3=torch.Tensor(train_3)
train4=torch.Tensor(train_4)
train5=torch.Tensor(train_5)

lable1=torch.Tensor(lable1)
lable2=torch.Tensor(lable2)
lable3=torch.Tensor(lable3)
lable4=torch.Tensor(lable4)
lable5=torch.Tensor(lable5)
train_1=torch.cat((train1,train2,train3,train4,train5),0)
lable1=torch.cat((lable1,lable2,lable3,lable4,lable5),0)
train_1=train_1.float()
test_1=torch.tensor(test_1,dtype=torch.float)
test_lable=torch.tensor(test_lable,dtype=torch.long)
lable1=lable1.long()
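# Training setup: Adam with a step LR schedule (x0.1 every 100 epochs). The full
# training and test sets are moved to the device once and sliced into
# mini-batches of 100 inside the loop.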
batch_size=100
lr,num_epochs=0.0001,2500
optimizer=torch.optim.Adam(net.parameters(),lr=lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net=net.to(device)
train_1=train_1.to(device)
lable1=lable1.to(device)
test_1=test_1.to(device)
test_lable=test_lable.to(device)
loss=torch.nn.CrossEntropyLoss()

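# Per epoch: 500 training mini-batches (50000 images), then accuracy over
# 100 test mini-batches (10000 images).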
for i in range(num_epochs):
    net.train()
    n,train_acc_sum,test_acc,train_l_sum,batch_count=0,0,0,0,0
    for j in range(500):
        train=train_1[j*batch_size:j*batch_size+batch_size]
        lable=lable1[j*batch_size:j*batch_size+batch_size]
        y_hat=net(train)
        l=loss(y_hat,lable)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        train_acc_sum += (y_hat.argmax(dim=1) == lable).sum().cpu().item()
        n += lable.shape[0]
        batch_count += 1
        train_l_sum += l.cpu().item()
    scheduler.step()  # step the LR schedule after the epoch's optimizer updates
    net.eval()  # use running BatchNorm statistics for evaluation
    with torch.no_grad():
        for j in range(100):
            test=test_1[j*batch_size:j*batch_size+batch_size]
            lable=test_lable[j*batch_size:j*batch_size+batch_size]
            y_hat=net(test)
            test_acc += (y_hat.argmax(dim=1) == lable).sum().cpu().item()
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
          % (i+1, train_l_sum/batch_count, train_acc_sum/n, test_acc/10000))