package xyz.felh.openai.jtokkit;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:xyz/felh/openai/jtokkit/TokenEncoder.class */
/**
 * Byte-pair-encoding core: looks up byte sequences in the rank table and, for pieces that are
 * not a single token, runs the pairwise merge loop to count (and optionally emit) tokens.
 *
 * <p>The per-length encoder maps are only read after construction. The decoder map is a
 * {@link ConcurrentHashMap} because {@link #decodeToken} inserts into it lazily; using a plain
 * {@code HashMap} there would let concurrent {@code decodeToken} calls corrupt it.
 */
public final class TokenEncoder {
    /** System property holding the byte length at which pieces switch to the large-piece algorithm. */
    public static final String VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY = "VERY_LARGE_TOKENIZER_BYTE_THRESHOLD";
    /** Rank stored at positions that were merged away and no longer start a part. */
    public static final int DUMMY_RANK = Integer.MAX_VALUE;
    /** Sentinel rank meaning "no token exists for this byte span". */
    public static final int MAX_RANK = Integer.MAX_VALUE - 1;

    /** Token maps indexed by byte length: {@code encoders[n]} holds the tokens of exactly n bytes, or null. */
    private final Map<ByteArrayWrapper, Integer>[] encoders;
    /** Reverse lookup from rank to token bytes; also caches bytes resolved via the special encoder. */
    private final Map<Integer, byte[]> decoder;
    /** Pieces with at least this many bytes are delegated to {@code TokenEncoderLarge}. */
    private final int veryLargeTokenizerByteThreshold;

    /**
     * Builds the lookup tables from a map of token bytes to rank.
     *
     * @param map token bytes mapped to their rank; an empty map yields an encoder that
     *            recognizes no byte sequence
     * @throws NumberFormatException if the {@link #VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY}
     *                               system property is set to a non-integer value
     */
    public TokenEncoder(Map<byte[], Integer> map) {
        if (map.isEmpty()) {
            @SuppressWarnings("unchecked") // generic array creation; only Map<ByteArrayWrapper, Integer> is stored
            Map<ByteArrayWrapper, Integer>[] noEncoders = new Map[0];
            this.encoders = noEncoders;
            // Must be mutable: Collections.emptyMap().computeIfAbsent throws UnsupportedOperationException,
            // which made decodeToken unusable for an empty encoder.
            this.decoder = new ConcurrentHashMap<>();
            // The property is not consulted for an empty encoder (kept from the original behavior),
            // so every unknown piece goes to the large-piece path.
            this.veryLargeTokenizerByteThreshold = 0;
        } else {
            this.veryLargeTokenizerByteThreshold =
                    Integer.parseInt(System.getProperty(VERY_LARGE_TOKENIZER_BYTE_THRESHOLD_KEY, "500"));

            // Bucket tokens by byte length so encode() only probes the bucket of matching length.
            TreeMap<Integer, Map<ByteArrayWrapper, Integer>> byLength = new TreeMap<>();
            map.forEach((bytes, rank) -> byLength
                    .computeIfAbsent(bytes.length, unused -> new HashMap<>())
                    .put(new ByteArrayWrapper(bytes), rank));

            @SuppressWarnings("unchecked") // generic array creation; only Map<ByteArrayWrapper, Integer> is stored
            Map<ByteArrayWrapper, Integer>[] table = new Map[byLength.lastKey() + 1];
            byLength.forEach((length, bucket) -> table[length] = bucket);
            this.encoders = table;

            Map<Integer, byte[]> reverse = new ConcurrentHashMap<>(map.size());
            map.forEach((bytes, rank) -> reverse.put(rank, bytes));
            this.decoder = reverse;
        }
    }

    /**
     * Returns the index of the smallest rank below {@link #MAX_RANK}, or -1 if there is none.
     * Only indices {@code 0 .. size-3} are scanned: in the layout built by calculateTokensSmall the
     * last two entries describe spans ending past the piece, so they never hold a valid rank.
     */
    private static int getMinRankIndex(List<Integer> ranks) {
        int minRankIndex = -1;
        int minRank = MAX_RANK;
        int last = ranks.size() - 3;
        int i = 0;
        // Manually unrolled by four; the tail loop below handles the remaining indices.
        for (; i < last - 2; i += 4) {
            int r0 = ranks.get(i);
            if (r0 < minRank) {
                minRankIndex = i;
                minRank = r0;
            }
            int r1 = ranks.get(i + 1);
            if (r1 < minRank) {
                minRankIndex = i + 1;
                minRank = r1;
            }
            int r2 = ranks.get(i + 2);
            if (r2 < minRank) {
                minRankIndex = i + 2;
                minRank = r2;
            }
            int r3 = ranks.get(i + 3);
            if (r3 < minRank) {
                minRankIndex = i + 3;
                minRank = r3;
            }
        }
        for (; i <= last; i++) {
            int r = ranks.get(i);
            if (r < minRank) {
                minRankIndex = i;
                minRank = r;
            }
        }
        return minRankIndex;
    }

    /** Returns the first index at or after {@code index} whose rank is not {@link #DUMMY_RANK} (may be {@code size()}). */
    private static int getNextIndex(List<Integer> ranks, int index) {
        while (index < ranks.size() && ranks.get(index) == DUMMY_RANK) {
            index++;
        }
        return index;
    }

    /** Returns the last index at or before {@code index} whose rank is not {@link #DUMMY_RANK} (may be -1). */
    private static int getPreviousIndex(List<Integer> ranks, int index) {
        while (index >= 0 && ranks.get(index) == DUMMY_RANK) {
            index--;
        }
        return index;
    }

    /**
     * Counts the tokens of one pre-split piece and, if requested, appends them to {@code out}.
     *
     * @param maxTokenCount stop appending once {@code out} reaches this size (merge paths only)
     * @param keepEncodings whether token ids should be appended to {@code out}
     * @param byteArray     the piece's bytes
     * @param out           destination for token ids
     * @param ranks         scratch list reused by the small-piece merge path
     * @return the number of tokens the piece encodes to
     */
    public int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteArray,
                                    List<Integer> out, ArrayList<Integer> ranks) {
        ByteArrayWrapper piece = new ByteArrayWrapper(byteArray);
        int token = encode(piece);
        if (token != MAX_RANK) {
            // Whole piece is a single token.
            if (keepEncodings) {
                out.add(token);
            }
            return 1;
        }
        int length = piece.length();
        return length < veryLargeTokenizerByteThreshold
                ? calculateTokensSmall(maxTokenCount, keepEncodings, out, ranks, piece, length)
                : TokenEncoderLarge.calculateTokensLarge(this, maxTokenCount, keepEncodings, out, piece, length);
    }

    /**
     * Merge path for pieces below the size threshold: seeds {@code ranks[i]} with the rank of the
     * two-byte span starting at i, runs the merge loop, then decodes the surviving part boundaries.
     */
    private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, List<Integer> out,
                                     ArrayList<Integer> ranks, ByteArrayWrapper piece, int length) {
        assert length > 1 : "Already filtered out";
        ranks.clear();
        ranks.ensureCapacity(length + 1);

        int validRanks = 0;
        int minRankIndex = -1;
        int minRank = MAX_RANK;
        for (int i = 0; i <= length; i++) {
            int rank = encode(piece, i, i + 2);
            if (rank != MAX_RANK) {
                validRanks++;
                if (rank < minRank) {
                    minRankIndex = i;
                    minRank = rank;
                }
            }
            ranks.add(rank);
        }

        int tokenCount = mergeBytesAndGetTokenCount(piece, length, ranks, validRanks, minRankIndex);
        if (keepEncodings) {
            // Every non-dummy index marks the start of the next part; emit the span [start, end).
            for (int start = 0, end = 1; end < ranks.size() && out.size() < maxTokenCount; end++) {
                if (ranks.get(end) != DUMMY_RANK) {
                    int token = encode(piece, start, end);
                    assert token != MAX_RANK : "Token should not be MAX_RANK";
                    out.add(token);
                    start = end;
                }
            }
        }
        return tokenCount;
    }

    /**
     * Repeatedly merges the lowest-ranked adjacent pair until no mergeable pair remains.
     *
     * @param piece        the bytes being tokenized
     * @param length       current number of parts (starts at the byte count)
     * @param ranks        per-position ranks; merged-away positions become {@link #DUMMY_RANK}
     * @param validRanks   number of entries that are neither {@link #MAX_RANK} nor {@link #DUMMY_RANK}
     * @param minRankIndex index of the current minimum rank, as getMinRankIndex would return
     * @return the number of parts (tokens) left after merging
     */
    int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, List<Integer> ranks,
                                   int validRanks, int minRankIndex) {
        assert getMinRankIndex(ranks) == minRankIndex;
        while (validRanks > 0) {
            assert minRankIndex >= 0;

            int previousIndex = getPreviousIndex(ranks, minRankIndex - 1);
            int nextIndex = getNextIndex(ranks, minRankIndex + 1);
            int nextNextIndex = getNextIndex(ranks, nextIndex + 1);
            int nextNextNextIndex = getNextIndex(ranks, nextNextIndex + 1);

            // The preceding part now pairs with the freshly merged part.
            if (previousIndex >= 0) {
                assert ranks.get(previousIndex) != DUMMY_RANK;
                int previousRank = encode(piece, previousIndex, nextNextIndex);
                int oldPreviousRank = ranks.set(previousIndex, previousRank);
                if ((previousRank == MAX_RANK) != (oldPreviousRank == MAX_RANK)) {
                    validRanks -= (previousRank == MAX_RANK) ? 1 : -1;
                }
            }

            // The merged part now pairs with the part after it.
            assert ranks.get(minRankIndex) != DUMMY_RANK;
            int mergedRank = encode(piece, minRankIndex, nextNextNextIndex);
            int oldMergedRank = ranks.set(minRankIndex, mergedRank);
            if ((mergedRank == MAX_RANK) != (oldMergedRank == MAX_RANK)) {
                validRanks--;
            }

            // The right half of the merged pair no longer starts a part.
            int oldDeletedRank = ranks.set(nextIndex, DUMMY_RANK);
            if (oldDeletedRank != MAX_RANK) {
                validRanks--;
            }

            length--;
            minRankIndex = getMinRankIndex(ranks);
        }
        assert getMinRankIndex(ranks) < 0;
        return length;
    }

    /** Looks up the whole wrapper as a token; returns {@link #MAX_RANK} if it is not one. */
    private int encode(ByteArrayWrapper payload) {
        int length = payload.length();
        if (length < encoders.length) {
            Map<ByteArrayWrapper, Integer> bucket = encoders[length];
            if (bucket != null) {
                Integer rank = bucket.get(payload);
                if (rank != null) {
                    return rank;
                }
            }
        }
        return MAX_RANK;
    }

    /**
     * Looks up bytes {@code [start, end)} of {@code piece} as a token.
     *
     * @return the token's rank, or {@link #MAX_RANK} if the span runs past the piece or is not a token
     */
    public int encode(ByteArrayWrapper piece, int start, int end) {
        if (end > piece.length()) {
            return MAX_RANK;
        }
        if (end - start == piece.length()) {
            // Whole piece: avoid copying the bytes.
            return encode(piece);
        }
        return encode(piece.getBytesBetween(start, end));
    }

    /**
     * Returns the bytes of {@code token}. Tokens missing from the rank table are resolved through
     * {@code specialEncoder.decodeIfPresent} and cached when that returns non-null.
     *
     * @throws NullPointerException if {@code specialEncoder} is null
     */
    public byte[] decodeToken(int token, SpecialEncoder specialEncoder) {
        Objects.requireNonNull(specialEncoder);
        return decoder.computeIfAbsent(token, specialEncoder::decodeIfPresent);
    }
}
