From 1f9583cb29d0ac105594ee416e66095b8998d8fd Mon Sep 17 00:00:00 2001
From: Alexey Skobkin
Date: Mon, 28 Oct 2024 00:35:35 +0300
Subject: [PATCH] #26 Adding in-memory chat history support. Removing inline
 queries. Refactoring stats and message processing a bit. Also changing LLM
 request context building a bit, adding an alias for summarization, and some
 other small changes.

---
 bot/bot.go             | 159 ++++++++++++++---------------------
 bot/helpers.go         |  43 ++++++++---
 bot/message_history.go |  90 +++++++++++++++++++++++
 bot/middleware.go      |  28 ++++++--
 bot/request_context.go |  60 +++++++++-------
 llm/llm.go             |  27 +++++--
 llm/request_context.go |  32 +++++----
 stats/stats.go         |  12 ++--
 8 files changed, 279 insertions(+), 172 deletions(-)
 create mode 100644 bot/message_history.go

diff --git a/bot/bot.go b/bot/bot.go
index 4d7d81f..204cf7b 100644
--- a/bot/bot.go
+++ b/bot/bot.go
@@ -18,12 +18,20 @@ var (
 	ErrHandlerInit = errors.New("cannot initialize handler")
 )
 
+type BotInfo struct {
+	Id       int64
+	Username string
+	Name     string
+}
+
 type Bot struct {
 	api       *telego.Bot
 	llm       *llm.LlmConnector
 	extractor *extractor.Extractor
 	stats     *stats.Stats
 	models    ModelSelection
+	history   map[int64]*MessageRingBuffer
+	profile   BotInfo
 
 	markdownV1Replacer *strings.Replacer
 }
@@ -40,6 +48,8 @@ func NewBot(
 		extractor: extractor,
 		stats:     stats.NewStats(),
 		models:    models,
+		history:   make(map[int64]*MessageRingBuffer),
+		profile:   BotInfo{0, "", ""},
 
 		markdownV1Replacer: strings.NewReplacer(
 			// https://core.telegram.org/bots/api#markdown-style
@@ -61,6 +71,12 @@ func (b *Bot) Run() error {
 
 	slog.Info("Running api as", "id", botUser.ID, "username", botUser.Username, "name", botUser.FirstName, "is_bot", botUser.IsBot)
 
+	b.profile = BotInfo{
+		Id:       botUser.ID,
+		Username: botUser.Username,
+		Name:     botUser.FirstName,
+	}
+
 	updates, err := b.api.UpdatesViaLongPolling(nil)
 	if err != nil {
 		slog.Error("Cannot get update channel", "error", err)
@@ -79,133 +95,60 @@ func (b *Bot) Run() error {
 	defer b.api.StopLongPolling()
 
 	// Middlewares
+	bh.Use(b.chatHistory)
 	bh.Use(b.chatTypeStatsCounter)
 
 	// Command handlers
+	bh.Handle(b.textMessageHandler, th.AnyMessageWithText())
 	bh.Handle(b.startHandler, th.CommandEqual("start"))
-	bh.Handle(b.heyHandler, th.CommandEqual("hey"))
-	bh.Handle(b.summarizeHandler, th.CommandEqual("summarize"))
+	bh.Handle(b.summarizeHandler, th.Or(th.CommandEqual("summarize"), th.CommandEqual("s")))
 	bh.Handle(b.statsHandler, th.CommandEqual("stats"))
 	bh.Handle(b.helpHandler, th.CommandEqual("help"))
 
-	// Inline query handlers
-	bh.Handle(b.inlineHandler, th.AnyInlineQuery())
-
 	bh.Start()
 
 	return nil
 }
 
-func (b *Bot) inlineHandler(bot *telego.Bot, update telego.Update) {
-	iq := update.InlineQuery
-	slog.Info("inline query received", "query", iq.Query)
+func (b *Bot) textMessageHandler(bot *telego.Bot, update telego.Update) {
+	slog.Debug("/any-message")
 
-	slog.Debug("query", "query", iq)
+	message := update.Message
 
-	if len(iq.Query) < 3 {
-		return
-	}
-
-	b.stats.InlineQuery()
-
-	queryParts := strings.SplitN(iq.Query, " ", 2)
-
-	if len(queryParts) < 1 {
-		slog.Debug("Empty query. Skipping.")
-
-		return
-	}
-
-	var response *telego.AnswerInlineQueryParams
-
-	switch isValidAndAllowedUrl(queryParts[0]) {
-	case true:
-		slog.Info("Inline /summarize request", "url", queryParts[0])
-
-		b.stats.SummarizeRequest()
-
-		article, err := b.extractor.GetArticleFromUrl(queryParts[0])
-		if err != nil {
-			slog.Error("Cannot retrieve an article using extractor", "error", err)
-		}
-
-		llmReply, err := b.llm.Summarize(article.Text, b.models.TextRequestModel)
-		if err != nil {
-			slog.Error("Cannot get reply from LLM connector")
-
-			b.trySendInlineQueryError(iq, "LLM request error. Try again later.")
-
-			return
-		}
-
-		slog.Debug("Got completion. Going to send.", "llm-completion", llmReply)
-
-		response = tu.InlineQuery(
-			iq.ID,
-			tu.ResultArticle(
-				"reply_"+iq.ID,
-				"Summary for "+queryParts[0],
-				tu.TextMessage(b.escapeMarkdownV1Symbols(llmReply)).WithParseMode("Markdown"),
-			),
-		)
-	case false:
-		b.stats.HeyRequest()
-
-		slog.Info("Inline /hey request", "text", iq.Query)
-
-		requestContext := createLlmRequestContextFromUpdate(update)
-
-		llmReply, err := b.llm.HandleSingleRequest(iq.Query, b.models.TextRequestModel, requestContext)
-		if err != nil {
-			slog.Error("Cannot get reply from LLM connector")
-
-			b.trySendInlineQueryError(iq, "LLM request error. Try again later.")
-
-			return
-		}
-
-		slog.Debug("Got completion. Going to send.", "llm-completion", llmReply)
-
-		response = tu.InlineQuery(
-			iq.ID,
-			tu.ResultArticle(
-				"reply_"+iq.ID,
-				"LLM reply to\""+iq.Query+"\"",
-				tu.TextMessage(b.escapeMarkdownV1Symbols(llmReply)).WithParseMode("Markdown"),
-			),
-		)
-	}
-
-	err := bot.AnswerInlineQuery(response)
-	if err != nil {
-		slog.Error("Can't answer to inline query", "error", err)
-
-		b.trySendInlineQueryError(iq, "Couldn't send intended reply, sorry")
+	switch {
+	// Mentions
+	case b.isMentionOfMe(update):
+		slog.Info("/any-message", "type", "mention")
+		b.processMention(message)
+	// Replies
+	case b.isReplyToMe(update):
+		slog.Info("/any-message", "type", "reply")
+		b.processMention(message)
+	// Private chat
+	case b.isPrivateWithMe(update):
+		slog.Info("/any-message", "type", "private")
+		b.processMention(message)
+	default:
+		slog.Debug("/any-message", "info", "Message is not mention, reply or private chat. Skipping.")
 	}
 }
 
-func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) {
-	slog.Info("/hey", "message-text", update.Message.Text)
+func (b *Bot) processMention(message *telego.Message) {
+	b.stats.Mention()
 
-	b.stats.HeyRequest()
+	slog.Info("/mention", "chat", message.Chat.ID)
 
-	parts := strings.SplitN(update.Message.Text, " ", 2)
-	userMessage := "Hey!"
-	if len(parts) == 2 {
-		userMessage = parts[1]
-	}
-
-	chatID := tu.ID(update.Message.Chat.ID)
+	chatID := tu.ID(message.Chat.ID)
 	b.sendTyping(chatID)
 
-	requestContext := createLlmRequestContextFromUpdate(update)
+	requestContext := b.createLlmRequestContextFromMessage(message)
 
-	llmReply, err := b.llm.HandleSingleRequest(userMessage, b.models.TextRequestModel, requestContext)
+	llmReply, err := b.llm.HandleChatMessage(message.Text, b.models.TextRequestModel, requestContext)
 	if err != nil {
 		slog.Error("Cannot get reply from LLM connector")
 
-		_, _ = b.api.SendMessage(b.reply(update.Message, tu.Message(
+		_, _ = b.api.SendMessage(b.reply(message, tu.Message(
 			chatID,
 			"LLM request error. Try again later.",
 		)))
@@ -215,17 +158,21 @@ func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) {
 
 	slog.Debug("Got completion. Going to send.", "llm-completion", llmReply)
 
-	message := tu.Message(
+	reply := tu.Message(
 		chatID,
 		b.escapeMarkdownV1Symbols(llmReply),
 	).WithParseMode("Markdown")
 
-	_, err = bot.SendMessage(b.reply(update.Message, message))
+	_, err = b.api.SendMessage(b.reply(message, reply))
 	if err != nil {
 		slog.Error("Can't send reply message", "error", err)
 
-		b.trySendReplyError(update.Message)
+		b.trySendReplyError(message)
+
+		return
 	}
+
+	b.saveBotReplyToHistory(message, llmReply)
 }
 
 func (b *Bot) summarizeHandler(bot *telego.Bot, update telego.Update) {
@@ -306,7 +253,9 @@ func (b *Bot) helpHandler(bot *telego.Bot, update telego.Update) {
 		"Instructions:\r\n"+
 			"/hey - Ask something from LLM\r\n"+
 			"/summarize - Summarize text from the provided link\r\n"+
-			"/help - Show this help",
+			"/s - Shorter version\r\n"+
+			"/help - Show this help\r\n\r\n"+
+			"Mention the bot or reply to its message to communicate with it",
 	)))
 	if err != nil {
 		slog.Error("Cannot send a message", "error", err)
diff --git a/bot/helpers.go b/bot/helpers.go
index 2fff50d..b4065e6 100644
--- a/bot/helpers.go
+++ b/bot/helpers.go
@@ -39,19 +39,40 @@ func (b *Bot) trySendReplyError(message *telego.Message) {
 	)))
 }
 
-func (b *Bot) trySendInlineQueryError(iq *telego.InlineQuery, text string) {
-	if iq == nil {
-		return
+func (b *Bot) isMentionOfMe(update telego.Update) bool {
+	if update.Message == nil {
+		return false
 	}
 
-	_ = b.api.AnswerInlineQuery(tu.InlineQuery(
-		iq.ID,
-		tu.ResultArticle(
-			string("error_"+iq.ID),
-			"Error: "+text,
-			tu.TextMessage(text),
-		),
-	))
+	return strings.Contains(update.Message.Text, "@"+b.profile.Username)
+}
+
+func (b *Bot) isReplyToMe(update telego.Update) bool {
+	message := update.Message
+
+	if message == nil {
+		return false
+	}
+	if message.ReplyToMessage == nil {
+		return false
+	}
+	if message.ReplyToMessage.From == nil {
+		return false
+	}
+
+	replyToMessage := message.ReplyToMessage
+
+	return replyToMessage != nil && replyToMessage.From.ID == b.profile.Id
+}
+
+func (b *Bot) isPrivateWithMe(update telego.Update) bool {
+	message := update.Message
+
+	if message == nil {
+		return false
+	}
+
+	return message.Chat.Type == telego.ChatTypePrivate
 }
 
 func isValidAndAllowedUrl(text string) bool {
diff --git a/bot/message_history.go b/bot/message_history.go
new file mode 100644
index 0000000..d7c756a
--- /dev/null
+++ b/bot/message_history.go
@@ -0,0 +1,90 @@
+package bot
+
+import (
+	"github.com/mymmrac/telego"
+	"log/slog"
+)
+
+const HistoryLength = 50
+
+type MessageRingBuffer struct {
+	messages []Message
+	capacity int
+}
+
+func NewMessageBuffer(capacity int) *MessageRingBuffer {
+	return &MessageRingBuffer{
+		messages: make([]Message, 0, capacity),
+		capacity: capacity,
+	}
+}
+
+func (b *MessageRingBuffer) Push(element Message) {
+	if len(b.messages) >= b.capacity {
+		b.messages = b.messages[1:]
+	}
+
+	b.messages = append(b.messages, element)
+}
+
+func (b *MessageRingBuffer) GetAll() []Message {
+	return b.messages
+}
+
+type Message struct {
+	Name string
+	Text string
+}
+
+func (b *Bot) saveChatMessageToHistory(message *telego.Message) {
+	chatId := message.Chat.ID
+
+	slog.Info(
+		"history-message-save",
+		"chat", chatId,
+		"from_id", message.From.ID,
+		"from_name", message.From.FirstName,
+		"text", message.Text,
+	)
+
+	_, ok := b.history[chatId]
+	if !ok {
+		b.history[chatId] = NewMessageBuffer(HistoryLength)
+	}
+
+	b.history[chatId].Push(Message{
+		Name: message.From.FirstName,
+		Text: message.Text,
+	})
+}
+
+func (b *Bot) saveBotReplyToHistory(message *telego.Message, reply string) {
+	chatId := message.Chat.ID
+
+	slog.Info(
+		"history-reply-save",
+		"chat", chatId,
+		"to_id", message.From.ID,
+		"to_name", message.From.FirstName,
+		"text", reply,
+	)
+
+	_, ok := b.history[chatId]
+	if !ok {
+		b.history[chatId] = NewMessageBuffer(HistoryLength)
+	}
+
+	b.history[chatId].Push(Message{
+		Name: b.profile.Username,
+		Text: reply,
+	})
+}
+
+func (b *Bot) getChatHistory(chatId int64) []Message {
+	_, ok := b.history[chatId]
+	if !ok {
+		return make([]Message, 0)
+	}
+
+	return b.history[chatId].GetAll()
+}
diff --git a/bot/middleware.go b/bot/middleware.go
index a1e2204..32d92e4 100644
--- a/bot/middleware.go
+++ b/bot/middleware.go
@@ -10,21 +10,41 @@ func (b *Bot) chatTypeStatsCounter(bot *telego.Bot, update telego.Update, next t
 	message := update.Message
 
 	if message == nil {
-		slog.Info("chat-type-middleware: update has no message. skipping.")
+		slog.Info("stats-middleware: update has no message. skipping.")
 
 		next(bot, update)
 
 		return
 	}
 
-	slog.Info("chat-type-middleware: counting message chat type in stats", "type", message.Chat.Type)
-
 	switch message.Chat.Type {
 	case telego.ChatTypeGroup, telego.ChatTypeSupergroup:
-		b.stats.GroupRequest()
+		if b.isMentionOfMe(update) || b.isReplyToMe(update) {
+			slog.Info("stats-middleware: counting message chat type in stats", "type", message.Chat.Type)
+			b.stats.GroupRequest()
+		}
 	case telego.ChatTypePrivate:
+		slog.Info("stats-middleware: counting message chat type in stats", "type", message.Chat.Type)
 		b.stats.PrivateRequest()
 	}
 
 	next(bot, update)
 }
+
+func (b *Bot) chatHistory(bot *telego.Bot, update telego.Update, next telegohandler.Handler) {
+	message := update.Message
+
+	if message == nil {
+		slog.Info("chat-history-middleware: update has no message. skipping.")
+
+		next(bot, update)
+
+		return
+	}
+
+	slog.Info("chat-history-middleware: saving message to history for", "chat_id", message.Chat.ID)
+
+	b.saveChatMessageToHistory(message)
+
+	next(bot, update)
+}
diff --git a/bot/request_context.go b/bot/request_context.go
index 93cba79..a045bef 100644
--- a/bot/request_context.go
+++ b/bot/request_context.go
@@ -6,33 +6,20 @@ import (
 	"telegram-ollama-reply-bot/llm"
 )
 
-func createLlmRequestContextFromUpdate(update telego.Update) llm.RequestContext {
-	message := update.Message
-	iq := update.InlineQuery
-
+func (b *Bot) createLlmRequestContextFromMessage(message *telego.Message) llm.RequestContext {
 	rc := llm.RequestContext{
-		Empty:  true,
-		Inline: false,
+		Empty: true,
 	}
 
-	switch {
-	case message == nil && iq == nil:
+	if message == nil {
 		slog.Debug("request context creation problem: no message provided. returning empty context.", "request-context", rc)
 
 		return rc
-	case iq != nil:
-		rc.Inline = true
 	}
 
 	rc.Empty = false
 
-	var user *telego.User
-
-	if rc.Inline {
-		user = &iq.From
-	} else {
-		user = message.From
-	}
+	user := message.From
 
 	if user != nil {
 		rc.User = llm.UserContext{
@@ -43,18 +30,39 @@ func createLlmRequestContextFromUpdate(update telego.Update) llm.RequestContext
 	}
 
-	if !rc.Inline {
-		// TODO: implement retrieval of chat description
-		chat := message.Chat
-		rc.Chat = llm.ChatContext{
-			Title: chat.Title,
-			// TODO: fill when ChatFullInfo retrieved
-			//Description: chat.Description,
-			Type: chat.Type,
-		}
+	// TODO: implement retrieval of chat description
+	chat := message.Chat
+
+	history := b.getChatHistory(chat.ID)
+
+	rc.Chat = llm.ChatContext{
+		Title: chat.Title,
+		// TODO: fill when ChatFullInfo retrieved
+		//Description: chat.Description,
+		Type:    chat.Type,
+		History: historyToLlmMessages(history),
 	}
 
 	slog.Debug("request context created", "request-context", rc)
 
 	return rc
 }
+
+func historyToLlmMessages(history []Message) []llm.ChatMessage {
+	length := len(history)
+
+	if length > 0 {
+		result := make([]llm.ChatMessage, 0, length)
+
+		for _, msg := range history {
+			result = append(result, llm.ChatMessage{
+				Name: msg.Name,
+				Text: msg.Text,
+			})
+		}
+
+		return result
+	}
+
+	return make([]llm.ChatMessage, 0)
+}
diff --git a/llm/llm.go b/llm/llm.go
index 7732264..eb7cc5a 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -5,6 +5,8 @@ import (
 	"errors"
 	"github.com/sashabaranov/go-openai"
 	"log/slog"
+	"strconv"
+	"strings"
 )
 
 var (
@@ -27,13 +29,15 @@ func NewConnector(baseUrl string, token string) *LlmConnector {
 	}
 }
 
-func (l *LlmConnector) HandleSingleRequest(text string, model string, requestContext RequestContext) (string, error) {
+func (l *LlmConnector) HandleChatMessage(text string, model string, requestContext RequestContext) (string, error) {
 	systemPrompt := "You're a bot in the Telegram chat.\n" +
-		"You're using a free model called \"" + model + "\".\n" +
-		"Currently you're not able to access chat history, so each message will be replied from a clean slate."
+		"You're using a free model called \"" + model + "\".\n\n" +
+		requestContext.Prompt()
 
-	if !requestContext.Empty {
-		systemPrompt += "\n" + requestContext.Prompt()
+	historyLength := len(requestContext.Chat.History)
+
+	if historyLength > 0 {
+		systemPrompt += "\nYou have access to the last " + strconv.Itoa(historyLength) + " messages in this chat."
 	}
 
 	req := openai.ChatCompletionRequest{
@@ -46,6 +50,15 @@ func (l *LlmConnector) HandleSingleRequest(text string, model string, requestCon
 		},
 	}
 
+	if historyLength > 0 {
+		for _, msg := range requestContext.Chat.History {
+			req.Messages = append(req.Messages, openai.ChatCompletionMessage{
+				Role:    openai.ChatMessageRoleUser,
+				Content: msg.Name + ":\n\n" + quoteMessage(msg.Text),
+			})
+		}
+	}
+
 	req.Messages = append(req.Messages, openai.ChatCompletionMessage{
 		Role:    openai.ChatMessageRoleUser,
 		Content: text,
 	})
@@ -140,3 +153,7 @@ func (l *LlmConnector) HasModel(id string) bool {
 
 	return false
 }
+
+func quoteMessage(text string) string {
+	return "> " + strings.ReplaceAll(text, "\n", "\n> ")
+}
diff --git a/llm/request_context.go b/llm/request_context.go
index 34fa144..d608d6b 100644
--- a/llm/request_context.go
+++ b/llm/request_context.go
@@ -1,10 +1,9 @@
 package llm
 
 type RequestContext struct {
-	Empty  bool
-	Inline bool
-	User   UserContext
-	Chat   ChatContext
+	Empty bool
+	User  UserContext
+	Chat  ChatContext
 }
 
 type UserContext struct {
@@ -18,6 +17,12 @@ type ChatContext struct {
 	Title       string
 	Description string
 	Type        string
+	History     []ChatMessage
+}
+
+type ChatMessage struct {
+	Name string
+	Text string
 }
 
 func (c RequestContext) Prompt() string {
@@ -26,20 +31,17 @@ func (c RequestContext) Prompt() string {
 	}
 
 	prompt := ""
-	if !c.Inline {
-		prompt += "The type of chat you're in is \"" + c.Chat.Type + "\". "
-		if c.Chat.Title != "" {
-			prompt += "Chat is called \"" + c.Chat.Title + "\". "
-		}
-		if c.Chat.Description != "" {
-			prompt += "Chat description is \"" + c.Chat.Description + "\". "
-		}
-	} else {
-		prompt += "You're responding to inline query, so you're not in the chat right now. "
+
+	prompt += "The type of chat you're in is \"" + c.Chat.Type + "\". "
+
+	if c.Chat.Title != "" {
+		prompt += "Chat is called \"" + c.Chat.Title + "\". "
+	}
+	if c.Chat.Description != "" {
+		prompt += "Chat description is \"" + c.Chat.Description + "\". "
 	}
 
-	prompt += "User profile:" +
+	prompt += "Profile of the user who mentioned you in the chat:\n" +
 		"First name: \"" + c.User.FirstName + "\"\n"
 	if c.User.Username != "" {
 		prompt += "Username: @" + c.User.Username + ".\n"
diff --git a/stats/stats.go b/stats/stats.go
index 733d5c6..1409e68 100644
--- a/stats/stats.go
+++ b/stats/stats.go
@@ -15,7 +15,7 @@ type Stats struct {
 	PrivateRequests uint64
 
 	InlineQueries uint64
-	HeyRequests   uint64
+	Mentions      uint64
 
 	SummarizeRequests uint64
 }
@@ -27,7 +27,7 @@ func NewStats() *Stats {
 		PrivateRequests: 0,
 
 		InlineQueries: 0,
-		HeyRequests:   0,
+		Mentions:      0,
 
 		SummarizeRequests: 0,
 	}
 }
@@ -40,7 +40,7 @@ func (s *Stats) MarshalJSON() ([]byte, error) {
 		PrivateRequests uint64 `json:"private_requests"`
 
 		InlineQueries uint64 `json:"inline_queries"`
-		HeyRequests   uint64 `json:"hey_requests"`
+		Mentions      uint64 `json:"mentions"`
 
 		SummarizeRequests uint64 `json:"summarize_requests"`
 	}{
 		Uptime: time.Now().Sub(s.RunningSince).String(),
@@ -49,7 +49,7 @@ func (s *Stats) MarshalJSON() ([]byte, error) {
 		PrivateRequests: s.PrivateRequests,
 
 		InlineQueries: s.InlineQueries,
-		HeyRequests:   s.HeyRequests,
+		Mentions:      s.Mentions,
 
 		SummarizeRequests: s.SummarizeRequests,
 	})
 }
@@ -81,10 +81,10 @@ func (s *Stats) PrivateRequest() {
 	s.PrivateRequests++
 }
 
-func (s *Stats) HeyRequest() {
+func (s *Stats) Mention() {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	s.HeyRequests++
+	s.Mentions++
 }
 
 func (s *Stats) SummarizeRequest() {
-- 
2.43.5