From a49ce5ab28f8cc5f881ae829a007f5111e1711b5 Mon Sep 17 00:00:00 2001 From: Alexey Skobkin Date: Wed, 6 Nov 2024 19:39:21 +0300 Subject: [PATCH] Using local search to check model existence (fixes #42). --- llm/llm.go | 48 +++++++++++++++++++++++------------------------- main.go | 9 ++++----- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/llm/llm.go b/llm/llm.go index def83fe..19bd3a4 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/sashabaranov/go-openai" "log/slog" + "slices" "strconv" ) @@ -113,36 +114,33 @@ func (l *LlmConnector) Summarize(text string, model string) (string, error) { return resp.Choices[0].Message.Content, nil } -func (l *LlmConnector) GetModels() []string { - var result []string - - models, err := l.client.ListModels(context.Background()) +func (l *LlmConnector) HasAllModels(modelIds []string) (bool, map[string]bool) { + modelList, err := l.client.ListModels(context.Background()) if err != nil { slog.Error("llm: Model list request failed", "error", err) - - return result } - slog.Info("Model list retrieved", "models", models) + slog.Info("llm: Returned model list", "models", modelList) + slog.Info("llm: Checking for requested models", "requested", modelIds) - for _, model := range models.Models { - result = append(result, model.ID) + requestedModelsCount := len(modelIds) + searchResult := make(map[string]bool, requestedModelsCount) + + for _, modelId := range modelIds { + searchResult[modelId] = false } - return result -} - -func (l *LlmConnector) HasModel(id string) bool { - model, err := l.client.GetModel(context.Background(), id) - if err != nil { - slog.Error("llm: Model request failed", "error", err) - } - - slog.Debug("llm: Returned model", "model", model) - - if model.ID != "" { - return true - } - - return false + for _, model := range modelList.Models { + if slices.Contains(modelIds, model.ID) { + searchResult[model.ID] = true + } + } + + for _, v := range searchResult { + if !v { + return false, searchResult + } + } + + return true, searchResult } diff --git a/main.go b/main.go index 84d3f24..e9f23d2 100644 --- a/main.go +++ b/main.go @@ -28,11 +28,10 @@ func main() { slog.Info("Checking models availability") - for _, model := range []string{models.TextRequestModel, models.SummarizeModel} { - if !llmc.HasModel(model) { - slog.Error("Model not unavailable", "model", model) - os.Exit(1) - } + hasAll, searchResult := llmc.HasAllModels([]string{models.TextRequestModel, models.SummarizeModel}) + if !hasAll { + slog.Error("Not all models are available", "result", searchResult) + os.Exit(1) } slog.Info("All needed models are available")