/*
 * Decompiled with CFR 0.152.
 */
package net.yacy.ai.llama3.Tensor;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import net.yacy.ai.llama3.Model.GGMLType;
import net.yacy.ai.llama3.Tensor.AbstractFloatTensor;
import net.yacy.ai.llama3.Tensor.FloatTensor;

/**
 * {@link FloatTensor} of type {@link GGMLType#F32} whose elements are stored as
 * 32-bit floats in a direct {@link ByteBuffer}.
 *
 * <p>Element {@code i} lives at byte position {@code i * 4} of the backing buffer
 * and is read/written with the buffer's absolute {@code getFloat}/{@code putFloat}
 * accessors, so the buffer's byte order determines the in-memory layout.
 */
public class DirectBufferFloatTensor
extends AbstractFloatTensor
implements FloatTensor {
    /** Backing storage; element {@code i} occupies bytes {@code [4*i, 4*i + 4)}. */
    final ByteBuffer byteBuffer;

    /**
     * Wraps or copies the remaining bytes of {@code bb}.
     *
     * <p>If {@code bb} is direct, the tensor shares its memory (a slice starting at
     * {@code bb}'s current position). Otherwise the remaining bytes are copied into
     * a newly allocated direct buffer. In both cases the byte order of {@code bb} is
     * kept and {@code bb}'s own position/limit are left untouched.
     *
     * @param bb source buffer holding 32-bit floats
     */
    public DirectBufferFloatTensor(ByteBuffer bb) {
        if (bb.isDirect()) {
            // slice() does not inherit the byte order, so it is re-applied explicitly.
            this.byteBuffer = bb.slice().order(bb.order());
        } else {
            int capacityBytes = bb.remaining();
            ByteBuffer direct = ByteBuffer.allocateDirect(capacityBytes).order(bb.order());
            // Copy from a slice so the caller's buffer position is not advanced.
            direct.put(bb.slice());
            direct.flip();
            this.byteBuffer = direct.slice().order(bb.order());
        }
    }

    /**
     * Creates a tensor holding a copy of {@code values} in a new direct buffer
     * using the platform's native byte order.
     *
     * @param values elements to copy; element {@code i} becomes tensor element {@code i}
     */
    public DirectBufferFloatTensor(float[] values) {
        int capacityBytes = values.length * Float.BYTES;
        this.byteBuffer = ByteBuffer.allocateDirect(capacityBytes).order(ByteOrder.nativeOrder());
        for (int i = 0; i < values.length; ++i) {
            this.byteBuffer.putFloat(i << 2, values[i]);
        }
    }

    /**
     * Allocates a new tensor with room for the number of elements computed by
     * {@link AbstractFloatTensor#numberOfElements(int...)} from {@code dims},
     * backed by a freshly allocated direct buffer in native byte order.
     *
     * @param dims tensor dimensions
     * @return the newly allocated tensor
     */
    public static FloatTensor allocate(int ... dims) {
        int numberOfElements = AbstractFloatTensor.numberOfElements(dims);
        int bytesNeeded = numberOfElements * Float.BYTES;
        ByteBuffer buffer = ByteBuffer.allocateDirect(bytesNeeded).order(ByteOrder.nativeOrder());
        return new DirectBufferFloatTensor(buffer);
    }

    /**
     * @return the number of float elements, i.e. the backing buffer's capacity in bytes divided by 4
     */
    @Override
    public final int size() {
        return this.byteBuffer.capacity() / Float.BYTES;
    }

    /**
     * Reads one element.
     *
     * @param index2 element index (not a byte offset)
     * @return the float stored at that element index
     */
    @Override
    public final float getFloat(int index2) {
        return this.byteBuffer.getFloat(index2 << 2);
    }

    /**
     * Writes one element.
     *
     * @param index2 element index (not a byte offset)
     * @param value value to store
     */
    @Override
    public final void setFloat(int index2, float value) {
        this.byteBuffer.putFloat(index2 << 2, value);
    }

    /**
     * @return always {@link GGMLType#F32}
     */
    @Override
    public final GGMLType type() {
        return GGMLType.F32;
    }

    /**
     * Computes the dot product of {@code size} consecutive elements of this tensor,
     * starting at {@code thisOffset}, with {@code size} consecutive elements of
     * {@code that}, starting at {@code thatOffset}.
     *
     * <p>The bulk of the work is done four elements at a time into four independent
     * accumulators; the remaining {@code size % 4} elements are added afterwards.
     * When {@code that} is also a {@code DirectBufferFloatTensor} its backing buffer
     * is read directly instead of going through {@link FloatTensor#getFloat(int)}.
     *
     * @param thisOffset element index of the first operand in this tensor
     * @param that the other tensor
     * @param thatOffset element index of the first operand in {@code that}
     * @param size number of element pairs to multiply and sum
     * @return the accumulated sum of products
     */
    @Override
    public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) {
        float sum0 = 0.0f;
        float sum1 = 0.0f;
        float sum2 = 0.0f;
        float sum3 = 0.0f;
        // Largest multiple of 4 not exceeding size: the part handled by the unrolled loop.
        int limit = size & 0xFFFFFFFC;
        // Byte position of the current element in this tensor's buffer.
        int i = thisOffset << 2;
        if (that instanceof DirectBufferFloatTensor) {
            ByteBuffer thatBuffer = ((DirectBufferFloatTensor)that).byteBuffer;
            // Byte position of the current element in the other tensor's buffer.
            int k = thatOffset << 2;
            for (int j = 0; j < limit; j += 4) {
                sum0 += this.byteBuffer.getFloat(i) * thatBuffer.getFloat(k);
                sum1 += this.byteBuffer.getFloat(i + 4) * thatBuffer.getFloat(k + 4);
                sum2 += this.byteBuffer.getFloat(i + 8) * thatBuffer.getFloat(k + 8);
                sum3 += this.byteBuffer.getFloat(i + 12) * thatBuffer.getFloat(k + 12);
                i += 16;
                k += 16;
            }
        } else {
            // Generic path: the other tensor is addressed by element index.
            int k = thatOffset;
            for (int j = 0; j < limit; j += 4) {
                sum0 += this.byteBuffer.getFloat(i) * that.getFloat(k);
                sum1 += this.byteBuffer.getFloat(i + 4) * that.getFloat(k + 1);
                sum2 += this.byteBuffer.getFloat(i + 8) * that.getFloat(k + 2);
                sum3 += this.byteBuffer.getFloat(i + 12) * that.getFloat(k + 3);
                i += 16;
                k += 4;
            }
        }
        float result = sum0 + sum1 + sum2 + sum3;
        // Tail: the up-to-three elements left over. Both offsets must be applied here
        // as well, otherwise the tail would read from the start of the tensors.
        for (int j = limit; j < size; ++j) {
            result += this.byteBuffer.getFloat((thisOffset + j) << 2) * that.getFloat(thatOffset + j);
        }
        return result;
    }
}

