package main

import (
    "os"
    "log"
    "errors"
    "syscall"
    "demo/pkg/crypto/sam"
)

// Chunk sizes, in bytes, used by main when exercising the decryptor:
// the starting read size, the increment between passes, and the upper limit.
var (
    blockSize     int64 = 4 << 20 // 0x00400000 (4 MiB)
    blockSizeStep int64 = 1 << 20 // 0x00100000 (1 MiB)
    blockSizeMax  int64 = 5 << 20 // 0x00500000 (5 MiB)
)

// SamCryptoEngine is a package-local alias for sam.SamCryptoEngine, so the
// engine type can be named here without the package qualifier.
type (
    SamCryptoEngine = sam.SamCryptoEngine
)

// SaveDataToFile writes data into the file name at byte offset off.
//
// If name already exists and its size equals targetSize, the file is treated
// as complete and nothing is written. When off is 0 the file is created (or
// truncated). Otherwise it is opened for writing and created if missing, so a
// block may be written before the block at offset 0.
//
// It returns 0 on success and -1 on failure. Failures are logged.
func SaveDataToFile(name string, data []byte, off int64, targetSize int64) int {
    // Skip the work when a previous pass already produced a file of the final
    // size. Any Stat error (typically "not exist") just means we must write.
    if fileInfo, err := os.Stat(name); err == nil && fileInfo.Size() == targetSize {
        return 0
    }

    var file *os.File
    var err error
    if off == 0 {
        file, err = os.Create(name)
        if err != nil {
            log.Printf("Create %s error: %s\n", name, err)
            return -1
        }
    } else {
        // O_CREATE lets a non-zero offset be written even if the offset-0
        // block has not created the file yet. WriteAt leaves a hole before it.
        file, err = os.OpenFile(name, os.O_WRONLY|os.O_CREATE, 0666)
        if err != nil {
            log.Printf("Open %s fail, error = %s\n", name, err)
            return -1
        }
    }

    wrSize, werr := file.WriteAt(data, off)
    // Close explicitly rather than via defer, so that a failed close of a file
    // we just wrote is reported instead of silently reported as success.
    cerr := file.Close()

    if werr != nil {
        log.Printf("Write data to file(off = %d) error: %s\n", off, werr)
        return -1
    }

    if wrSize != len(data) {
        log.Printf("size is not equal: %d %d\n", wrSize, len(data))
        return -1
    }

    if cerr != nil {
        log.Printf("Close %s error: %s\n", name, cerr)
        return -1
    }

    return 0
}

// CompareBlockData verifies that data equals the bytes of file name starting
// at offset off. On a match, data is also written at off into "model.dec"
// through SaveDataToFile, with name's size passed as the expected final size.
//
// It returns (0, nil) on success and (-1, err) on failure. Every failure is
// also logged.
func CompareBlockData(name string, data []byte, off int64) (int, error) {
    // A negative offset cannot address the file. It would also break the
    // page-alignment arithmetic below.
    if off < 0 {
        log.Printf("Invalid offset %d for %s\n", off, name)
        return -1, errors.New("negative offset")
    }

    file, err := os.Open(name)
    if err != nil {
        log.Printf("Open %s error: %s\n", name, err)
        return -1, err
    }
    defer file.Close()

    fileInfo, err := file.Stat()
    if err != nil {
        log.Printf("File %s stat error: %s\n", name, err)
        return -1, err
    }

    dataSize := len(data)
    fileSize := fileInfo.Size()
    if int64(dataSize)+off > fileSize {
        log.Printf("File data overflow, %d %d\n", int64(dataSize)+off, fileSize)
        return -1, errors.New("data overflow")
    }

    // mmap rejects a zero-length mapping, and an empty block has nothing to compare.
    if dataSize > 0 {
        // mmap requires a page-aligned file offset. Map from the page boundary
        // at or below off, then skip the first delta bytes of the mapping.
        pageSize := int64(os.Getpagesize())
        mapOff := off - off%pageSize
        delta := int(off - mapOff)

        mmaped, err := syscall.Mmap(int(file.Fd()), mapOff,
            delta+dataSize, syscall.PROT_READ, syscall.MAP_SHARED)
        if err != nil {
            log.Printf("syscall mmap error: %s\n", err)
            return -1, err
        }
        // Unmap on return. By then the comparison is finished, so an unmap
        // error cannot change the result.
        defer syscall.Munmap(mmaped)

        fileData := mmaped[delta:]
        for i, want := range data {
            if want != fileData[i] {
                log.Printf("data[%d] is not equal, 0x%x 0x%x\n", int64(i)+off, want, fileData[i])
                return -1, errors.New("data is not equal")
            }
        }
    }

    outFileName := "model.dec"
    if ret := SaveDataToFile(outFileName, data, off, fileSize); ret < 0 {
        log.Printf("SaveDataToFile fail\n")
        return -1, errors.New("SaveDataToFile error")
    }

    return 0, nil
}

// main verifies that decrypting modelFile with password reproduces plainFile.
//
// Usage: <prog> password modelFile plainFile
//
// It reads the plaintext through the SAM crypto engine in chunks. Each pass
// starts at blockSize and the next pass adds blockSizeStep, while the size is
// below blockSizeMax (capped at the plaintext size); at least one pass always
// runs. Every chunk is compared against plainFile via CompareBlockData.
// The process exits with status 0 on success, 1 on any failure and 2 on bad
// usage.
func main() {
    os.Exit(runModelCheck())
}

// runModelCheck does the work of main and returns the process exit code.
// It is separate from main so that the deferred cleanup (closing the model
// file and calling the engine's wipe function) runs before os.Exit, which
// skips deferred calls.
func runModelCheck() int {
    if len(os.Args) < 4 {
        log.Printf("Usage: %s [password] [modelFile] [plainFile]\n", os.Args[0])
        return 2
    }

    password := os.Args[1]
    modelFile := os.Args[2]
    plainFile := os.Args[3]

    // Never write the secret itself to the log.
    log.Printf("------> password: <redacted, %d bytes>\n", len(password))
    log.Printf("------> modelFile: %s\n", modelFile)
    log.Printf("------> plainFile: %s\n", plainFile)

    file, err := os.Open(modelFile)
    if err != nil {
        log.Printf("Open %s error: %s\n", modelFile, err)
        return 1
    }
    defer file.Close()

    fileInfo, err := file.Stat()
    if err != nil {
        log.Printf("Stat %s error: %s\n", modelFile, err)
        return 1
    }

    ce, wipe := sam.NewSamCryptoEngine([]byte(password), uint32(blockSize), "/tmp", "UID_12345678")
    if ce == nil {
        log.Printf("Failed to create SamCryptoEngine\n")
        return 1
    }
    defer wipe()

    cipherSize := uint64(fileInfo.Size())
    var plainSize uint64
    plainSize, err = ce.CipherSizeToPlainSize(modelFile, cipherSize)
    if err != nil {
        log.Printf("CipherSize to PlainSize error: %s\n", err)
        return 1
    }

    // Work on locals instead of mutating the package-level settings.
    limit := blockSizeMax
    if uint64(limit) > plainSize {
        limit = int64(plainSize)
    }

    log.Printf("--------> plainSize: %d\n", plainSize)
    log.Printf("--------> blockSizeMax: %d\n", limit)

    total := int64(plainSize)
    bs := blockSize
    // count == 0 forces one pass when the plaintext is smaller than the
    // starting block size. In that case limit is below bs, and the size
    // condition alone would skip verification entirely yet still exit 0.
    for count := 0; count == 0 || bs < limit; count++ {
        log.Printf("\n")
        log.Printf("--------------->Test Loop(%d): blockSize = 0x%x\n", count, bs)

        outSize := bs
        outData := make([]byte, outSize)

        for off := int64(0); off < total; off += bs {
            if off+bs > total {
                // Final, shorter chunk at the end of the plaintext.
                outSize = total - off
                outData = make([]byte, outSize)
            }

            readSize, err := ce.ReadAtCipherText(0, file, outData, off)
            if err != nil {
                log.Printf("ReadAtCipherText error: %s\n", err)
                return 1
            }
            if int64(readSize) != outSize {
                log.Printf("readSize is not correct: %d %d\n", readSize, outSize)
                return 1
            }

            if _, err := CompareBlockData(plainFile, outData, off); err != nil {
                log.Printf("CompareBlockData error: %s\n", err)
                return 1
            }
        }

        bs += blockSizeStep
    }

    return 0
}

