|
|
@ -3,11 +3,11 @@ import { persist } from "zustand/middleware";
|
|
|
|
|
|
|
|
|
|
|
|
import { trimTopic } from "../utils";
|
|
|
|
import { trimTopic } from "../utils";
|
|
|
|
|
|
|
|
|
|
|
|
import Locale from "../locales";
|
|
|
|
import Locale, { getLang } from "../locales";
|
|
|
|
import { showToast } from "../components/ui-lib";
|
|
|
|
import { showToast } from "../components/ui-lib";
|
|
|
|
import { ModelType } from "./config";
|
|
|
|
import { ModelConfig, ModelType } from "./config";
|
|
|
|
import { createEmptyMask, Mask } from "./mask";
|
|
|
|
import { createEmptyMask, Mask } from "./mask";
|
|
|
|
import { StoreKey } from "../constant";
|
|
|
|
import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
|
|
|
|
import { api, RequestMessage } from "../client/api";
|
|
|
|
import { api, RequestMessage } from "../client/api";
|
|
|
|
import { ChatControllerPool } from "../client/controller";
|
|
|
|
import { ChatControllerPool } from "../client/controller";
|
|
|
|
import { prettyObject } from "../utils/format";
|
|
|
|
import { prettyObject } from "../utils/format";
|
|
|
@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) {
|
|
|
|
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
|
|
|
|
return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
|
|
|
|
|
|
|
|
const vars = {
|
|
|
|
|
|
|
|
model: modelConfig.model,
|
|
|
|
|
|
|
|
time: new Date().toLocaleString(),
|
|
|
|
|
|
|
|
lang: getLang(),
|
|
|
|
|
|
|
|
input: input,
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// must contains {{input}}
|
|
|
|
|
|
|
|
const inputVar = "{{input}}";
|
|
|
|
|
|
|
|
if (!output.includes(inputVar)) {
|
|
|
|
|
|
|
|
output += "\n" + inputVar;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Object.entries(vars).forEach(([name, value]) => {
|
|
|
|
|
|
|
|
output = output.replaceAll(`{{${name}}}`, value);
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return output;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
export const useChatStore = create<ChatStore>()(
|
|
|
|
export const useChatStore = create<ChatStore>()(
|
|
|
|
persist(
|
|
|
|
persist(
|
|
|
|
(set, get) => ({
|
|
|
|
(set, get) => ({
|
|
|
@ -238,9 +261,12 @@ export const useChatStore = create<ChatStore>()(
|
|
|
|
const session = get().currentSession();
|
|
|
|
const session = get().currentSession();
|
|
|
|
const modelConfig = session.mask.modelConfig;
|
|
|
|
const modelConfig = session.mask.modelConfig;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const userContent = fillTemplateWith(content, modelConfig);
|
|
|
|
|
|
|
|
console.log("[User Input] fill with template: ", userContent);
|
|
|
|
|
|
|
|
|
|
|
|
const userMessage: ChatMessage = createMessage({
|
|
|
|
const userMessage: ChatMessage = createMessage({
|
|
|
|
role: "user",
|
|
|
|
role: "user",
|
|
|
|
content,
|
|
|
|
content: userContent,
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
const botMessage: ChatMessage = createMessage({
|
|
|
|
const botMessage: ChatMessage = createMessage({
|
|
|
@ -250,31 +276,22 @@ export const useChatStore = create<ChatStore>()(
|
|
|
|
model: modelConfig.model,
|
|
|
|
model: modelConfig.model,
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
const systemInfo = createMessage({
|
|
|
|
|
|
|
|
role: "system",
|
|
|
|
|
|
|
|
content: `IMPORTANT: You are a virtual assistant powered by the ${
|
|
|
|
|
|
|
|
modelConfig.model
|
|
|
|
|
|
|
|
} model, now time is ${new Date().toLocaleString()}}`,
|
|
|
|
|
|
|
|
id: botMessage.id! + 1,
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// get recent messages
|
|
|
|
// get recent messages
|
|
|
|
const systemMessages = [];
|
|
|
|
|
|
|
|
// if user define a mask with context prompts, wont send system info
|
|
|
|
|
|
|
|
if (session.mask.context.length === 0) {
|
|
|
|
|
|
|
|
systemMessages.push(systemInfo);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const recentMessages = get().getMessagesWithMemory();
|
|
|
|
const recentMessages = get().getMessagesWithMemory();
|
|
|
|
const sendMessages = systemMessages.concat(
|
|
|
|
const sendMessages = recentMessages.concat(userMessage);
|
|
|
|
recentMessages.concat(userMessage),
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
const sessionIndex = get().currentSessionIndex;
|
|
|
|
const sessionIndex = get().currentSessionIndex;
|
|
|
|
const messageIndex = get().currentSession().messages.length + 1;
|
|
|
|
const messageIndex = get().currentSession().messages.length + 1;
|
|
|
|
|
|
|
|
|
|
|
|
// save user's and bot's message
|
|
|
|
// save user's and bot's message
|
|
|
|
get().updateCurrentSession((session) => {
|
|
|
|
get().updateCurrentSession((session) => {
|
|
|
|
session.messages = session.messages.concat([userMessage, botMessage]);
|
|
|
|
const savedUserMessage = {
|
|
|
|
|
|
|
|
...userMessage,
|
|
|
|
|
|
|
|
content,
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
session.messages = session.messages.concat([
|
|
|
|
|
|
|
|
savedUserMessage,
|
|
|
|
|
|
|
|
botMessage,
|
|
|
|
|
|
|
|
]);
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
// make request
|
|
|
|
// make request
|
|
|
@ -350,55 +367,62 @@ export const useChatStore = create<ChatStore>()(
|
|
|
|
getMessagesWithMemory() {
|
|
|
|
getMessagesWithMemory() {
|
|
|
|
const session = get().currentSession();
|
|
|
|
const session = get().currentSession();
|
|
|
|
const modelConfig = session.mask.modelConfig;
|
|
|
|
const modelConfig = session.mask.modelConfig;
|
|
|
|
|
|
|
|
const clearContextIndex = session.clearContextIndex ?? 0;
|
|
|
|
|
|
|
|
const messages = session.messages.slice();
|
|
|
|
|
|
|
|
const totalMessageCount = session.messages.length;
|
|
|
|
|
|
|
|
|
|
|
|
// wont send cleared context messages
|
|
|
|
// in-context prompts
|
|
|
|
const clearedContextMessages = session.messages.slice(
|
|
|
|
const contextPrompts = session.mask.context.slice();
|
|
|
|
session.clearContextIndex ?? 0,
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
const messages = clearedContextMessages.filter((msg) => !msg.isError);
|
|
|
|
|
|
|
|
const n = messages.length;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const context = session.mask.context.slice();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// long term memory
|
|
|
|
// long term memory
|
|
|
|
if (
|
|
|
|
const shouldSendLongTermMemory =
|
|
|
|
modelConfig.sendMemory &&
|
|
|
|
modelConfig.sendMemory &&
|
|
|
|
session.memoryPrompt &&
|
|
|
|
session.memoryPrompt &&
|
|
|
|
session.memoryPrompt.length > 0
|
|
|
|
session.memoryPrompt.length > 0 &&
|
|
|
|
) {
|
|
|
|
session.lastSummarizeIndex <= clearContextIndex;
|
|
|
|
const memoryPrompt = get().getMemoryPrompt();
|
|
|
|
const longTermMemoryPrompts = shouldSendLongTermMemory
|
|
|
|
context.push(memoryPrompt);
|
|
|
|
? [get().getMemoryPrompt()]
|
|
|
|
}
|
|
|
|
: [];
|
|
|
|
|
|
|
|
const longTermMemoryStartIndex = session.lastSummarizeIndex;
|
|
|
|
// get short term and unmemorized long term memory
|
|
|
|
|
|
|
|
const shortTermMemoryMessageIndex = Math.max(
|
|
|
|
// short term memory
|
|
|
|
|
|
|
|
const shortTermMemoryStartIndex = Math.max(
|
|
|
|
0,
|
|
|
|
0,
|
|
|
|
n - modelConfig.historyMessageCount,
|
|
|
|
totalMessageCount - modelConfig.historyMessageCount,
|
|
|
|
);
|
|
|
|
);
|
|
|
|
const longTermMemoryMessageIndex = session.lastSummarizeIndex;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// try to concat history messages
|
|
|
|
// lets concat send messages, including 4 parts:
|
|
|
|
|
|
|
|
// 1. long term memory: summarized memory messages
|
|
|
|
|
|
|
|
// 2. pre-defined in-context prompts
|
|
|
|
|
|
|
|
// 3. short term memory: latest n messages
|
|
|
|
|
|
|
|
// 4. newest input message
|
|
|
|
const memoryStartIndex = Math.min(
|
|
|
|
const memoryStartIndex = Math.min(
|
|
|
|
shortTermMemoryMessageIndex,
|
|
|
|
longTermMemoryStartIndex,
|
|
|
|
longTermMemoryMessageIndex,
|
|
|
|
shortTermMemoryStartIndex,
|
|
|
|
);
|
|
|
|
);
|
|
|
|
const threshold = modelConfig.max_tokens;
|
|
|
|
// and if user has cleared history messages, we should exclude the memory too.
|
|
|
|
|
|
|
|
const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
|
|
|
|
|
|
|
|
const maxTokenThreshold = modelConfig.max_tokens;
|
|
|
|
|
|
|
|
|
|
|
|
// get recent messages as many as possible
|
|
|
|
// get recent messages as much as possible
|
|
|
|
const reversedRecentMessages = [];
|
|
|
|
const reversedRecentMessages = [];
|
|
|
|
for (
|
|
|
|
for (
|
|
|
|
let i = n - 1, count = 0;
|
|
|
|
let i = totalMessageCount - 1, tokenCount = 0;
|
|
|
|
i >= memoryStartIndex && count < threshold;
|
|
|
|
i >= contextStartIndex && tokenCount < maxTokenThreshold;
|
|
|
|
i -= 1
|
|
|
|
i -= 1
|
|
|
|
) {
|
|
|
|
) {
|
|
|
|
const msg = messages[i];
|
|
|
|
const msg = messages[i];
|
|
|
|
if (!msg || msg.isError) continue;
|
|
|
|
if (!msg || msg.isError) continue;
|
|
|
|
count += estimateTokenLength(msg.content);
|
|
|
|
tokenCount += estimateTokenLength(msg.content);
|
|
|
|
reversedRecentMessages.push(msg);
|
|
|
|
reversedRecentMessages.push(msg);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// concat
|
|
|
|
// concat all messages
|
|
|
|
const recentMessages = context.concat(reversedRecentMessages.reverse());
|
|
|
|
const recentMessages = [
|
|
|
|
|
|
|
|
...longTermMemoryPrompts,
|
|
|
|
|
|
|
|
...contextPrompts,
|
|
|
|
|
|
|
|
...reversedRecentMessages.reverse(),
|
|
|
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
|
|
return recentMessages;
|
|
|
|
return recentMessages;
|
|
|
|
},
|
|
|
|
},
|
|
|
|