Spaces:
Configuration error
Configuration error
package grpc | |
import ( | |
"context" | |
"fmt" | |
"log" | |
"net" | |
pb "github.com/mudler/LocalAI/pkg/grpc/proto" | |
"google.golang.org/grpc" | |
) | |
// A GRPC Server that allows to run LLM inference.
// It is used by the LLMServices to expose the LLM functionalities that are called by the client.
// The GRPC Service is general, trying to encompass all the possible LLM options models.
// It depends on the real implementer then what can be done or not.
//
// The server is implemented as a GRPC service; among its methods are:
//   - Predict: to run the inference with options
//   - PredictStream: to run the inference with options and stream the results
//
// server is the type registered through pb.RegisterBackendServer. Every RPC
// it implements forwards the request to the wrapped llm value.
type server struct {
	// Embedded generated stub from the pb package.
	// NOTE(review): presumably supplies default handlers for RPCs this type
	// does not override - confirm against the generated code.
	pb.UnimplementedBackendServer
	// llm is the backend implementation each RPC delegates to.
	llm LLM
}
func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) { | |
return newReply("OK"), nil | |
} | |
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
embeds, err := s.llm.Embeddings(in) | |
if err != nil { | |
return nil, err | |
} | |
return &pb.EmbeddingResult{Embeddings: embeds}, nil | |
} | |
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
err := s.llm.Load(in) | |
if err != nil { | |
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err | |
} | |
return &pb.Result{Message: "Loading succeeded", Success: true}, nil | |
} | |
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
result, err := s.llm.Predict(in) | |
return newReply(result), err | |
} | |
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
err := s.llm.GenerateImage(in) | |
if err != nil { | |
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err | |
} | |
return &pb.Result{Message: "Image generated", Success: true}, nil | |
} | |
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
err := s.llm.TTS(in) | |
if err != nil { | |
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err | |
} | |
return &pb.Result{Message: "TTS audio generated", Success: true}, nil | |
} | |
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
err := s.llm.SoundGeneration(in) | |
if err != nil { | |
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err | |
} | |
return &pb.Result{Message: "Sound Generation audio generated", Success: true}, nil | |
} | |
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
result, err := s.llm.AudioTranscription(in) | |
if err != nil { | |
return nil, err | |
} | |
tresult := &pb.TranscriptResult{} | |
for _, s := range result.Segments { | |
tks := []int32{} | |
for _, t := range s.Tokens { | |
tks = append(tks, int32(t)) | |
} | |
tresult.Segments = append(tresult.Segments, | |
&pb.TranscriptSegment{ | |
Text: s.Text, | |
Id: int32(s.Id), | |
Start: int64(s.Start), | |
End: int64(s.End), | |
Tokens: tks, | |
}) | |
} | |
tresult.Text = result.Text | |
return tresult, nil | |
} | |
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
resultChan := make(chan string) | |
done := make(chan bool) | |
go func() { | |
for result := range resultChan { | |
stream.Send(newReply(result)) | |
} | |
done <- true | |
}() | |
err := s.llm.PredictStream(in, resultChan) | |
<-done | |
return err | |
} | |
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
res, err := s.llm.TokenizeString(in) | |
if err != nil { | |
return nil, err | |
} | |
castTokens := make([]int32, len(res.Tokens)) | |
for i, v := range res.Tokens { | |
castTokens[i] = int32(v) | |
} | |
return &pb.TokenizationResponse{ | |
Length: int32(res.Length), | |
Tokens: castTokens, | |
}, err | |
} | |
func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) { | |
res, err := s.llm.Status() | |
if err != nil { | |
return nil, err | |
} | |
return &res, nil | |
} | |
func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
err := s.llm.StoresSet(in) | |
if err != nil { | |
return &pb.Result{Message: fmt.Sprintf("Error setting entry: %s", err.Error()), Success: false}, err | |
} | |
return &pb.Result{Message: "Set key", Success: true}, nil | |
} | |
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
err := s.llm.StoresDelete(in) | |
if err != nil { | |
return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err | |
} | |
return &pb.Result{Message: "Deleted key", Success: true}, nil | |
} | |
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
res, err := s.llm.StoresGet(in) | |
if err != nil { | |
return nil, err | |
} | |
return &res, nil | |
} | |
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) { | |
if s.llm.Locking() { | |
s.llm.Lock() | |
defer s.llm.Unlock() | |
} | |
res, err := s.llm.StoresFind(in) | |
if err != nil { | |
return nil, err | |
} | |
return &res, nil | |
} | |
func StartServer(address string, model LLM) error { | |
lis, err := net.Listen("tcp", address) | |
if err != nil { | |
return err | |
} | |
s := grpc.NewServer() | |
pb.RegisterBackendServer(s, &server{llm: model}) | |
log.Printf("gRPC Server listening at %v", lis.Addr()) | |
if err := s.Serve(lis); err != nil { | |
return err | |
} | |
return nil | |
} | |
func RunServer(address string, model LLM) (func() error, error) { | |
lis, err := net.Listen("tcp", address) | |
if err != nil { | |
return nil, err | |
} | |
s := grpc.NewServer() | |
pb.RegisterBackendServer(s, &server{llm: model}) | |
log.Printf("gRPC Server listening at %v", lis.Addr()) | |
if err = s.Serve(lis); err != nil { | |
return func() error { | |
return lis.Close() | |
}, err | |
} | |
return func() error { | |
s.GracefulStop() | |
return nil | |
}, nil | |
} | |