From ac44c1360a31a23412ae3a29b5ea2d07a7aa7327 Mon Sep 17 00:00:00 2001
From: Alexey Skobkin
Date: Fri, 16 Aug 2024 03:07:47 +0300
Subject: [PATCH] Allow selection of models using ENV.

---
 Dockerfile             |  7 +++++++
 bot/bot.go             | 17 ++++++++++++-----
 bot/models.go          |  6 ++++++
 bot/request_context.go |  8 +++++---
 llm/llm.go             | 37 ++++++++++++++++++++++++++++++++++---
 main.go                | 27 +++++++++++++++++++++++----
 6 files changed, 87 insertions(+), 15 deletions(-)
 create mode 100644 bot/models.go

diff --git a/Dockerfile b/Dockerfile
index 048399a..57e1e41 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -13,4 +13,11 @@ WORKDIR /app
 
 COPY --from=builder /build/app .
 
+# Do not forget "/v1" at the end
+ENV OPENAI_API_BASE_URL="" \
+    OPENAI_API_TOKEN="" \
+    TELEGRAM_TOKEN="" \
+    MODEL_TEXT_REQUEST="llama3.1:8b-instruct-q6_K" \
+    MODEL_SUMMARIZE_REQUEST="llama3.1:8b-instruct-q6_K"
+
 CMD ["/app/app"]
diff --git a/bot/bot.go b/bot/bot.go
index 6810eee..4d7d81f 100644
--- a/bot/bot.go
+++ b/bot/bot.go
@@ -23,16 +23,23 @@ type Bot struct {
 	llm       *llm.LlmConnector
 	extractor *extractor.Extractor
 	stats     *stats.Stats
+	models    ModelSelection
 
 	markdownV1Replacer *strings.Replacer
 }
 
-func NewBot(api *telego.Bot, llm *llm.LlmConnector, extractor *extractor.Extractor) *Bot {
+func NewBot(
+	api *telego.Bot,
+	llm *llm.LlmConnector,
+	extractor *extractor.Extractor,
+	models ModelSelection,
+) *Bot {
 	return &Bot{
 		api:       api,
 		llm:       llm,
 		extractor: extractor,
 		stats:     stats.NewStats(),
+		models:    models,
 
 		markdownV1Replacer: strings.NewReplacer(
 			// https://core.telegram.org/bots/api#markdown-style
@@ -122,7 +129,7 @@ func (b *Bot) inlineHandler(bot *telego.Bot, update telego.Update) {
 		slog.Error("Cannot retrieve an article using extractor", "error", err)
 	}
 
-	llmReply, err := b.llm.Summarize(article.Text, llm.ModelLlama3Uncensored)
+	llmReply, err := b.llm.Summarize(article.Text, b.models.TextRequestModel)
 	if err != nil {
 		slog.Error("Cannot get reply from LLM connector")
 
@@ -148,7 +155,7 @@ func (b *Bot) inlineHandler(bot *telego.Bot, update telego.Update) {
 
 	requestContext := createLlmRequestContextFromUpdate(update)
 
-	llmReply, err := b.llm.HandleSingleRequest(iq.Query, llm.ModelLlama3Uncensored, requestContext)
+	llmReply, err := b.llm.HandleSingleRequest(iq.Query, b.models.TextRequestModel, requestContext)
 	if err != nil {
 		slog.Error("Cannot get reply from LLM connector")
 
@@ -194,7 +201,7 @@ func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) {
 
 	requestContext := createLlmRequestContextFromUpdate(update)
 
-	llmReply, err := b.llm.HandleSingleRequest(userMessage, llm.ModelLlama3Uncensored, requestContext)
+	llmReply, err := b.llm.HandleSingleRequest(userMessage, b.models.TextRequestModel, requestContext)
 	if err != nil {
 		slog.Error("Cannot get reply from LLM connector")
 
@@ -259,7 +266,7 @@ func (b *Bot) summarizeHandler(bot *telego.Bot, update telego.Update) {
 		slog.Error("Cannot retrieve an article using extractor", "error", err)
 	}
 
-	llmReply, err := b.llm.Summarize(article.Text, llm.ModelMistralUncensored)
+	llmReply, err := b.llm.Summarize(article.Text, b.models.SummarizeModel)
 	if err != nil {
 		slog.Error("Cannot get reply from LLM connector")
 
diff --git a/bot/models.go b/bot/models.go
new file mode 100644
index 0000000..93259d0
--- /dev/null
+++ b/bot/models.go
@@ -0,0 +1,6 @@
+package bot
+
+type ModelSelection struct {
+	TextRequestModel string
+	SummarizeModel   string
+}
diff --git a/bot/request_context.go b/bot/request_context.go
index 969abf6..93cba79 100644
--- a/bot/request_context.go
+++ b/bot/request_context.go
@@ -44,11 +44,13 @@ func createLlmRequestContextFromUpdate(update telego.Update) llm.RequestContext
 	}
 
 	if !rc.Inline {
+		// TODO: implement retrieval of chat description
 		chat := message.Chat
 		rc.Chat = llm.ChatContext{
-			Title:       chat.Title,
-			Description: chat.Description,
-			Type:        chat.Type,
+			Title: chat.Title,
+			// TODO: fill when ChatFullInfo is retrieved
+			//Description: chat.Description,
+			Type: chat.Type,
 		}
 	}
 
diff --git a/llm/llm.go b/llm/llm.go
index 73affe8..748cb8b 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -10,9 +10,6 @@ import (
 var (
 	ErrLlmBackendRequestFailed = errors.New("llm back-end request failed")
 	ErrNoChoices               = errors.New("no choices in LLM response")
-
-	ModelMistralUncensored = "dolphin-mistral:7b-v2.8-q4_K_M"
-	ModelLlama3Uncensored  = "dolphin-llama3:8b-v2.9-q4_K_M"
 )
 
 type LlmConnector struct {
@@ -108,3 +105,37 @@ func (l *LlmConnector) Summarize(text string, model string) (string, error) {
 
 	return resp.Choices[0].Message.Content, nil
 }
+
+func (l *LlmConnector) GetModels() []string {
+	var result []string
+
+	models, err := l.client.ListModels(context.Background())
+	if err != nil {
+		slog.Error("llm: Model list request failed", "error", err)
+
+		return result
+	}
+
+	slog.Info("Model list retrieved", "models", models)
+
+	for _, model := range models.Models {
+		result = append(result, model.ID)
+	}
+
+	return result
+}
+
+func (l *LlmConnector) HasModel(id string) bool {
+	model, err := l.client.GetModel(context.Background(), id)
+	if err != nil {
+		slog.Error("llm: Model request failed", "error", err)
+	}
+
+	slog.Debug("llm: Returned model", "model", model)
+
+	if model.ID != "" {
+		return true
+	}
+
+	return false
+}
diff --git a/main.go b/main.go
index 6611c9a..84d3f24 100644
--- a/main.go
+++ b/main.go
@@ -12,12 +12,31 @@ import (
 )
 
 func main() {
-	ollamaToken := os.Getenv("OLLAMA_TOKEN")
-	ollamaBaseUrl := os.Getenv("OLLAMA_BASE_URL")
+	apiToken := os.Getenv("OPENAI_API_TOKEN")
+	apiBaseUrl := os.Getenv("OPENAI_API_BASE_URL")
+
+	models := bot.ModelSelection{
+		TextRequestModel: os.Getenv("MODEL_TEXT_REQUEST"),
+		SummarizeModel:   os.Getenv("MODEL_SUMMARIZE_REQUEST"),
+	}
+
+	slog.Info("Selected", "models", models)
 
 	telegramToken := os.Getenv("TELEGRAM_TOKEN")
 
-	llmc := llm.NewConnector(ollamaBaseUrl, ollamaToken)
+	llmc := llm.NewConnector(apiBaseUrl, apiToken)
+
+	slog.Info("Checking model availability")
+
+	for _, model := range []string{models.TextRequestModel, models.SummarizeModel} {
+		if !llmc.HasModel(model) {
+			slog.Error("Model not available", "model", model)
+			os.Exit(1)
+		}
+	}
+
+	slog.Info("All needed models are available")
+
 	ext := extractor.NewExtractor()
 
 	telegramApi, err := tg.NewBot(telegramToken, tg.WithLogger(bot.NewLogger("telego: ")))
@@ -26,7 +45,7 @@ func main() {
 		os.Exit(1)
 	}
 
-	botService := bot.NewBot(telegramApi, llmc, ext)
+	botService := bot.NewBot(telegramApi, llmc, ext, models)
 
 	err = botService.Run()
 	if err != nil {
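
A note on the new connector methods: GetModels() and HasModel() go through the
OpenAI-compatible client, which is why the Dockerfile comment insists that
OPENAI_API_BASE_URL end in "/v1". Roughly what ListModels() exchanges on the
wire, as a standalone sketch in plain net/http (not code from this patch; the
modelList struct is mine, the response shape is the standard OpenAI-compatible
model list, and the Ollama URL in the comment is only an example):

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"os"
)

// modelList mirrors the OpenAI-compatible GET <base>/models response:
// {"object": "list", "data": [{"id": "..."}, ...]}
type modelList struct {
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`
}

func main() {
	base := os.Getenv("OPENAI_API_BASE_URL") // e.g. "http://ollama:11434/v1"
	token := os.Getenv("OPENAI_API_TOKEN")

	req, err := http.NewRequest(http.MethodGet, base+"/models", nil)
	if err != nil {
		fmt.Fprintln(os.Stderr, "bad request:", err)
		os.Exit(1)
	}
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Fprintln(os.Stderr, "request failed:", err)
		os.Exit(1)
	}
	defer resp.Body.Close()

	var list modelList
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		fmt.Fprintln(os.Stderr, "bad response:", err)
		os.Exit(1)
	}

	// These are the IDs the startup check matches the MODEL_* values against.
	for _, m := range list.Data {
		fmt.Println(m.ID)
	}
}

Ollama and other OpenAI-compatible back-ends expose this same endpoint, which
is what lets the patch swap the OLLAMA_* variables for OPENAI_* ones.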
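
One consequence of the startup loop in main.go: when the binary runs outside
the image (or with the MODEL_* variables overridden to empty), os.Getenv
returns empty strings, HasModel() fails, and the process exits, because the
llama3.1 defaults live only in the Dockerfile's ENV lines. A deployment that
preferred in-process fallbacks could wrap os.Getenv, sketched below with a
hypothetical envOr helper that is not part of this patch:

package main

import (
	"fmt"
	"os"
)

// envOr returns the environment variable's value,
// or fallback when it is unset or empty.
// Hypothetical helper; not part of the patch above.
func envOr(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

func main() {
	// Defaults mirror the values baked into the Dockerfile.
	text := envOr("MODEL_TEXT_REQUEST", "llama3.1:8b-instruct-q6_K")
	summarize := envOr("MODEL_SUMMARIZE_REQUEST", "llama3.1:8b-instruct-q6_K")

	fmt.Println("text request model:", text)
	fmt.Println("summarize model:", summarize)
}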