/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3;

import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import net.yacy.ai.llama3.Model.Arch;
import net.yacy.ai.llama3.Model.Tokenizer;
import net.yacy.ai.llama3.Sampler;
import net.yacy.ai.llama3.Tensor.AbstractFloatTensor;
import net.yacy.ai.llama3.Tensor.DirectBufferFloatTensor;
import net.yacy.ai.llama3.Tensor.FloatTensor;

public final class Llama {
    /** Hyper-parameters of the model, as handed to the constructor. */
    private final Configuration configuration;
    /** Tokenizer handed to the constructor; {@link #createNewState(int)} reads its special tokens. */
    private final Tokenizer tokenizer;
    /** Weight tensors handed to the constructor. */
    private final Weights weights;

    /**
     * Bundles a configuration, a tokenizer and a set of weights into one model object.
     * The arguments are stored as given; nothing is copied or validated.
     *
     * @param configuration hyper-parameters of the model
     * @param tokenizer     tokenizer belonging to the model
     * @param weights       weight tensors of the model
     */
    public Llama(Configuration configuration, Tokenizer tokenizer, Weights weights) {
        this.weights = weights;
        this.tokenizer = tokenizer;
        this.configuration = configuration;
    }

    /**
     * @return the configuration this model was constructed with
     */
    public Configuration configuration() {
        return configuration;
    }

    /**
     * @return the tokenizer this model was constructed with
     */
    public Tokenizer tokenizer() {
        return tokenizer;
    }

    /**
     * @return the weights this model was constructed with
     */
    public Weights weights() {
        return weights;
    }

    /**
     * Allocates a fresh {@link State} sized for this model's configuration and seeds its
     * {@code latestToken} with the tokenizer's {@code <|begin_of_text|>} special token.
     *
     * @param batchsize number of per-token scratch buffers to allocate in the state
     * @return the newly allocated state
     */
    public State createNewState(int batchsize) {
        final State fresh = new State(configuration, batchsize);
        fresh.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>");
        return fresh;
    }

    /**
     * Writes {@code out[i] = weight[i] * (scale * x[i])} for {@code i} in {@code [0, size)}, where
     * {@code scale = 1 / sqrt(sumOfSquares / size + rmsNormEps)}.
     *
     * <p>The sum of squares comes from {@code x.reduce(0, size, ...)} — presumably a fold over the
     * first {@code size} entries of {@code x}. It is computed before any element of {@code out} is
     * written, and each iteration reads {@code x[i]} before writing {@code out[i]}.
     *
     * @param out        destination tensor
     * @param x          source tensor
     * @param weight     per-element gain, read with absolute {@code get(i)}
     * @param size       number of elements to process
     * @param rmsNormEps value added to the mean square before taking the square root
     */
    static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
        final float sumOfSquares = x.reduce(0, size, 0.0f, (acc, xi) -> acc + xi * xi);
        final float meanSquare = sumOfSquares / (float) size;
        final float shifted = meanSquare + rmsNormEps;
        final float scale = (float) (1.0 / Math.sqrt(shifted));
        int idx = 0;
        while (idx < size) {
            final float gain = weight.get(idx);
            final float normalized = scale * x.getFloat(idx);
            out.setFloat(idx, gain * normalized);
            ++idx;
        }
    }

    /**
     * Runs the network over a batch of consecutive tokens, updating {@code state} in place.
     *
     * <p>Token {@code tokens[t]} is treated as sitting at sequence index {@code position + t}: that
     * index selects its rotation factors and the row of the key/value cache it is written to.
     *
     * <p>NOTE(review): nothing here validates the arguments. The per-token buffers of {@code state}
     * are indexed with {@code 0..tokens.length-1} and {@code state.x[tokens.length - 1]} is read at
     * the end, so {@code tokens} presumably has to be non-empty and no longer than the number of
     * buffers in {@code state} — confirm against callers.
     *
     * @param model         supplies the configuration and the weights
     * @param state         scratch buffers and key/value caches; overwritten by this call
     * @param tokens        token ids of the batch
     * @param position      sequence index of {@code tokens[0]}
     * @param computeLogits when {@code false}, the last layer stops right after its keys/values are
     *                      cached and the method returns {@code null}
     * @return {@code state.logits}, computed from the last token of the batch, or {@code null} when
     *         the early exit for {@code computeLogits == false} is taken
     */
    static FloatTensor forward(Llama model, State state, int[] tokens, int position, boolean computeLogits) {
        Configuration config = model.configuration();
        Weights weights = model.weights();
        int dim = config.dim;
        int headSize = config.headSize;
        // Width of one token's key (or value) vector: dim scaled by kvHeads / heads.
        int kvDim = config.dim * config.numberOfKeyValueHeads / config.numberOfHeads;
        // Number of query heads that map onto the same key/value head (see h / kvMul below).
        int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads;
        float sqrtHeadSize = (float)Math.sqrt(headSize);
        int nTokens = tokens.length;
        // Copy each token's row (offset tokens[t] * dim, length dim) of the embedding table into x[t].
        AbstractFloatTensor.parallelFor(0, nTokens, t -> weights.token_embedding_table.copyTo(tokens[t] * dim, state.x[t], 0, dim));
        for (int l = 0; l < config.numberOfLayers; ++l) {
            // Copy of the loop variable so the lambdas below can capture it.
            int curLayer = l;
            // Normalise x into xb with this layer's attention-norm weights; x itself is kept for the residual.
            AbstractFloatTensor.parallelFor(0, nTokens, t -> Llama.rmsnorm(state.xb[t], state.x[t], weights.rms_att_weight[curLayer], dim, config.rmsNormEps));
            // Project xb into q, k and v (presumably one matrix-vector product per token — see FloatTensor.matmul).
            weights.wq[l].matmul(nTokens, state.xb, state.q, dim, dim);
            weights.wk[l].matmul(nTokens, state.xb, state.k, kvDim, dim);
            weights.wv[l].matmul(nTokens, state.xb, state.v, kvDim, dim);
            // Rotate every (i, i+1) pair by the complex factor (fcr, fci) looked up for sequence index position + t.
            AbstractFloatTensor.parallelFor(0, nTokens, t -> {
                for (int i = 0; i < dim; i += 2) {
                    int head_dim = i % headSize;
                    float fcr = weights.freq_cis_real.get((position + t) * (headSize / 2) + head_dim / 2);
                    float fci = weights.freq_cis_imag.get((position + t) * (headSize / 2) + head_dim / 2);
                    // Below kvDim both q and k are rotated; from kvDim on only q is.
                    int rotn = i < kvDim ? 2 : 1;
                    for (int vi = 0; vi < rotn; ++vi) {
                        FloatTensor vec = vi == 0 ? state.q[t] : state.k[t];
                        float v0 = vec.getFloat(i);
                        float v1 = vec.getFloat(i + 1);
                        // (v0 + i*v1) * (fcr + i*fci)
                        vec.setFloat(i, v0 * fcr - v1 * fci);
                        vec.setFloat(i + 1, v0 * fci + v1 * fcr);
                    }
                }
            });
            // Store this layer's k and v at cache offset (position + t) * kvDim.
            AbstractFloatTensor.parallelFor(0, nTokens, t -> {
                state.k[t].copyTo(0, state.keyCache[curLayer], (position + t) * kvDim, kvDim);
                state.v[t].copyTo(0, state.valueCache[curLayer], (position + t) * kvDim, kvDim);
            });
            // No logits requested: once the last layer's keys/values are cached, skip the rest of it.
            if (!computeLogits && curLayer == config.numberOfLayers - 1) {
                state.idxPrevBlock = nTokens - 1;
                return null;
            }
            // One work item per (token, head) pair, flattened into a single long index.
            AbstractFloatTensor.parallelForLong(0L, (long)nTokens * (long)config.numberOfHeads, ht -> {
                int token = (int)(ht / (long)config.numberOfHeads);
                int h = (int)(ht % (long)config.numberOfHeads);
                int qOffset = h * headSize;
                // Each head owns a contextLength-long slice of att[token].
                int attOffset = h * config.contextLength;
                // Score this head's query against every cached key at sequence indices 0..position+token.
                for (int t = 0; t <= position + token; ++t) {
                    // h / kvMul picks the key/value head used by query head h.
                    int keyCacheOffset = t * kvDim + h / kvMul * headSize;
                    float score = state.q[token].dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
                    state.att[token].setFloat(attOffset + t, score /= sqrtHeadSize);
                }
                // Normalise the position + token + 1 scores in place.
                state.att[token].softmaxInPlace(attOffset, position + token + 1);
                // Reset this head's slice of xb, then accumulate cached values scaled by the attention
                // coefficients (saxpyInPlace presumably computes xb += a * value).
                int xbOffset = h * headSize;
                state.xb[token].fillInPlace(xbOffset, headSize, 0.0f);
                for (int t = 0; t <= position + token; ++t) {
                    int vOffset = t * kvDim + h / kvMul * headSize;
                    float a = state.att[token].getFloat(attOffset + t);
                    state.xb[token].saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
                }
            });
            // Output projection of the attention result into xb2, then residual add into x.
            weights.wo[l].matmul(nTokens, state.xb, state.xb2, dim, dim);
            AbstractFloatTensor.parallelFor(0, nTokens, t -> state.x[t].addInPlace(state.xb2[t]));
            // Feed-forward part: normalise x into xb with the ffn-norm weights.
            AbstractFloatTensor.parallelFor(0, nTokens, t -> Llama.rmsnorm(state.xb[t], state.x[t], weights.rms_ffn_weight[curLayer], dim, config.rmsNormEps));
            // Two projections of xb to hiddenDim: hb via w1, hb2 via w3.
            weights.w1[l].matmul(nTokens, state.xb, state.hb, config.hiddenDim, dim);
            weights.w3[l].matmul(nTokens, state.xb, state.hb2, config.hiddenDim, dim);
            // hb <- hb / (1 + exp(-hb)), i.e. hb * sigmoid(hb), then multiplied element-wise by hb2.
            AbstractFloatTensor.parallelFor(0, nTokens, t -> state.hb[t].mapInPlace(value -> value / (float)(1.0 + Math.exp(-value))));
            AbstractFloatTensor.parallelFor(0, nTokens, t -> state.hb[t].multiplyInPlace(state.hb2[t]));
            // Project back to dim via w2 (xb is reused as the output buffer), then residual add into x.
            weights.w2[l].matmul(nTokens, state.hb, state.xb, dim, config.hiddenDim);
            AbstractFloatTensor.parallelFor(0, nTokens, t -> state.x[t].addInPlace(state.xb[t]));
        }
        // Final normalisation, in place (out and x are the same tensor), for every token of the batch.
        AbstractFloatTensor.parallelFor(0, nTokens, t -> Llama.rmsnorm(state.x[t], state.x[t], weights.rms_final_weight, dim, config.rmsNormEps));
        // Only the last token of the batch is projected to vocabulary logits.
        weights.wcls.matmul(state.x[nTokens - 1], state.logits, config.vocabularySize, dim);
        state.idxPrevBlock = nTokens - 1;
        return state.logits;
    }

    /**
     * Feeds the prompt through the model in batches, then samples tokens one at a time until a stop
     * token is produced or the position limit is reached.
     *
     * <p>Prompt tokens are ingested in chunks of at most {@code state.batchsize}; logits are only
     * requested for the chunk that contains the end of the prompt. After that each sampled token is
     * fed back as the next input. A sampled stop token is appended to the result (and reported to
     * the callback) before the loop ends, but is not stored in {@code state.latestToken}.
     *
     * @param model            model to run
     * @param state            working state; its caches, logits and {@code latestToken} are updated
     * @param startPosition    sequence position of the first token processed by this call
     * @param promptTokens     prompt token ids; may be empty, in which case generation continues
     *                         from {@code state.latestToken}
     * @param stopTokens       token ids that end generation
     * @param maxTokens        exclusive upper bound on the sequence position (not a count of new
     *                         tokens); a negative value or one above the context length is replaced
     *                         by the context length
     * @param sampler          picks the next token from {@code state.logits}
     * @param onTokenGenerated invoked with every sampled token; may be {@code null}
     * @return the sampled tokens in generation order
     * @throws IllegalArgumentException if a prompt has to be ingested but {@code state.batchsize} is
     *                                  smaller than 1
     */
    public static List<Integer> generateTokens(Llama model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, IntConsumer onTokenGenerated) {
        if (maxTokens < 0 || model.configuration().contextLength < maxTokens) {
            maxTokens = model.configuration().contextLength;
        }
        List<Integer> generatedTokens = new ArrayList<>(maxTokens);
        int token = state.latestToken;
        int promptIndex = 0;
        for (int position = startPosition; position < maxTokens; ++position) {
            if (promptIndex < promptTokens.size()) {
                // Chunk size: bounded by the remaining positions, the remaining prompt and the batch size.
                int nTokens = Math.min(maxTokens - position, Math.min(promptTokens.size() - promptIndex, state.batchsize));
                if (nTokens < 1) {
                    // The first two bounds are >= 1 here, so only a non-positive batch size gets us here.
                    // With a zero-sized chunk neither position nor promptIndex would advance and the
                    // loop would never terminate.
                    throw new IllegalArgumentException("state.batchsize must be at least 1 to ingest a prompt, got " + state.batchsize);
                }
                int[] tokens = new int[nTokens];
                for (int i = 0; i < nTokens; ++i) {
                    tokens[i] = promptTokens.get(promptIndex + i);
                }
                // Logits are only needed once the chunk reaches the end of the prompt.
                boolean computeLogits = promptIndex + nTokens >= promptTokens.size();
                Llama.forward(model, state, tokens, position, computeLogits);
                // The loop header adds the remaining +1 for this chunk.
                position += nTokens - 1;
                promptIndex += nTokens;
                if (promptIndex < promptTokens.size()) {
                    continue;
                }
            } else {
                Llama.forward(model, state, new int[]{token}, position, true);
            }
            int nextToken = sampler.sampleToken(state.logits);
            generatedTokens.add(nextToken);
            if (onTokenGenerated != null) {
                onTokenGenerated.accept(nextToken);
            }
            if (stopTokens.contains(nextToken)) {
                break;
            }
            token = nextToken;
            state.latestToken = nextToken;
        }
        return generatedTokens;
    }

    /**
     * Immutable bundle of model hyper-parameters. All fields are public and final;
     * {@link #headSize} is derived from {@code dim} and {@code numberOfHeads}.
     */
    public static final class Configuration {
        /** Architecture tag; not interpreted by the code in this file. */
        public final Arch arch;
        /** Width of the per-token vectors {@code x}, {@code xb} and {@code q}. */
        public final int dim;
        /** Width of the feed-forward buffers {@code hb} and {@code hb2}. */
        public final int hiddenDim;
        /** Number of layers iterated by the forward pass. */
        public final int numberOfLayers;
        /** Number of query heads. */
        public final int numberOfHeads;
        /** Number of key/value heads. */
        public final int numberOfKeyValueHeads;
        /** Length of the logits vector. */
        public final int vocabularySize;
        /** Number of sequence positions the key/value caches and attention buffers are sized for. */
        public final int contextLength;
        /** Stored as given; not read by the code in this file. */
        public final boolean sharedWeights;
        /** Epsilon passed to the RMS normalisation. */
        public final float rmsNormEps;
        /** Stored as given; not read by the code in this file. */
        public final float ropeTheta;
        /** {@code dim / numberOfHeads} (integer division). */
        public final int headSize;

        /**
         * Stores all arguments unchanged and derives {@link #headSize}.
         */
        public Configuration(Arch arch, int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, boolean sharedWeights, float rmsNormEps, float ropeTheta) {
            this.arch = arch;
            this.sharedWeights = sharedWeights;

            this.dim = dim;
            this.hiddenDim = hiddenDim;
            this.vocabularySize = vocabularySize;
            this.contextLength = contextLength;

            this.numberOfLayers = numberOfLayers;
            this.numberOfHeads = numberOfHeads;
            this.numberOfKeyValueHeads = numberOfKeyValueHeads;

            this.rmsNormEps = rmsNormEps;
            this.ropeTheta = ropeTheta;

            this.headSize = dim / numberOfHeads;
        }

        /**
         * Returns a configuration that differs from this one only in its context length.
         *
         * @param newContextLength the context length to use; a negative value means "keep as is"
         * @return this instance when {@code newContextLength} is negative, otherwise a new instance
         */
        public Configuration withContextLength(int newContextLength) {
            return newContextLength < 0
                    ? this
                    : new Configuration(arch, dim, hiddenDim, numberOfLayers, numberOfHeads, numberOfKeyValueHeads, vocabularySize, newContextLength, sharedWeights, rmsNormEps, ropeTheta);
        }
    }

    /**
     * Plain holder for the model's weight tensors. Arrays are indexed by layer.
     * References are stored as given; nothing is copied.
     */
    public static final class Weights {
        // --- token embedding ---
        /** Embedding table; the forward pass copies {@code dim} values starting at {@code token * dim}. */
        public final FloatTensor token_embedding_table;

        // --- attention, per layer ---
        /** Gains for the normalisation applied before the q/k/v projections. */
        public final FloatBuffer[] rms_att_weight;
        /** Query projection. */
        public final FloatTensor[] wq;
        /** Key projection. */
        public final FloatTensor[] wk;
        /** Value projection. */
        public final FloatTensor[] wv;
        /** Output projection applied to the attention result. */
        public final FloatTensor[] wo;
        /** Stored as given; not read by the forward pass in this file. */
        public final FloatTensor[] q_bias;
        /** Stored as given; not read by the forward pass in this file. */
        public final FloatTensor[] k_bias;
        /** Stored as given; not read by the forward pass in this file. */
        public final FloatTensor[] v_bias;

        // --- feed-forward, per layer ---
        /** Gains for the normalisation applied before the feed-forward projections. */
        public final FloatBuffer[] rms_ffn_weight;
        /** Projection from {@code dim} to {@code hiddenDim} whose result goes through the activation. */
        public final FloatTensor[] w1;
        /** Projection from {@code hiddenDim} back to {@code dim}. */
        public final FloatTensor[] w2;
        /** Projection from {@code dim} to {@code hiddenDim} that is multiplied with the activated {@code w1} output. */
        public final FloatTensor[] w3;

        // --- output ---
        /** Gains for the final normalisation. */
        public final FloatBuffer rms_final_weight;
        /** Real parts of the rotation factors, read at {@code pos * (headSize / 2) + pairIndex}. */
        public final FloatBuffer freq_cis_real;
        /** Imaginary parts of the rotation factors, same indexing as {@link #freq_cis_real}. */
        public final FloatBuffer freq_cis_imag;
        /** Projection from the final hidden vector to vocabulary logits. */
        public final FloatTensor wcls;

        /**
         * Stores every argument in the field of the same name. Note the parameter order:
         * the three bias arrays come before {@code wo}.
         */
        public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] q_bias, FloatTensor[] k_bias, FloatTensor[] v_bias, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight, FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls) {
            this.token_embedding_table = token_embedding_table;

            this.rms_att_weight = rms_att_weight;
            this.wq = wq;
            this.wk = wk;
            this.wv = wv;
            this.wo = wo;
            this.q_bias = q_bias;
            this.k_bias = k_bias;
            this.v_bias = v_bias;

            this.rms_ffn_weight = rms_ffn_weight;
            this.w1 = w1;
            this.w2 = w2;
            this.w3 = w3;

            this.rms_final_weight = rms_final_weight;
            this.freq_cis_real = freq_cis_real;
            this.freq_cis_imag = freq_cis_imag;
            this.wcls = wcls;
        }
    }

    /**
     * Working memory for a generation run: per-token scratch buffers (one per batch slot),
     * the logits vector, and one key cache and one value cache per layer.
     */
    public static final class State {
        /** Number of per-token buffers allocated in each of the arrays below. */
        public final int batchsize;
        /** Current hidden vector per token, {@code dim} wide. */
        public final FloatTensor[] x;
        /** Scratch buffer per token, {@code dim} wide. */
        public final FloatTensor[] xb;
        /** Second scratch buffer per token, {@code dim} wide. */
        public final FloatTensor[] xb2;
        /** Feed-forward buffer per token, {@code hiddenDim} wide. */
        public final FloatTensor[] hb;
        /** Second feed-forward buffer per token, {@code hiddenDim} wide. */
        public final FloatTensor[] hb2;
        /** Query vector per token, {@code dim} wide. */
        public final FloatTensor[] q;
        /** Key vector per token; allocated {@code dim} wide. */
        public final FloatTensor[] k;
        /** Value vector per token; allocated {@code dim} wide. */
        public final FloatTensor[] v;
        /** Attention scores per token, allocated as {@code numberOfHeads x contextLength}. */
        public final FloatTensor[] att;
        /** Output logits, {@code vocabularySize} long. */
        public final FloatTensor logits;
        /** Key cache per layer, allocated as {@code contextLength x kvDim}. */
        public final FloatTensor[] keyCache;
        /** Value cache per layer, allocated as {@code contextLength x kvDim}. */
        public final FloatTensor[] valueCache;
        /** Starts at -1; the forward pass sets it to the index of the last token of the batch it processed. */
        int idxPrevBlock;
        /** Most recent token; seeded by {@code createNewState} and updated during generation. */
        public int latestToken;

        /**
         * Allocates all buffers for the given configuration.
         *
         * @param config    source of all buffer dimensions
         * @param batchsize number of per-token buffers to allocate in each array
         */
        State(Configuration config, int batchsize) {
            final int modelDim = config.dim;
            final int ffnDim = config.hiddenDim;

            this.batchsize = batchsize;
            this.x = allocate(batchsize, modelDim);
            this.xb = allocate(batchsize, modelDim);
            this.xb2 = allocate(batchsize, modelDim);
            this.hb = allocate(batchsize, ffnDim);
            this.hb2 = allocate(batchsize, ffnDim);
            this.q = allocate(batchsize, modelDim);
            this.k = allocate(batchsize, modelDim);
            this.v = allocate(batchsize, modelDim);
            this.att = allocate(batchsize, config.numberOfHeads, config.contextLength);
            this.idxPrevBlock = -1;
            this.logits = DirectBufferFloatTensor.allocate(config.vocabularySize);

            final int kvDim = config.dim * config.numberOfKeyValueHeads / config.numberOfHeads;
            this.keyCache = allocatePerLayer(config, kvDim);
            this.valueCache = allocatePerLayer(config, kvDim);
        }

        /** Allocates one {@code contextLength x kvDim} tensor for each layer of {@code config}. */
        private static FloatTensor[] allocatePerLayer(Configuration config, int kvDim) {
            return Stream.generate(() -> DirectBufferFloatTensor.allocate(config.contextLength, kvDim))
                    .limit(config.numberOfLayers)
                    .toArray(FloatTensor[]::new);
        }

        /** Allocates {@code numTokens} independent tensors, each with the given dimensions. */
        private static FloatTensor[] allocate(int numTokens, int ... dims) {
            return IntStream.range(0, numTokens)
                    .mapToObj(slot -> DirectBufferFloatTensor.allocate(dims))
                    .toArray(FloatTensor[]::new);
        }
    }
}

