diff --git a/definition.go b/definition.go index b5f0048b..0fbbbd52 100644 --- a/definition.go +++ b/definition.go @@ -534,6 +534,7 @@ func defineFieldMap(ttype Named, fieldMap Fields) (FieldDefinitionMap, error) { Description: field.Description, Type: field.Type, Resolve: field.Resolve, + Subscribe: field.Subscribe, DeprecationReason: field.DeprecationReason, } @@ -606,6 +607,7 @@ type Field struct { Type Output `json:"type"` Args FieldConfigArgument `json:"args"` Resolve FieldResolveFn `json:"-"` + Subscribe FieldResolveFn `json:"-"` DeprecationReason string `json:"deprecationReason"` Description string `json:"description"` } @@ -625,6 +627,7 @@ type FieldDefinition struct { Type Output `json:"type"` Args []*Argument `json:"args"` Resolve FieldResolveFn `json:"-"` + Subscribe FieldResolveFn `json:"-"` DeprecationReason string `json:"deprecationReason"` } diff --git a/go.mod b/go.mod index 399b200d..7e02f765 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/graphql-go/graphql + +go 1.13 diff --git a/subscription.go b/subscription.go new file mode 100644 index 00000000..ef5d73ef --- /dev/null +++ b/subscription.go @@ -0,0 +1,228 @@ +package graphql + +import ( + "context" + "fmt" + + "github.com/graphql-go/graphql/gqlerrors" + "github.com/graphql-go/graphql/language/parser" + "github.com/graphql-go/graphql/language/source" +) + +// SubscribeParams parameters for subscribing +type SubscribeParams struct { + Schema Schema + RequestString string + RootValue interface{} + // ContextValue context.Context + VariableValues map[string]interface{} + OperationName string + FieldResolver FieldResolveFn + FieldSubscriber FieldResolveFn +} + +// Subscribe performs a subscribe operation on the given query and schema +// To finish a subscription you can simply close the channel from inside the `Subscribe` function +// currently does not support extensions hooks +func Subscribe(p Params) chan *Result { + + source := source.NewSource(&source.Source{ + Body: []byte(p.RequestString), + Name: "GraphQL request", + }) + + // TODO run extensions hooks + + // parse the source + AST, err := parser.Parse(parser.ParseParams{Source: source}) + if err != nil { + + // merge the errors from extensions and the original error from parser + return sendOneResultAndClose(&Result{ + Errors: gqlerrors.FormatErrors(err), + }) + } + + // validate document + validationResult := ValidateDocument(&p.Schema, AST, nil) + + if !validationResult.IsValid { + // run validation finish functions for extensions + return sendOneResultAndClose(&Result{ + Errors: validationResult.Errors, + }) + + } + return ExecuteSubscription(ExecuteParams{ + Schema: p.Schema, + Root: p.RootObject, + AST: AST, + OperationName: p.OperationName, + Args: p.VariableValues, + Context: p.Context, + }) +} + +func sendOneResultAndClose(res *Result) chan *Result { + resultChannel := make(chan *Result, 1) + resultChannel <- res + close(resultChannel) + return resultChannel +} + +// ExecuteSubscription is similar to graphql.Execute but returns a channel instead of a Result +// currently does not support extensions +func ExecuteSubscription(p ExecuteParams) chan *Result { + + if p.Context == nil { + p.Context = context.Background() + } + + var mapSourceToResponse = func(payload interface{}) *Result { + return Execute(ExecuteParams{ + Schema: p.Schema, + Root: payload, + AST: p.AST, + OperationName: p.OperationName, + Args: p.Args, + Context: p.Context, + }) + } + var resultChannel = make(chan *Result) + go func() { + defer close(resultChannel) + defer func() { + if err := recover(); err != nil { + e, ok := err.(error) + if !ok { + return + } + resultChannel <- &Result{ + Errors: gqlerrors.FormatErrors(e), + } + } + return + }() + + exeContext, err := buildExecutionContext(buildExecutionCtxParams{ + Schema: p.Schema, + Root: p.Root, + AST: p.AST, + OperationName: p.OperationName, + Args: p.Args, + Context: p.Context, + }) + + if err != nil { + resultChannel <- &Result{ + Errors: gqlerrors.FormatErrors(err), + } + + return + } + + operationType, err := getOperationRootType(p.Schema, exeContext.Operation) + if err != nil { + resultChannel <- &Result{ + Errors: gqlerrors.FormatErrors(err), + } + + return + } + + fields := collectFields(collectFieldsParams{ + ExeContext: exeContext, + RuntimeType: operationType, + SelectionSet: exeContext.Operation.GetSelectionSet(), + }) + + responseNames := []string{} + for name := range fields { + responseNames = append(responseNames, name) + } + responseName := responseNames[0] + fieldNodes := fields[responseName] + fieldNode := fieldNodes[0] + fieldName := fieldNode.Name.Value + fieldDef := getFieldDef(p.Schema, operationType, fieldName) + + if fieldDef == nil { + resultChannel <- &Result{ + Errors: gqlerrors.FormatErrors(fmt.Errorf("the subscription field %q is not defined", fieldName)), + } + + return + } + + resolveFn := fieldDef.Subscribe + + if resolveFn == nil { + resultChannel <- &Result{ + Errors: gqlerrors.FormatErrors(fmt.Errorf("the subscription function %q is not defined", fieldName)), + } + return + } + fieldPath := &ResponsePath{ + Key: responseName, + } + + args := getArgumentValues(fieldDef.Args, fieldNode.Arguments, exeContext.VariableValues) + info := ResolveInfo{ + FieldName: fieldName, + FieldASTs: fieldNodes, + Path: fieldPath, + ReturnType: fieldDef.Type, + ParentType: operationType, + Schema: p.Schema, + Fragments: exeContext.Fragments, + RootValue: exeContext.Root, + Operation: exeContext.Operation, + VariableValues: exeContext.VariableValues, + } + + fieldResult, err := resolveFn(ResolveParams{ + Source: p.Root, + Args: args, + Info: info, + Context: p.Context, + }) + if err != nil { + resultChannel <- &Result{ + Errors: gqlerrors.FormatErrors(err), + } + + return + } + + if fieldResult == nil { + resultChannel <- &Result{ + Errors: gqlerrors.FormatErrors(fmt.Errorf("no field result")), + } + + return + } + + switch fieldResult.(type) { + case chan interface{}: + sub := fieldResult.(chan interface{}) + for { + select { + case <-p.Context.Done(): + return + + case res, more := <-sub: + if !more { + return + } + resultChannel <- mapSourceToResponse(res) + } + } + default: + resultChannel <- mapSourceToResponse(fieldResult) + return + } + }() + + // return a result channel + return resultChannel +} diff --git a/subscription_test.go b/subscription_test.go new file mode 100644 index 00000000..0a4bebee --- /dev/null +++ b/subscription_test.go @@ -0,0 +1,287 @@ +package graphql_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/testutil" +) + +func TestSchemaSubscribe(t *testing.T) { + + testutil.RunSubscribes(t, []*testutil.TestSubscription{ + { + Name: "subscribe without resolver", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "sub_without_resolver": &graphql.Field{ + Type: graphql.String, + Subscribe: makeSubscribeToMapFunction([]map[string]interface{}{ + { + "sub_without_resolver": "a", + }, + { + "sub_without_resolver": "b", + }, + { + "sub_without_resolver": "c", + }, + }), + }, + }, + }), + Query: ` + subscription { + sub_without_resolver + } + `, + ExpectedResults: []testutil.TestResponse{ + {Data: `{ "sub_without_resolver": "a" }`}, + {Data: `{ "sub_without_resolver": "b" }`}, + {Data: `{ "sub_without_resolver": "c" }`}, + }, + }, + { + Name: "subscribe with resolver", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "sub_with_resolver": &graphql.Field{ + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return p.Source, nil + }, + Subscribe: makeSubscribeToStringFunction([]string{"a", "b", "c"}), + }, + }, + }), + Query: ` + subscription { + sub_with_resolver + } + `, + ExpectedResults: []testutil.TestResponse{ + {Data: `{ "sub_with_resolver": "a" }`}, + {Data: `{ "sub_with_resolver": "b" }`}, + {Data: `{ "sub_with_resolver": "c" }`}, + }, + }, + { + Name: "receive query validation error", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "sub_without_resolver": &graphql.Field{ + Type: graphql.String, + Subscribe: makeSubscribeToStringFunction([]string{"a", "b", "c"}), + }, + }, + }), + Query: ` + subscription { + sub_without_resolver + xxx + } + `, + ExpectedResults: []testutil.TestResponse{ + {Errors: []string{"Cannot query field \"xxx\" on type \"Subscription\"."}}, + }, + }, + { + Name: "panic inside subscribe is recovered", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "should_error": &graphql.Field{ + Type: graphql.String, + Subscribe: func(p graphql.ResolveParams) (interface{}, error) { + panic(errors.New("got a panic error")) + }, + }, + }, + }), + Query: ` + subscription { + should_error + } + `, + ExpectedResults: []testutil.TestResponse{ + {Errors: []string{"got a panic error"}}, + }, + }, + { + Name: "subscribe with resolver changes output", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "sub_with_resolver": &graphql.Field{ + Type: graphql.String, + Subscribe: makeSubscribeToStringFunction([]string{"a", "b", "c", "d"}), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return fmt.Sprintf("result=%v", p.Source), nil + }, + }, + }, + }), + Query: ` + subscription { + sub_with_resolver + } + `, + ExpectedResults: []testutil.TestResponse{ + {Data: `{ "sub_with_resolver": "result=a" }`}, + {Data: `{ "sub_with_resolver": "result=b" }`}, + {Data: `{ "sub_with_resolver": "result=c" }`}, + {Data: `{ "sub_with_resolver": "result=d" }`}, + }, + }, + { + Name: "subscribe to a nested object", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "sub_with_object": &graphql.Field{ + Type: graphql.NewObject(graphql.ObjectConfig{ + Name: "Obj", + Fields: graphql.Fields{ + "field": &graphql.Field{ + Type: graphql.String, + }, + }, + }), + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return p.Source, nil + }, + Subscribe: makeSubscribeToMapFunction([]map[string]interface{}{ + { + "field": "hello", + }, + { + "field": "bye", + }, + { + "field": nil, + }, + }), + }, + }, + }), + Query: ` + subscription { + sub_with_object { + field + } + } + `, + ExpectedResults: []testutil.TestResponse{ + {Data: `{ "sub_with_object": { "field": "hello" } }`}, + {Data: `{ "sub_with_object": { "field": "bye" } }`}, + {Data: `{ "sub_with_object": { "field": null } }`}, + }, + }, + + { + Name: "subscription_resolver_can_error", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "should_error": &graphql.Field{ + Type: graphql.String, + Subscribe: func(p graphql.ResolveParams) (interface{}, error) { + return nil, errors.New("got a subscribe error") + }, + }, + }, + }), + Query: ` + subscription { + should_error + } + `, + ExpectedResults: []testutil.TestResponse{ + { + Errors: []string{"got a subscribe error"}, + }, + }, + }, + { + Name: "schema_without_subscribe_errors", + Schema: makeSubscriptionSchema(t, graphql.ObjectConfig{ + Name: "Subscription", + Fields: graphql.Fields{ + "should_error": &graphql.Field{ + Type: graphql.String, + }, + }, + }), + Query: ` + subscription { + should_error + } + `, + ExpectedResults: []testutil.TestResponse{ + { + Errors: []string{"the subscription function \"should_error\" is not defined"}, + }, + }, + }, + }) +} + +func makeSubscribeToStringFunction(elements []string) func(p graphql.ResolveParams) (interface{}, error) { + return func(p graphql.ResolveParams) (interface{}, error) { + c := make(chan interface{}) + go func() { + for _, r := range elements { + select { + case <-p.Context.Done(): + close(c) + return + case c <- r: + } + } + close(c) + }() + return c, nil + } +} + +func makeSubscribeToMapFunction(elements []map[string]interface{}) func(p graphql.ResolveParams) (interface{}, error) { + return func(p graphql.ResolveParams) (interface{}, error) { + c := make(chan interface{}) + go func() { + for _, r := range elements { + select { + case <-p.Context.Done(): + close(c) + return + case c <- r: + } + } + close(c) + }() + return c, nil + } +} + +func makeSubscriptionSchema(t *testing.T, c graphql.ObjectConfig) graphql.Schema { + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: dummyQuery, + Subscription: graphql.NewObject(c), + }) + if err != nil { + t.Errorf("failed to create schema: %v", err) + } + return schema +} + +var dummyQuery = graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + + "hello": &graphql.Field{Type: graphql.String}, + }, +}) diff --git a/testutil/subscription.go b/testutil/subscription.go new file mode 100644 index 00000000..b17c4b65 --- /dev/null +++ b/testutil/subscription.go @@ -0,0 +1,149 @@ +package testutil + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "strconv" + "testing" + + "github.com/graphql-go/graphql" +) + +// TestResponse models the expected response +type TestResponse struct { + Data string + Errors []string +} + +// TestSubscription is a GraphQL test case to be used with RunSubscribe. +type TestSubscription struct { + Name string + Schema graphql.Schema + Query string + OperationName string + Variables map[string]interface{} + ExpectedResults []TestResponse +} + +// RunSubscribes runs the given GraphQL subscription test cases as subtests. +func RunSubscribes(t *testing.T, tests []*TestSubscription) { + for i, test := range tests { + if test.Name == "" { + test.Name = strconv.Itoa(i + 1) + } + + t.Run(test.Name, func(t *testing.T) { + RunSubscribe(t, test) + }) + } +} + +// RunSubscribe runs a single GraphQL subscription test case. +func RunSubscribe(t *testing.T, test *TestSubscription) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := graphql.Subscribe(graphql.Params{ + Context: ctx, + OperationName: test.OperationName, + RequestString: test.Query, + VariableValues: test.Variables, + Schema: test.Schema, + }) + // if err != nil { + // if err.Error() != test.ExpectedErr.Error() { + // t.Fatalf("unexpected error: got %+v, want %+v", err, test.ExpectedErr) + // } + + // return + // } + + var results []*graphql.Result + for res := range c { + t.Log(pretty(res)) + results = append(results, res) + } + + for i, expected := range test.ExpectedResults { + if len(results)-1 < i { + t.Error(errors.New("not enough results, expected results are more than actual results")) + return + } + res := results[i] + + var errs []string + for _, err := range res.Errors { + errs = append(errs, err.Message) + } + checkErrorStrings(t, expected.Errors, errs) + if expected.Data == "" { + continue + } + + got, err := json.MarshalIndent(res.Data, "", " ") + if err != nil { + t.Fatalf("got: invalid JSON: %s; raw: %s", err, got) + } + + if err != nil { + t.Fatal(err) + } + want, err := formatJSON(expected.Data) + if err != nil { + t.Fatalf("got: invalid JSON: %s; raw: %s", err, res.Data) + } + + if !bytes.Equal(got, want) { + t.Logf("got: %s", got) + t.Logf("want: %s", want) + t.Fail() + } + } +} + +func checkErrorStrings(t *testing.T, expected, actual []string) { + expectedCount, actualCount := len(expected), len(actual) + + if expectedCount != actualCount { + t.Fatalf("unexpected number of errors: want `%d`, got `%d`", expectedCount, actualCount) + } + + if expectedCount > 0 { + for i, want := range expected { + got := actual[i] + + if got != want { + t.Fatalf("unexpected error: got `%+v`, want `%+v`", got, want) + } + } + + // Return because we're done checking. + return + } + + for _, err := range actual { + t.Errorf("unexpected error: '%s'", err) + } +} + +func formatJSON(data string) ([]byte, error) { + var v interface{} + if err := json.Unmarshal([]byte(data), &v); err != nil { + return nil, err + } + formatted, err := json.MarshalIndent(v, "", " ") + if err != nil { + return nil, err + } + return formatted, nil +} + +func pretty(x interface{}) string { + got, err := json.MarshalIndent(x, "", " ") + if err != nil { + panic(err) + } + return string(got) +}