/*
 * Decompiled with CFR 0.152.
 */
package xyz.felh.openai.jtokkit.utils;

import com.alibaba.fastjson2.JSONObject;
import com.alibaba.fastjson2.JSONWriter;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.Serializable;
import java.net.URL;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.imageio.ImageIO;
import org.apache.commons.lang3.SerializationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import xyz.felh.openai.chat.ChatCompletion;
import xyz.felh.openai.chat.ChatMessage;
import xyz.felh.openai.chat.ChatMessageRole;
import xyz.felh.openai.chat.CreateChatCompletionRequest;
import xyz.felh.openai.chat.tool.Tool;
import xyz.felh.openai.chat.tool.ToolCall;
import xyz.felh.openai.chat.tool.ToolChoice;
import xyz.felh.openai.chat.tool.Type;
import xyz.felh.openai.jtokkit.Encodings;
import xyz.felh.openai.jtokkit.api.Encoding;
import xyz.felh.openai.jtokkit.api.EncodingRegistry;
import xyz.felh.openai.jtokkit.api.EncodingType;
import xyz.felh.openai.jtokkit.api.ModelType;
import xyz.felh.openai.jtokkit.utils.FunctionFormat;
import xyz.felh.openai.utils.ListUtils;
import xyz.felh.openai.utils.Preconditions;

public class TikTokenUtils {
    private static final Logger log = LoggerFactory.getLogger(TikTokenUtils.class);
    private static final Map<String, Encoding> modelMap = new HashMap<String, Encoding>();
    private static final EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();

    public static List<Integer> encode(Encoding enc, String text) {
        return TikTokenUtils.isBlank(text) ? new ArrayList() : enc.encode(text);
    }

    public static int tokens(Encoding enc, String text) {
        return TikTokenUtils.encode(enc, text).size();
    }

    public static String decode(Encoding enc, List<Integer> encoded) {
        return enc.decode(encoded);
    }

    public static Encoding getEncoding(EncodingType encodingType) {
        return registry.getEncoding(encodingType);
    }

    public static List<Integer> encode(EncodingType encodingType, String text) {
        if (TikTokenUtils.isBlank(text)) {
            return new ArrayList<Integer>();
        }
        Encoding enc = TikTokenUtils.getEncoding(encodingType);
        return enc.encode(text);
    }

    public static int tokens(EncodingType encodingType, String text) {
        return TikTokenUtils.encode(encodingType, text).size();
    }

    public static String decode(EncodingType encodingType, List<Integer> encoded) {
        Encoding enc = TikTokenUtils.getEncoding(encodingType);
        return enc.decode(encoded);
    }

    public static Encoding getEncoding(String modelName) {
        String baseModel;
        Encoding encoding = modelMap.get(modelName);
        if (Preconditions.isBlank((Object)encoding) && modelName.toLowerCase().startsWith("ft:") && Preconditions.isBlank((Object)(encoding = modelMap.get(baseModel = modelName.split(":")[1])))) {
            if (baseModel.toLowerCase().startsWith("gpt-3.5")) {
                encoding = modelMap.get(ModelType.GPT_3_5_TURBO.getName());
            }
            if (baseModel.toLowerCase().startsWith("gpt-4")) {
                encoding = modelMap.get(ModelType.GPT_4.getName());
            }
        }
        return encoding;
    }

    public static List<Integer> encode(String modelName, String text) {
        if (TikTokenUtils.isBlank(text)) {
            return new ArrayList<Integer>();
        }
        Encoding enc = TikTokenUtils.getEncoding(modelName);
        if (Objects.isNull(enc)) {
            log.warn("[{}]\u6a21\u578b\u4e0d\u5b58\u5728\u6216\u8005\u6682\u4e0d\u652f\u6301\u8ba1\u7b97tokens\uff0c\u76f4\u63a5\u8fd4\u56detokens==0", (Object)modelName);
            return new ArrayList<Integer>();
        }
        return enc.encode(text);
    }

    public static int tokens(String modelName, String text) {
        return TikTokenUtils.encode(modelName, text).size();
    }

    public static int estimateTokens(CreateChatCompletionRequest request) {
        List messages = request.getMessages();
        List tools = request.getTools();
        Object toolChoice = request.getToolChoice();
        String chatModel = request.getModel();
        int tokens = 0;
        tokens += TikTokenUtils.estimateTokensInMessages(chatModel, messages, tools);
        if (Preconditions.isNotBlank((Object)tools)) {
            tokens += TikTokenUtils.estimateTokensInTools(chatModel, tools);
        }
        if (Preconditions.isNotBlank((Object)tools) && messages.stream().anyMatch(it -> it.getRole() == ChatMessageRole.SYSTEM)) {
            tokens -= 4;
        }
        if (Preconditions.isNotBlank((Object)toolChoice) && !"auto".equals(toolChoice.toString())) {
            ToolChoice tc;
            if ("none".equals(toolChoice.toString())) {
                ++tokens;
            } else if (toolChoice instanceof ToolChoice && Preconditions.isNotBlank((Object)(tc = (ToolChoice)toolChoice).getFunction().getName())) {
                tokens += TikTokenUtils.tokens(chatModel, tc.getFunction().getName()) + 4;
            }
        }
        return tokens;
    }

    public static int estimateTokensInTools(String modelName, List<Tool> tools) {
        Encoding encoding = TikTokenUtils.getEncoding(modelName);
        int tokens = TikTokenUtils.tokens(encoding, FunctionFormat.formatFunctionDefinitions(tools));
        return tokens += 9;
    }

    public static int estimateTokensInMessages(String modelName, List<ChatMessage> messages) {
        return TikTokenUtils.estimateTokensInMessages(modelName, messages, null);
    }

    public static int estimateTokensInMessages(String modelName, List<ChatMessage> messages, List<Tool> tools) {
        int tokens = 0;
        boolean paddedSystem = false;
        for (ChatMessage message : messages) {
            ChatMessage msg = (ChatMessage)SerializationUtils.clone((Serializable)message);
            if (msg.getRole() == ChatMessageRole.SYSTEM && Preconditions.isNotBlank(tools) && !paddedSystem) {
                if (Preconditions.isNotBlank((Object)msg.getContent()) && msg.getContent() instanceof String) {
                    msg.setContent((Object)(String.valueOf(msg.getContent()) + "\n"));
                }
                paddedSystem = true;
            }
            tokens += TikTokenUtils.estimateTokensInMessage(modelName, msg);
        }
        return tokens += 3;
    }

    public static int estimateTokensInMessage(String modelName, ChatMessage message) {
        Encoding encoding = TikTokenUtils.getEncoding(modelName);
        int tokens = 0;
        tokens += TikTokenUtils.tokens(encoding, message.getRole().value());
        if (message.getContent() instanceof String) {
            tokens += TikTokenUtils.tokens(encoding, message.getContent().toString());
        } else {
            List items = ListUtils.castList((Object)message.getContent(), ChatMessage.ContentItem.class);
            if (Preconditions.isNotBlank((Object)items)) {
                for (ChatMessage.ContentItem item : items) {
                    if (item.getType() == ChatMessage.ContentType.TEXT) {
                        tokens += TikTokenUtils.tokens(encoding, item.getText());
                        continue;
                    }
                    if (item.getType() != ChatMessage.ContentType.IMAGE_URL) continue;
                    ChatMessage.ImageUrl imageUrl = item.getImageUrl();
                    if (imageUrl.getDetail() == ChatMessage.ImageUrlDetail.LOW) {
                        tokens += 85;
                        continue;
                    }
                    if (imageUrl.getDetail() != ChatMessage.ImageUrlDetail.HIGH) continue;
                    tokens += 85;
                    int width = 0;
                    int height = 0;
                    if (imageUrl.getUrl().startsWith("f")) {
                        Base64.Decoder decoder = Base64.getDecoder();
                        try {
                            String b64 = imageUrl.getUrl();
                            b64 = b64.substring(b64.indexOf(";base64,") + 8);
                            b64 = b64.substring(0, b64.length() - 1);
                            byte[] bytes = decoder.decode(b64);
                            ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes);
                            BufferedImage bi = ImageIO.read(inputStream);
                            if (Preconditions.isNotBlank((Object)inputStream)) {
                                inputStream.close();
                            }
                            width = bi.getWidth();
                            height = bi.getHeight();
                        }
                        catch (Exception e) {
                            log.error("image to base64 error", (Throwable)e);
                        }
                    } else {
                        try {
                            BufferedImage bi = ImageIO.read(new URL(imageUrl.getUrl()));
                            width = bi.getWidth();
                            height = bi.getHeight();
                        }
                        catch (IOException e) {
                            throw new RuntimeException(e);
                        }
                    }
                    int tiles = (int)Math.ceil((double)width / 512.0) * (int)Math.ceil((double)height / 512.0);
                    tokens += 170 * tiles;
                }
            }
        }
        if (Preconditions.isNotBlank((Object)message.getName()) && message.getRole() != ChatMessageRole.TOOL) {
            tokens += TikTokenUtils.tokens(encoding, message.getName()) + 1;
        }
        if (message.getRole() == ChatMessageRole.ASSISTANT && Preconditions.isNotBlank((Object)message.getToolCalls())) {
            for (ToolCall toolCall : message.getToolCalls()) {
                tokens += 6;
                tokens += TikTokenUtils.tokens(encoding, toolCall.getType().value());
                if (toolCall.getType() != Type.FUNCTION) continue;
                if (Preconditions.isNotBlank((Object)toolCall.getFunction().getName())) {
                    tokens += TikTokenUtils.tokens(encoding, toolCall.getFunction().getName());
                }
                if (!Preconditions.isNotBlank((Object)toolCall.getFunction().getArguments())) continue;
                String args = JSONObject.toJSONString((Object)JSONObject.parseObject((String)toolCall.getFunction().getArguments()), (JSONWriter.Feature[])new JSONWriter.Feature[]{JSONWriter.Feature.PrettyFormat});
                args = args.replaceAll("\\t", "");
                tokens += TikTokenUtils.tokens(encoding, args);
            }
            if (message.getToolCalls().size() > 1) {
                tokens += 15;
                tokens -= (message.getToolCalls().size() - 1) * 5 - 1;
            } else {
                tokens -= 2;
            }
        }
        tokens = message.getRole() == ChatMessageRole.TOOL ? (tokens += 2) : (tokens += 3);
        return tokens;
    }

    public static int tokens(String modelName, Object functionCall, List<Tool> tools) {
        Encoding encoding = TikTokenUtils.getEncoding(modelName);
        int sum = 0;
        if (Preconditions.isNotBlank((Object)functionCall) && functionCall instanceof JSONObject) {
            sum += TikTokenUtils.tokens(encoding, functionCall.toString());
        }
        sum += TikTokenUtils.tokens(encoding, FunctionFormat.formatFunctionDefinitions(tools));
        return sum += 9;
    }

    public static String decode(String modelName, List<Integer> encoded) {
        Encoding enc = TikTokenUtils.getEncoding(modelName);
        return enc.decode(encoded);
    }

    public static ModelType getModelTypeByName(String name) {
        if (ChatCompletion.Model.GPT_3_5_TURBO.getName().equals(name) || ChatCompletion.Model.GPT_3_5_TURBO_INSTRUCT.getName().equals(name) || ChatCompletion.Model.GPT_3_5_TURBO_1106.getName().equals(name)) {
            return ModelType.GPT_3_5_TURBO;
        }
        if (ChatCompletion.Model.GPT_4.getName().equals(name) || ChatCompletion.Model.GPT_4_32K.getName().equals(name) || ChatCompletion.Model.GPT_4_1106_PREVIEW.getName().equals(name) || ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName().equals(name) || ChatCompletion.Model.GPT_4_0125_PREVIEW.getName().equals(name)) {
            return ModelType.GPT_4;
        }
        for (ModelType modelType : ModelType.values()) {
            if (!modelType.getName().equals(name)) continue;
            return modelType;
        }
        log.warn("[{}]\u6a21\u578b\u4e0d\u5b58\u5728\u6216\u8005\u6682\u4e0d\u652f\u6301\u8ba1\u7b97tokens", (Object)name);
        return null;
    }

    public static boolean isBlankChar(int c) {
        return Character.isWhitespace(c) || Character.isSpaceChar(c) || c == 65279 || c == 8234 || c == 0 || c == 12644 || c == 10240 || c == 6158;
    }

    public static boolean isBlankChar(char c) {
        return TikTokenUtils.isBlankChar((int)c);
    }

    public static boolean isNotBlank(CharSequence str) {
        return !TikTokenUtils.isBlank(str);
    }

    public static boolean isBlank(CharSequence str) {
        int length;
        if (str != null && (length = str.length()) != 0) {
            for (int i = 0; i < length; ++i) {
                if (TikTokenUtils.isBlankChar(str.charAt(i))) continue;
                return false;
            }
            return true;
        }
        return true;
    }

    static {
        for (ModelType modelType : ModelType.values()) {
            modelMap.put(modelType.getName(), registry.getEncodingForModel(modelType));
        }
        modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_1106.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
        modelMap.put(ChatCompletion.Model.GPT_3_5_TURBO_INSTRUCT.getName(), registry.getEncodingForModel(ModelType.GPT_3_5_TURBO));
        modelMap.put(ChatCompletion.Model.GPT_4_32K.getName(), registry.getEncodingForModel(ModelType.GPT_4));
        modelMap.put(ChatCompletion.Model.GPT_4_1106_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
        modelMap.put(ChatCompletion.Model.GPT_4_VISION_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
        modelMap.put(ChatCompletion.Model.GPT_4_0125_PREVIEW.getName(), registry.getEncodingForModel(ModelType.GPT_4));
    }
}

