/*
 * Decompiled with CFR 0.152.
 */
package com.aliyun.odps.udf.example.speech;

import com.aliyun.odps.Column;
import com.aliyun.odps.OdpsType;
import com.aliyun.odps.data.ArrayRecord;
import com.aliyun.odps.data.Record;
import com.aliyun.odps.io.InputStreamSet;
import com.aliyun.odps.io.SourceInputStream;
import com.aliyun.odps.udf.DataAttributes;
import com.aliyun.odps.udf.ExecutionContext;
import com.aliyun.odps.udf.Extractor;
import com.aliyun.odps.udf.example.speech.UtteranceLabel;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class SpeechSentenceSnrExtractor
extends Extractor {
    private static final Log logger = LogFactory.getLog(SpeechSentenceSnrExtractor.class);
    private static final String MLF_FILE_ATTRIBUTE_KEY = "mlfFileName";
    private static final String SPEECH_SAMPLE_RATE_KEY = "speechSampleRateInKHz";
    private String mlfFileName;
    private HashMap<String, UtteranceLabel> utteranceLabels = new HashMap();
    private InputStreamSet inputs;
    private DataAttributes attributes;
    private double sampleRateInKHz;

    public void setup(ExecutionContext ctx, InputStreamSet inputs, DataAttributes attributes) {
        this.inputs = inputs;
        this.attributes = attributes;
        this.mlfFileName = this.attributes.getValueByKey(MLF_FILE_ATTRIBUTE_KEY);
        if (this.mlfFileName == null) {
            throw new IllegalArgumentException("A mlf file must be specified in extractor attribute.");
        }
        String sampleRateInKHzStr = this.attributes.getValueByKey(SPEECH_SAMPLE_RATE_KEY);
        if (sampleRateInKHzStr == null) {
            throw new IllegalArgumentException("The speech sampling rate must be specified in extractor attribute.");
        }
        this.sampleRateInKHz = Double.parseDouble(sampleRateInKHzStr);
        try {
            BufferedInputStream inputStream = ctx.readResourceFileAsStream(this.mlfFileName);
            this.loadMlfLabelsFromResource(inputStream);
            inputStream.close();
        }
        catch (IOException e) {
            throw new RuntimeException("reading model from mlf failed with exception " + e.getMessage());
        }
    }

    public Record extract() throws IOException {
        SourceInputStream inputStream = this.inputs.next();
        if (inputStream == null) {
            return null;
        }
        String fileName = inputStream.getFileName();
        fileName = fileName.substring(fileName.lastIndexOf(47) + 1);
        logger.info((Object)("Processing wav file " + fileName));
        String id = fileName.substring(0, fileName.lastIndexOf(46));
        long fileSize = inputStream.getFileSize();
        if (fileSize > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Do not support speech file larger than 2G bytes");
        }
        byte[] buffer = new byte[(int)fileSize];
        Column[] outputColumns = this.attributes.getRecordColumns();
        ArrayRecord record = new ArrayRecord(outputColumns);
        if (outputColumns.length != 2 || outputColumns[0].getType() != OdpsType.DOUBLE || outputColumns[1].getType() != OdpsType.STRING) {
            throw new IllegalArgumentException("Expecting output to of schema double|string.");
        }
        int readSize = inputStream.readToEnd(buffer);
        inputStream.close();
        double snr = this.computeSnr(id, buffer, readSize);
        record.setDouble(0, Double.valueOf(snr));
        record.setString(1, id);
        logger.info((Object)String.format("file [%s] snr computed to be [%f]db", fileName, snr));
        return record;
    }

    public void close() {
    }

    private void loadMlfLabelsFromResource(BufferedInputStream fileInputStream) throws IOException {
        String line;
        BufferedReader br = new BufferedReader(new InputStreamReader(fileInputStream));
        String id = "";
        while ((line = br.readLine()) != null) {
            if (line.trim().isEmpty()) continue;
            if (line.startsWith("id:")) {
                id = line.split(":")[1].trim();
                continue;
            }
            this.utteranceLabels.put(id, new UtteranceLabel(id, line, " "));
        }
    }

    private double computeSnr(String id, byte[] buffer, int validBufferLen) {
        int headerLength = 44;
        if (validBufferLen < 44) {
            throw new IllegalArgumentException("A wav buffer must be at least larger than standard wav header size.");
        }
        int dataLen = (validBufferLen - 44) / 2;
        int sampleCountPerFrame = (int)this.sampleRateInKHz * 10;
        if (dataLen % sampleCountPerFrame != 0) {
            throw new IllegalArgumentException(String.format("Invalid wav file where dataLen %d does not divide sampleCountPerFrame %d", dataLen, sampleCountPerFrame));
        }
        int frameCount = dataLen / sampleCountPerFrame;
        UtteranceLabel utteranceLabel = this.utteranceLabels.get(id);
        if (utteranceLabel == null) {
            throw new IllegalArgumentException(String.format("Cannot find label of id %s from MLF.", id));
        }
        ArrayList<Long> labels = utteranceLabel.getLabels();
        if (labels.size() + 2 != frameCount) {
            throw new IllegalArgumentException(String.format("Mismatched frame labels size % d and frameCount %d.", labels.size() + 2, frameCount));
        }
        int offset = 44;
        short[] data = new short[sampleCountPerFrame];
        double[] energies = new double[frameCount];
        for (int i = 0; i < frameCount; ++i) {
            ByteBuffer.wrap(buffer, offset, sampleCountPerFrame * 2).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(data);
            double frameEnergy = 0.0;
            for (int j = 0; j < sampleCountPerFrame; ++j) {
                frameEnergy += (double)(data[j] * data[j]);
            }
            energies[i] = frameEnergy;
            offset += sampleCountPerFrame * 2;
        }
        double averageSpeechPower = 0.0;
        double averageNoisePower = 1.0E-8;
        int speechframeCount = 0;
        int noiseframeCount = 0;
        for (int i = 0; i < labels.size(); ++i) {
            if (labels.get(i) == 0L) {
                averageNoisePower += energies[i];
                ++noiseframeCount;
                continue;
            }
            averageSpeechPower += energies[i];
            ++speechframeCount;
        }
        if (noiseframeCount > 0) {
            averageNoisePower /= (double)noiseframeCount;
        } else {
            return 100.0;
        }
        if (speechframeCount <= 0) {
            return -100.0;
        }
        return 10.0 * Math.log10((averageSpeechPower /= (double)speechframeCount) / averageNoisePower);
    }
}

