Skip to content

Commit

Permalink
Valiadte middleware config when create/update bot, add /middlewares A…
Browse files Browse the repository at this point in the history
…PI to get available middleware info
  • Loading branch information
xwjdsh committed May 17, 2023
1 parent e3486ef commit a26a12e
Show file tree
Hide file tree
Showing 10 changed files with 277 additions and 46 deletions.
59 changes: 40 additions & 19 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,35 +97,35 @@ type GetTurnRequest struct {
type GetTurnResponse Turn

type Bot struct {
ID uint `json:"id"`
Name string `json:"name"`
ChatModel string `json:"chat_model"`
Prompt string `json:"prompt"`
BoundaryPrompt string `json:"boundary_prompt"`
ContextTurnCount int `json:"context_turn_count"`
Temperature float32 `json:"temperature"`
Middleware MiddlewareConfig `json:"middleware"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID uint `json:"id"`
Name string `json:"name"`
ChatModel string `json:"chat_model"`
Prompt string `json:"prompt"`
BoundaryPrompt string `json:"boundary_prompt"`
ContextTurnCount int `json:"context_turn_count"`
Temperature float32 `json:"temperature"`
Middlewares *MiddlewareConfig `json:"middlewares,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}

type Middleware struct {
Name string `json:"name"`
Options map[string]any `json:"options,omitempty"`
Name string `json:"name"`
Options map[string]string `json:"options,omitempty"`
}

type MiddlewareConfig struct {
Items []*Middleware `json:"items,omitempty"`
}

type CreateBotRequest struct {
Name string `json:"name" binding:"required"`
ChatModel string `json:"chat_model" binding:"required"`
Prompt string `json:"prompt"`
BoundaryPrompt string `json:"boundary_prompt"`
Temperature float32 `json:"temperature" binding:"required"`
ContextTurnCount int `json:"context_turn_count" binding:"required"`
Middlewares MiddlewareConfig `json:"middlewares"`
Name string `json:"name" binding:"required"`
ChatModel string `json:"chat_model" binding:"required"`
Prompt string `json:"prompt"`
BoundaryPrompt string `json:"boundary_prompt"`
Temperature float32 `json:"temperature" binding:"required"`
ContextTurnCount int `json:"context_turn_count" binding:"required"`
Middlewares *MiddlewareConfig `json:"middlewares"`
}

type CreateBotResponse Bot
Expand All @@ -140,3 +140,24 @@ type ListModelsResponse struct {
ChatModels []string `json:"chat_models"`
EmbeddingModels []string `json:"embedding_models"`
}

type MiddlewareDescOption struct {
Name string `json:"name"`
Desc string `json:"desc"`
DefaultValue string `json:"default_value,omitempty"`
Required bool `json:"required,omitempty"`

Value any `json:"-"`
ParseValueFunc func(string) (any, error) `json:"-"`
}

type MiddlewareDesc struct {
Name string `json:"name"`
Desc string `json:"desc"`
Options []*MiddlewareDescOption `json:"options"`
}

type ListMiddlewaresResponse struct {
GeneralOptions []*MiddlewareDescOption `json:"general_options"`
Middlewares []*MiddlewareDesc `json:"middlewares"`
}
11 changes: 11 additions & 0 deletions cmd/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/pandodao/botastic/internal/llms"
"github.com/pandodao/botastic/internal/starter"
"github.com/pandodao/botastic/pkg/chanhub"
"github.com/pandodao/botastic/pkg/middleware"
"github.com/pandodao/botastic/state"
"github.com/pandodao/botastic/storage"
"go.uber.org/zap"
Expand All @@ -25,6 +26,12 @@ func provideHttpdStarter(cfgFile string) (starter.Starter, error) {
wire.NewSet(storage.Init),
wire.NewSet(llms.New),
wire.NewSet(chanhub.New),
wire.NewSet(
middleware.NewFetch,
provideMiddlewares,
middleware.New,
wire.Bind(new(httpd.MiddlewareHandler), new(*middleware.Handler)),
),
wire.NewSet(
httpd.New,
httpd.NewHandler,
Expand Down Expand Up @@ -53,3 +60,7 @@ func provideLogger(cfg config.LogConfig) (*zap.Logger, error) {
zapCfg.Level = zap.NewAtomicLevelAt(level)
return zapCfg.Build()
}

func provideMiddlewares(m1 *middleware.Fetch) []middleware.Middleware {
return []middleware.Middleware{m1}
}
14 changes: 11 additions & 3 deletions cmd/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 0 additions & 6 deletions deploy.env

This file was deleted.

38 changes: 27 additions & 11 deletions internal/httpd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,30 @@ type TurnTransmitter interface {
GetTurnsChan() chan<- *models.Turn
}

type MiddlewareHandler interface {
Middlewares() []*api.MiddlewareDesc
GeneralOptions() []*api.MiddlewareDescOption
ValidateConfig(*api.MiddlewareConfig) error
}

type Handler struct {
logger *zap.Logger
llms *llms.Handler
sh *storage.Handler
hub *chanhub.Hub
turnTransmitter TurnTransmitter
logger *zap.Logger
llms *llms.Handler
sh *storage.Handler
hub *chanhub.Hub
turnTransmitter TurnTransmitter
middlewareHandler MiddlewareHandler
}

func NewHandler(sh *storage.Handler, llms *llms.Handler, hub *chanhub.Hub, turnTransmitter TurnTransmitter, logger *zap.Logger) *Handler {
func NewHandler(sh *storage.Handler, llms *llms.Handler, hub *chanhub.Hub, turnTransmitter TurnTransmitter,
logger *zap.Logger, middlewareHandler MiddlewareHandler) *Handler {
return &Handler{
logger: logger.Named("httpd/handler"),
llms: llms,
sh: sh,
hub: hub,
turnTransmitter: turnTransmitter,
logger: logger.Named("httpd/handler"),
llms: llms,
sh: sh,
hub: hub,
turnTransmitter: turnTransmitter,
middlewareHandler: middlewareHandler,
}
}

Expand All @@ -56,3 +65,10 @@ func (h *Handler) ListModels(c *gin.Context) {
EmbeddingModels: h.llms.EmbeddingModels(),
})
}

func (h *Handler) ListMiddlewares(c *gin.Context) {
h.respData(c, api.ListMiddlewaresResponse{
Middlewares: h.middlewareHandler.Middlewares(),
GeneralOptions: h.middlewareHandler.GeneralOptions(),
})
}
29 changes: 25 additions & 4 deletions internal/httpd/handler_bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ func (h *Handler) CreateBot(c *gin.Context) {
h.respErr(c, http.StatusBadRequest, errors.New("chat model does not exist"))
return
}
if req.Middlewares != nil {
if err := h.middlewareHandler.ValidateConfig(req.Middlewares); err != nil {
h.respErr(c, http.StatusBadRequest, err)
return
}
}

bot := &models.Bot{
Name: req.Name,
Expand All @@ -29,8 +35,12 @@ func (h *Handler) CreateBot(c *gin.Context) {
BoundaryPrompt: req.BoundaryPrompt,
ContextTurnCount: req.ContextTurnCount,
Temperature: req.Temperature,
Middleware: models.MiddlewareConfig(req.Middlewares),
}
if req.Middlewares != nil {
v := models.MiddlewareConfig(*req.Middlewares)
bot.Middlewares = &v
}

if err := h.sh.CreateBot(c, bot); err != nil {
h.respErr(c, http.StatusInternalServerError, err)
return
Expand Down Expand Up @@ -92,16 +102,27 @@ func (h *Handler) UpdateBot(c *gin.Context) {
h.respErr(c, http.StatusBadRequest, errors.New("chat model does not exist"))
return
}
if req.Middlewares != nil {
if err := h.middlewareHandler.ValidateConfig(req.Middlewares); err != nil {
h.respErr(c, http.StatusBadRequest, err)
return
}
}

rowsAffected, err := h.sh.UpdateBot(c, uint(botId), map[string]any{
m := map[string]any{
"name": req.Name,
"chat_model": req.ChatModel,
"prompt": req.Prompt,
"boundary_prompt": req.BoundaryPrompt,
"context_turn_count": req.ContextTurnCount,
"temperature": req.Temperature,
"middleware": models.MiddlewareConfig(req.Middlewares),
})
}
if req.Middlewares != nil {
v := models.MiddlewareConfig(*req.Middlewares)
m["middlewares"] = &v
}

rowsAffected, err := h.sh.UpdateBot(c, uint(botId), m)
if err != nil {
h.respErr(c, http.StatusInternalServerError, err)
return
Expand Down
1 change: 1 addition & 0 deletions internal/httpd/httpd.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (s *Server) initRoutes() {
v1 := s.engine.Group("/api/v1")
{
v1.GET("/models", h.ListModels)
v1.GET("/middlewares", h.ListMiddlewares)

convs := v1.Group("/conversations")
{
Expand Down
11 changes: 8 additions & 3 deletions models/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,27 @@ type Bot struct {
BoundaryPrompt string `gorm:"type:text"`
ContextTurnCount int
Temperature float32
Middleware MiddlewareConfig `gorm:"type:json"`
Middlewares *MiddlewareConfig `gorm:"type:json"`
}

func (b Bot) API() api.Bot {
return api.Bot{
r := api.Bot{
ID: b.ID,
Name: b.Name,
ChatModel: b.ChatModel,
Prompt: b.Prompt,
BoundaryPrompt: b.BoundaryPrompt,
ContextTurnCount: b.ContextTurnCount,
Temperature: b.Temperature,
Middleware: api.MiddlewareConfig(b.Middleware),
CreatedAt: b.CreatedAt,
UpdatedAt: b.UpdatedAt,
}
if b.Middlewares != nil {
v := api.MiddlewareConfig(*b.Middlewares)
r.Middlewares = &v
}

return r
}

type MiddlewareConfig api.MiddlewareConfig
Expand Down
82 changes: 82 additions & 0 deletions pkg/middleware/fetch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package middleware

import (
"context"
"fmt"
"io"
"net/http"
"net/url"

"github.com/pandodao/botastic/api"
"github.com/pandodao/botastic/models"
)

type Fetch struct{}

func NewFetch() *Fetch {
return &Fetch{}
}

func (m *Fetch) Desc() *api.MiddlewareDesc {
return &api.MiddlewareDesc{
Name: "fetch",
Desc: "fetch middleware will send a GET request to the specified URL and return the response body as string",
Options: []*api.MiddlewareDescOption{
{
Name: "url",
Desc: "URL to fetch",
Required: true,
ParseValueFunc: func(v string) (any, error) {
_, err := url.Parse(v)
return v, err
},
},
},
}
}

func (m *Fetch) Parse(opts map[string]string) (map[string]*api.MiddlewareDescOption, error) {
result := map[string]*api.MiddlewareDescOption{}
desc := m.Desc()
for _, opt := range m.Desc().Options {
if opt.Required && opts[opt.Name] == "" {
return nil, fmt.Errorf("missing required option: %s, middleware: %s", opt.Name, desc.Name)
}

if opt.ParseValueFunc != nil {
v, err := opt.ParseValueFunc(opts[opt.Name])
if err != nil {
return nil, fmt.Errorf("failed to parse option: %s, middleware: %s, err: %w", opt.Name, desc.Name, err)
}
opt.Value = v
}
result[opt.Name] = opt
}

return result, nil
}

func (m *Fetch) Process(ctx context.Context, opts map[string]*api.MiddlewareDescOption, turn *models.Turn) (string, error) {
u := opts["url"].Value.(string)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
if err != nil {
return "", err
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}

if resp.StatusCode/100 != 2 {
return "", fmt.Errorf("failed to fetch url: %s, status code: %d, body: %s", u, resp.StatusCode, string(body))
}

return string(body), nil
}

0 comments on commit a26a12e

Please sign in to comment.