/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.neuralnet;

import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.neuralnet.Layer;
import ciir.umass.edu.learning.neuralnet.PropParameter;
import ciir.umass.edu.learning.neuralnet.RankNet;
import ciir.umass.edu.metric.MetricScorer;
import java.util.List;

public class LambdaRank
extends RankNet {
    protected float[][] targetValue = null;

    public LambdaRank() {
    }

    public LambdaRank(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    @Override
    protected int[][] batchFeedForward(RankList rl) {
        int[][] pairMap = new int[rl.size()][];
        this.targetValue = new float[rl.size()][];
        for (int i = 0; i < rl.size(); ++i) {
            this.addInput(rl.get(i));
            this.propagate(i);
            int count = 0;
            for (int j = 0; j < rl.size(); ++j) {
                if (!(rl.get(i).getLabel() > rl.get(j).getLabel()) && !(rl.get(i).getLabel() < rl.get(j).getLabel())) continue;
                ++count;
            }
            pairMap[i] = new int[count];
            this.targetValue[i] = new float[count];
            int k = 0;
            for (int j = 0; j < rl.size(); ++j) {
                if (!(rl.get(i).getLabel() > rl.get(j).getLabel()) && !(rl.get(i).getLabel() < rl.get(j).getLabel())) continue;
                pairMap[i][k] = j;
                this.targetValue[i][k] = rl.get(i).getLabel() > rl.get(j).getLabel() ? 1.0f : 0.0f;
                ++k;
            }
        }
        return pairMap;
    }

    @Override
    protected void batchBackPropagate(int[][] pairMap, float[][] pairWeight) {
        for (int i = 0; i < pairMap.length; ++i) {
            int j;
            PropParameter p = new PropParameter(i, pairMap, pairWeight, this.targetValue);
            this.outputLayer.computeDelta(p);
            for (j = this.layers.size() - 2; j >= 1; --j) {
                ((Layer)this.layers.get(j)).updateDelta(p);
            }
            this.outputLayer.updateWeight(p);
            for (j = this.layers.size() - 2; j >= 1; --j) {
                ((Layer)this.layers.get(j)).updateWeight(p);
            }
        }
    }

    @Override
    protected RankList internalReorder(RankList rl) {
        return this.rank(rl);
    }

    @Override
    protected float[][] computePairWeight(int[][] pairMap, RankList rl) {
        double[][] changes = this.scorer.swapChange(rl);
        float[][] weight = new float[pairMap.length][];
        for (int i = 0; i < weight.length; ++i) {
            weight[i] = new float[pairMap[i].length];
            for (int j = 0; j < pairMap[i].length; ++j) {
                int sign = rl.get(i).getLabel() > rl.get(pairMap[i][j]).getLabel() ? 1 : -1;
                weight[i][j] = (float)Math.abs(changes[i][pairMap[i][j]]) * (float)sign;
            }
        }
        return weight;
    }

    @Override
    protected void estimateLoss() {
        this.misorderedPairs = 0;
        for (int j = 0; j < this.samples.size(); ++j) {
            RankList rl = (RankList)this.samples.get(j);
            for (int k = 0; k < rl.size() - 1; ++k) {
                double o1 = this.eval(rl.get(k));
                for (int l = k + 1; l < rl.size(); ++l) {
                    double o2;
                    if (!(rl.get(k).getLabel() > rl.get(l).getLabel()) || !(o1 < (o2 = this.eval(rl.get(l))))) continue;
                    ++this.misorderedPairs;
                }
            }
        }
        this.error = 1.0 - this.scoreOnTrainingData;
        this.straightLoss = this.error > this.lastError ? ++this.straightLoss : 0;
        this.lastError = this.error;
    }

    @Override
    public Ranker createNew() {
        return new LambdaRank();
    }

    @Override
    public String name() {
        return "LambdaRank";
    }
}

