# upload_dataset.py

from torchvision import transforms
from PIL import Image
import oss2
import os
from oss2.credentials import EnvironmentVariableCredentialsProvider

# Connection settings for the destination OSS bucket.
# This example uses the internal endpoint of the China (Hangzhou) region.
OSS_ENDPOINT = "oss-cn-hangzhou-internal.aliyuncs.com"
# Name of the destination bucket; replace the placeholder with a real bucket name.
OSS_BUCKET_NAME = "<YourBucketName>"
# ID of the region in which the bucket is located.
BUCKET_REGION = "cn-hangzhou"

# Custom key prefix under which the dataset objects are stored in the bucket.
OSS_URI_BASE = "dataset/imagenet/ILSVRC/Data"

def to_tensor(img_path):
    """Load an image, run the preprocessing pipeline, and return raw tensor bytes.

    The image is converted to RGB, randomly cropped/resized to 224x224,
    randomly flipped horizontally, converted to a tensor, and normalized with
    the per-channel mean/std listed below. Because the crop and the flip are
    random, repeated calls on the same file return different bytes.

    Args:
        img_path: Path of the image file to read.

    Returns:
        bytes: The raw buffer of the tensor's NumPy array (``ndarray.tobytes()``).
        No shape or dtype header is included, so a reader must already know
        the layout (presumably float32, 3 x 224 x 224 -- confirm against the
        code that consumes these objects).
    """
    IMG_DIM_224 = 224
    compose = transforms.Compose([
            transforms.RandomResizedCrop(IMG_DIM_224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    # Open the file through a context manager so its handle is released
    # deterministically, even when decoding fails. convert() returns a new
    # image object, which remains usable after the source image is closed.
    with Image.open(img_path) as source_img:
        img = source_img.convert('RGB')
    img_tensor = compose(img)
    numpy_data = img_tensor.numpy()
    binary_data = numpy_data.tobytes()
    return binary_data

def list_dir(directory):
    """Yield the path of every file under *directory*, relative to *directory*.

    The tree is traversed with ``os.walk``. Files located directly in
    *directory* are yielded as bare file names; files in subdirectories are
    yielded as ``<relative subdirectory>/<file name>`` joined with
    ``os.path.join``.

    Args:
        directory: Root directory to traverse.

    Yields:
        str: One relative file path per file found.
    """
    for current_dir, _subdirs, filenames in os.walk(directory):
        prefix = os.path.relpath(current_dir, start=directory)
        if prefix == '.':
            # Top-level files need no directory prefix.
            yield from filenames
        else:
            yield from (os.path.join(prefix, name) for name in filenames)
# Local root directory that holds the source images (absolute or relative path).
IMG_DIR_BASE = "./dataset"
"""
    IMG_DIR_BASE is the local root directory of the images. It can be an absolute or a relative path.
    The directory layout under IMG_DIR_BASE must match the following structure:
    {IMG_DIR_BASE}/
        train/
            n10148035/
                n10148035_10034.JPEG
                n10148035_10217.JPEG
                ... 
            n11879895/
                n11879895_10016.JPEG
                n11879895_10019.JPEG
                ...
            ...
        val/
            ILSVRC2012_val_00000001.JPEG
            ILSVRC2012_val_00000002.JPEG
            ...
"""

# Client bound to the destination bucket. Authentication uses ProviderAuthV4
# with credentials supplied by EnvironmentVariableCredentialsProvider.
bucket_api = oss2.Bucket(
    oss2.ProviderAuthV4(EnvironmentVariableCredentialsProvider()),
    OSS_ENDPOINT,
    OSS_BUCKET_NAME,
    region=BUCKET_REGION,
)

# Preprocess every image of each split and upload the result as one object
# per image under OSS_URI_BASE/<phase>/<relative path>.
for phase in ["val", "train"]:
    IMG_DIR = "%s/%s" % (IMG_DIR_BASE, phase)
    for img_relative_path in list_dir(IMG_DIR):
        # Object keys use "/" as the separator; normalize OS-specific
        # separators so the key layout is the same on every platform.
        img_bin_name = img_relative_path.replace(os.sep, "/")
        # Swap only a trailing ".JPEG" extension for ".pt"; a blanket
        # str.replace would also rewrite matching text inside directory names.
        # NOTE: the payload is the raw byte buffer returned by to_tensor,
        # not a torch.save() archive, despite the ".pt" suffix.
        if img_bin_name.endswith(".JPEG"):
            img_bin_name = img_bin_name[:-len(".JPEG")] + ".pt"
        object_key = "%s/%s/%s" % (OSS_URI_BASE, phase, img_bin_name)
        local_path = os.path.join(IMG_DIR, img_relative_path)
        bucket_api.put_object(object_key, to_tensor(local_path))
