# oss_dataloader.py

import json
import numpy as np
from torch.utils.data import DataLoader
import torch

class ImageCls():
    """Lookup table from synset ID to integer class index.

    The table is read from ``imagenet_class_index.json`` (resolved relative to
    the current working directory). The file is expected to be a JSON object
    whose keys are class indices (as strings) and whose values are sequences
    with the synset ID at position 0 and the label at position 1.
    """

    def __init__(self):
        """Load the class index file and build the lookup dictionaries."""
        self.__syn_to_class = {}
        # NOTE: despite its name, this dictionary is keyed by the integer class index.
        self.__syn_to_label = {}
        with open("imagenet_class_index.json", "rb") as index_file:
            for raw_index, entry in json.load(index_file).items():
                syn, label = entry[0], entry[1]
                class_index = int(raw_index)
                self.__syn_to_class[syn] = class_index
                self.__syn_to_label[class_index] = label

    def __len__(self):
        """Return the number of distinct class indices that were loaded."""
        return len(self.__syn_to_label)

    def __getitem__(self, syn):
        """Return the integer class index for synset ID ``syn``.

        Raises:
            KeyError: if ``syn`` is not present in the class index file.
        """
        return self.__syn_to_class[syn]

class ImageValSet():
    """Lookup table from validation image file name to synset ID.

    The table is read from ``ILSVRC2012_val_labels.json`` (resolved relative to
    the current working directory), which is expected to be a JSON object.
    """

    def __init__(self):
        """Load the validation label file into an in-memory dictionary."""
        with open("ILSVRC2012_val_labels.json", "rb") as labels_file:
            loaded = json.load(labels_file)
        # Copy the entries into a dictionary owned by this instance.
        self.__val_to_syn = dict(loaded.items())

    def __getitem__(self, val):
        """Return the synset ID stored for ``val``.

        Raises:
            KeyError: if ``val`` is not present in the label file.
        """
        return self.__val_to_syn[val]

# Module-level lookup tables used by the transform functions in this file.
# NOTE: both constructors open their JSON files through relative paths, so
# importing this module requires "imagenet_class_index.json" and
# "ILSVRC2012_val_labels.json" to exist in the current working directory.
imageCls = ImageCls()
imageValSet = ImageValSet()


# Side length used for the last two dimensions when object bytes are reshaped
# into a [3, IMG_DIM_224, IMG_DIM_224] array.
IMG_DIM_224 = 224
# Base OSS URI of the dataset; "/train/" and "/val/" are appended to it to form
# the dataset prefixes. "<YourBucketName>" is a placeholder to be replaced.
OSS_URI_BASE = "oss://<YourBucketName>/dataset/imagenet/ILSVRC/Data"

# Specify the accelerated OSS endpoint to download datasets. Replace the endpoint with your actual information.
ENDPOINT = "cn-hangzhou-internal.oss-data-acc.aliyuncs.com" 

def obj_to_tensor(object):
    """Decode an object's raw bytes into a float32 tensor of shape [3, 224, 224].

    The payload is interpreted as a flat sequence of float32 values and
    reshaped to ``[3, IMG_DIM_224, IMG_DIM_224]``.

    Args:
        object: Any object exposing ``read()`` that returns a bytes-like payload.
            (The parameter name shadows the builtin but is kept so existing
            callers are unaffected.)

    Returns:
        A ``torch.float32`` tensor of shape ``[3, IMG_DIM_224, IMG_DIM_224]``
        that owns writable memory.

    Raises:
        ValueError: if the payload does not contain exactly
            ``3 * IMG_DIM_224 * IMG_DIM_224`` float32 values.
    """
    data = object.read()
    flat = np.frombuffer(data, dtype=np.float32)

    expected_count = 3 * IMG_DIM_224 * IMG_DIM_224
    if flat.size != expected_count:
        raise ValueError(
            f"expected {expected_count} float32 values "
            f"({expected_count * flat.itemsize} bytes), got {flat.size} values"
        )

    # np.frombuffer over immutable bytes yields a read-only view; handing that to
    # torch.from_numpy triggers a warning and makes in-place tensor writes
    # undefined behaviour. Copy once so the tensor is backed by writable memory.
    chw_array = flat.reshape([3, IMG_DIM_224, IMG_DIM_224]).copy()
    return torch.from_numpy(chw_array)

def train_tensor_transform(object):
    """Convert a training object into an ``(image_tensor, class_index)`` pair.

    The second-to-last ``/``-separated component of ``object.key`` is used as
    the synset ID and mapped to its class index through ``imageCls``.

    Args:
        object: Object exposing ``read()`` and a string ``key`` attribute.

    Returns:
        Tuple of the decoded tensor and the integer class index.
    """
    image_tensor = obj_to_tensor(object)
    # Only the last two separators matter: [..., <synset>, <file name>].
    synset = object.key.rsplit('/', 2)[-2]
    return image_tensor, imageCls[synset]

def val_tensor_transform(object):
    """Convert a validation object into an ``(image_tensor, class_index)`` pair.

    The last ``/``-separated component of ``object.key`` is cut at its first
    ``.`` and suffixed with ``.JPEG``; that name is resolved to a synset ID via
    ``imageValSet`` and then to a class index via ``imageCls``.

    Args:
        object: Object exposing ``read()`` and a string ``key`` attribute.

    Returns:
        Tuple of the decoded tensor and the integer class index.
    """
    image_tensor = obj_to_tensor(object)
    file_name = object.key.rpartition('/')[2]
    # Keep only the text before the first dot, then restore the JPEG suffix.
    stem = file_name.partition('.')[0]
    image_name = stem + ".JPEG"
    synset = imageValSet[image_name]
    return image_tensor, imageCls[synset]


def make_oss_dataloader(dataset, batch_size, num_worker, shuffle):
    """Build the ``'train'`` and ``'val'`` DataLoaders for the OSS-hosted dataset.

    Args:
        dataset: Object exposing ``from_prefix(uri, endpoint=..., transform=...)``
            (presumably an OSS connector dataset class — confirm with the caller).
        batch_size: Batch size passed to both DataLoaders.
        num_worker: Forwarded as ``num_workers`` to both DataLoaders.
        shuffle: Forwarded as ``shuffle`` to both DataLoaders, validation included.

    Returns:
        Dict with keys ``'train'`` and ``'val'`` mapping to the DataLoaders.
    """
    split_specs = (
        ('train', OSS_URI_BASE + "/train/", train_tensor_transform),
        ('val', OSS_URI_BASE + "/val/", val_tensor_transform),
    )

    # Create every dataset before any DataLoader is constructed.
    split_datasets = {
        split: dataset.from_prefix(prefix, endpoint=ENDPOINT, transform=transform)
        for split, prefix, transform in split_specs
    }

    return {
        split: DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_worker,
        )
        for split, split_dataset in split_datasets.items()
    }
