URL scheme whitelist and Inline Queries. #21

Merged
skobkin merged 3 commits from fix_chat_type_middleware_nil_pointer into main 2024-03-12 22:20:06 +00:00
6 changed files with 155 additions and 23 deletions
Showing only changes of commit 7bb5c65d59 - Show all commits

View file

@ -74,18 +74,109 @@ func (b *Bot) Run() error {
// Middlewares // Middlewares
bh.Use(b.chatTypeStatsCounter) bh.Use(b.chatTypeStatsCounter)
// Handlers // Command handlers
bh.Handle(b.startHandler, th.CommandEqual("start")) bh.Handle(b.startHandler, th.CommandEqual("start"))
bh.Handle(b.heyHandler, th.CommandEqual("hey")) bh.Handle(b.heyHandler, th.CommandEqual("hey"))
bh.Handle(b.summarizeHandler, th.CommandEqual("summarize")) bh.Handle(b.summarizeHandler, th.CommandEqual("summarize"))
bh.Handle(b.statsHandler, th.CommandEqual("stats")) bh.Handle(b.statsHandler, th.CommandEqual("stats"))
bh.Handle(b.helpHandler, th.CommandEqual("help")) bh.Handle(b.helpHandler, th.CommandEqual("help"))
// Inline query handlers
bh.Handle(b.inlineHandler, th.AnyInlineQuery())
bh.Start() bh.Start()
return nil return nil
} }
func (b *Bot) inlineHandler(bot *telego.Bot, update telego.Update) {
iq := update.InlineQuery
slog.Info("inline query received", "query", iq.Query)
slog.Debug("query", "query", iq)
if len(iq.Query) < 3 {
return
}
b.stats.InlineQuery()
queryParts := strings.SplitN(iq.Query, " ", 2)
if len(queryParts) < 1 {
slog.Debug("Empty query. Skipping.")
return
}
var response *telego.AnswerInlineQueryParams
switch isValidAndAllowedUrl(queryParts[0]) {
case true:
slog.Info("Inline /summarize request", "url", queryParts[0])
b.stats.SummarizeRequest()
article, err := b.extractor.GetArticleFromUrl(queryParts[0])
if err != nil {
slog.Error("Cannot retrieve an article using extractor", "error", err)
}
llmReply, err := b.llm.Summarize(article.Text, llm.ModelMistralUncensored)
if err != nil {
slog.Error("Cannot get reply from LLM connector")
b.trySendInlineQueryError(iq, "LLM request error. Try again later.")
return
}
slog.Debug("Got completion. Going to send.", "llm-completion", llmReply)
response = tu.InlineQuery(
iq.ID,
tu.ResultArticle(
"reply_"+iq.ID,
"Summary for "+queryParts[0],
tu.TextMessage(b.escapeMarkdownV1Symbols(llmReply)).WithParseMode("Markdown"),
),
)
case false:
b.stats.HeyRequest()
slog.Info("Inline /hey request", "text", iq.Query)
requestContext := createLlmRequestContextFromUpdate(update)
llmReply, err := b.llm.HandleSingleRequest(iq.Query, llm.ModelMistralUncensored, requestContext)
if err != nil {
slog.Error("Cannot get reply from LLM connector")
b.trySendInlineQueryError(iq, "LLM request error. Try again later.")
return
}
slog.Debug("Got completion. Going to send.", "llm-completion", llmReply)
response = tu.InlineQuery(
iq.ID,
tu.ResultArticle(
"reply_"+iq.ID,
"LLM reply to\""+iq.Query+"\"",
tu.TextMessage(b.escapeMarkdownV1Symbols(llmReply)).WithParseMode("Markdown"),
),
)
}
err := bot.AnswerInlineQuery(response)
if err != nil {
slog.Error("Can't answer to inline query", "error", err)
b.trySendInlineQueryError(iq, "Couldn't send intended reply, sorry")
}
}
func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) { func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) {
slog.Info("/hey", "message-text", update.Message.Text) slog.Info("/hey", "message-text", update.Message.Text)
@ -101,7 +192,7 @@ func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) {
b.sendTyping(chatID) b.sendTyping(chatID)
requestContext := b.createLlmRequestContext(update) requestContext := createLlmRequestContextFromUpdate(update)
llmReply, err := b.llm.HandleSingleRequest(userMessage, llm.ModelMistralUncensored, requestContext) llmReply, err := b.llm.HandleSingleRequest(userMessage, llm.ModelMistralUncensored, requestContext)
if err != nil { if err != nil {
@ -115,7 +206,7 @@ func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) {
return return
} }
slog.Debug("Got completion. Going to send.", "llm-reply", llmReply) slog.Debug("Got completion. Going to send.", "llm-completion", llmReply)
message := tu.Message( message := tu.Message(
chatID, chatID,
@ -139,7 +230,7 @@ func (b *Bot) summarizeHandler(bot *telego.Bot, update telego.Update) {
b.sendTyping(chatID) b.sendTyping(chatID)
args := strings.Split(update.Message.Text, " ") args := strings.SplitN(update.Message.Text, " ", 2)
if len(args) < 2 { if len(args) < 2 {
_, _ = bot.SendMessage(tu.Message( _, _ = bot.SendMessage(tu.Message(
@ -180,7 +271,7 @@ func (b *Bot) summarizeHandler(bot *telego.Bot, update telego.Update) {
return return
} }
slog.Debug("Got completion. Going to send.", "llm-reply", llmReply) slog.Debug("Got completion. Going to send.", "llm-completion", llmReply)
message := tu.Message( message := tu.Message(
chatID, chatID,

View file

@ -2,7 +2,7 @@ package bot
import ( import (
"github.com/mymmrac/telego" "github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegoutil" tu "github.com/mymmrac/telego/telegoutil"
"log/slog" "log/slog"
"net/url" "net/url"
"slices" "slices"
@ -22,7 +22,7 @@ func (b *Bot) reply(originalMessage *telego.Message, newMessage *telego.SendMess
func (b *Bot) sendTyping(chatId telego.ChatID) { func (b *Bot) sendTyping(chatId telego.ChatID) {
slog.Debug("Setting 'typing' chat action") slog.Debug("Setting 'typing' chat action")
err := b.api.SendChatAction(telegoutil.ChatAction(chatId, "typing")) err := b.api.SendChatAction(tu.ChatAction(chatId, "typing"))
if err != nil { if err != nil {
slog.Error("Cannot set chat action", "error", err) slog.Error("Cannot set chat action", "error", err)
} }
@ -33,12 +33,27 @@ func (b *Bot) trySendReplyError(message *telego.Message) {
return return
} }
_, _ = b.api.SendMessage(b.reply(message, telegoutil.Message( _, _ = b.api.SendMessage(b.reply(message, tu.Message(
telegoutil.ID(message.Chat.ID), tu.ID(message.Chat.ID),
"Error occurred while trying to send reply.", "Error occurred while trying to send reply.",
))) )))
} }
func (b *Bot) trySendInlineQueryError(iq *telego.InlineQuery, text string) {
if iq == nil {
return
}
_ = b.api.AnswerInlineQuery(tu.InlineQuery(
iq.ID,
tu.ResultArticle(
string("error_"+iq.ID),
"Error: "+text,
tu.TextMessage(text),
),
))
}
func isValidAndAllowedUrl(text string) bool { func isValidAndAllowedUrl(text string) bool {
u, err := url.ParseRequestURI(text) u, err := url.ParseRequestURI(text)
if err != nil { if err != nil {

View file

@ -6,7 +6,7 @@ import (
"telegram-ollama-reply-bot/llm" "telegram-ollama-reply-bot/llm"
) )
func (b *Bot) createLlmRequestContext(update telego.Update) llm.RequestContext { func createLlmRequestContextFromUpdate(update telego.Update) llm.RequestContext {
message := update.Message message := update.Message
iq := update.InlineQuery iq := update.InlineQuery

View file

@ -30,15 +30,20 @@ func NewConnector(baseUrl string, token string) *LlmConnector {
} }
func (l *LlmConnector) HandleSingleRequest(text string, model string, requestContext RequestContext) (string, error) { func (l *LlmConnector) HandleSingleRequest(text string, model string, requestContext RequestContext) (string, error) {
systemPrompt := "You're a bot in the Telegram chat. " +
"You're using a free model called \"" + model + "\". " +
"You see only messages addressed to you using commands due to privacy settings."
if !requestContext.Empty {
systemPrompt += " " + requestContext.Prompt()
}
req := openai.ChatCompletionRequest{ req := openai.ChatCompletionRequest{
Model: model, Model: model,
Messages: []openai.ChatCompletionMessage{ Messages: []openai.ChatCompletionMessage{
{ {
Role: openai.ChatMessageRoleSystem, Role: openai.ChatMessageRoleSystem,
Content: "You're a bot in the Telegram chat. " + Content: systemPrompt,
"You're using a free model called \"" + model + "\". " +
"You see only messages addressed to you using commands due to privacy settings. " +
requestContext.Prompt(),
}, },
}, },
} }

View file

@ -1,6 +1,8 @@
package llm package llm
type RequestContext struct { type RequestContext struct {
Empty bool
Inline bool
User UserContext User UserContext
Chat ChatContext Chat ChatContext
} }
@ -19,7 +21,13 @@ type ChatContext struct {
} }
func (c RequestContext) Prompt() string { func (c RequestContext) Prompt() string {
prompt := "The type of chat you're in is \"" + c.Chat.Type + "\". " if c.Empty {
return ""
}
prompt := ""
if !c.Inline {
prompt += "The type of chat you're in is \"" + c.Chat.Type + "\". "
if c.Chat.Title != "" { if c.Chat.Title != "" {
prompt += "Chat is called \"" + c.Chat.Title + "\". " prompt += "Chat is called \"" + c.Chat.Title + "\". "
@ -27,6 +35,9 @@ func (c RequestContext) Prompt() string {
if c.Chat.Description != "" { if c.Chat.Description != "" {
prompt += "Chat description is \"" + c.Chat.Description + "\". " prompt += "Chat description is \"" + c.Chat.Description + "\". "
} }
} else {
prompt += "You're responding to inline query, so you're not in the chat right now. "
}
prompt += "According to their profile, first name of the user who wrote you is \"" + c.User.FirstName + "\". " prompt += "According to their profile, first name of the user who wrote you is \"" + c.User.FirstName + "\". "
if c.User.Username != "" { if c.User.Username != "" {

View file

@ -13,6 +13,7 @@ type Stats struct {
GroupRequests uint64 GroupRequests uint64
PrivateRequests uint64 PrivateRequests uint64
InlineQueries uint64
HeyRequests uint64 HeyRequests uint64
SummarizeRequests uint64 SummarizeRequests uint64
@ -24,6 +25,7 @@ func NewStats() *Stats {
GroupRequests: 0, GroupRequests: 0,
PrivateRequests: 0, PrivateRequests: 0,
InlineQueries: 0,
HeyRequests: 0, HeyRequests: 0,
SummarizeRequests: 0, SummarizeRequests: 0,
@ -36,6 +38,7 @@ func (s *Stats) MarshalJSON() ([]byte, error) {
GroupRequests uint64 `json:"group_requests"` GroupRequests uint64 `json:"group_requests"`
PrivateRequests uint64 `json:"private_requests"` PrivateRequests uint64 `json:"private_requests"`
InlineQueries uint64 `json:"inline_queries"`
HeyRequests uint64 `json:"hey_requests"` HeyRequests uint64 `json:"hey_requests"`
SummarizeRequests uint64 `json:"summarize_requests"` SummarizeRequests uint64 `json:"summarize_requests"`
@ -44,6 +47,7 @@ func (s *Stats) MarshalJSON() ([]byte, error) {
GroupRequests: s.GroupRequests, GroupRequests: s.GroupRequests,
PrivateRequests: s.PrivateRequests, PrivateRequests: s.PrivateRequests,
InlineQueries: s.InlineQueries,
HeyRequests: s.HeyRequests, HeyRequests: s.HeyRequests,
SummarizeRequests: s.SummarizeRequests, SummarizeRequests: s.SummarizeRequests,
@ -59,6 +63,12 @@ func (s *Stats) String() string {
return string(data) return string(data)
} }
func (s *Stats) InlineQuery() {
s.mu.Lock()
defer s.mu.Unlock()
s.InlineQueries++
}
func (s *Stats) GroupRequest() { func (s *Stats) GroupRequest() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()