/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3.Model;

import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import net.yacy.ai.llama3.Model.GGMLTensorEntry;
import net.yacy.ai.llama3.Model.GGMLType;
import net.yacy.ai.llama3.Model.Pair;
import net.yacy.ai.llama3.Tensor.AbstractFloatTensor;

/**
 * Reader for GGUF model files.
 *
 * <p>{@link #loadModel(Path)} parses the file header, the metadata key/value section and the
 * tensor descriptors, and records the absolute file offset at which the aligned tensor data
 * section begins. The tensor payloads themselves are memory-mapped separately by
 * {@link #loadTensors(FileChannel, long, Map)}.
 *
 * <p>Header and metadata scalars are decoded as little-endian. Instances are not thread-safe:
 * the scratch buffers used while parsing are shared per instance.
 */
public final class GGUF {
    /** The ASCII bytes {@code "GGUF"} interpreted as a little-endian {@code int}. */
    private static final int GGUF_MAGIC = 1179993927;
    /** Alignment used when the metadata has no {@code general.alignment} entry. */
    private static final int DEFAULT_ALIGNMENT = 32;
    private static final String ALIGNMENT_KEY = "general.alignment";
    private static final List<Integer> SUPPORTED_GGUF_VERSIONS = List.of(2, 3);

    private int magic;
    private int version;
    private int tensorCount;
    /** Resolved alignment; 0 means it has not been read from the metadata yet. */
    private int alignment;
    private int metadataKeyValueCount;
    private Map<String, Object> metadata;
    private Map<String, GGUFTensorInfo> tensorInfos;
    private long tensorDataOffset;
    // NOTE(review): this private field is never assigned in this class, so getTensorEntries()
    // always returns null; callers presumably build the map via loadTensors(...) -- confirm.
    private Map<String, GGMLTensorEntry> tensorEntries;

    // Per-instance scratch buffers for decoding little-endian scalars.
    private final ByteBuffer byteScratch = ByteBuffer.allocate(1).order(ByteOrder.LITTLE_ENDIAN);
    private final ByteBuffer shortScratch = ByteBuffer.allocate(2).order(ByteOrder.LITTLE_ENDIAN);
    private final ByteBuffer intScratch = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
    private final ByteBuffer longScratch = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN);

    /** Returns the tensor descriptors keyed by tensor name (the internal, mutable map). */
    public Map<String, GGUFTensorInfo> getTensorInfos() {
        return this.tensorInfos;
    }

    /** Returns the absolute file offset at which the tensor data section starts. */
    public long getTensorDataOffset() {
        return this.tensorDataOffset;
    }

    /** Returns the metadata key/value pairs (the internal, mutable map). */
    public Map<String, Object> getMetadata() {
        return this.metadata;
    }

    /** Returns the tensor entry map field (see the note on the field). */
    public Map<String, GGMLTensorEntry> getTensorEntries() {
        return this.tensorEntries;
    }

    /**
     * Opens {@code modelPath} for reading, parses header, metadata and tensor descriptors, and
     * closes the file again.
     *
     * @param modelPath path of the GGUF file
     * @return the parsed descriptor
     * @throws IOException if the file cannot be read, ends prematurely or is malformed
     * @throws IllegalArgumentException if the magic number or version is unsupported
     */
    public static GGUF loadModel(Path modelPath) throws IOException {
        try (FileChannel fileChannel = FileChannel.open(modelPath)) {
            GGUF gguf = new GGUF();
            gguf.loadModelImpl(fileChannel);
            return gguf;
        }
    }

    /** Parses everything up to the tensor data and leaves the channel at its start. */
    private void loadModelImpl(FileChannel fileChannel) throws IOException {
        readHeader(fileChannel);
        this.tensorInfos = new HashMap<>(this.tensorCount);
        for (int i = 0; i < this.tensorCount; ++i) {
            GGUFTensorInfo ti = readTensorInfo(fileChannel);
            assert !this.tensorInfos.containsKey(ti.name());
            this.tensorInfos.put(ti.name(), ti);
        }
        // The data section starts at the next multiple of the alignment. When the descriptors
        // already end on an aligned offset there is no padding at all.
        this.tensorDataOffset = alignOffset(fileChannel.position(), getAlignment());
        fileChannel.position(this.tensorDataOffset);
    }

    /** Rounds {@code offset} up to the next multiple of {@code alignment} (unchanged if aligned). */
    private static long alignOffset(long offset, int alignment) {
        return offset + (alignment - offset % alignment) % alignment;
    }

    /**
     * Memory-maps the tensor payloads of a parsed GGUF file.
     *
     * <p>If the data section is at most {@link Integer#MAX_VALUE} bytes, it is mapped once and
     * each tensor becomes a slice of that mapping. Otherwise the file is mapped in segments
     * split at every tensor start/end, so a tensor normally lies inside a single segment. A
     * tensor that still straddles segments is copied into a direct buffer.
     *
     * @param fileChannel open channel of the GGUF file
     * @param tensorDataOffset absolute file offset of the tensor data section
     * @param tensorInfos tensor descriptors whose offsets are relative to {@code tensorDataOffset}
     * @return tensor entries keyed by tensor name
     * @throws IOException if mapping fails or a tensor lies outside the mapped data
     */
    static Map<String, GGMLTensorEntry> loadTensors(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
        long totalDataSize = fileChannel.size() - tensorDataOffset;
        if (totalDataSize <= Integer.MAX_VALUE) {
            return loadTensorsFromSingleMapping(fileChannel, tensorDataOffset, totalDataSize, tensorInfos);
        }
        return loadTensorsFromSegments(fileChannel, tensorDataOffset, tensorInfos);
    }

    /** Maps the whole data section once and slices every tensor out of that mapping. */
    private static Map<String, GGMLTensorEntry> loadTensorsFromSingleMapping(FileChannel fileChannel, long tensorDataOffset, long totalDataSize, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
        MappedByteBuffer dataSection = (MappedByteBuffer) fileChannel
                .map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, totalDataSize)
                .order(ByteOrder.nativeOrder());
        Map<String, GGMLTensorEntry> entries = new HashMap<>(tensorInfos.size());
        for (GGUFTensorInfo ti : tensorInfos.values()) {
            // Tensor offsets are relative to the data section, which is exactly what is mapped.
            int start = Math.toIntExact(ti.offset());
            int sizeInBytes = Math.toIntExact(tensorByteSize(ti));
            MappedByteBuffer tensorBuffer = slice(dataSection, start, sizeInBytes);
            entries.put(ti.name(), new GGMLTensorEntry(ti.name(), ti.ggmlType(), ti.dimensions(), tensorBuffer));
        }
        return entries;
    }

    /** Handles data sections too large for one mapping by working on absolute-offset segments. */
    private static Map<String, GGMLTensorEntry> loadTensorsFromSegments(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
        List<MappedSegment> segments = mapSegments(fileChannel, tensorDataOffset, tensorInfos);
        Map<String, GGMLTensorEntry> entries = new HashMap<>(tensorInfos.size());
        for (GGUFTensorInfo ti : tensorInfos.values()) {
            String name = ti.name();
            long tensorOffset = tensorDataOffset + ti.offset();
            long tensorSize = tensorByteSize(ti);
            // Search at least one byte so a zero-sized tensor still finds the segment holding its start.
            long searchEnd = tensorOffset + Math.max(tensorSize, 1L);
            // The end offsets are exclusive; segments are already in ascending order.
            List<MappedSegment> overlapping = segments.stream()
                    .filter(seg -> seg.startOffset < searchEnd && seg.endOffset() > tensorOffset)
                    .collect(Collectors.toList());
            if (overlapping.isEmpty()) {
                throw new IOException("Tensor " + name + " not found in any mapped segment");
            }
            if (overlapping.size() == 1) {
                MappedSegment segment = overlapping.get(0);
                int start = Math.toIntExact(tensorOffset - segment.startOffset);
                MappedByteBuffer tensorBuffer = slice(segment.buffer, start, Math.toIntExact(tensorSize));
                entries.put(name, new GGMLTensorEntry(name, ti.ggmlType(), ti.dimensions(), tensorBuffer));
            } else {
                ByteBuffer combined = copyAcrossSegments(name, overlapping, tensorOffset, tensorSize);
                entries.put(name, new GGMLTensorEntry(name, ti.ggmlType(), ti.dimensions(), combined));
            }
        }
        return entries;
    }

    /**
     * Maps [tensorDataOffset, end of file) into consecutive read-only segments of at most
     * {@link Integer#MAX_VALUE} bytes, starting a new segment at every absolute tensor start
     * and end offset. Segments are returned in ascending offset order.
     */
    private static List<MappedSegment> mapSegments(FileChannel fileChannel, long tensorDataOffset, Map<String, GGUFTensorInfo> tensorInfos) throws IOException {
        long fileSize = fileChannel.size();
        List<Long> boundaries = new ArrayList<>(2 * tensorInfos.size() + 1);
        for (GGUFTensorInfo ti : tensorInfos.values()) {
            long start = tensorDataOffset + ti.offset();
            boundaries.add(start);
            boundaries.add(start + tensorByteSize(ti));
        }
        // Also cover trailing bytes after the last tensor, up to the end of the file.
        boundaries.add(fileSize);
        List<Long> orderedBoundaries = boundaries.stream().distinct().sorted().collect(Collectors.toList());

        List<MappedSegment> segments = new ArrayList<>();
        long currentPos = tensorDataOffset;
        for (long boundary : orderedBoundaries) {
            long segmentEnd = Math.min(boundary, fileSize);
            while (currentPos < segmentEnd) {
                long chunkSize = Math.min(segmentEnd - currentPos, Integer.MAX_VALUE);
                MappedByteBuffer buffer = (MappedByteBuffer) fileChannel
                        .map(FileChannel.MapMode.READ_ONLY, currentPos, chunkSize)
                        .order(ByteOrder.nativeOrder());
                segments.add(new MappedSegment(currentPos, buffer));
                currentPos += chunkSize;
            }
        }
        return segments;
    }

    /** Copies a tensor that straddles several segments into one direct, native-order buffer. */
    private static ByteBuffer copyAcrossSegments(String name, List<MappedSegment> segments, long tensorOffset, long tensorSize) throws IOException {
        ByteBuffer combined = ByteBuffer.allocateDirect(Math.toIntExact(tensorSize)).order(ByteOrder.nativeOrder());
        long tensorEnd = tensorOffset + tensorSize;
        for (MappedSegment segment : segments) {
            long copyStart = Math.max(tensorOffset, segment.startOffset);
            long copyEnd = Math.min(tensorEnd, segment.endOffset());
            if (copyEnd <= copyStart) {
                continue;
            }
            // Copy through a duplicate so the shared segment buffer's position is left untouched.
            ByteBuffer source = segment.buffer.duplicate();
            int sourceStart = (int) (copyStart - segment.startOffset);
            source.position(sourceStart);
            source.limit(sourceStart + (int) (copyEnd - copyStart));
            combined.put(source);
        }
        if (combined.hasRemaining()) {
            throw new IOException("Tensor " + name + " is only partially covered by mapped segments");
        }
        combined.flip();
        return combined;
    }

    /** Returns a native-order view of {@code length} bytes of {@code source} starting at {@code start}. */
    private static MappedByteBuffer slice(MappedByteBuffer source, int start, int length) {
        MappedByteBuffer view = source.duplicate();
        view.position(start);
        view.limit(Math.addExact(start, length));
        return (MappedByteBuffer) view.slice().order(ByteOrder.nativeOrder());
    }

    /** Number of bytes occupied by the tensor's data, as reported by its GGML type. */
    private static long tensorByteSize(GGUFTensorInfo ti) {
        return ti.ggmlType().byteSizeFor(AbstractFloatTensor.numberOfElements(ti.dimensions()));
    }

    private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
        return GGMLType.fromId(readInt(fileChannel));
    }

    /** Reads one tensor descriptor: name, dimension list, GGML type id and relative data offset. */
    private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
        String name = readString(fileChannel);
        assert name.length() <= 64;
        int dimensionCount = readInt(fileChannel);
        assert dimensionCount <= 4;
        if (dimensionCount < 0) {
            throw new IOException("invalid dimension count " + dimensionCount + " for tensor " + name);
        }
        int[] dimensions = new int[dimensionCount];
        for (int d = 0; d < dimensionCount; ++d) {
            dimensions[d] = Math.toIntExact(readLong(fileChannel));
        }
        GGMLType ggmlType = readGGMLType(fileChannel);
        long offset = readLong(fileChannel);
        assert offset % getAlignment() == 0L;
        return new GGUFTensorInfo(name, dimensions, ggmlType, offset);
    }

    /** Reads a GGUF string: a 64-bit byte length followed by that many UTF-8 bytes. */
    private String readString(FileChannel fileChannel) throws IOException {
        byte[] bytes = new byte[readLength(fileChannel)];
        readFully(fileChannel, ByteBuffer.wrap(bytes));
        return new String(bytes, StandardCharsets.UTF_8);
    }

    private Pair<String, Object> readKeyValuePair(FileChannel fileChannel) throws IOException {
        String key = readString(fileChannel);
        assert key.length() < 65536;
        // Keys are expected to consist of [a-z0-9_.] only.
        assert key.codePoints().allMatch(cp -> 'a' <= cp && cp <= 'z' || '0' <= cp && cp <= '9' || cp == '_' || cp == '.');
        Object value = readMetadataValue(fileChannel);
        return new Pair<>(key, value);
    }

    private Object readMetadataValue(FileChannel fileChannel) throws IOException {
        return readMetadataValueOfType(readMetadataValueType(fileChannel), fileChannel);
    }

    /**
     * Reads the magic number, version, tensor count and all metadata key/value pairs.
     *
     * @throws IllegalArgumentException if the magic number or version is unsupported
     * @throws IOException if the file ends prematurely or a count is out of range
     */
    void readHeader(FileChannel fileChannel) throws IOException {
        this.magic = readInt(fileChannel);
        if (this.magic != GGUF_MAGIC) {
            throw new IllegalArgumentException("unsupported header.magic " + this.magic);
        }
        this.version = readInt(fileChannel);
        if (!SUPPORTED_GGUF_VERSIONS.contains(this.version)) {
            throw new IllegalArgumentException("unsupported header.version " + this.version);
        }
        this.tensorCount = readLength(fileChannel);
        this.metadataKeyValueCount = readLength(fileChannel);
        this.metadata = new HashMap<>(this.metadataKeyValueCount);
        for (int i = 0; i < this.metadataKeyValueCount; ++i) {
            Pair<String, Object> keyValue = readKeyValuePair(fileChannel);
            assert !this.metadata.containsKey(keyValue.first());
            this.metadata.put(keyValue.first(), keyValue.second());
        }
    }

    /**
     * Reads a typed metadata array: element type id, 64-bit length, then the elements.
     * Unsigned integer types are returned in the signed Java array of the same width.
     */
    private Object readArray(FileChannel fileChannel) throws IOException {
        MetadataValueType elementType = readMetadataValueType(fileChannel);
        int length = readLength(fileChannel);
        return switch (elementType) {
            case UINT8, INT8 -> {
                byte[] values = new byte[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readByte(fileChannel);
                }
                yield values;
            }
            case UINT16, INT16 -> {
                short[] values = new short[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readShort(fileChannel);
                }
                yield values;
            }
            case UINT32, INT32 -> {
                int[] values = new int[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readInt(fileChannel);
                }
                yield values;
            }
            case FLOAT32 -> {
                float[] values = new float[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readFloat(fileChannel);
                }
                yield values;
            }
            case BOOL -> {
                boolean[] values = new boolean[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readBoolean(fileChannel);
                }
                yield values;
            }
            case STRING -> {
                String[] values = new String[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readString(fileChannel);
                }
                yield values;
            }
            case ARRAY -> {
                // Each nested element carries its own element type and length.
                Object[] values = new Object[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readArray(fileChannel);
                }
                yield values;
            }
            case UINT64, INT64 -> {
                long[] values = new long[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readLong(fileChannel);
                }
                yield values;
            }
            case FLOAT64 -> {
                double[] values = new double[length];
                for (int i = 0; i < length; ++i) {
                    values[i] = readDouble(fileChannel);
                }
                yield values;
            }
        };
    }

    /** Reads one scalar, string or array value of the given type, boxed as an {@code Object}. */
    private Object readMetadataValueOfType(MetadataValueType valueType, FileChannel fileChannel) throws IOException {
        // Used in a return context with target type Object, so each arm is boxed individually.
        return switch (valueType) {
            case UINT8, INT8 -> readByte(fileChannel);
            case UINT16, INT16 -> readShort(fileChannel);
            case UINT32, INT32 -> readInt(fileChannel);
            case FLOAT32 -> readFloat(fileChannel);
            case UINT64, INT64 -> readLong(fileChannel);
            case FLOAT64 -> readDouble(fileChannel);
            case BOOL -> readBoolean(fileChannel);
            case STRING -> readString(fileChannel);
            case ARRAY -> readArray(fileChannel);
        };
    }

    /**
     * Reads until {@code buffer} is full. A single {@link FileChannel#read(ByteBuffer)} may
     * transfer fewer bytes than requested.
     *
     * @throws EOFException if the end of the file is reached first
     */
    private static void readFully(FileChannel fileChannel, ByteBuffer buffer) throws IOException {
        while (buffer.hasRemaining()) {
            if (fileChannel.read(buffer) < 0) {
                throw new EOFException("unexpected end of GGUF file at position " + fileChannel.position());
            }
        }
    }

    /** Refills a scratch buffer completely from the channel and returns it. */
    private static ByteBuffer fill(FileChannel fileChannel, ByteBuffer scratch) throws IOException {
        scratch.clear();
        readFully(fileChannel, scratch);
        return scratch;
    }

    /** Reads an unsigned 64-bit length/count and checks that it fits into a Java array size. */
    private int readLength(FileChannel fileChannel) throws IOException {
        long length = readLong(fileChannel);
        if (length < 0 || length > Integer.MAX_VALUE) {
            throw new IOException("unsupported length " + Long.toUnsignedString(length) + " before position " + fileChannel.position());
        }
        return (int) length;
    }

    private byte readByte(FileChannel fileChannel) throws IOException {
        return fill(fileChannel, this.byteScratch).get(0);
    }

    private boolean readBoolean(FileChannel fileChannel) throws IOException {
        return readByte(fileChannel) != 0;
    }

    private short readShort(FileChannel fileChannel) throws IOException {
        return fill(fileChannel, this.shortScratch).getShort(0);
    }

    private int readInt(FileChannel fileChannel) throws IOException {
        return fill(fileChannel, this.intScratch).getInt(0);
    }

    private long readLong(FileChannel fileChannel) throws IOException {
        return fill(fileChannel, this.longScratch).getLong(0);
    }

    private float readFloat(FileChannel fileChannel) throws IOException {
        return Float.intBitsToFloat(readInt(fileChannel));
    }

    private double readDouble(FileChannel fileChannel) throws IOException {
        return Double.longBitsToDouble(readLong(fileChannel));
    }

    private MetadataValueType readMetadataValueType(FileChannel fileChannel) throws IOException {
        return MetadataValueType.fromIndex(readInt(fileChannel));
    }

    /**
     * Returns the data alignment from the {@code general.alignment} metadata entry, defaulting
     * to 32. The value is cached after the first successful lookup. Requires the metadata to
     * have been read.
     *
     * @throws IllegalArgumentException if the configured alignment is not positive
     */
    public int getAlignment() {
        if (this.alignment != 0) {
            return this.alignment;
        }
        Object configured = this.metadata.getOrDefault(ALIGNMENT_KEY, DEFAULT_ALIGNMENT);
        int resolved = ((Number) configured).intValue();
        if (resolved <= 0) {
            throw new IllegalArgumentException("invalid " + ALIGNMENT_KEY + " " + configured);
        }
        assert Integer.bitCount(resolved) == 1 : "alignment must be a power of two";
        this.alignment = resolved;
        return resolved;
    }

    /** Immutable descriptor of one tensor; the dimension array is defensively copied. */
    public static final class GGUFTensorInfo {
        private final String name;
        private final int[] dimensions;
        private final GGMLType ggmlType;
        /** Offset of the tensor data relative to the start of the tensor data section. */
        private final long offset;

        public GGUFTensorInfo(String name, int[] dimensions, GGMLType ggmlType, long offset) {
            this.name = name;
            this.dimensions = dimensions != null ? dimensions.clone() : null;
            this.ggmlType = ggmlType;
            this.offset = offset;
        }

        public String name() {
            return this.name;
        }

        /** Returns a copy of the dimensions, or {@code null} if none were given. */
        public int[] dimensions() {
            return this.dimensions != null ? this.dimensions.clone() : null;
        }

        public GGMLType ggmlType() {
            return this.ggmlType;
        }

        public long offset() {
            return this.offset;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            GGUFTensorInfo that = (GGUFTensorInfo) o;
            return this.offset == that.offset
                    && Objects.equals(this.name, that.name)
                    && Arrays.equals(this.dimensions, that.dimensions)
                    && Objects.equals(this.ggmlType, that.ggmlType);
        }

        @Override
        public int hashCode() {
            int result = Objects.hash(this.name, this.ggmlType, this.offset);
            return 31 * result + Arrays.hashCode(this.dimensions);
        }

        @Override
        public String toString() {
            return "GGUFTensorInfo[name=" + this.name + ", dimensions=" + Arrays.toString(this.dimensions) + ", ggmlType=" + this.ggmlType + ", offset=" + this.offset + "]";
        }
    }

    /** A read-only mapping of part of the file, starting at an absolute file offset. */
    private static final class MappedSegment {
        final long startOffset;
        final MappedByteBuffer buffer;

        MappedSegment(long startOffset, MappedByteBuffer buffer) {
            this.startOffset = startOffset;
            this.buffer = buffer;
        }

        /** Absolute file offset one past the last mapped byte. */
        long endOffset() {
            return this.startOffset + this.buffer.capacity();
        }
    }

    /** GGUF metadata value types, in the order of their on-disk type ids. */
    enum MetadataValueType {
        UINT8(1),
        INT8(1),
        UINT16(2),
        INT16(2),
        UINT32(4),
        INT32(4),
        FLOAT32(4),
        BOOL(1),
        STRING(-8),
        ARRAY(-8),
        UINT64(8),
        INT64(8),
        FLOAT64(8);

        private static final MetadataValueType[] VALUES = values();

        // Negative sizes mark variable-length types (STRING, ARRAY).
        private final int byteSize;

        MetadataValueType(int byteSize) {
            this.byteSize = byteSize;
        }

        /**
         * Maps an on-disk type id to its enum constant.
         *
         * @throws IllegalArgumentException if the id is unknown
         */
        public static MetadataValueType fromIndex(int index) {
            if (index < 0 || index >= VALUES.length) {
                throw new IllegalArgumentException("unsupported metadata value type " + index);
            }
            return VALUES[index];
        }

        public int byteSize() {
            return this.byteSize;
        }
    }
}

