/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import net.yacy.ai.OllamaClient;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

/**
 * Minimal client for an OpenAI-compatible chat completion HTTP API.
 *
 * <p>Requests are sent with {@link HttpURLConnection}. Request and response bodies are JSON
 * and are encoded/decoded as UTF-8.
 */
public class OpenAIClient {
    /** Stop sequences sent with every chat request so generation ends at common end-of-turn markers. */
    private static final String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "<EOS_TOKEN>", "</s>", "<|end|>"};

    /** Upper bound on how much of a server error body is copied into an exception message. */
    private static final int MAX_ERROR_DETAIL = 1024;

    /** Base URL of the server; endpoint paths such as {@code /v1/chat/completions} are appended to it verbatim. */
    private final String hoststub;

    /**
     * Creates a client for the given server.
     *
     * @param hoststub base URL of the server. Paths beginning with "/" are appended to it verbatim,
     *                 so presumably it should not end with a slash (TODO confirm with callers).
     */
    public OpenAIClient(String hoststub) {
        this.hoststub = hoststub;
    }

    /**
     * POSTs a JSON document and returns the response body.
     *
     * @param endpoint absolute URL to post to
     * @param data     JSON payload, sent as UTF-8 with {@code Content-Type: application/json}
     * @return the response body, with each line trimmed and the lines concatenated without separators
     * @throws IOException        on transport failure or when the status code is not 200; the message
     *                            contains the status code and, if available, the server's error body
     * @throws URISyntaxException if {@code endpoint} is not a valid URI
     */
    public static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException {
        URL url = new URI(endpoint).toURL();
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setRequestMethod("POST");
        conn.setRequestProperty("Content-Type", "application/json");
        conn.setDoOutput(true);
        byte[] input = data.toString().getBytes(StandardCharsets.UTF_8);
        try (OutputStream os = conn.getOutputStream()) {
            os.write(input);
        }
        return readResponse(conn);
    }

    /**
     * Performs a GET request and returns the response body.
     *
     * @param endpoint absolute URL to fetch
     * @return the response body, with each line trimmed and the lines concatenated without separators
     * @throws IOException        on transport failure or when the status code is not 200; the message
     *                            contains the status code and, if available, the server's error body
     * @throws URISyntaxException if {@code endpoint} is not a valid URI
     */
    public static String sendGetRequest(String endpoint) throws IOException, URISyntaxException {
        URL url = new URI(endpoint).toURL();
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setRequestMethod("GET");
        return readResponse(conn);
    }

    /**
     * Reads the body of a 200 response, or throws an IOException describing any other status.
     * Shared by the POST and GET helpers.
     */
    private static String readResponse(HttpURLConnection conn) throws IOException {
        int responseCode = conn.getResponseCode();
        if (responseCode == 200) {
            return readBody(conn.getInputStream());
        }
        // Consume and close the error stream. This keeps the connection reusable, and the
        // server's explanation is far more useful in the exception than the bare status code.
        String detail = "";
        try {
            detail = readBody(conn.getErrorStream());
        } catch (IOException ignored) {
            // best effort: the status code alone is still reported below
        }
        if (detail.length() > MAX_ERROR_DETAIL) {
            detail = detail.substring(0, MAX_ERROR_DETAIL) + "...";
        }
        String message = "Request failed with response code " + responseCode;
        throw new IOException(detail.isEmpty() ? message : message + ": " + detail);
    }

    /**
     * Reads a UTF-8 stream fully, trimming each line and concatenating the lines without separators.
     * JSON cannot contain raw line breaks inside strings, so this is lossless for JSON bodies.
     * Returns "" for a {@code null} stream, which {@link HttpURLConnection#getErrorStream()} may return.
     */
    private static String readBody(InputStream in) throws IOException {
        if (in == null) {
            return "";
        }
        try (BufferedReader br = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
            StringBuilder response = new StringBuilder();
            String responseLine;
            while ((responseLine = br.readLine()) != null) {
                response.append(responseLine.trim());
            }
            return response.toString();
        }
    }

    /**
     * Sends a single-turn, non-streaming chat completion request and returns the first choice's text.
     *
     * <p>The request uses a fixed system prompt ("Make short answers."), temperature 0.1, and the
     * {@link #STOPTOKENS} stop sequences.
     *
     * @param model      model identifier passed to the server
     * @param prompt     user message
     * @param max_tokens maximum number of tokens to generate
     * @return the {@code content} of the first choice's message, or "" if it has none
     * @throws IOException on transport/HTTP failure, or if the request URL or the response JSON is
     *                     malformed (the original exception is attached as the cause)
     */
    public String chat(String model, String prompt, int max_tokens) throws IOException {
        JSONObject data = new JSONObject();
        JSONArray messages = new JSONArray();
        JSONObject systemPrompt = new JSONObject(true);
        JSONObject userPrompt = new JSONObject(true);
        messages.put(systemPrompt);
        messages.put(userPrompt);
        try {
            systemPrompt.put("role", "system");
            systemPrompt.put("content", "Make short answers.");
            userPrompt.put("role", "user");
            userPrompt.put("content", prompt);
            data.put("model", model);
            data.put("temperature", 0.1);
            data.put("max_tokens", max_tokens);
            data.put("messages", messages);
            data.put("stop", new JSONArray(STOPTOKENS));
            data.put("stream", false);
            String response = OpenAIClient.sendPostRequest(this.hoststub + "/v1/chat/completions", data);
            JSONObject responseObject = new JSONObject(response);
            JSONArray choices = responseObject.getJSONArray("choices");
            JSONObject choice = choices.getJSONObject(0);
            JSONObject message = choice.getJSONObject("message");
            return message.optString("content", "");
        } catch (URISyntaxException | JSONException e) {
            // keep the cause so the original stack trace is not lost
            throw new IOException(e.getMessage(), e);
        }
    }

    /**
     * Extracts the first JSON array of strings embedded in a free-text model answer.
     *
     * <p>The array is taken from the first '[' up to the first ']' that follows it, so nested arrays
     * or strings containing ']' are not supported.
     *
     * @param answer model output; may be {@code null}
     * @return the array elements, or an empty array if no parsable array of strings is found
     */
    public static String[] stringsFromChat(String answer) {
        if (answer == null) {
            return new String[0];
        }
        int p = answer.indexOf('[');
        if (p < 0) {
            return new String[0];
        }
        // search for the closing bracket only after the opening one, so a stray ']' earlier in the
        // text does not hide a valid array
        int q = answer.indexOf(']', p);
        if (q < 0) {
            return new String[0];
        }
        try {
            JSONArray a = new JSONArray(answer.substring(p, q + 1));
            String[] arr = new String[a.length()];
            for (int i = 0; i < a.length(); ++i) {
                arr[i] = a.getString(i);
            }
            return arr;
        } catch (JSONException e) {
            return new String[0];
        }
    }

    /** Manual smoke test against the server at {@code OllamaClient.OLLAMA_API_HOST}. */
    public static void main(String[] args) {
        String model = "phi3:3.8b";
        OpenAIClient oaic = new OpenAIClient(OllamaClient.OLLAMA_API_HOST);
        String question = "Who invented the wheel?";
        try {
            String answer = oaic.chat(model, question, 80);
            System.out.println(answer);
        } catch (IOException e) {
            e.printStackTrace();
        }
        question = "Make a list of four names from Star Wars movies. Use a JSON Array.";
        try {
            String[] names = OpenAIClient.stringsFromChat(oaic.chat(model, question, 80));
            for (String s : names) {
                System.out.println(s);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

