/*
 * Decompiled with CFR 0.152.
 */
package io.antmedia.enterprise.tensorflow.detection;

import com.google.protobuf.Message;
import com.google.protobuf.TextFormat;
import io.antmedia.enterprise.tensorflow.detection.Classifier;
import io.antmedia.enterprise.tensorflow.detection.TensorFlowInferenceInterface;
import io.antmedia.enterprise.tensorflow.detection.Utils;
import io.antmedia.enterprise.tensorflow.detection.protos.StringIntLabelMapOuterClass;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;
import javafx.geometry.Rectangle2D;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;

/**
 * Object detector backed by a frozen TensorFlow graph.
 *
 * <p>{@link #create(String)} reads {@code model.pb} (the serialized graph) and, best effort,
 * {@code label.pbtxt} (a text-format {@code StringIntLabelMap}) from a model directory. The
 * graph must contain an {@code image_tensor} input node and {@code detection_scores},
 * {@code detection_boxes} and {@code detection_classes} output nodes; {@code num_detections}
 * is fetched as well.
 *
 * <p>{@link #recognizeImage(BufferedImage)} reassigns this instance's output buffers on every
 * call, so a single instance should not be used from several threads at once without external
 * synchronization.
 */
public class TFObjectDetector
implements Classifier {
    /** Size of the per-image output buffers, and upper bound on results returned per image. */
    private static final int MAX_RESULTS = 1001;
    /** A detection is kept only when its score is strictly greater than this value. */
    private static final double MIN_CONFIDENCE = 0.5;
    /** Name of the graph node the image bytes are fed into. */
    private static final String INPUT_NAME = "image_tensor";

    /** Orders recognitions by descending confidence (highest first). */
    private static final Comparator<Classifier.Recognition> BY_CONFIDENCE_DESC =
            new Comparator<Classifier.Recognition>() {
                @Override
                public int compare(Classifier.Recognition lhs, Classifier.Recognition rhs) {
                    return Float.compare(rhs.getConfidence().floatValue(), lhs.getConfidence().floatValue());
                }
            };

    private TensorFlowInferenceInterface inference;
    private String inputName;
    /** Display names indexed by class id; empty if the label file could not be loaded. */
    private final Vector<String> labels = new Vector<>();
    private float[] outputLocations;
    private float[] outputScores;
    private float[] outputClasses;
    private float[] outputNumDetections;
    /** Output node names; recognizeImage relies on this exact order (boxes, scores, classes, count). */
    private String[] outputNames;

    /**
     * Builds a detector from the files in {@code modelDir}.
     *
     * @param modelDir directory containing {@code model.pb} and, optionally, {@code label.pbtxt}
     * @return a ready-to-use detector
     * @throws IOException propagated from model loading
     * @throws RuntimeException if a required input/output node is missing from the graph
     */
    public static Classifier create(String modelDir) throws IOException {
        TFObjectDetector d = new TFObjectDetector();
        try {
            d.labels.addAll(TFObjectDetector.loadLabels(modelDir));
        }
        catch (Exception e) {
            // Labels are optional: the detector still loads without them, and labelFor()
            // then yields "" for every class id instead of failing at recognition time.
            e.printStackTrace();
        }
        byte[] graphDef = Utils.readAllBytesOrExit(Paths.get(modelDir, "model.pb"));
        // This Graph is only used to check that the expected nodes exist; the inference
        // interface below is constructed from the same graph bytes.
        try (Graph g = new Graph();){
            g.importGraphDef(graphDef);
            d.inputName = INPUT_NAME;
            TFObjectDetector.checkOperationExists(g, d.inputName, "input");
            TFObjectDetector.checkOperationExists(g, "detection_scores", "output");
            TFObjectDetector.checkOperationExists(g, "detection_boxes", "output");
            TFObjectDetector.checkOperationExists(g, "detection_classes", "output");
            d.outputNames = new String[]{"detection_boxes", "detection_scores", "detection_classes", "num_detections"};
            d.outputScores = new float[MAX_RESULTS];
            d.outputLocations = new float[4 * MAX_RESULTS];
            d.outputClasses = new float[MAX_RESULTS];
            d.outputNumDetections = new float[1];
            d.inference = new TensorFlowInferenceInterface(graphDef);
            return d;
        }
    }

    /** Throws if {@code graph} has no operation called {@code name}; {@code role} is "input"/"output" for the message. */
    private static void checkOperationExists(Graph graph, String name, String role) {
        GraphOperation op = graph.operation(name);
        if (op == null) {
            throw new RuntimeException("Failed to find " + role + " Node '" + name + "'");
        }
    }

    /**
     * Reads {@code label.pbtxt} from {@code modelDir} and returns display names indexed by id.
     *
     * <p>The list has {@code maxId + 1} entries; ids that do not appear in the map are {@code ""}.
     *
     * @param modelDir directory containing {@code label.pbtxt}
     * @return display names indexed by label id
     * @throws Exception if the file cannot be read or parsed, or an id cannot be used as an index
     */
    private static List<String> loadLabels(String modelDir) throws Exception {
        String text = new String(Files.readAllBytes(Paths.get(modelDir, "label.pbtxt")), StandardCharsets.UTF_8);
        StringIntLabelMapOuterClass.StringIntLabelMap.Builder builder = StringIntLabelMapOuterClass.StringIntLabelMap.newBuilder();
        TextFormat.merge((CharSequence)text, (Message.Builder)builder);
        StringIntLabelMapOuterClass.StringIntLabelMap proto = builder.build();
        List<StringIntLabelMapOuterClass.StringIntLabelMapItem> items = proto.getItemList();
        int maxId = 0;
        for (StringIntLabelMapOuterClass.StringIntLabelMapItem item : items) {
            maxId = Math.max(maxId, item.getId());
        }
        ArrayList<String> ret = new ArrayList<>(maxId + 1);
        for (int i = 0; i <= maxId; ++i) {
            ret.add("");
        }
        for (StringIntLabelMapOuterClass.StringIntLabelMapItem item : items) {
            ret.set(item.getId(), item.getDisplayName());
        }
        return ret;
    }

    /**
     * Runs the detector on {@code image}.
     *
     * @param image the frame to analyse
     * @return detections scoring above {@link #MIN_CONFIDENCE}, highest confidence first,
     *         at most {@link #MAX_RESULTS} entries; box coordinates are in image pixels
     */
    @Override
    public List<Classifier.Recognition> recognizeImage(BufferedImage image) {
        this.inference.feedImage(this.inputName, TFObjectDetector.getPixelBytes(image));
        this.inference.run(this.outputNames, false);
        // Fresh buffers per call so that any element not written by fetch() is zero rather
        // than a value left over from the previous image.
        this.outputLocations = new float[4 * MAX_RESULTS];
        this.outputScores = new float[MAX_RESULTS];
        this.outputClasses = new float[MAX_RESULTS];
        this.outputNumDetections = new float[1];
        this.inference.fetch(this.outputNames[0], this.outputLocations);
        this.inference.fetch(this.outputNames[1], this.outputScores);
        this.inference.fetch(this.outputNames[2], this.outputClasses);
        this.inference.fetch(this.outputNames[3], this.outputNumDetections);

        float width = image.getWidth();
        float height = image.getHeight();
        PriorityQueue<Classifier.Recognition> pq = new PriorityQueue<Classifier.Recognition>(1, BY_CONFIDENCE_DESC);
        for (int i = 0; i < this.outputScores.length; ++i) {
            float score = this.outputScores[i];
            // Filter before building the rectangle: Rectangle2D rejects negative sizes, so a
            // junk low-score box must not be able to abort the whole frame. The negated '>'
            // also skips NaN scores.
            if (!(score > MIN_CONFIDENCE)) {
                continue;
            }
            // Each box is four floats ordered (ymin, xmin, ymax, xmax), scaled here by the image
            // size; presumably the model emits them normalized to [0, 1] -- confirm for new models.
            float xmin = this.outputLocations[4 * i + 1] * width;
            float ymin = this.outputLocations[4 * i] * height;
            float xmax = this.outputLocations[4 * i + 3] * width;
            float ymax = this.outputLocations[4 * i + 2] * height;
            Rectangle2D detection = new Rectangle2D(xmin, ymin, xmax - xmin, ymax - ymin);
            String label = this.labelFor((int)this.outputClasses[i]);
            pq.add(new Classifier.Recognition(String.valueOf(i), label, Float.valueOf(score), detection));
        }
        // Drain until empty or full. The old loop compared against pq.size() while polling,
        // which shrinks the queue and silently dropped about half of the detections.
        ArrayList<Classifier.Recognition> recognitions = new ArrayList<>();
        while (!pq.isEmpty() && recognitions.size() < MAX_RESULTS) {
            recognitions.add(pq.poll());
        }
        return recognitions;
    }

    /** Returns the display name for {@code classId}, or "" when no label is known for it. */
    private String labelFor(int classId) {
        if (classId < 0 || classId >= this.labels.size()) {
            return "";
        }
        return this.labels.get(classId);
    }

    /**
     * Converts {@code image} into a {@code [1][height][width][3]} byte tensor holding the red,
     * green and blue channels reported by {@link BufferedImage#getRGB(int, int)}; alpha is dropped.
     */
    private static byte[][][][] getPixelBytes(BufferedImage image) {
        int imageWidth = image.getWidth();
        int imageHeight = image.getHeight();
        byte[][][][] featuresTensorData = new byte[1][imageHeight][imageWidth][3];
        for (int row = 0; row < imageHeight; ++row) {
            byte[][] rowData = featuresTensorData[0][row];
            for (int column = 0; column < imageWidth; ++column) {
                int pixel = image.getRGB(column, row);
                byte[] rgb = rowData[column];
                rgb[0] = (byte)(pixel >> 16 & 0xFF);
                rgb[1] = (byte)(pixel >> 8 & 0xFF);
                rgb[2] = (byte)(pixel & 0xFF);
            }
        }
        return featuresTensorData;
    }
}

