
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

class SimpleCNN(nn.Module):
    """Small two-conv CNN for 1x28x28 inputs producing 10-class logits."""

    def __init__(self):
        super().__init__()
        # 1x28x28 -> conv(5x5) -> 10x24x24 -> maxpool(2) -> 10x12x12
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        # 10x12x12 -> conv(3x3) -> 20x10x10 -> maxpool(2) -> 20x5x5
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        self.fc1 = nn.Linear(20 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw (unsoftmaxed) class logits for a batch of images."""
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 20 * 5 * 5)
        out = F.relu(self.fc1(out))
        return self.fc2(out)

def main() -> None:
    """Distributed (DDP over NCCL) MNIST training entry point.

    Expects to be launched by torchrun (or an equivalent launcher) that sets
    the RANK / WORLD_SIZE / LOCAL_RANK environment variables. Each process
    drives one GPU: it trains SimpleCNN on its DistributedSampler shard of
    MNIST, all-reduces validation metrics across ranks, and rank 0 alone
    writes TensorBoard logs and saves the best checkpoint.
    """
    # Process identity comes from the launcher's environment.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend='nccl')
    # Bind this process to its GPU before any NCCL collective runs.
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    batch_size = 64       # per-process batch size (effective global batch = 64 * world_size)
    learning_rate = 0.01
    num_epochs = 20

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Only the main process (rank 0) should download the dataset; the other
    # processes must wait for it to finish.
    # Ranks != 0 park at the barrier first.
    if rank != 0:
        dist.barrier()

    # Every process constructs the dataset object,
    # but only rank 0 actually performs the download.
    train_dataset = datasets.MNIST(root='/mnt/data/dataSet', train=True, download=(rank == 0), transform=transform)

    # Once rank 0 has finished downloading it reaches the barrier too,
    # releasing all waiting processes.
    if rank == 0:
        dist.barrier()

    # At this point all processes are synchronized and the data files exist.
    val_dataset = datasets.MNIST(root='/mnt/data/dataSet', train=False, download=False, transform=transform)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4, pin_memory=True)
    # NOTE(review): val_loader has no DistributedSampler, so EVERY rank
    # evaluates the full validation set; the all_reduce sums below therefore
    # count each sample world_size times. The reported ratios (avg loss,
    # accuracy) are still correct, but the work is duplicated — consider
    # sharding validation too.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    model = SimpleCNN().to(device)
    # DDP wraps the local model; gradients are all-reduced during backward().
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Only rank 0 owns the TensorBoard writer (SummaryWriter creates the dir).
    if rank == 0:
        writer = SummaryWriter('/mnt/data/output_distributed/runs/mnist_experiment')
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        # Reseed the sampler so each epoch gets a different shard shuffle.
        train_sampler.set_epoch(epoch)
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                # Every rank / local_rank prints its own (local, pre-reduce) loss.
                # The sample counter is an approximation: batch_idx * local
                # batch size * world_size vs. the full dataset length.
                print(f"Rank: {rank}, Local_Rank: {local_rank} -- Train Epoch: {epoch} "
                      f"[{batch_idx * len(data) * world_size}/{len(train_loader.dataset)} "
                      f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

                if rank == 0:
                    writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx)

        # Validation (every epoch, on every rank — see NOTE above).
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
                output = model(data)
                # Weight the mean batch loss by batch size so the final
                # division by `total` yields a proper per-sample average.
                val_loss += criterion(output, target).item() * data.size(0)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
                total += target.size(0)
        # Aggregate metrics across all ranks (sums, then ratios below).
        val_loss_tensor = torch.tensor([val_loss], dtype=torch.float32, device=device)
        correct_tensor = torch.tensor([correct], dtype=torch.float32, device=device)
        total_tensor = torch.tensor([total], dtype=torch.float32, device=device)
        dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)

        val_loss = val_loss_tensor.item() / total_tensor.item()
        val_accuracy = 100. * correct_tensor.item() / total_tensor.item()

        if rank == 0:
            print(f'Validation Loss: {val_loss:.4f}, Accuracy: {int(correct_tensor.item())}/{int(total_tensor.item())} ({val_accuracy:.0f}%)')
            writer.add_scalar('Loss/validation', val_loss, epoch)
            writer.add_scalar('Accuracy/validation', val_accuracy, epoch)
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                # Save the unwrapped module so the checkpoint loads without DDP.
                # NOTE(review): torch.save does not create directories — this
                # assumes /mnt/data/output_distributed already exists (the
                # SummaryWriter above only creates its own runs/ subtree).
                torch.save(model.module.state_dict(), '/mnt/data/output_distributed/best_model.pth')
                print(f'Model saved with accuracy: {best_val_accuracy:.2f}%')
    if rank == 0:
        writer.close()
    dist.destroy_process_group()
    if rank == 0:
        print('Training complete. writer.close()')


# Entry point: torchrun spawns one process per GPU; each runs main().
if __name__ == "__main__":
    main()
