/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Scanner;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.IntConsumer;
import net.yacy.ai.llama3.ChatFormat;
import net.yacy.ai.llama3.Context;
import net.yacy.ai.llama3.Llama;
import net.yacy.ai.llama3.Model.ModelLoader;
import net.yacy.ai.llama3.Sampler;

/**
 * Thin front end around a loaded {@link Llama} model that offers a one-shot
 * "instruct" run and an interactive console chat through {@link TokenSampler}.
 */
public class Llama3 {
    /**
     * Value handed to {@code Llama.createNewState}. Taken from the
     * {@code llama.BatchSize} system property, falling back to 16.
     */
    private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);

    /** The loaded model; assigned once in the constructor. */
    Llama model;

    /**
     * Loads the model file through {@code ModelLoader.loadModel}.
     *
     * @param modelPath     location of the model file
     * @param contextLength context length forwarded to the loader
     * @throws IOException if the loader fails to read the model
     */
    public Llama3(Path modelPath, int contextLength) throws IOException {
        // NOTE(review): the meaning of the trailing 'true' flag is defined by ModelLoader - confirm there.
        this.model = ModelLoader.loadModel(modelPath, contextLength, true);
    }

    /**
     * Decodes a token sequence into text using the tokenizer of the given model.
     *
     * @param model  model whose tokenizer is used
     * @param tokens token ids to decode
     * @return the decoded text
     */
    public static String toString(Llama model, List<Integer> tokens) {
        return model.tokenizer().decode(tokens);
    }

    /**
     * Demo entry point: runs a single instruct request, streams the generated
     * text to standard output and prints a throughput figure afterwards.
     *
     * @param args optional; {@code args[0]} overrides the built-in default model path
     * @throws IOException if the model cannot be loaded
     */
    public static void main(String[] args) throws IOException {
        System.out.println("JVM version " + Runtime.version());

        // The built-in default only exists on the original developer machine,
        // so a path given on the command line takes precedence.
        Path modelPath = (args != null && args.length > 0)
                ? Path.of(args[0])
                : Path.of("/Users/admin/git/yacy_search_server", "DATA", "LLMS", "Llama-3.2-1B-Instruct-Q4_0.gguf");

        // NOTE(review): argument order is presumably (prompt, systemPrompt, temp, topp, seed, maxTokens) - confirm in Context.
        Context context = new Context("Write a Java program which computes the first 42 prime numbers.", "Be a very good programmer.", 0.0f, 0.95f, 0L, 1024);
        Llama3 llama3 = new Llama3(modelPath, 1024);

        // nanoTime is monotonic, which makes it the right clock for measuring elapsed time.
        long startNanos = System.nanoTime();
        TokenSampler tokenSampler = llama3.new TokenSampler(context);
        List<String> resultToken = tokenSampler.runInstruct(token -> System.out.print(token));
        long elapsedNanos = System.nanoTime() - startNanos;

        // Guard against a zero interval so the report never shows Infinity or NaN.
        double tokensPerSecond = elapsedNanos > 0L ? resultToken.size() * 1.0e9 / elapsedNanos : 0.0;
        System.out.println("\nToken: " + resultToken.size() + ", " + tokensPerSecond + " Tokens per second");
    }

    /**
     * Couples a {@link Context} (prompts and sampling settings) with a
     * {@link Sampler} and runs generation against the enclosing model.
     */
    public class TokenSampler {
        /** Sampler chosen from the context's temp, topp and seed values. */
        Sampler sampler;
        /** Prompts and generation settings used by every run of this instance. */
        Context context;

        /**
         * Creates a sampler for the enclosing model's vocabulary size and the
         * sampling settings found in {@code context}.
         *
         * @param context prompts and sampling settings
         */
        public TokenSampler(Context context) {
            this.context = context;
            this.sampler = Sampler.selectSampler(Llama3.this.model.configuration().vocabularySize, context.temp, context.topp, context.seed);
        }

        /**
         * Console chat loop: reads one line from standard input per turn, generates
         * a reply and prints its non-special tokens as they are produced.
         * <p>
         * The loop ends when a reply does not finish with a stop token (reported as
         * "Ran out of context length...") or, silently, when standard input has no
         * more lines.
         */
        public void runInteractive() {
            Llama.State state = null;
            List<Integer> conversationTokens = new ArrayList<>();
            ChatFormat chatFormat = new ChatFormat(Llama3.this.model.tokenizer());
            conversationTokens.add(chatFormat.beginOfText);
            if (this.context.systemPrompt != null) {
                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, this.context.systemPrompt)));
            }

            // Index of the first conversation token that has not been passed to the model yet.
            int startPosition = 0;
            // Intentionally not closed: closing the Scanner would close System.in as well.
            Scanner in = new Scanner(System.in);
            while (true) {
                System.out.print("\n> ");
                System.out.flush();
                if (!in.hasNextLine()) {
                    // End of input (closed pipe, Ctrl-D): nextLine() would throw here.
                    return;
                }
                String userText = in.nextLine();

                // The state is created on first use and then reused for all later turns.
                if (state == null) {
                    state = Llama3.this.model.createNewState(BATCH_SIZE);
                }
                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));

                Set<Integer> stopTokens = chatFormat.getStopTokens();
                List<Integer> pendingTokens = conversationTokens.subList(startPosition, conversationTokens.size());
                List<Integer> responseTokens = Llama.generateTokens(Llama3.this.model, state, startPosition, pendingTokens, stopTokens, this.context.maxTokens, this.sampler, token -> {
                    if (!Llama3.this.model.tokenizer().isSpecialToken(token)) {
                        System.out.print(Llama3.this.model.tokenizer().decode(List.of(token)));
                    }
                });

                // The complete reply, including a trailing stop token if present, becomes part of the history.
                conversationTokens.addAll(responseTokens);
                startPosition = conversationTokens.size();

                boolean endedWithStopToken = !responseTokens.isEmpty()
                        && stopTokens.contains(responseTokens.get(responseTokens.size() - 1));
                if (!endedWithStopToken) {
                    break;
                }
            }
            System.out.println("Ran out of context length...");
        }

        /**
         * Runs one instruct request and delivers the generated text piece by piece.
         * <p>
         * Each generated token id is decoded on its own. Special tokens are skipped,
         * the same filtering {@link #runInteractive()} applies before printing, so
         * neither the callback nor the returned list contains their text.
         *
         * @param onTokenGenerated receives every decoded piece as it is produced;
         *                         may be {@code null} when only the result is wanted
         * @return the decoded pieces in generation order
         */
        public List<String> runInstruct(Consumer<String> onTokenGenerated) {
            List<String> result = new ArrayList<>();
            this.runInstructOnce(token -> {
                // NOTE(review): presumably the stop token is reported to this callback too - confirm in Llama.generateTokens.
                if (Llama3.this.model.tokenizer().isSpecialToken(token)) {
                    return;
                }
                String text = Llama3.this.model.tokenizer().decode(List.of(token));
                if (onTokenGenerated != null) {
                    onTokenGenerated.accept(text);
                }
                result.add(text);
            });
            return result;
        }

        /**
         * Builds the prompt (begin-of-text, optional system message, user message,
         * assistant header) from the context and generates a single reply on a
         * freshly created state.
         *
         * @param onTokenGenerated callback passed through to {@code Llama.generateTokens}
         * @return the generated token ids, without a trailing stop token
         */
        public List<Integer> runInstructOnce(IntConsumer onTokenGenerated) {
            Llama.State state = Llama3.this.model.createNewState(BATCH_SIZE);
            ChatFormat chatFormat = new ChatFormat(Llama3.this.model.tokenizer());

            List<Integer> promptTokens = new ArrayList<>();
            promptTokens.add(chatFormat.beginOfText);
            if (this.context.systemPrompt != null) {
                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, this.context.systemPrompt)));
            }
            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, this.context.prompt)));
            promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));

            Set<Integer> stopTokens = chatFormat.getStopTokens();
            List<Integer> responseTokens = Llama.generateTokens(Llama3.this.model, state, 0, promptTokens, stopTokens, this.context.maxTokens, this.sampler, onTokenGenerated);

            // Drop the stop token so callers only see the content of the reply.
            int lastIndex = responseTokens.size() - 1;
            if (lastIndex >= 0 && stopTokens.contains(responseTokens.get(lastIndex))) {
                responseTokens.remove(lastIndex);
            }
            return responseTokens;
        }
    }
}

