import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

# Hyperparameters
batch_size = 64  # number of samples per batch handed to the data loaders
learning_rate = 0.01  # learning rate for the optimizer
num_epochs = 20  # number of training epochs

# Pick the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Preprocessing: convert each image to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST training split (downloaded on demand) and test split used for validation
train_dataset = datasets.MNIST(
    root='/mnt/data/dataSet',
    train=True,
    download=True,
    transform=transform,
)
val_dataset = datasets.MNIST(
    root='/mnt/data/dataSet',
    train=False,
    download=False,
    transform=transform,
)

# Shuffle only the training data; validation order stays fixed
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)


# A small convolutional network for single-channel 28x28 images
class SimpleCNN(nn.Module):
    """Two convolution layers followed by two linear layers, 10 outputs.

    The layer sizes assume an input of shape ``[batch, 1, 28, 28]``: after
    both convolution + pooling stages the feature map must hold exactly
    ``20 * 5 * 5`` values per sample, which is the input width of ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # First convolution: 1 input channel (grayscale), 10 output channels, 5x5 kernel
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        # Second convolution: 10 input channels, 20 output channels, 3x3 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        # Fully connected layer: 20*5*5 flattened features -> 128
        self.fc1 = nn.Linear(20 * 5 * 5, 128)
        # Output layer: 128 -> 10 (one score per digit class)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class scores of shape ``[batch, 10]`` (no softmax applied).

        Args:
            x: Input tensor, expected shape ``[batch, 1, 28, 28]``.

        Raises:
            RuntimeError: If the flattened feature size per sample is not
                ``20 * 5 * 5`` (i.e. the spatial input size is not 28x28).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # [batch, 10, 12, 12]
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # [batch, 20, 5, 5]
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 500)``, this never folds a size mismatch into the batch
        # dimension, so a wrongly sized input fails loudly at ``fc1`` instead
        # of silently producing a different number of rows.
        x = x.flatten(start_dim=1)  # [batch, 500]
        x = F.relu(self.fc1(x))     # [batch, 128]
        x = self.fc2(x)             # [batch, 10]
        return x


# Build the network and move it to the selected device (GPU when available)
model = SimpleCNN()
model = model.to(device)

# Cross-entropy loss, optimized with plain SGD over all model parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

# TensorBoard SummaryWriter used to visualize the training run
writer = SummaryWriter(log_dir='/mnt/data/output/runs/mnist_experiment')

# Highest validation accuracy seen so far; used to decide when to save the model
best_val_accuracy = 0.0

# Train the model, logging loss and accuracy to TensorBoard. The writer is
# closed in a ``finally`` clause so buffered events are flushed even when
# training is interrupted by an exception.
try:
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Move the batch to the same device as the model
            data, target = data.to(device), target.to(device)

            # Standard optimization step: reset gradients, forward pass,
            # compute loss, backpropagate, update parameters
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Log the training loss every 100 batches, using the global
            # batch count as the TensorBoard step
            if batch_idx % 100 == 0:
                writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

        # Evaluate on the validation set
        model.eval()
        val_loss = 0.0
        correct = 0
        with torch.no_grad():  # gradients are not needed for evaluation
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # The criterion (CrossEntropyLoss with its default reduction)
                # yields a per-batch mean; weight it by the batch size so a
                # shorter final batch is not over-represented in the average.
                val_loss += criterion(output, target).item() * data.size(0)
                pred = output.argmax(dim=1)  # predicted class per sample
                correct += (pred == target).sum().item()

        val_loss /= len(val_loader.dataset)  # mean loss per validation sample
        val_accuracy = 100. * correct / len(val_loader.dataset)  # accuracy in percent
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {correct}/{len(val_loader.dataset)} ({val_accuracy:.0f}%)')

        # Log validation metrics once per epoch
        writer.add_scalar('Loss/validation', val_loss, epoch)
        writer.add_scalar('Accuracy/validation', val_accuracy, epoch)

        # Keep only the weights with the best validation accuracy so far
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), '/mnt/data/output/best_model.pth')
            print(f'Model saved with accuracy: {best_val_accuracy:.2f}%')
finally:
    # Close the SummaryWriter whether or not training completed
    writer.close()
print('Training complete. writer.close()')