Allowing selection of models using ENV.

This commit is contained in:
Alexey Skobkin 2024-08-16 03:07:47 +03:00
parent 82c4d953d4
commit ac44c1360a
No known key found for this signature in database
GPG key ID: 5D5CEF6F221278E7
6 changed files with 87 additions and 15 deletions

View file

@ -13,4 +13,11 @@ WORKDIR /app
COPY --from=builder /build/app . COPY --from=builder /build/app .
# Do not forget "/v1" in the end
ENV OPENAI_API_BASE_URL="" \
OPENAI_API_TOKEN="" \
TELEGRAM_TOKEN="" \
MODEL_TEXT_REQUEST="llama3.1:8b-instruct-q6_K" \
MODEL_SUMMARIZE_REQUEST="llama3.1:8b-instruct-q6_K"
CMD ["/app/app"] CMD ["/app/app"]

View file

@ -23,16 +23,23 @@ type Bot struct {
llm *llm.LlmConnector llm *llm.LlmConnector
extractor *extractor.Extractor extractor *extractor.Extractor
stats *stats.Stats stats *stats.Stats
models ModelSelection
markdownV1Replacer *strings.Replacer markdownV1Replacer *strings.Replacer
} }
func NewBot(api *telego.Bot, llm *llm.LlmConnector, extractor *extractor.Extractor) *Bot { func NewBot(
api *telego.Bot,
llm *llm.LlmConnector,
extractor *extractor.Extractor,
models ModelSelection,
) *Bot {
return &Bot{ return &Bot{
api: api, api: api,
llm: llm, llm: llm,
extractor: extractor, extractor: extractor,
stats: stats.NewStats(), stats: stats.NewStats(),
models: models,
markdownV1Replacer: strings.NewReplacer( markdownV1Replacer: strings.NewReplacer(
// https://core.telegram.org/bots/api#markdown-style // https://core.telegram.org/bots/api#markdown-style
@ -122,7 +129,7 @@ func (b *Bot) inlineHandler(bot *telego.Bot, update telego.Update) {
slog.Error("Cannot retrieve an article using extractor", "error", err) slog.Error("Cannot retrieve an article using extractor", "error", err)
} }
llmReply, err := b.llm.Summarize(article.Text, llm.ModelLlama3Uncensored ) llmReply, err := b.llm.Summarize(article.Text, b.models.TextRequestModel)
if err != nil { if err != nil {
slog.Error("Cannot get reply from LLM connector") slog.Error("Cannot get reply from LLM connector")
@ -148,7 +155,7 @@ func (b *Bot) inlineHandler(bot *telego.Bot, update telego.Update) {
requestContext := createLlmRequestContextFromUpdate(update) requestContext := createLlmRequestContextFromUpdate(update)
llmReply, err := b.llm.HandleSingleRequest(iq.Query, llm.ModelLlama3Uncensored, requestContext) llmReply, err := b.llm.HandleSingleRequest(iq.Query, b.models.TextRequestModel, requestContext)
if err != nil { if err != nil {
slog.Error("Cannot get reply from LLM connector") slog.Error("Cannot get reply from LLM connector")
@ -194,7 +201,7 @@ func (b *Bot) heyHandler(bot *telego.Bot, update telego.Update) {
requestContext := createLlmRequestContextFromUpdate(update) requestContext := createLlmRequestContextFromUpdate(update)
llmReply, err := b.llm.HandleSingleRequest(userMessage, llm.ModelLlama3Uncensored, requestContext) llmReply, err := b.llm.HandleSingleRequest(userMessage, b.models.TextRequestModel, requestContext)
if err != nil { if err != nil {
slog.Error("Cannot get reply from LLM connector") slog.Error("Cannot get reply from LLM connector")
@ -259,7 +266,7 @@ func (b *Bot) summarizeHandler(bot *telego.Bot, update telego.Update) {
slog.Error("Cannot retrieve an article using extractor", "error", err) slog.Error("Cannot retrieve an article using extractor", "error", err)
} }
llmReply, err := b.llm.Summarize(article.Text, llm.ModelMistralUncensored) llmReply, err := b.llm.Summarize(article.Text, b.models.SummarizeModel)
if err != nil { if err != nil {
slog.Error("Cannot get reply from LLM connector") slog.Error("Cannot get reply from LLM connector")

6
bot/models.go Normal file
View file

@ -0,0 +1,6 @@
package bot
type ModelSelection struct {
TextRequestModel string
SummarizeModel string
}

View file

@ -44,11 +44,13 @@ func createLlmRequestContextFromUpdate(update telego.Update) llm.RequestContext
} }
if !rc.Inline { if !rc.Inline {
// TODO: implement retrieval of chat description
chat := message.Chat chat := message.Chat
rc.Chat = llm.ChatContext{ rc.Chat = llm.ChatContext{
Title: chat.Title, Title: chat.Title,
Description: chat.Description, // TODO: fill when ChatFullInfo retrieved
Type: chat.Type, //Description: chat.Description,
Type: chat.Type,
} }
} }

View file

@ -10,9 +10,6 @@ import (
var ( var (
ErrLlmBackendRequestFailed = errors.New("llm back-end request failed") ErrLlmBackendRequestFailed = errors.New("llm back-end request failed")
ErrNoChoices = errors.New("no choices in LLM response") ErrNoChoices = errors.New("no choices in LLM response")
ModelMistralUncensored = "dolphin-mistral:7b-v2.8-q4_K_M"
ModelLlama3Uncensored = "dolphin-llama3:8b-v2.9-q4_K_M"
) )
type LlmConnector struct { type LlmConnector struct {
@ -108,3 +105,37 @@ func (l *LlmConnector) Summarize(text string, model string) (string, error) {
return resp.Choices[0].Message.Content, nil return resp.Choices[0].Message.Content, nil
} }
func (l *LlmConnector) GetModels() []string {
var result []string
models, err := l.client.ListModels(context.Background())
if err != nil {
slog.Error("llm: Model list request failed", "error", err)
return result
}
slog.Info("Model list retrieved", "models", models)
for _, model := range models.Models {
result = append(result, model.ID)
}
return result
}
func (l *LlmConnector) HasModel(id string) bool {
model, err := l.client.GetModel(context.Background(), id)
if err != nil {
slog.Error("llm: Model request failed", "error", err)
}
slog.Debug("llm: Returned model", "model", model)
if model.ID != "" {
return true
}
return false
}

27
main.go
View file

@ -12,12 +12,31 @@ import (
) )
func main() { func main() {
ollamaToken := os.Getenv("OLLAMA_TOKEN") apiToken := os.Getenv("OPENAI_API_TOKEN")
ollamaBaseUrl := os.Getenv("OLLAMA_BASE_URL") apiBaseUrl := os.Getenv("OPENAI_API_BASE_URL")
models := bot.ModelSelection{
TextRequestModel: os.Getenv("MODEL_TEXT_REQUEST"),
SummarizeModel: os.Getenv("MODEL_SUMMARIZE_REQUEST"),
}
slog.Info("Selected", "models", models)
telegramToken := os.Getenv("TELEGRAM_TOKEN") telegramToken := os.Getenv("TELEGRAM_TOKEN")
llmc := llm.NewConnector(ollamaBaseUrl, ollamaToken) llmc := llm.NewConnector(apiBaseUrl, apiToken)
slog.Info("Checking models availability")
for _, model := range []string{models.TextRequestModel, models.SummarizeModel} {
if !llmc.HasModel(model) {
slog.Error("Model not unavailable", "model", model)
os.Exit(1)
}
}
slog.Info("All needed models are available")
ext := extractor.NewExtractor() ext := extractor.NewExtractor()
telegramApi, err := tg.NewBot(telegramToken, tg.WithLogger(bot.NewLogger("telego: "))) telegramApi, err := tg.NewBot(telegramToken, tg.WithLogger(bot.NewLogger("telego: ")))
@ -26,7 +45,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
botService := bot.NewBot(telegramApi, llmc, ext) botService := bot.NewBot(telegramApi, llmc, ext, models)
err = botService.Run() err = botService.Run()
if err != nil { if err != nil {