/*
 * Decompiled with CFR 0.152.
 */
package io.antmedia.enterprise.tensorflow.detection;

import io.antmedia.enterprise.tensorflow.detection.RunStats;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.types.UInt8;

/**
 * Convenience wrapper around a TensorFlow {@link Graph} and {@link Session}.
 *
 * <p>Typical usage is: call one or more {@code feed*} methods to register inputs, call
 * {@link #run(String[])} to execute the graph, then call {@code fetch} to copy the outputs into
 * caller-supplied arrays or buffers. Tensor names are given as {@code "node_name"} or
 * {@code "node_name:output_index"}.
 *
 * <p>Input tensors are released at the end of every {@code run} call (successful or not); output
 * tensors stay available until the next {@code run} or {@link #close()}.
 */
public class TensorFlowInferenceInterface {
    /** Model name used in error messages; both constructors set it to the empty string. */
    private final String modelName;
    private final Graph g;
    private final Session sess;
    /** Runner collecting feeds/fetches for the next run; replaced after every run. */
    private Session.Runner runner;
    private List<String> feedNames = new ArrayList<String>();
    private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
    /** Names passed to the last run; index i corresponds to {@code fetchTensors.get(i)}. */
    private List<String> fetchNames = new ArrayList<String>();
    private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
    /** Lazily created the first time a run is executed with statistics enabled. */
    private RunStats runStats;

    /**
     * Creates an interface from a serialized graph definition.
     *
     * @param graphDef serialized graph definition passed to {@link Graph#importGraphDef(byte[])}
     */
    public TensorFlowInferenceInterface(byte[] graphDef) {
        this.modelName = "";
        Graph graph = new Graph();
        try {
            graph.importGraphDef(graphDef);
        } catch (RuntimeException e) {
            // The graph was created here, so nobody else can release it if the import fails.
            graph.close();
            throw e;
        }
        this.g = graph;
        this.sess = new Session(graph);
        this.runner = this.sess.runner();
    }

    /**
     * Creates an interface around an existing graph.
     *
     * <p>Note that {@link #close()} closes the supplied graph as well.
     *
     * @param g graph to run
     */
    public TensorFlowInferenceInterface(Graph g) {
        this.modelName = "";
        this.g = g;
        this.sess = new Session(g);
        this.runner = this.sess.runner();
    }

    /**
     * Runs the graph without collecting run statistics.
     *
     * @param outputNames names of the tensors to compute
     * @see #run(String[], boolean)
     */
    public void run(String[] outputNames) {
        this.run(outputNames, false);
    }

    /**
     * Runs the graph with the currently registered feeds and keeps the requested outputs so they
     * can be read with the {@code fetch} methods.
     *
     * <p>Outputs of the previous run are closed first. Whether the run succeeds or fails, all fed
     * tensors are closed and a fresh runner is created, so inputs must be fed again before the
     * next run. If the run fails, no outputs are retained.
     *
     * @param outputNames names of the tensors to compute
     * @param enableStats when {@code true}, run metadata is collected and added to the statistics
     *     reported by {@link #getStatString()}
     * @throws RuntimeException whatever the underlying runner throws, rethrown unchanged
     */
    public void run(String[] outputNames, boolean enableStats) {
        this.closeFetches();
        try {
            // Registering fetches happens inside the try block so that a bad output name still
            // resets the runner; otherwise stale fetches would leak into the next run and the
            // returned tensors would no longer line up with fetchNames.
            for (String outputName : outputNames) {
                TensorId tid = TensorId.parse(outputName);
                this.runner.fetch(tid.name, tid.outputIndex);
                this.fetchNames.add(outputName);
            }
            if (enableStats) {
                Session.Run result = this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
                this.fetchTensors = result.outputs;
                if (this.runStats == null) {
                    this.runStats = new RunStats();
                }
                this.runStats.add(result.metadata);
            } else {
                this.fetchTensors = this.runner.run();
            }
        } catch (RuntimeException e) {
            // Leave no half-populated fetch state behind: names without tensors would make a
            // later fetch fail with an index error instead of the "not provided to run()" message.
            this.closeFetches();
            throw e;
        } finally {
            // This run is over: release the inputs and start with a clean runner.
            this.closeFeeds();
            this.runner = this.sess.runner();
        }
    }

    /**
     * Returns the graph this interface runs.
     *
     * @return the underlying graph
     */
    public Graph graph() {
        return this.g;
    }

    /**
     * Looks up an operation of the graph by name.
     *
     * @param operationName name of the operation
     * @return the operation, never {@code null}
     * @throws RuntimeException if the graph has no operation with that name
     */
    public Operation graphOperation(String operationName) {
        GraphOperation operation = this.g.operation(operationName);
        if (operation == null) {
            throw new RuntimeException("Node '" + operationName + "' does not exist in model '" + this.modelName + "'");
        }
        return operation;
    }

    /**
     * Returns a summary of the statistics collected so far.
     *
     * @return the summary, or an empty string if no run has collected statistics
     */
    public String getStatString() {
        return this.runStats == null ? "" : this.runStats.summary();
    }

    /**
     * Releases all tensors still held, then the session, the graph and the run statistics.
     */
    public void close() {
        this.closeFeeds();
        this.closeFetches();
        this.sess.close();
        this.g.close();
        if (this.runStats != null) {
            this.runStats.close();
        }
        this.runStats = null;
    }

    /** Calls {@link #close()} before delegating to the superclass finalizer. */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.close();
        } finally {
            super.finalize();
        }
    }

    /**
     * Feeds a boolean tensor; each element is converted to one byte (1 for {@code true}, 0 for
     * {@code false}).
     *
     * @param inputName name of the tensor to feed
     * @param src flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, boolean[] src, long... dims) {
        byte[] bytes = new byte[src.length];
        for (int i = 0; i < src.length; ++i) {
            bytes[i] = (byte) (src[i] ? 1 : 0);
        }
        this.addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(bytes)));
    }

    /**
     * Feeds a float tensor.
     *
     * @param inputName name of the tensor to feed
     * @param src flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, float[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
    }

    /**
     * Feeds an int tensor.
     *
     * @param inputName name of the tensor to feed
     * @param src flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, int[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
    }

    /**
     * Feeds a long tensor.
     *
     * @param inputName name of the tensor to feed
     * @param src flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, long[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
    }

    /**
     * Feeds a double tensor.
     *
     * @param inputName name of the tensor to feed
     * @param src flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, double[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
    }

    /**
     * Feeds a {@link UInt8} tensor.
     *
     * @param inputName name of the tensor to feed
     * @param src flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, byte[] src, long... dims) {
        this.addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
    }

    /**
     * Feeds a tensor built from a single byte array via {@link Tensors#create(byte[])}.
     *
     * @param inputName name of the tensor to feed
     * @param src bytes of the value
     */
    public void feedString(String inputName, byte[] src) {
        this.addFeed(inputName, Tensors.create(src));
    }

    /**
     * Feeds a tensor built from an array of byte arrays via {@link Tensors#create(byte[][])}.
     *
     * @param inputName name of the tensor to feed
     * @param src one byte array per element
     */
    public void feedString(String inputName, byte[][] src) {
        this.addFeed(inputName, Tensors.create(src));
    }

    /**
     * Feeds a {@link UInt8} tensor built directly from a four-dimensional byte array.
     *
     * @param inputName name of the tensor to feed
     * @param src tensor values; the array nesting determines the tensor shape
     */
    public void feedImage(String inputName, byte[][][][] src) {
        this.addFeed(inputName, Tensor.create(src, UInt8.class));
    }

    /**
     * Feeds a float tensor from a buffer.
     *
     * @param inputName name of the tensor to feed
     * @param src buffer holding the flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, FloatBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    /**
     * Feeds an int tensor from a buffer.
     *
     * @param inputName name of the tensor to feed
     * @param src buffer holding the flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, IntBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    /**
     * Feeds a long tensor from a buffer.
     *
     * @param inputName name of the tensor to feed
     * @param src buffer holding the flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, LongBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    /**
     * Feeds a double tensor from a buffer.
     *
     * @param inputName name of the tensor to feed
     * @param src buffer holding the flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, DoubleBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(dims, src));
    }

    /**
     * Feeds a {@link UInt8} tensor from a buffer.
     *
     * @param inputName name of the tensor to feed
     * @param src buffer holding the flattened tensor values
     * @param dims shape of the tensor
     */
    public void feed(String inputName, ByteBuffer src, long... dims) {
        this.addFeed(inputName, Tensor.create(UInt8.class, dims, src));
    }

    /**
     * Copies an output of the last run into a float array.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination array
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, float[] dst) {
        this.fetch(outputName, FloatBuffer.wrap(dst));
    }

    /**
     * Copies an output of the last run into an int array.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination array
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, int[] dst) {
        this.fetch(outputName, IntBuffer.wrap(dst));
    }

    /**
     * Copies an output of the last run into a long array.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination array
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, long[] dst) {
        this.fetch(outputName, LongBuffer.wrap(dst));
    }

    /**
     * Copies an output of the last run into a double array.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination array
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, double[] dst) {
        this.fetch(outputName, DoubleBuffer.wrap(dst));
    }

    /**
     * Copies an output of the last run into a byte array.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination array
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, byte[] dst) {
        this.fetch(outputName, ByteBuffer.wrap(dst));
    }

    /**
     * Writes an output of the last run into a float buffer.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination buffer
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, FloatBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    /**
     * Writes an output of the last run into an int buffer.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination buffer
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, IntBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    /**
     * Writes an output of the last run into a long buffer.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination buffer
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, LongBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    /**
     * Writes an output of the last run into a double buffer.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination buffer
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, DoubleBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    /**
     * Writes an output of the last run into a byte buffer.
     *
     * @param outputName name that was passed to {@code run}
     * @param dst destination buffer
     * @throws RuntimeException if {@code outputName} was not requested in the last run
     */
    public void fetch(String outputName, ByteBuffer dst) {
        this.getTensor(outputName).writeTo(dst);
    }

    /**
     * Registers {@code t} as the value for {@code inputName} on the pending runner and records it
     * so it can be closed after the run. If registration fails, the tensor is closed here because
     * it would otherwise never be tracked or released.
     */
    private void addFeed(String inputName, Tensor<?> t) {
        try {
            TensorId tid = TensorId.parse(inputName);
            this.runner.feed(tid.name, tid.outputIndex, t);
        } catch (RuntimeException e) {
            t.close();
            throw e;
        }
        this.feedNames.add(inputName);
        this.feedTensors.add(t);
    }

    /**
     * Returns the output tensor of the last run that was requested under {@code outputName}
     * (the first match if the name was requested more than once).
     */
    private Tensor<?> getTensor(String outputName) {
        int index = this.fetchNames.indexOf(outputName);
        if (index < 0) {
            throw new RuntimeException("Node '" + outputName + "' was not provided to run(), so it cannot be read");
        }
        return this.fetchTensors.get(index);
    }

    /** Closes every pending input tensor and forgets the associated names. */
    private void closeFeeds() {
        for (Tensor<?> t : this.feedTensors) {
            t.close();
        }
        this.feedTensors.clear();
        this.feedNames.clear();
    }

    /** Closes every retained output tensor and forgets the associated names. */
    private void closeFetches() {
        for (Tensor<?> t : this.fetchTensors) {
            t.close();
        }
        this.fetchTensors.clear();
        this.fetchNames.clear();
    }

    /** A tensor reference split into operation name and output index. */
    private static class TensorId {
        String name;
        int outputIndex;

        private TensorId() {
        }

        /**
         * Parses {@code "name"} or {@code "name:index"}.
         *
         * <p>The text after the last colon is used as the output index when it is a valid
         * integer; otherwise (no colon, or a non-numeric suffix) the whole string is the
         * operation name and the output index is 0.
         */
        public static TensorId parse(String name) {
            TensorId tid = new TensorId();
            int colonIndex = name.lastIndexOf(':');
            if (colonIndex < 0) {
                tid.outputIndex = 0;
                tid.name = name;
                return tid;
            }
            try {
                tid.outputIndex = Integer.parseInt(name.substring(colonIndex + 1));
                tid.name = name.substring(0, colonIndex);
            } catch (NumberFormatException e) {
                // The suffix is not an index, so the colon is part of the operation name.
                tid.outputIndex = 0;
                tid.name = name;
            }
            return tid;
        }
    }
}

