/*
 * Decompiled with CFR 0.152.
 */
package unbbayes.evaluation;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import unbbayes.evaluation.AEvaluation;
import unbbayes.evaluation.EvidenceEvaluation;
import unbbayes.evaluation.exception.EvaluationException;
import unbbayes.prs.bn.TreeVariable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Evaluator that computes the confusion matrix exactly, by enumerating every joint
 * state combination of the (single) target node and the evidence nodes and running
 * network inference ({@code compile()} + {@code updateEvidences()}) for each one.
 */
public class ExactEvaluation
extends AEvaluation {
    protected TreeVariable targetNode;

    /**
     * Always returns 0: this evaluator enumerates every state combination rather than sampling.
     *
     * @return {@code 0.0f}
     */
    @Override
    public float getError() {
        return 0.0f;
    }

    /**
     * Computes the confusion matrix {@code CM = P(T|E) x P(E|T)} for a single target node.
     *
     * @param targetNodeNameList names of the target nodes; exactly one is currently supported
     * @param evidenceNodeNameList names of the evidence nodes
     * @return an {@code N x N} matrix, where {@code N} is the number of target states
     * @throws EvaluationException if not exactly one target node is given, or the network fails to compile
     */
    @Override
    protected float[][] computeCM(List<String> targetNodeNameList, List<String> evidenceNodeNameList) throws EvaluationException {
        this.init(targetNodeNameList, evidenceNodeNameList);
        // Validate before indexing, so an empty list reports the intended error
        // instead of an ArrayIndexOutOfBoundsException.
        if (this.targetNodeList == null || this.targetNodeList.length != 1) {
            throw new EvaluationException("For now, just one target node is accepted!");
        }
        this.targetNode = this.targetNodeList[0];
        int targetStates = this.targetNode.getStatesSize();

        // The prior of each target state does not depend on the row: compute it once.
        float[] targetPrior = this.getTargetPriorProbability();

        float[][] postTGivenE = new float[targetStates][this.evidenceStatesProduct];
        float[][] postEGivenT = new float[this.evidenceStatesProduct][targetStates];

        // P(E) depends only on the evidence combination, which repeats once per target state.
        float[] evidenceJoint = new float[this.evidenceStatesProduct];
        boolean[] evidenceJointKnown = new boolean[this.evidenceStatesProduct];

        for (int row = 0; row < this.statesProduct; row++) {
            int[] states = this.getMultidimensionalCoord(row);
            int indexTarget = states[0];
            int indexEvidence = this.getEvidenceLinearCoord(states);

            float probTGivenE = this.getProbTargetGivenEvidence(states);
            postTGivenE[indexTarget][indexEvidence] = probTGivenE;

            if (!evidenceJointKnown[indexEvidence]) {
                // states[0] is the target; the remaining entries are the evidence states.
                int[] evidenceStates = new int[states.length - 1];
                System.arraycopy(states, 1, evidenceStates, 0, evidenceStates.length);
                evidenceJoint[indexEvidence] = this.getEvidencesJointProbability(evidenceStates);
                evidenceJointKnown[indexEvidence] = true;
            }
            float probE = evidenceJoint[indexEvidence];
            float probT = targetPrior[indexTarget];

            // Bayes: P(E|T) = P(T|E) * P(E) / P(T). A zero-prior target state contributes
            // nothing; dividing by zero would poison the whole CM column with NaN/Infinity.
            postEGivenT[indexEvidence][indexTarget] = probT > 0.0f ? probTGivenE * probE / probT : 0.0f;
        }

        // Leave the network freshly compiled, so findings from the last row do not leak
        // to later callers (compile() is what every helper here uses to discard findings).
        this.compileNetOrThrow();

        return multiplyPosteriors(postTGivenE, postEGivenT, targetStates);
    }

    /**
     * Multiplies {@code a} ({@code n x m}) by {@code b} ({@code m x n}) into a new {@code n x n} matrix.
     * The i-k-j loop order accumulates each cell over {@code k} in ascending order.
     */
    private static float[][] multiplyPosteriors(float[][] a, float[][] b, int n) {
        float[][] product = new float[n][n];
        for (int i = 0; i < n; i++) {
            float[] aRow = a[i];
            float[] productRow = product[i];
            for (int k = 0; k < aRow.length; k++) {
                float aik = aRow[k];
                float[] bRow = b[k];
                for (int j = 0; j < n; j++) {
                    productRow[j] += aik * bRow[j];
                }
            }
        }
        return product;
    }

    /**
     * Recompiles the network, wrapping any failure in an {@link EvaluationException}.
     * The original exception is attached as the cause.
     */
    private void compileNetOrThrow() throws EvaluationException {
        try {
            this.net.compile();
        }
        catch (Exception e) {
            EvaluationException wrapped = new EvaluationException(e.getMessage());
            // NOTE(review): assumes EvaluationException(String) leaves the cause unset, so initCause is legal.
            wrapped.initCause(e);
            throw wrapped;
        }
    }

    /**
     * Computes P(target state | evidence states) for every row of the full joint state space.
     * The name (sic, "Exat") is kept for compatibility with existing callers.
     *
     * <p>Each row is decoded as a mixed-radix number. The target state is the most
     * significant digit, followed by the evidence nodes in list order.
     *
     * @return one posterior per row, or 0 for rows whose evidence could not be propagated
     * @throws EvaluationException if the network fails to compile
     */
    protected float[] getExatProbTargetGivenEvidence() throws EvaluationException {
        TreeVariable target = this.targetNodeList[0];
        this.compileNetOrThrow();
        int targetStates = target.getStatesSize();
        float[] postProbList = new float[this.statesProduct];
        for (int row = 0; row < this.statesProduct; row++) {
            int radixProduct = targetStates;
            int targetState = row / (this.statesProduct / radixProduct);
            for (int j = 0; j < this.evidenceNodeList.length; j++) {
                int evidenceStates = this.evidenceNodeList[j].getStatesSize();
                radixProduct *= evidenceStates;
                int state = row / (this.statesProduct / radixProduct) % evidenceStates;
                this.evidenceNodeList[j].addFinding(state);
            }
            try {
                this.net.updateEvidences();
                postProbList[row] = target.getMarginalAt(targetState);
            }
            catch (Exception e) {
                // Best effort: presumably the finding combination is impossible; treat as probability 0.
                postProbList[row] = 0.0f;
            }
            // Discard this row's findings before the next one.
            this.compileNetOrThrow();
        }
        return postProbList;
    }

    /**
     * Computes P(target = states[0] | evidence_i = states[1 + i] for all i).
     * Leaves the evidence findings set on the network.
     *
     * @param states joint coordinates; index 0 is the target state, the rest follow the evidence node order
     * @return the posterior, or 0 if evidence propagation fails
     * @throws EvaluationException if the network fails to compile
     */
    protected float getProbTargetGivenEvidence(int[] states) throws EvaluationException {
        this.compileNetOrThrow();
        for (int i = 0; i < this.evidenceNodeList.length; i++) {
            this.evidenceNodeList[i].addFinding(states[1 + i]);
        }
        try {
            this.net.updateEvidences();
            return this.targetNode.getMarginalAt(states[0]);
        }
        catch (Exception e) {
            // Best effort: presumably inconsistent evidence; treat as probability 0.
            return 0.0f;
        }
    }

    /**
     * Computes the joint probability of one evidence combination via the chain rule.
     * The result is P(e0) * P(e1 | e0) * ... * P(en | e0..en-1).
     * Leaves the findings for all but the last evidence node set on the network.
     *
     * @param states one state per evidence node, in evidence node order
     * @return the joint probability, or 0 as soon as evidence propagation fails
     * @throws EvaluationException if the network fails to compile
     */
    protected float getEvidencesJointProbability(int[] states) throws EvaluationException {
        this.compileNetOrThrow();
        float prob = 1.0f;
        for (int j = 0; j < this.evidenceNodeList.length; j++) {
            if (j > 0) {
                // Condition on the previous evidence node before reading the next marginal.
                this.evidenceNodeList[j - 1].addFinding(states[j - 1]);
            }
            try {
                this.net.updateEvidences();
                prob *= this.evidenceNodeList[j].getMarginalAt(states[j]);
            }
            catch (Exception e) {
                return 0.0f;
            }
        }
        return prob;
    }

    /**
     * Computes the joint probability of every evidence combination via the chain rule.
     *
     * <p>Row {@code r} is decoded as a mixed-radix number, with the first evidence node as
     * the most significant digit. A row whose propagation fails gets probability 0.
     *
     * @return one joint probability per evidence combination
     * @throws EvaluationException if the network fails to compile
     */
    protected float[] getEvidencesJointProbability() throws EvaluationException {
        this.compileNetOrThrow();
        float[] jointProbability = new float[this.evidenceStatesProduct];
        for (int row = 0; row < this.evidenceStatesProduct; row++) {
            float prob = 1.0f;
            int radixProduct = 1;
            int previousState = 0;
            for (int j = 0; j < this.evidenceNodeList.length; j++) {
                int evidenceStates = this.evidenceNodeList[j].getStatesSize();
                radixProduct *= evidenceStates;
                int currentState = row / (this.evidenceStatesProduct / radixProduct) % evidenceStates;
                if (j > 0) {
                    this.evidenceNodeList[j - 1].addFinding(previousState);
                }
                try {
                    this.net.updateEvidences();
                    prob *= this.evidenceNodeList[j].getMarginalAt(currentState);
                }
                catch (Exception e) {
                    // The evidence prefix is (presumably) impossible, so the joint is 0; stop
                    // adding findings to a network that failed to propagate.
                    prob = 0.0f;
                    break;
                }
                previousState = currentState;
            }
            jointProbability[row] = prob;
            // Discard this row's findings before the next one.
            this.compileNetOrThrow();
        }
        return jointProbability;
    }

    /**
     * Returns the marginal of the target node at {@code state} on a freshly compiled
     * network (no findings).
     *
     * @param state target state index
     * @return the prior probability of that state
     * @throws EvaluationException if the network fails to compile
     */
    protected float getTargetPriorProbability(int state) throws EvaluationException {
        this.compileNetOrThrow();
        return this.targetNode.getMarginalAt(state);
    }

    /**
     * Returns the marginals of every target state on a freshly compiled network (no findings).
     *
     * @return one prior probability per target state
     * @throws EvaluationException if the network fails to compile
     */
    protected float[] getTargetPriorProbability() throws EvaluationException {
        this.compileNetOrThrow();
        int size = this.targetNode.getStatesSize();
        float[] priorProb = new float[size];
        for (int i = 0; i < size; i++) {
            priorProb[i] = this.targetNode.getMarginalAt(i);
        }
        return priorProb;
    }

    /**
     * Manual driver: evaluates one of two bundled test networks and prints the results.
     */
    public static void main(String[] args) throws Exception {
        boolean runSmallTest = false;
        boolean onlyGCM = true;
        ArrayList<String> targetNodeNameList = new ArrayList<String>();
        ArrayList<String> evidenceNodeNameList = new ArrayList<String>();
        String netFileName;
        if (runSmallTest) {
            targetNodeNameList.add("Springler");
            evidenceNodeNameList.add("Cloudy");
            evidenceNodeNameList.add("Rain");
            evidenceNodeNameList.add("Wet");
            netFileName = "src/test/resources/testCases/evaluation/WetGrass.xml";
        } else {
            targetNodeNameList.add("TargetType");
            evidenceNodeNameList.add("UHRR_Confusion");
            evidenceNodeNameList.add("ModulationFrequency");
            evidenceNodeNameList.add("CenterFrequency");
            evidenceNodeNameList.add("PRI");
            evidenceNodeNameList.add("PRF");
            netFileName = "src/test/resources/testCases/evaluation/AirID.xml";
        }
        ExactEvaluation evaluationExact = new ExactEvaluation();
        evaluationExact.evaluate(netFileName, targetNodeNameList, evidenceNodeNameList, onlyGCM);
        System.out.println("----TOTAL------");
        System.out.println("LCM:\n");
        ExactEvaluation.show(evaluationExact.getEvidenceSetCM());
        System.out.println("\n");
        System.out.println("PCC: ");
        System.out.printf("%2.2f\n", evaluationExact.getEvidenceSetPCC() * 100.0f);
        if (!onlyGCM) {
            System.out.println("\n\n\n");
            System.out.println("----MARGINAL------");
            System.out.println("\n\n");
            List<EvidenceEvaluation> list = evaluationExact.getBestMarginalImprovement();
            for (EvidenceEvaluation evidenceEvaluation : list) {
                System.out.println("-" + evidenceEvaluation.getName() + "-");
                System.out.println("\n\n");
                System.out.println("LCM:\n");
                ExactEvaluation.show(evidenceEvaluation.getMarginalCM());
                System.out.println("\n");
                System.out.println("PCC: ");
                System.out.printf("%2.2f\n", evidenceEvaluation.getMarginalPCC() * 100.0f);
                System.out.println("\n");
                System.out.println("Marginal Improvement: ");
                System.out.printf("%2.2f\n", evidenceEvaluation.getMarginalImprovement() * 100.0f);
                System.out.println("\n\n");
            }
            System.out.println("\n");
            System.out.println("----INDIVIDUAL PCC------");
            System.out.println("\n\n");
            // One generator for all the placeholder costs, rather than a new one per evidence.
            Random random = new Random();
            list = evaluationExact.getBestIndividualPCC();
            for (EvidenceEvaluation evidenceEvaluation : list) {
                System.out.println("-" + evidenceEvaluation.getName() + "-");
                System.out.println("\n\n");
                System.out.println("LCM:\n");
                ExactEvaluation.show(evidenceEvaluation.getIndividualLCM());
                System.out.println("\n");
                System.out.println("PCC: ");
                System.out.printf("%2.2f\n", evidenceEvaluation.getIndividualPCC() * 100.0f);
                System.out.println("\n\n");
                evidenceEvaluation.setCost(random.nextFloat() * 1000.0f);
            }
            System.out.println("\n");
            System.out.println("----INDIVIDUAL COST RATE------");
            System.out.println("\n\n");
            list = evaluationExact.getBestIndividualCostRate();
            for (EvidenceEvaluation evidenceEvaluation : list) {
                System.out.println("-" + evidenceEvaluation.getName() + "-");
                System.out.println("\n\n");
                System.out.println("PCC: ");
                System.out.printf("%2.2f\n", evidenceEvaluation.getIndividualPCC() * 100.0f);
                System.out.println("\n");
                System.out.println("Cost: ");
                System.out.printf("%2.2f\n", evidenceEvaluation.getCost());
                System.out.println("\n");
                System.out.println("Cost Rate: ");
                System.out.printf("%2.2f\n", evidenceEvaluation.getMarginalCost() * 100.0f);
                System.out.println("\n\n");
            }
        }
    }
}

