/*
 * Decompiled with CFR 0.152.
 */
package xyz.felh.openai.jtokkit;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import xyz.felh.openai.jtokkit.ImmutableByteArray;
import xyz.felh.openai.jtokkit.TokenEncoder;
import xyz.felh.openai.jtokkit.api.Encoding;
import xyz.felh.openai.jtokkit.api.EncodingResult;
import xyz.felh.openai.jtokkit.api.GptBytePairEncodingParams;

/**
 * {@link Encoding} implementation that splits text with a regular expression and turns each
 * matched piece into token ids, either by direct lookup or by byte pair merging.
 *
 * <p>Two lookup tables are held: one for ordinary tokens keyed by their raw bytes and one for
 * special tokens keyed by their string form. Encoding text that contains a special token is
 * rejected by the {@code encode} methods; the {@code encodeOrdinary} methods skip that check.
 */
final class GptBytePairEncoding implements Encoding {
    private final String name;
    private final Pattern pattern;
    private final TokenEncoder<ImmutableByteArray, Integer> encoder;
    private final TokenEncoder<String, Integer> specialTokensEncoder;

    /**
     * Builds the encoding from the given parameters.
     *
     * @param params supplies the name, the splitting pattern, the ordinary token table and the
     *               special token table
     */
    GptBytePairEncoding(GptBytePairEncodingParams params) {
        this.name = params.getName();
        this.pattern = params.getPattern();
        this.encoder = new TokenEncoder<ImmutableByteArray, Integer>(params.getEncoder(), ImmutableByteArray::from);
        this.specialTokensEncoder = new TokenEncoder<String, Integer>(params.getSpecialTokensEncoder());
    }

    /**
     * Encodes the text into token ids.
     *
     * @param text the text to encode; {@code null} yields an empty list
     * @return the token ids
     * @throws UnsupportedOperationException if the text contains any special token
     */
    @Override
    public List<Integer> encode(String text) {
        return encodeInternal(text, null).getTokens();
    }

    /**
     * Encodes the text into at most {@code maxTokens} token ids.
     *
     * @param text      the text to encode; {@code null} yields an empty result
     * @param maxTokens upper bound on the number of tokens produced
     * @return the encoding result
     * @throws UnsupportedOperationException if the text contains any special token
     */
    @Override
    public EncodingResult encode(String text, int maxTokens) {
        return encodeInternal(text, maxTokens);
    }

    /**
     * Rejects text containing a special token, then delegates to the ordinary encoding path.
     *
     * @param maxTokens token limit, or {@code null} for no limit
     */
    private EncodingResult encodeInternal(String text, Integer maxTokens) {
        if (text == null) {
            return new EncodingResult(Collections.emptyList(), false);
        }
        for (String specialToken : specialTokensEncoder.getDecodedTokens()) {
            if (text.contains(specialToken)) {
                throw new UnsupportedOperationException("Encoding special tokens is not supported yet.");
            }
        }
        return encodeOrdinaryInternal(text, maxTokens);
    }

    /**
     * Encodes the text into token ids without checking for special tokens.
     *
     * @param text the text to encode; {@code null} yields an empty list
     * @return the token ids
     */
    @Override
    public List<Integer> encodeOrdinary(String text) {
        return encodeOrdinaryInternal(text, null).getTokens();
    }

    /**
     * Encodes the text into at most {@code maxTokens} token ids without checking for special
     * tokens.
     *
     * @param text      the text to encode; {@code null} yields an empty result
     * @param maxTokens upper bound on the number of tokens produced
     * @return the encoding result
     */
    @Override
    public EncodingResult encodeOrdinary(String text, int maxTokens) {
        return encodeOrdinaryInternal(text, maxTokens);
    }

    /**
     * Core encoding loop: walks the regex matches of {@code text} and appends the tokens of each
     * match until the text is exhausted or the token limit is hit.
     *
     * @param maxTokens token limit, or {@code null} for no limit
     * @return the tokens plus a flag that is {@code true} when a limit was given and the returned
     *         tokens decode to less than the whole text
     */
    private EncodingResult encodeOrdinaryInternal(String text, Integer maxTokens) {
        if (text == null) {
            return new EncodingResult(Collections.emptyList(), false);
        }

        List<Integer> out = new ArrayList<>();
        Matcher matcher = pattern.matcher(text);
        int tokenCount = 0;
        while (matcher.find() && maxTokenCountNotReached(maxTokens, tokenCount)) {
            ImmutableByteArray match = ImmutableByteArray.from(matcher.group());
            if (encoder.containsDecodedToken(match)) {
                // The whole match is a known token: no merging needed.
                out.add(encoder.encode(match));
                tokenCount++;
            } else {
                tokenCount += addTokens(out, bytePairMerge(match), maxTokens);
            }
        }

        if (maxTokens != null) {
            // Drop trailing tokens until what remains decodes to a prefix of the input.
            // Presumably this guards against a cut that splits a multi-byte UTF-8 sequence,
            // in which case the decoded string would not match the original text.
            // The empty list decodes to "" and always matches, so this loop always returns.
            for (int tokensToRemove = 0; tokensToRemove <= out.size(); tokensToRemove++) {
                List<Integer> tokens = out.subList(0, out.size() - tokensToRemove);
                String decoded = decode(tokens);
                if (text.startsWith(decoded)) {
                    return new EncodingResult(tokens, text.length() > decoded.length());
                }
            }
        }

        return new EncodingResult(out, false);
    }

    /**
     * Appends {@code tokensToAdd} to {@code out}, clipping so that {@code out} does not grow
     * beyond {@code maxTokens} entries when a limit is given.
     *
     * @return the number of tokens actually appended
     */
    private int addTokens(List<Integer> out, List<Integer> tokensToAdd, Integer maxTokens) {
        if (maxTokens == null) {
            out.addAll(tokensToAdd);
            return tokensToAdd.size();
        }
        int room = maxTokens - out.size();
        List<Integer> clipped = tokensToAdd.subList(0, Math.min(room, tokensToAdd.size()));
        out.addAll(clipped);
        return clipped.size();
    }

    /**
     * Counts the tokens of the text by fully encoding it with {@link #encode(String)}.
     *
     * @throws UnsupportedOperationException if the text contains any special token
     */
    @Override
    public int countTokens(String text) {
        return encode(text).size();
    }

    /**
     * Counts the tokens of the text by fully encoding it with {@link #encodeOrdinary(String)}.
     */
    @Override
    public int countTokensOrdinary(String text) {
        return encodeOrdinary(text).size();
    }

    /**
     * Decodes the token ids and interprets the resulting bytes as UTF-8.
     *
     * @param tokens the token ids to decode
     * @return the decoded text
     * @throws IllegalArgumentException if a token id is in neither lookup table
     */
    @Override
    public String decode(List<Integer> tokens) {
        return new String(decodeBytes(tokens), StandardCharsets.UTF_8);
    }

    /**
     * Decodes the token ids into the concatenation of their byte sequences.
     *
     * @param tokens the token ids to decode
     * @return the concatenated bytes, in token order
     * @throws IllegalArgumentException if a token id is in neither lookup table
     */
    @Override
    public byte[] decodeBytes(List<Integer> tokens) {
        // Gather the per-token chunks first so the result can be allocated once at its final
        // size and filled with bulk copies, instead of boxing every single byte.
        List<byte[]> chunks = new ArrayList<>(tokens.size());
        int totalLength = 0;
        for (int token : tokens) {
            byte[] chunk = decodeToken(token);
            chunks.add(chunk);
            totalLength += chunk.length;
        }

        byte[] result = new byte[totalLength];
        int offset = 0;
        for (byte[] chunk : chunks) {
            System.arraycopy(chunk, 0, result, offset, chunk.length);
            offset += chunk.length;
        }
        return result;
    }

    /**
     * Returns the name taken from the construction parameters.
     */
    @Override
    public String getName() {
        return name;
    }

    /**
     * Splits {@code piece} into tokens by repeatedly merging the adjacent pair of parts with the
     * lowest rank until no mergeable pair remains.
     *
     * <p>{@code parts} holds boundaries into {@code piece}: entry {@code k} marks where part
     * {@code k} starts, and its {@code rank} is the rank of the byte range formed by part
     * {@code k} together with part {@code k + 1} ({@code Integer.MAX_VALUE} when that range is
     * not a known token). The last entry is a sentinel at {@code piece.length()}.
     *
     * @param piece the bytes of one regex match
     * @return the token ids of the final parts, in order
     */
    private List<Integer> bytePairMerge(ImmutableByteArray piece) {
        List<PieceIndexToRank> parts = new ArrayList<>(piece.length() + 1);
        for (int i = 0; i < piece.length() + 1; i++) {
            parts.add(new PieceIndexToRank(i, Integer.MAX_VALUE));
        }

        // Initial ranks of every adjacent pair; the last two entries have no pair to their right.
        for (int i = 0; i < parts.size() - 2; i++) {
            parts.get(i).rank = getRank(piece, parts, i, 0).orElse(Integer.MAX_VALUE);
        }

        while (parts.size() > 1) {
            // Find the first occurrence of the lowest rank.
            int minRankIndex = 0;
            int minRank = Integer.MAX_VALUE;
            for (int i = 0; i < parts.size() - 1; i++) {
                int rank = parts.get(i).rank;
                if (rank < minRank) {
                    minRank = rank;
                    minRankIndex = i;
                }
            }
            if (minRank == Integer.MAX_VALUE) {
                // Nothing left that can be merged.
                break;
            }

            // Recompute the ranks affected by the merge before removing the boundary;
            // skip = 1 makes getRank look past the entry that is about to disappear.
            parts.get(minRankIndex).rank = getRank(piece, parts, minRankIndex, 1).orElse(Integer.MAX_VALUE);
            if (minRankIndex > 0) {
                parts.get(minRankIndex - 1).rank =
                        getRank(piece, parts, minRankIndex - 1, 1).orElse(Integer.MAX_VALUE);
            }
            parts.remove(minRankIndex + 1);
        }

        List<Integer> out = new ArrayList<>(parts.size() - 1);
        for (int i = 0; i < parts.size() - 1; i++) {
            int from = parts.get(i).index;
            int to = parts.get(i + 1).index;
            out.add(encoder.encode(piece.getBytesBetween(from, to)));
        }
        return out;
    }

    /**
     * Tells whether a limit is set and {@code tokenCount} has reached or exceeded it.
     */
    private boolean maxTokenCountReached(Integer maxTokenCount, int tokenCount) {
        return maxTokenCount != null && maxTokenCount <= tokenCount;
    }

    /**
     * Negation of {@link #maxTokenCountReached(Integer, int)}.
     */
    private boolean maxTokenCountNotReached(Integer maxTokenCount, int tokenCount) {
        return !maxTokenCountReached(maxTokenCount, tokenCount);
    }

    /**
     * Looks up the rank of the byte range that starts at boundary {@code startIndex} and ends at
     * boundary {@code startIndex + skip + 2}.
     *
     * @param skip number of extra boundaries to step over when locating the end of the range
     * @return the rank, or empty when the end boundary does not exist or the range is not a
     *         known token
     */
    private Optional<Integer> getRank(ImmutableByteArray piece, List<PieceIndexToRank> parts, int startIndex, int skip) {
        int endPart = startIndex + skip + 2;
        if (endPart >= parts.size()) {
            return Optional.empty();
        }
        int pieceStartIndex = parts.get(startIndex).index;
        int pieceEndIndex = parts.get(endPart).index;
        return encoder.encodeIfPresent(piece.getBytesBetween(pieceStartIndex, pieceEndIndex));
    }

    /**
     * Resolves a single token id to its bytes, trying the ordinary table first and the special
     * token table second.
     *
     * @throws IllegalArgumentException if the id is in neither table
     */
    private byte[] decodeToken(int token) {
        Optional<ImmutableByteArray> ordinary = encoder.decodeIfPresent(token);
        if (ordinary.isPresent()) {
            return ordinary.get().getRawArray();
        }
        return specialTokensEncoder.decodeIfPresent(token)
                .map(special -> special.getBytes(StandardCharsets.UTF_8))
                .orElseThrow(() -> new IllegalArgumentException("Unknown token for decoding: " + token));
    }

    /**
     * A boundary into the piece being merged, together with the rank of the pair that starts at
     * this boundary.
     */
    private static class PieceIndexToRank {
        private final int index;
        private int rank;

        public PieceIndexToRank(int index, int rank) {
            this.index = index;
            this.rank = rank;
        }
    }
}

