diff --git a/examples/dataloader/example.go b/examples/dataloader/example.go new file mode 100644 index 00000000..089e1b39 --- /dev/null +++ b/examples/dataloader/example.go @@ -0,0 +1,226 @@ +package dataloaderexample + +import ( + "fmt" + "log" + "time" + + "golang.org/x/net/context" + + "github.com/bigdrum/godataloader" + "github.com/graphql-go/graphql" +) + +var postDB = map[string]*post{ + "1": &post{ + ID: "1", + Content: "Hello 1", + AuthorID: "1", + }, + "2": &post{ + ID: "2", + Content: "Hello 2", + AuthorID: "1", + }, + "3": &post{ + ID: "3", + Content: "Hello 3", + AuthorID: "2", + }, + "4": &post{ + ID: "4", + Content: "Hello 4", + AuthorID: "2", + }, +} + +var userDB = map[string]*user{ + "1": &user{ + ID: "1", + Name: "Mike", + }, + "2": &user{ + ID: "2", + Name: "John", + }, + "3": &user{ + ID: "3", + Name: "Kate", + }, +} + +var loaderKey = struct{}{} + +type loader struct { + postLoader *dataloader.DataLoader + userLoader *dataloader.DataLoader +} + +func newLoader(sch *dataloader.Scheduler) *loader { + return &loader{ + postLoader: dataloader.New(sch, dataloader.Parallel(func(key interface{}) dataloader.Value { + // In practice, we will make remote request (e.g. SQL) to fetch post. + // Here we just fake it. + log.Print("Load post ", key) + time.Sleep(time.Second) + id := key.(string) + return dataloader.NewValue(postDB[id], nil) + })), + userLoader: dataloader.New(sch, func(keys []interface{}) []dataloader.Value { + // In practice, we will make remote request (e.g. SQL) to fetch multiple users. + // Here we just fake it. + log.Print("Batch load users ", keys) + time.Sleep(time.Second) + var ret []dataloader.Value + for _, key := range keys { + id := key.(string) + ret = append(ret, dataloader.NewValue(userDB[id], nil)) + } + return ret + }), + } +} + +type post struct { + ID string `json:"id"` + Content string `json:"content"` + AuthorID string `json:"author_id"` +} + +type user struct { + ID string `json:"id"` + Name string `json:"name"` +} + +func attachNewDataLoader(parent context.Context, sch *dataloader.Scheduler) context.Context { + dl := newLoader(sch) + return context.WithValue(parent, loaderKey, dl) +} + +func getDataLoader(ctx context.Context) *loader { + return ctx.Value(loaderKey).(*loader) +} + +func CreateSchema() graphql.Schema { + userType := graphql.NewObject(graphql.ObjectConfig{ + Name: "User", + Fields: graphql.Fields{ + "id": &graphql.Field{ + Type: graphql.String, + }, + "name": &graphql.Field{ + Type: graphql.String, + }, + }, + }) + + postType := graphql.NewObject(graphql.ObjectConfig{ + Name: "Post", + Fields: graphql.Fields{ + "id": &graphql.Field{ + Type: graphql.String, + }, + "content": &graphql.Field{ + Type: graphql.String, + }, + "author": &graphql.Field{ + Type: userType, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + post := p.Source.(*post) + id := post.AuthorID + dl := getDataLoader(p.Context) + return dl.userLoader.Load(id).Unbox() + }, + }, + }, + }) + + rootQuery := graphql.NewObject(graphql.ObjectConfig{ + Name: "RootQuery", + Fields: graphql.Fields{ + "post": &graphql.Field{ + Type: postType, + Args: graphql.FieldConfigArgument{ + "id": &graphql.ArgumentConfig{ + Type: graphql.String, + }, + }, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + id, ok := p.Args["id"].(string) + if !ok { + return nil, nil + } + dl := getDataLoader(p.Context) + return dl.postLoader.Load(id).Unbox() + }, + }, + "user": &graphql.Field{ + Type: userType, + Args: graphql.FieldConfigArgument{ + "id": &graphql.ArgumentConfig{ + Type: graphql.String, + }, + }, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + id, ok := p.Args["id"].(string) + if !ok { + return nil, nil + } + dl := getDataLoader(p.Context) + return dl.userLoader.Load(id).Unbox() + }, + }, + }}) + + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: rootQuery, + }) + if err != nil { + panic(err) + } + return schema +} + +type dataloaderExecutor struct { + sch *dataloader.Scheduler +} + +func (e *dataloaderExecutor) RunMany(fs []func()) { + if len(fs) == 1 { + fs[0]() + return + } + if len(fs) == 0 { + return + } + + wg := dataloader.NewWaitGroup(e.sch) + for i := range fs { + f := fs[i] + wg.Add(1) + e.sch.Spawn(func() { + defer wg.Done() + f() + }) + } + wg.Wait() +} + +func RunQuery(query string, schema graphql.Schema) *graphql.Result { + var result *graphql.Result + dataloader.RunWithScheduler(func(sch *dataloader.Scheduler) { + executor := dataloaderExecutor{sch} + ctx := attachNewDataLoader(context.Background(), sch) + result = graphql.Do(graphql.Params{ + Schema: schema, + RequestString: query, + Context: ctx, + Executor: &executor, + }) + if len(result.Errors) > 0 { + fmt.Printf("wrong result, unexpected errors: %v", result.Errors) + } + }) + + return result +} diff --git a/examples/dataloader/example_test.go b/examples/dataloader/example_test.go new file mode 100644 index 00000000..3b111685 --- /dev/null +++ b/examples/dataloader/example_test.go @@ -0,0 +1,42 @@ +package dataloaderexample_test + +import ( + "testing" + + "github.com/graphql-go/graphql/examples/dataloader" +) + +func TestQuery(t *testing.T) { + schema := dataloaderexample.CreateSchema() + r := dataloaderexample.RunQuery(`{ + p1_0: post(id: "1") { id author { name }} + p1_1: post(id: "1") { id author { name }} + p1_2: post(id: "1") { id author { name }} + p1_3: post(id: "1") { id author { name }} + p1_4: post(id: "1") { id author { name }} + p1_5: post(id: "1") { id author { name }} + p2_1: post(id: "2") { id author { name }} + p2_2: post(id: "2") { id author { name }} + p2_3: post(id: "2") { id author { name }} + p3_1: post(id: "3") { id author { name }} + p3_2: post(id: "3") { id author { name }} + p3_3: post(id: "3") { id author { name }} + u1_1: user(id: "1") { name } + u1_2: user(id: "1") { name } + u1_3: user(id: "1") { name } + u2_1: user(id: "3") { name } + u2_2: user(id: "3") { name } + u2_3: user(id: "3") { name } + }`, schema) + if len(r.Errors) != 0 { + t.Error(r.Errors) + } + // The above query would produce log like this: + // 2016/07/23 23:49:31 Load post 3 + // 2016/07/23 23:49:31 Load post 1 + // 2016/07/23 23:49:31 Load post 2 + // 2016/07/23 23:49:32 Batch load users [3 1 2] + // Notice the first level post loading is done concurrently without duplicate. + // And the second level user loading is also done in the same fashion, but batched fetch is used instead. + // TODO: Make test actually verify that. +} diff --git a/executor.go b/executor.go index b03c8a01..e99047af 100644 --- a/executor.go +++ b/executor.go @@ -5,12 +5,27 @@ import ( "fmt" "reflect" "strings" + "sync" "github.com/graphql-go/graphql/gqlerrors" "github.com/graphql-go/graphql/language/ast" "golang.org/x/net/context" ) +type Executor interface { + RunMany(f []func()) +} + +type SerialExecutor struct{} + +func (e *SerialExecutor) RunMany(fs []func()) { + for _, f := range fs { + f() + } +} + +var defaultExecutor = &SerialExecutor{} + type ExecuteParams struct { Schema Schema Root interface{} @@ -21,10 +36,15 @@ type ExecuteParams struct { // Context may be provided to pass application-specific per-request // information to resolve functions. Context context.Context + + Executor Executor } func Execute(p ExecuteParams) (result *Result) { result = &Result{} + if p.Executor == nil { + p.Executor = defaultExecutor + } exeContext, err := buildExecutionContext(BuildExecutionCtxParams{ Schema: p.Schema, @@ -35,6 +55,7 @@ func Execute(p ExecuteParams) (result *Result) { Errors: nil, Result: result, Context: p.Context, + Executor: p.Executor, }) if err != nil { @@ -69,6 +90,7 @@ type BuildExecutionCtxParams struct { Errors []gqlerrors.FormattedError Result *Result Context context.Context + Executor Executor } type ExecutionContext struct { Schema Schema @@ -78,6 +100,7 @@ type ExecutionContext struct { VariableValues map[string]interface{} Errors []gqlerrors.FormattedError Context context.Context + Executor Executor } func buildExecutionContext(p BuildExecutionCtxParams) (*ExecutionContext, error) { @@ -124,6 +147,7 @@ func buildExecutionContext(p BuildExecutionCtxParams) (*ExecutionContext, error) eCtx.VariableValues = variableValues eCtx.Errors = p.Errors eCtx.Context = p.Context + eCtx.Executor = p.Executor return eCtx, nil } @@ -247,14 +271,22 @@ func executeFields(p ExecuteFieldsParams) *Result { } finalResults := map[string]interface{}{} + fs := make([]func(), 0, len(p.Fields)) + var resultsMu sync.Mutex for responseName, fieldASTs := range p.Fields { - resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) - if state.hasNoFieldDefs { - continue - } - finalResults[responseName] = resolved + responseName := responseName + fieldASTs := fieldASTs + fs = append(fs, func() { + resolved, state := resolveField(p.ExecutionContext, p.ParentType, p.Source, fieldASTs) + if state.hasNoFieldDefs { + return + } + resultsMu.Lock() + finalResults[responseName] = resolved + resultsMu.Unlock() + }) } - + p.ExecutionContext.Executor.RunMany(fs) return &Result{ Data: finalResults, Errors: p.ExecutionContext.Errors, @@ -756,12 +788,17 @@ func completeListValue(eCtx *ExecutionContext, returnType *List, fieldASTs []*as } itemType := returnType.OfType - completedResults := []interface{}{} + completedResults := make([]interface{}, resultVal.Len()) + fs := make([]func(), 0, resultVal.Len()) for i := 0; i < resultVal.Len(); i++ { - val := resultVal.Index(i).Interface() - completedItem := completeValueCatchingError(eCtx, itemType, fieldASTs, info, val) - completedResults = append(completedResults, completedItem) - } + i := i + fs = append(fs, func() { + val := resultVal.Index(i).Interface() + completedItem := completeValueCatchingError(eCtx, itemType, fieldASTs, info, val) + completedResults[i] = completedItem + }) + } + eCtx.Executor.RunMany(fs) return completedResults } diff --git a/graphql.go b/graphql.go index af9dd65a..f9a1cd87 100644 --- a/graphql.go +++ b/graphql.go @@ -30,6 +30,10 @@ type Params struct { // Context may be provided to pass application-specific per-request // information to resolve functions. Context context.Context + + // Executor allows to control the behavior of how to perform resolving function that + // can be run concurrently. If not given, they will be executed serially. + Executor Executor } func Do(p Params) *Result { @@ -58,5 +62,6 @@ func Do(p Params) *Result { OperationName: p.OperationName, Args: p.VariableValues, Context: p.Context, + Executor: p.Executor, }) }