package xyz.felh.openai.jtokkit;

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import xyz.felh.openai.jtokkit.api.Encoding;
import xyz.felh.openai.jtokkit.api.EncodingResult;
import xyz.felh.openai.jtokkit.api.GptBytePairEncodingParams;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:xyz/felh/openai/jtokkit/GptBytePairEncoding.class */
/**
 * {@link Encoding} implementation that splits text into pieces with a regular
 * expression and turns each piece into tokens via a {@code TokenEncoder}.
 *
 * <p>All state is assigned once in the constructor and held in {@code final} fields.
 */
public class GptBytePairEncoding implements Encoding {
    private final String name;
    private final Pattern pattern;
    private final TokenEncoder encoder;
    private final SpecialEncoder specialEncoder;

    /**
     * Creates an encoding from the given parameters.
     *
     * @param gptBytePairEncodingParams supplies the encoding name, the split pattern,
     *        the token encoder data and the special-token encoder data
     */
    public GptBytePairEncoding(GptBytePairEncodingParams gptBytePairEncodingParams) {
        this.name = gptBytePairEncodingParams.getName();
        this.pattern = gptBytePairEncodingParams.getPattern();
        this.encoder = new TokenEncoder(gptBytePairEncodingParams.getEncoder());
        this.specialEncoder = new SpecialEncoder(gptBytePairEncodingParams.getSpecialTokensEncoder());
    }

    /**
     * Encodes {@code str} without an explicit token limit and returns the tokens.
     *
     * @param str the text to encode; {@code null} yields the tokens of an empty result
     * @return the tokens of the encoding result
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public List<Integer> encode(String str) {
        // NOTE(review): DUMMY_RANK is used as the "no limit" value; the truncation check in
        // encodeOrdinaryInternal compares against Integer.MAX_VALUE, so the two are presumably
        // equal - confirm against TokenEncoder.
        return encode(str, TokenEncoder.DUMMY_RANK).getTokens();
    }

    /**
     * Encodes {@code str}, limited by {@code i}, after running the special-token check.
     *
     * @param str the text to encode; may be {@code null}
     * @param i the token limit passed on to the encoder
     * @return the encoding result
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public EncodingResult encode(String str, int i) {
        return encodeInternal(str, i, true);
    }

    /**
     * Shared entry point for the encode and count paths: handles {@code null}, runs
     * {@code SpecialEncoder.checkForSpecialTokens} on the text and then delegates to the
     * ordinary encoding path.
     *
     * @param text the text to encode; {@code null} yields an empty result
     * @param maxTokenCount the token limit passed on to the encoder
     * @param keepEncodings flag forwarded to the token encoder; callers that need the token
     *        list pass {@code true}, {@link #countTokens(String)} passes {@code false}
     * @return the encoding result
     */
    private EncodingResult encodeInternal(String text, int maxTokenCount, boolean keepEncodings) {
        if (text == null) {
            return new EncodingResult(Collections.emptyList(), -1, false);
        }
        this.specialEncoder.checkForSpecialTokens(text);
        return encodeOrdinaryInternal(text, maxTokenCount, keepEncodings);
    }

    /**
     * Encodes {@code str} without the special-token check and without an explicit token limit.
     *
     * @param str the text to encode; {@code null} yields the tokens of an empty result
     * @return the tokens of the encoding result
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public List<Integer> encodeOrdinary(String str) {
        return encodeOrdinary(str, TokenEncoder.DUMMY_RANK).getTokens();
    }

    /**
     * Encodes {@code str}, limited by {@code i}, without the special-token check.
     *
     * @param str the text to encode; may be {@code null}
     * @param i the token limit passed on to the encoder
     * @return the encoding result
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public EncodingResult encodeOrdinary(String str, int i) {
        return encodeOrdinaryInternal(str, i, true);
    }

    /**
     * Encodes the text and, when tokens are kept and a limit other than
     * {@link Integer#MAX_VALUE} was given, trims trailing tokens until the remaining tokens
     * decode to a prefix of the input text.
     *
     * @param text the text to encode; {@code null} yields an empty result
     * @param maxTokenCount the token limit passed on to the encoder
     * @param keepEncodings flag forwarded to the token encoder; the trimming step only runs
     *        when it is {@code true}
     * @return the encoding result
     */
    private EncodingResult encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings) {
        if (text == null) {
            return new EncodingResult(Collections.emptyList(), -1, false);
        }
        List<Integer> tokens = new ArrayList<>();
        int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, keepEncodings, tokens);
        if (keepEncodings && maxTokenCount != Integer.MAX_VALUE) {
            // Try the full token list first, then drop one trailing token at a time. The empty
            // list decodes to "" which is a prefix of every string, so this loop always returns.
            // Presumably this guards against a limit that cut the output in the middle of a
            // multi-byte character - TODO confirm.
            for (int size = tokens.size(); size >= 0; size--) {
                List<Integer> head = new ArrayList<>(tokens.subList(0, size));
                String decoded = decode(head);
                if (text.startsWith(decoded)) {
                    // The last argument is true when the decoded text is shorter than the input.
                    return new EncodingResult(head, -1, text.length() > decoded.length());
                }
            }
        }
        return new EncodingResult(tokens, tokenCount, false);
    }

    /**
     * Runs the split pattern over the text and feeds the UTF-8 bytes of every match to the
     * token encoder until the text is exhausted or the running count reaches the limit.
     *
     * @param text the text to encode; must not be {@code null}
     * @param maxTokenCount the loop stops once the running count is no longer below this value
     * @param keepEncodings flag forwarded to {@code TokenEncoder.addTokensAndGetCount}
     * @param out list handed to the token encoder as its output list
     * @return the sum of the counts reported by the token encoder
     */
    int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, List<Integer> out) {
        int tokenCount = 0;
        // One scratch list shared by all encoder calls for this text.
        ArrayList<Integer> ranks = new ArrayList<>();
        Matcher matcher = this.pattern.matcher(text);
        while (tokenCount < maxTokenCount && matcher.find()) {
            byte[] piece = matcher.group().getBytes(StandardCharsets.UTF_8);
            tokenCount += this.encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, piece, out, ranks);
        }
        return tokenCount;
    }

    /**
     * Counts the tokens of {@code str} by running the encode path with the keep flag off and
     * reading the token count of the result.
     *
     * @param str the text to measure; may be {@code null}
     * @return the token count reported by the encoding result
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public int countTokens(String str) {
        return encodeInternal(str, TokenEncoder.DUMMY_RANK, false).getTokenCount();
    }

    /**
     * Decodes the tokens to bytes and interprets those bytes as UTF-8 text.
     *
     * @param list the tokens to decode; must not be {@code null}
     * @return the decoded text
     * @throws NullPointerException if a token is unknown to the encoders
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public String decode(List<Integer> list) {
        return new String(decodeBytes(list), StandardCharsets.UTF_8);
    }

    /**
     * Concatenates the byte sequences of the given tokens, in order.
     *
     * @param list the tokens to decode; must not be {@code null} or contain {@code null}
     * @return the concatenated bytes
     * @throws NullPointerException if a token is unknown to the encoders
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public byte[] decodeBytes(List<Integer> list) {
        // Initial capacity of 10 bytes per token; the stream grows on demand if that is too small.
        ByteArrayOutputStream buffer = new ByteArrayOutputStream(10 * list.size());
        for (int token : list) {
            byte[] tokenBytes = decodeToken(token);
            buffer.write(tokenBytes, 0, tokenBytes.length);
        }
        return buffer.toByteArray();
    }

    /**
     * Returns the name taken from the construction parameters.
     *
     * @return the encoding name
     */
    @Override // xyz.felh.openai.jtokkit.api.Encoding
    public String getName() {
        return this.name;
    }

    /**
     * Looks up the bytes of a single token through the token encoder, which is also given the
     * special-token encoder.
     *
     * @param token the token to decode
     * @return the bytes of the token, never {@code null}
     * @throws NullPointerException if the lookup returns {@code null}
     */
    private byte[] decodeToken(int token) {
        byte[] tokenBytes = this.encoder.decodeToken(token, this.specialEncoder);
        if (tokenBytes == null) {
            // Build the message only on failure; this runs once per decoded token.
            throw new NullPointerException("Unknown token for decoding: " + token);
        }
        return tokenBytes;
    }
}
