From 6fb9e53dfe723dd74104f076d83216062b9be10e Mon Sep 17 00:00:00 2001 From: Aayush Mittal <94428324+mittalaa@users.noreply.github.com> Date: Mon, 15 May 2023 15:42:06 +0530 Subject: [PATCH] Package update (#17) * Support for embedded struct type in resolver * fix bug in slice pop * fix bug while finding field * add 'getFieldCount' to resolve ambiguity * Increase extensions test coverage * Remove duplicate unit tests * rename 'getFieldCount' to 'fieldCount' * add test for ambiguous field panic * add unit tests for embedded struct feature * rename TestEmbedded => TestEmbeddedStruct * Fixes #357 * Actually fix #357 * Print context to panic log * Add Example of Custom Errors Adding example and documentation for how to create custom error implementations which include `extensions` within their `error` payload * Clarify errors for mismatching input implementation Producing clearer error messages when field input arguments are implemented by code: * Which does not match the schema e.g. missing field; or * Function missing struct wrapper for field arguments * Allow `schema` to be omitted when using default root op names * Strip Common Indentation from BlockString Descriptions Multi-line descriptions need to have their common indentation level (which results from indentation of that part of the schema, rather than being intentional for the description text) removed to ensure the descriptions use the correct value, are formatted correctly etc This is to meet the condition documented in the GraphQL spec: https://graphql.github.io/graphql-spec/June2018/#sec-String-Value > Since block strings represent freeform text often used in indented > positions, the string value semantics of a block string excludes > uniform indentation and blank initial and trailing lines via > BlockStringValue(). * Syntax highlighting fixed in README * Add walkthrough Fixed small punctuation and added my walkthrough package * Update README.md * Add support for directives in schema parser * Use operationName from query if missing from POST * Fix SIGSEGV when client subs to multiple fields * bugfix: correctly determine fragment usage In previous versions of this code, this validation would exit when it encountered a fragment legitimately used twice. This bugfix skips the recursion but does not stop progress altogether allowing other fragments to be marked as used. * Limit the number of concurrent list nodes processed It uses the current capacity of the limiter as a hint as this is set based on the maxParallelism field on the schema. * Remove need for WaitGroup * More descriptive error when unmarshaling ID/Time This adds a tiny bit more information to the error messages produced when unmarshaling an input value to an ID or Time fails. * Improve README.md Fixes #307 Add short descriptions for different schema options. Move community examples to wiki. Add companies that use this library. * fix #241 Similar to https://github.com/graph-gophers/graphql-go/pull/407, but adds test cases. * Add comment explaining why we limit concurrency * Issue #299: unclear error message in case of multiline string argument * handle case where interface is type-asserted to same interface * Issue #299: unclear error message in case of multiline strings * Issue #299: unclear error message in case of multiline string argument * Update logic to always check for nil pointer returns * Adding variables parameter for query validations. * Fixed `reflect.Value.Type on zero Value` panic when subscription resolver itself panicks The internal exec Subscribe method had code to deal with subscription resolver panicking when being called. But when such handling happen, the error is attached to the request object and it never checked later on. This leads to some zero checks to fail when we try to extract the type from the resolver's channel since this variable was never set. Doing this creates a second panic which is not handled and make the application die. To fix the issue, we now check if there is errors on the request object before continuing with the rest of the check, if there is errors, it's because a panic occurs and we return the response right away. * Added possibility to customize subscription resolver timeout value The previous value was hard-coded to 1 second. This is problematic for resolver that takes more time than this to return a result. When parsing the schema, it's not possible to pass a custom value for the subscription resolver timeout. Extracted from #317 * Allowed Subscription resolver to return `*QueryError` directly Previously, any error returned by the Subscription resolver was immediately wrapped inside its own `*QueryError` value even if the returned error was already a `*QueryError`. Now, when receiving such types, we use it as-is without wrapping again. * Adding/removing empty lines where needed * DisableIntrospection should not skip __typename for usages of GraphQL union types * Add context to validation tracing Context is needed for tracing to access the current span, in order to add tags to it, or create child spans. As presently defined (without a context), this cannot be done: new spans could be created, but they would not be associated with the existing request trace. OpenTracingTracer implements the new interface (it never implemented the old one). Via this 'extension interface', the tracer configured (or the default tracer) will be used as the validation tracer if: * The tracer implements the (optional) interface; and * A validation tracer isn't provided using the deprecated option What this means is that the deprecated option is _preferred_ as an override. This allows users to migrate in a non-breaking, non-behaviour changing way, until such time as they intentionally remove the use of the deprecated option. For those who are currently using the default tracer, and not supplying a validation tracer, validation will be traced immediately with no change required to configuration options. * Add support for nullable types This allows to differentiate between an omitted value and a null value in an input struct. * Fixed duplicated __typename in response (fixes #369) * Create CHANGELOG.md * Update CHANGELOG.md * ignore JetBrains IDEA and vscode meta directories * expose packer.Unmarshaler interface as graphql.Unmarshaler - add tests for graphql.Time as reference implementation * move packer.Unmarshaler interface to decode.Unmarshaler, so the methods are actually visible * add types package Part of #434 and related to #116 this change adds a new package containing all types used by graphql-go in representing the GraphQL specification. The names used in this package should match the specification as closely as possible. In order to have cohesion, all internal packages that use GraphQL types have been changed to use this new package. This change is large but mostly mechanical. I recommend starting by reading through the `types` package to build familiarity. I'll call out places in the code where I made decisions and what the tradeoffs were. * add getter for the types.Schema field This additive function shouldn't break backward compatibility will allow those who want access to the types to get at an AST version of the `types.Schema` * unused fields * rename to match types * remove unused * use a string and not an Ident for a FieldDefinition's name This was an error. When this field was renamed from schema.Field (to avoid ambiguity) its name field changed to match query.Field (to Ident). This caused a cascade of useless changes that will be rolled back in the next commit * fix compile errors introduced by ab449f07e * merge conflict errors * add location fields to type definitions * Fix dir in readme * coerce float64 to int32 in NullInt and vice versa in NullFloat * errors.Errorf preserves original error similar to fmt.Error * removed test dependency on errors.Is * checkErrors ignores the raw error for purposes of determining if the test passed or failed * Update CHANGELOG.md * internal/exec: assign parent type name to __typename fields * Accepting value Json in parameter of request's body in custom Scalar (#467) Accept JSON value in resolver args * Add option for custom panic handler (#468) Add option for custom panic handler * Tests showing query variables are validated correctly (#470) * README nit -- Move '$' out of cut/paste buffer (#473) Move '$' out of cut/paste buffer * internal/exec/resolvable: include struct field name in errors (#477) * internal/exec/resolvable: include struct field name in errors We were only adding method name, which meant that it was taking an empty string if the resolver was a struct field. This was making the error messages hard to parse as the user can't know which field has the error. Added a check to use the correct variable. * improve test * ci: setup SemaphoreCI v2 (#479) Update Semaphore configuration * Support "Interfaces Implementing Interfaces" (#471) Interface implementing interfaces support https://spec.graphql.org/draft/#sec-Interfaces.Interfaces-Implementing-Interfaces * README.md: Fix build status badge I broke this accidentally when removing the legacy SemaphoreCI integration. * fix golangci lint errors in the codebase (#478) Added a base golangci-config to the codebase to get started. Some more changes are pending, and those checks are commented out in the config. * Improve Sempahore CI (#481) Improve Sempahore CI build * Make some more golang-ci improvements (#483) * graphql.Time unmarshal unix nano time (#486) * validation: fix bug in maxDepth fragment spread logic (#492) * Create codeql-analysis.yml * Add OpenTelemetry Support (#493) Add OpenTelemetry tracer implementation * Improve the Getting Started section * Update README.md * Improve the Getting Started section in the README * Create SECURITY.md * Fix the OTEL tracer package name (#495) * Fix parseObjectDef will terminate when object has bad syntax (#491) (#500) Thank you for your contribution * Fix remove checkNilCase test helper function (#504) * Add graphql.Time example (#508) * Apollo Federation Spec: Fetch service capabilities (#507) Add basic support for Apollo Federation Co-authored-by: Alam Co-authored-by: pavelnikolov * Ignore yarn.lock file * add support for repeatable directives (#502) add support for repeatable directives * Fix example/social code (#510) The `Friends` field had higher priority than the `FriendsResolver` method. This is the reason why the field was renamed to a value, that doesn't match the GraphQL resolver. * Fix lint error (#512) * Refactor trace package (#513) Remove dependency for graphql-go on OpenTracing and OpenTelemetry except where those tracers are explicitly configured for use. * Adding in primitive value validation. (#515) * Update README.md * Update README.md * Improve type assertion method argument validation (require zero) (#516) Improve type assertion method argument validation (require zero) It's tempting to include a context argument (or think it's allowed), but not discover that this will fail until a query is executed. Validating the resolver during schema parsing reduces the chance of inadvertant errors here. Signed-off-by: Evan Owen * Disallow repeat of non repeatable directives (#525) * Disallow repeat of non repeatable directives * Remove unnecessary scallar * Added changes lost after package update * merging old prs * adding gqlerrors support * adding dev message, error code support * Readded Export query name method functionality after package update * Fix: extension initialisation and updated error method to return extension details * updated QueryError Extensions to not emit if empty --------- Signed-off-by: Evan Owen Co-authored-by: Elijah Oyekunle Co-authored-by: Pavel Nikolov Co-authored-by: Pavel Nikolov Co-authored-by: Dorian Thiessen Co-authored-by: Ivan Co-authored-by: David Ackroyd Co-authored-by: pavemaksim Co-authored-by: Tony Ghita Co-authored-by: Zaydek Co-authored-by: Sylvain Cleymans Co-authored-by: will@newrelic.com Co-authored-by: Nicolas Maquet Co-authored-by: Sean Sorrell Co-authored-by: Ryan Slade Co-authored-by: Thorsten Ball Co-authored-by: obei Co-authored-by: Quinn Slack Co-authored-by: suntoucha Co-authored-by: Barry Dutton Co-authored-by: Sebastian Motavita Co-authored-by: Matthieu Vachon Co-authored-by: Epsirom Co-authored-by: David Ackroyd <23301187+dackroyd@users.noreply.github.com> Co-authored-by: Vincent Composieux Co-authored-by: Silvio Ginter Co-authored-by: Sam Ko Co-authored-by: jinleileiking Co-authored-by: Edward Ma Co-authored-by: Matt Ho Co-authored-by: Tony Ghita Co-authored-by: Gustavo Delfim Co-authored-by: John Starich Co-authored-by: Florian Suess Co-authored-by: wejafoo <79415032+wejafoo@users.noreply.github.com> Co-authored-by: Agniva De Sarker Co-authored-by: Steve Gray Co-authored-by: Connor Vanderhook <14183191+cnnrrss@users.noreply.github.com> Co-authored-by: roaris <61813626+roaris@users.noreply.github.com> Co-authored-by: Sulthan Alam <40392850+aeramu@users.noreply.github.com> Co-authored-by: Alam Co-authored-by: speezepearson Co-authored-by: Dallas Phillips Co-authored-by: Evan Owen Co-authored-by: Igor <9917165+ostrea@users.noreply.github.com> Co-authored-by: Amritansh Kumar Co-authored-by: kumaramritansh <105722986+kumaramritansh@users.noreply.github.com> --- .gitignore | 2 + .golangci.yml | 35 + .semaphore/semaphore.yml | 35 + CHANGELOG.md | 10 + README.md | 135 +- SECURITY.md | 17 + decode/decode.go | 13 + errors/errors.go | 64 +- errors/errors_test.go | 55 + errors/panic_handler.go | 18 + errors/panic_handler_test.go | 24 + example/apollo_federation/README.md | 35 + example/apollo_federation/gateway/.gitignore | 2 + example/apollo_federation/gateway/index.js | 20 + .../apollo_federation/gateway/package.json | 14 + .../apollo_federation/subgraph_one/server.go | 34 + .../apollo_federation/subgraph_two/server.go | 34 + example/customerrors/server/server.go | 65 + example/customerrors/starwars.go | 78 + example/scalar_map/server.go | 45 + example/scalar_map/types/map.go | 19 + example/scalar_time/server.go | 31 + example/social/README.md | 4 +- example/social/introspect.json | 65 + example/social/social.go | 79 +- example/starwars/introspect.json | 27 + go.mod | 7 +- go.sum | 33 +- gqltesting/testing.go | 25 +- graphql.go | 170 +- graphql_test.go | 1402 ++++++++++++++++- id.go | 4 +- internal/common/blockstring.go | 103 ++ internal/common/directive.go | 24 +- internal/common/lexer.go | 21 +- internal/common/lexer_test.go | 39 + internal/common/literals.go | 176 +-- internal/common/types.go | 61 +- internal/common/values.go | 61 +- internal/exec/exec.go | 132 +- internal/exec/packer/packer.go | 109 +- internal/exec/resolvable/meta.go | 42 +- internal/exec/resolvable/resolvable.go | 159 +- internal/exec/selected/selected.go | 99 +- internal/exec/subscribe.go | 40 +- internal/query/query.go | 146 +- internal/schema/meta.go | 22 +- internal/schema/schema.go | 593 +++---- internal/schema/schema_internal_test.go | 284 +++- internal/schema/schema_test.go | 877 ++++++++++- internal/validation/testdata/tests.json | 128 +- .../validation/validate_max_depth_test.go | 90 +- internal/validation/validation.go | 365 +++-- internal/validation/validation_test.go | 6 +- introspection/introspection.go | 80 +- introspection_test.go | 4 +- log/log.go | 4 +- nullable_types.go | 166 ++ nullable_types_test.go | 213 +++ relay/relay.go | 2 +- scripts/golangci_install.sh | 407 +++++ subscription_test.go | 274 ++++ subscriptions.go | 16 +- time.go | 17 +- time_test.go | 165 ++ trace/noop/trace.go | 24 + trace/noop/trace_test.go | 22 + trace/opentracing/trace.go | 79 + trace/opentracing/trace_test.go | 22 + trace/otel/trace.go | 91 ++ trace/otel/trace_test.go | 29 + trace/trace.go | 85 +- trace/trace_test.go | 42 + trace/tracer/tracer.go | 34 + trace/validation_trace.go | 17 +- types/argument.go | 44 + types/directive.go | 35 + types/doc.go | 9 + types/enum.go | 32 + types/extension.go | 13 + types/field.go | 39 + types/fragment.go | 51 + types/input.go | 47 + types/interface.go | 25 + types/object.go | 25 + types/query.go | 62 + types/scalar.go | 22 + types/schema.go | 43 + types/types.go | 63 + types/union.go | 24 + types/value.go | 151 ++ types/variable.go | 15 + 92 files changed, 7208 insertions(+), 1463 deletions(-) create mode 100644 .golangci.yml create mode 100644 .semaphore/semaphore.yml create mode 100644 CHANGELOG.md create mode 100644 SECURITY.md create mode 100644 decode/decode.go create mode 100644 errors/errors_test.go create mode 100644 errors/panic_handler.go create mode 100644 errors/panic_handler_test.go create mode 100644 example/apollo_federation/README.md create mode 100644 example/apollo_federation/gateway/.gitignore create mode 100644 example/apollo_federation/gateway/index.js create mode 100644 example/apollo_federation/gateway/package.json create mode 100644 example/apollo_federation/subgraph_one/server.go create mode 100644 example/apollo_federation/subgraph_two/server.go create mode 100644 example/customerrors/server/server.go create mode 100644 example/customerrors/starwars.go create mode 100644 example/scalar_map/server.go create mode 100644 example/scalar_map/types/map.go create mode 100644 example/scalar_time/server.go create mode 100644 internal/common/blockstring.go create mode 100644 nullable_types.go create mode 100644 nullable_types_test.go create mode 100755 scripts/golangci_install.sh create mode 100644 time_test.go create mode 100644 trace/noop/trace.go create mode 100644 trace/noop/trace_test.go create mode 100644 trace/opentracing/trace.go create mode 100644 trace/opentracing/trace_test.go create mode 100644 trace/otel/trace.go create mode 100644 trace/otel/trace_test.go create mode 100644 trace/trace_test.go create mode 100644 trace/tracer/tracer.go create mode 100644 types/argument.go create mode 100644 types/directive.go create mode 100644 types/doc.go create mode 100644 types/enum.go create mode 100644 types/extension.go create mode 100644 types/field.go create mode 100644 types/fragment.go create mode 100644 types/input.go create mode 100644 types/interface.go create mode 100644 types/object.go create mode 100644 types/query.go create mode 100644 types/scalar.go create mode 100644 types/schema.go create mode 100644 types/types.go create mode 100644 types/union.go create mode 100644 types/value.go create mode 100644 types/variable.go diff --git a/.gitignore b/.gitignore index 7b3bcd13b..2fa95abef 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +/.idea +/.vscode /internal/validation/testdata/graphql-js /internal/validation/testdata/node_modules /vendor diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..c6741d58a --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,35 @@ +run: + timeout: 5m + +linters-settings: + gofmt: + simplify: true + govet: + check-shadowing: true + enable-all: true + disable: + - fieldalignment + - deepequalerrors # remove later + +linters: + disable-all: true + enable: + - deadcode + - gofmt + - gosimple + - govet + - ineffassign + - exportloopref + - structcheck + - staticcheck + - unconvert + - unused + - varcheck + - misspell + - goimports + +issues: + exclude-rules: + - linters: + - unused + path: "graphql_test.go" \ No newline at end of file diff --git a/.semaphore/semaphore.yml b/.semaphore/semaphore.yml new file mode 100644 index 000000000..9ea18e45b --- /dev/null +++ b/.semaphore/semaphore.yml @@ -0,0 +1,35 @@ +version: v1.0 +name: Go +agent: + machine: + type: e1-standard-2 + os_image: ubuntu2004 +blocks: + - name: Style Check + task: + jobs: + - name: fmt + commands: + - sem-version go 1.17 + - checkout + - ./scripts/golangci_install.sh -b $(go env GOPATH)/bin v1.42.1 + - export PATH=$(go env GOPATH)/bin:$PATH + - golangci-lint run ./... + + - name: Test & Build + task: + prologue: + commands: + - sem-version go 1.17 + - export PATH=$(go env GOPATH)/bin:$PATH + - checkout + - go version + + jobs: + - name: Test + commands: + - go test ./... + + - name: Build + commands: + - go build -v . diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..4ec3b4282 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,10 @@ +CHANGELOG + +[v1.1.0](https://github.com/tokopedia/graphql-go/releases/tag/v1.1.0) Release v1.1.0 +* [FEATURE] Add types package #437 +* [FEATURE] Expose `packer.Unmarshaler` as `decode.Unmarshaler` to the public #450 +* [FEATURE] Add location fields to type definitions #454 +* [FEATURE] `errors.Errorf` preserves original error similar to `fmt.Errorf` #456 +* [BUGFIX] Fix duplicated __typename in response (fixes #369) #443 + +[v1.0.0](https://github.com/tokopedia/graphql-go/releases/tag/v1.0.0) Initial release diff --git a/README.md b/README.md index 33a30af13..0e0dc73f1 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# graphql-go [![Sourcegraph](https://sourcegraph.com/github.com/tokopedia/graphql-go/-/badge.svg)](https://sourcegraph.com/github.com/tokopedia/graphql-go?badge) [![Build Status](https://semaphoreci.com/api/v1/graph-gophers/graphql-go/branches/master/badge.svg)](https://semaphoreci.com/graph-gophers/graphql-go) [![GoDoc](https://godoc.org/github.com/tokopedia/graphql-go?status.svg)](https://godoc.org/github.com/tokopedia/graphql-go) +# graphql-go [![Sourcegraph](https://sourcegraph.com/github.com/tokopedia/graphql-go/-/badge.svg)](https://sourcegraph.com/github.com/tokopedia/graphql-go?badge) [![Build Status](https://tokopedia.semaphoreci.com/badges/graphql-go/branches/master.svg?style=shields)](https://tokopedia.semaphoreci.com/projects/graphql-go) [![GoDoc](https://godoc.org/github.com/tokopedia/graphql-go?status.svg)](https://godoc.org/github.com/tokopedia/graphql-go)

@@ -11,13 +11,13 @@ safe for production use. - minimal API - support for `context.Context` -- support for the `OpenTracing` standard +- support for the `OpenTelemetry` and `OpenTracing` standards - schema type-checking against resolvers - resolvers are matched to the schema based on method sets (can resolve a GraphQL schema with a Go interface or Go struct). - handles panics in resolvers - parallel execution of resolvers - subscriptions - - [sample WS transport](https://github.com/graph-gophers/graphql-transport-ws) + - [sample WS transport](https://github.com/tokopedia/graphql-transport-ws) ## Roadmap @@ -26,8 +26,9 @@ Feedback is welcome and appreciated. ## (Some) Documentation -### Basic Sample +### Getting started +In order to run a simple GraphQL server locally create a `main.go` file with the following content: ```go package main @@ -45,9 +46,6 @@ func (_ *query) Hello() string { return "Hello, world!" } func main() { s := ` - schema { - query: Query - } type Query { hello: String! } @@ -57,11 +55,12 @@ func main() { log.Fatal(http.ListenAndServe(":8080", nil)) } ``` - -To test: +Then run the file with `go run main.go`. To test: + ```sh -$ curl -XPOST -d '{"query": "{ hello }"}' localhost:8080/query +curl -XPOST -d '{"query": "{ hello }"}' localhost:8080/query ``` +For more realistic usecases check our [examples section](https://github.com/tokopedia/graphql-go/wiki/Examples). ### Resolvers @@ -103,10 +102,118 @@ func (r *helloWorldResolver) Hello(ctx context.Context) (string, error) { } ``` -### Community Examples +### Schema Options + +- `UseStringDescriptions()` enables the usage of double quoted and triple quoted. When this is not enabled, comments are parsed as descriptions instead. +- `UseFieldResolvers()` specifies whether to use struct field resolvers. +- `MaxDepth(n int)` specifies the maximum field nesting depth in a query. The default is 0 which disables max depth checking. +- `MaxParallelism(n int)` specifies the maximum number of resolvers per request allowed to run in parallel. The default is 10. +- `Tracer(tracer trace.Tracer)` is used to trace queries and fields. It defaults to `noop.Tracer`. +- `Logger(logger log.Logger)` is used to log panics during query execution. It defaults to `exec.DefaultLogger`. +- `PanicHandler(panicHandler errors.PanicHandler)` is used to transform panics into errors during query execution. It defaults to `errors.DefaultPanicHandler`. +- `DisableIntrospection()` disables introspection queries. + +### Custom Errors + +Errors returned by resolvers can include custom extensions by implementing the `ResolverError` interface: + +```go +type ResolverError interface { + error + Extensions() map[string]interface{} +} +``` + +Example of a simple custom error: + +```go +type droidNotFoundError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (e droidNotFoundError) Error() string { + return fmt.Sprintf("error [%s]: %s", e.Code, e.Message) +} + +func (e droidNotFoundError) Extensions() map[string]interface{} { + return map[string]interface{}{ + "code": e.Code, + "message": e.Message, + } +} +``` + +Which could produce a GraphQL error such as: + +```go +{ + "errors": [ + { + "message": "error [NotFound]: This is not the droid you are looking for", + "path": [ + "droid" + ], + "extensions": { + "code": "NotFound", + "message": "This is not the droid you are looking for" + } + } + ], + "data": null +} +``` + +### Tracing + +By default the library uses `noop.Tracer`. If you want to change that you can use the OpenTelemetry or the OpenTracing implementations, respectively: + +```go +// OpenTelemetry tracer +package main + +import ( + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/example/starwars" + otelgraphql "github.com/tokopedia/graphql-go/trace/otel" + "github.com/tokopedia/graphql-go/trace/tracer" +) +// ... +_, err := graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(otelgraphql.DefaultTracer())) +// ... +``` +Alternatively you can pass an existing trace.Tracer instance: +```go +tr := otel.Tracer("example") +_, err = graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(&otelgraphql.Tracer{Tracer: tr})) +``` + + +```go +// OpenTracing tracer +package main + +import ( + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/example/starwars" + "github.com/tokopedia/graphql-go/trace/opentracing" + "github.com/tokopedia/graphql-go/trace/tracer" +) +// ... +_, err := graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(opentracing.Tracer{})) + +// ... +``` + +If you need to implement a custom tracer the library would accept any tracer which implements the interface below: +```go +type Tracer interface { + TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, func([]*errors.QueryError)) + TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, func(*errors.QueryError)) + TraceValidation(context.Context) func([]*errors.QueryError) +} +``` -[tonyghita/graphql-go-example](https://github.com/tonyghita/graphql-go-example) - A more "productionized" version of the Star Wars API example given in this repository. -[deltaskelta/graphql-go-pets-example](https://github.com/deltaskelta/graphql-go-pets-example) - graphql-go resolving against a sqlite database +### [Examples](https://github.com/tokopedia/graphql-go/wiki/Examples) -[OscarYuen/go-graphql-starter](https://github.com/OscarYuen/go-graphql-starter) - a starter application integrated with dataloader, psql and basic authentication diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..f79511db1 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,17 @@ +# Security Policy + +## Supported Versions + +We always try to maintain the library secure and suggest our users to upgrade to the latest stable version. We realize that sometimes this is not possible. + +| Version | Supported | +| ------- | ------------------ | +| 1.x | :white_check_mark: | +| < 1.0 | :x: | + +## MaxDepth +If you are using the `graphql.MaxDepth` schema option, make sure that you upgrade to version v1.3.0 or higher due to a bug causing security vulnerability in earlier versions. + +## Reporting a Vulnerability + +If you find a security vulnerability with this library, please, DO NOT submit a pull request right away. Please, report the issue to @pavelnikolov and/or @tony in the Gophers Slack in a private message. diff --git a/decode/decode.go b/decode/decode.go new file mode 100644 index 000000000..56a9d5b53 --- /dev/null +++ b/decode/decode.go @@ -0,0 +1,13 @@ +package decode + +// Unmarshaler defines the api of Go types mapped to custom GraphQL scalar types +type Unmarshaler interface { + // ImplementsGraphQLType maps the implementing custom Go type + // to the GraphQL scalar type in the schema. + ImplementsGraphQLType(name string) bool + // UnmarshalGraphQL is the custom unmarshaler for the implementing type + // + // This function will be called whenever you use the + // custom GraphQL scalar type as an input + UnmarshalGraphQL(input interface{}) error +} diff --git a/errors/errors.go b/errors/errors.go index a554ddacb..9130c0142 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -9,20 +9,15 @@ type GraphQLError interface { } type QueryError struct { - Message string `json:"message"` - Locations []Location `json:"locations,omitempty"` - Path []interface{} `json:"path,omitempty"` - Rule string `json:"-"` - ResolverError error `json:"-"` - Extensions Extensions `json:"extensions"` + Err error `json:"-"` // Err holds underlying if available + Message string `json:"message"` + Locations []Location `json:"locations,omitempty"` + Path []interface{} `json:"path,omitempty"` + Rule string `json:"-"` + ResolverError error `json:"-"` + Extensions map[string]interface{} `json:"extensions"` } -type Extensions struct { - Code int `json:"code,omitempty"` - DeveloperMessage string `json:"developerMessage,omitempty"` - MoreInfo string `json:"moreInfo,omitempty"` - Timestamp string `json:"timestamp,omitempty"` -} type Location struct { Line int `json:"line"` Column int `json:"column"` @@ -33,8 +28,18 @@ func (a Location) Before(b Location) bool { } func Errorf(format string, a ...interface{}) *QueryError { + // similar to fmt.Errorf, Errorf will wrap the last argument if it is an instance of error + var err error + if n := len(a); n > 0 { + if v, ok := a[n-1].(error); ok { + err = v + } + } + return &QueryError{ - Message: fmt.Sprintf(format, a...), + Err: err, + Message: fmt.Sprintf(format, a...), + Extensions: make(map[string]interface{}), } } @@ -43,42 +48,51 @@ func (err *QueryError) Error() string { return "" } str := fmt.Sprintf("graphql: %s", err.Message) - if err.Extensions.Code != 0 { - str += fmt.Sprintf(" code: %d", err.Extensions.Code) + + if err.Extensions["Code"] != 0 && err.Extensions["Code"] != nil { + str += fmt.Sprintf(" code: %d", err.Extensions["Code"]) } - if err.Extensions.DeveloperMessage != "" { - str += fmt.Sprintf(" developerMessage: %s", err.Extensions.DeveloperMessage) + if err.Extensions["DeveloperMessage"] != "" && err.Extensions["Code"] != nil { + str += fmt.Sprintf(" developerMessage: %s", err.Extensions["DeveloperMessage"]) } - if err.Extensions.MoreInfo != "" { - str += fmt.Sprintf(" moreInfo: %s", err.Extensions.MoreInfo) + if err.Extensions["MoreInfo"] != "" && err.Extensions["Code"] != nil { + str += fmt.Sprintf(" moreInfo: %s", err.Extensions["MoreInfo"]) } - if err.Extensions.Timestamp != "" { - str += fmt.Sprintf(" timestamp: %s", err.Extensions.Timestamp) + if err.Extensions["Timestamp"] != "" && err.Extensions["Code"] != nil { + str += fmt.Sprintf(" timestamp: %s", err.Extensions["Timestamp"]) } + for _, loc := range err.Locations { str += fmt.Sprintf(" (line %d, column %d)", loc.Line, loc.Column) } return str } +func (err *QueryError) Unwrap() error { + if err == nil { + return nil + } + return err.Err +} + var _ error = &QueryError{} func (err *QueryError) AddErrCode(code int) *QueryError { - err.Extensions.Code = code + err.Extensions["Code"] = code return err } func (err *QueryError) AddDevMsg(msg string) *QueryError { - err.Extensions.DeveloperMessage = msg + err.Extensions["DeveloperMessage"] = msg return err } func (err *QueryError) AddMoreInfo(moreInfo string) *QueryError { - err.Extensions.MoreInfo = moreInfo + err.Extensions["MoreInfo"] = moreInfo return err } func (err *QueryError) AddErrTimestamp(errTime string) *QueryError { - err.Extensions.Timestamp = errTime + err.Extensions["Timestamp"] = errTime return err } diff --git a/errors/errors_test.go b/errors/errors_test.go new file mode 100644 index 000000000..85e6ad6a2 --- /dev/null +++ b/errors/errors_test.go @@ -0,0 +1,55 @@ +package errors + +import ( + "io" + "testing" +) + +// Is is simplified facsimile of the go 1.13 errors.Is to ensure QueryError is compatible +func Is(err, target error) bool { + for err != nil { + if target == err { + return true + } + + switch e := err.(type) { + case interface{ Unwrap() error }: + err = e.Unwrap() + default: + break + } + } + return false +} + +func TestErrorf(t *testing.T) { + cause := io.EOF + + t.Run("wrap error", func(t *testing.T) { + err := Errorf("boom: %v", cause) + if !Is(err, cause) { + t.Fatalf("expected errors.Is to return true") + } + }) + + t.Run("handles nil", func(t *testing.T) { + var err *QueryError + if Is(err, cause) { + t.Fatalf("expected errors.Is to return false") + } + }) + + t.Run("handle no arguments", func(t *testing.T) { + err := Errorf("boom") + if Is(err, cause) { + t.Fatalf("expected errors.Is to return false") + } + }) + + t.Run("handle non-error argument arguments", func(t *testing.T) { + err := Errorf("boom: %v", "shaka") + if Is(err, cause) { + t.Fatalf("expected errors.Is to return false") + } + }) +} diff --git a/errors/panic_handler.go b/errors/panic_handler.go new file mode 100644 index 000000000..5446c2a9c --- /dev/null +++ b/errors/panic_handler.go @@ -0,0 +1,18 @@ +package errors + +import ( + "context" +) + +// PanicHandler is the interface used to create custom panic errors that occur during query execution +type PanicHandler interface { + MakePanicError(ctx context.Context, value interface{}) *QueryError +} + +// DefaultPanicHandler is the default PanicHandler +type DefaultPanicHandler struct{} + +// MakePanicError creates a new QueryError from a panic that occurred during execution +func (h *DefaultPanicHandler) MakePanicError(ctx context.Context, value interface{}) *QueryError { + return Errorf("panic occurred: %v", value) +} diff --git a/errors/panic_handler_test.go b/errors/panic_handler_test.go new file mode 100644 index 000000000..d82085bc0 --- /dev/null +++ b/errors/panic_handler_test.go @@ -0,0 +1,24 @@ +package errors + +import ( + "context" + "testing" +) + +func TestDefaultPanicHandler(t *testing.T) { + handler := &DefaultPanicHandler{} + qErr := handler.MakePanicError(context.Background(), "foo") + if qErr == nil { + t.Fatal("Panic error must not be nil") + } + const ( + expectedMessage = "panic occurred: foo" + expectedError = "graphql: " + expectedMessage + ) + if qErr.Error() != expectedError { + t.Errorf("Unexpected panic error message: %q != %q", qErr.Error(), expectedError) + } + if qErr.Message != expectedMessage { + t.Errorf("Unexpected panic QueryError.Message: %q != %q", qErr.Message, expectedMessage) + } +} diff --git a/example/apollo_federation/README.md b/example/apollo_federation/README.md new file mode 100644 index 000000000..04bc15bd9 --- /dev/null +++ b/example/apollo_federation/README.md @@ -0,0 +1,35 @@ +# Apollo Federation + +A simple example of integration with apollo federation as subgraph. Tested with Go v1.18, Node.js v16.14.2 and yarn 1.22.18. + +To run this server + +`go run ./example/apollo_federation/subgraph_one/server.go` + +`go run ./example/apollo_federation/subgraph_two/server.go` + +`cd example/apollo_federation/gateway` + +`yarn start` + +and go to localhost:4000 to interact + +Execute the query: + +``` +query { + hello + hi +} +``` + +and you should see a result similar to this: + +```json +{ + "data": { + "hello": "Hello from subgraph one!", + "hi": "Hi from subgraph two!" + } +} +``` diff --git a/example/apollo_federation/gateway/.gitignore b/example/apollo_federation/gateway/.gitignore new file mode 100644 index 000000000..5add9449b --- /dev/null +++ b/example/apollo_federation/gateway/.gitignore @@ -0,0 +1,2 @@ +/node_modules +/yarn.lock diff --git a/example/apollo_federation/gateway/index.js b/example/apollo_federation/gateway/index.js new file mode 100644 index 000000000..f46b00a10 --- /dev/null +++ b/example/apollo_federation/gateway/index.js @@ -0,0 +1,20 @@ +const { ApolloServer } = require('apollo-server') +const { ApolloGateway, IntrospectAndCompose } = require('@apollo/gateway'); + +const gateway = new ApolloGateway({ + supergraphSdl: new IntrospectAndCompose({ + subgraphs: [ + { name: 'one', url: 'http://localhost:4001/query' }, + { name: 'two', url: 'http://localhost:4002/query' }, + ], + }), +}); + +const server = new ApolloServer({ + gateway, + subscriptions: false, +}); + +server.listen().then(({ url }) => { + console.log(`Server ready at ${url}`); +}); diff --git a/example/apollo_federation/gateway/package.json b/example/apollo_federation/gateway/package.json new file mode 100644 index 000000000..b2e5af494 --- /dev/null +++ b/example/apollo_federation/gateway/package.json @@ -0,0 +1,14 @@ +{ + "name": "apollo-federation-gateway", + "version": "1.0.0", + "description": "Graphql Federation", + "main": "index.js", + "scripts": { + "start": "node index.js" + }, + "dependencies": { + "@apollo/gateway": "^0.49.0", + "apollo-server": "^2.21.1", + "graphql": "^15.5.0" + } +} diff --git a/example/apollo_federation/subgraph_one/server.go b/example/apollo_federation/subgraph_one/server.go new file mode 100644 index 000000000..ce119d371 --- /dev/null +++ b/example/apollo_federation/subgraph_one/server.go @@ -0,0 +1,34 @@ +package main + +import ( + "log" + "net/http" + + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/relay" +) + +var schema = ` + schema { + query: Query + } + + type Query { + hello: String! + } +` + +type resolver struct{} + +func (r *resolver) Hello() string { + return "Hello from subgraph one!" +} + +func main() { + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers(), graphql.MaxParallelism(20)} + schema := graphql.MustParseSchema(schema, &resolver{}, opts...) + + http.Handle("/query", &relay.Handler{Schema: schema}) + + log.Fatal(http.ListenAndServe(":4001", nil)) +} diff --git a/example/apollo_federation/subgraph_two/server.go b/example/apollo_federation/subgraph_two/server.go new file mode 100644 index 000000000..2b641ccfb --- /dev/null +++ b/example/apollo_federation/subgraph_two/server.go @@ -0,0 +1,34 @@ +package main + +import ( + "log" + "net/http" + + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/relay" +) + +var schema = ` + schema { + query: Query + } + + type Query { + hi: String! + } +` + +type resolver struct{} + +func (r *resolver) Hi() string { + return "Hi from subgraph two!" +} + +func main() { + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers(), graphql.MaxParallelism(20)} + schema := graphql.MustParseSchema(schema, &resolver{}, opts...) + + http.Handle("/query", &relay.Handler{Schema: schema}) + + log.Fatal(http.ListenAndServe(":4002", nil)) +} diff --git a/example/customerrors/server/server.go b/example/customerrors/server/server.go new file mode 100644 index 000000000..ca2442cff --- /dev/null +++ b/example/customerrors/server/server.go @@ -0,0 +1,65 @@ +package main + +import ( + "log" + "net/http" + + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/example/customerrors" + "github.com/tokopedia/graphql-go/relay" +) + +var schema *graphql.Schema + +func init() { + schema = graphql.MustParseSchema(customerrors.Schema, &customerrors.Resolver{}) +} + +func main() { + http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(page) + })) + + http.Handle("/query", &relay.Handler{Schema: schema}) + + log.Fatal(http.ListenAndServe(":8080", nil)) +} + +var page = []byte(` + + + + + + + + + + + +
Loading...
+ + + +`) diff --git a/example/customerrors/starwars.go b/example/customerrors/starwars.go new file mode 100644 index 000000000..ee74bd369 --- /dev/null +++ b/example/customerrors/starwars.go @@ -0,0 +1,78 @@ +package customerrors + +import ( + "fmt" + + "github.com/tokopedia/graphql-go" +) + +var Schema = ` + schema { + query: Query + } + type Query { + droid(id: ID!): Droid! + } + # An autonomous mechanical character in the Star Wars universe + type Droid { + # The ID of the droid + id: ID! + # What others call this droid + name: String! + } +` + +type droid struct { + ID graphql.ID + Name string +} + +var droids = []*droid{ + {ID: "2000", Name: "C-3PO"}, + {ID: "2001", Name: "R2-D2"}, +} + +var droidData = make(map[graphql.ID]*droid) + +func init() { + for _, d := range droids { + droidData[d.ID] = d + } +} + +type Resolver struct{} + +func (r *Resolver) Droid(args struct{ ID graphql.ID }) (*droidResolver, error) { + if d := droidData[args.ID]; d != nil { + return &droidResolver{d: d}, nil + } + return nil, &droidNotFoundError{Code: "NotFound", Message: "This is not the droid you are looking for"} +} + +type droidResolver struct { + d *droid +} + +func (r *droidResolver) ID() graphql.ID { + return r.d.ID +} + +func (r *droidResolver) Name() string { + return r.d.Name +} + +type droidNotFoundError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +func (e droidNotFoundError) Error() string { + return fmt.Sprintf("error [%s]: %s", e.Code, e.Message) +} + +func (e droidNotFoundError) Extensions() map[string]interface{} { + return map[string]interface{}{ + "code": e.Code, + "message": e.Message, + } +} diff --git a/example/scalar_map/server.go b/example/scalar_map/server.go new file mode 100644 index 000000000..e55e16cdd --- /dev/null +++ b/example/scalar_map/server.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + "log" + "net/http" + + graphql "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/example/scalar_map/types" + "github.com/tokopedia/graphql-go/relay" +) + +type Args struct { + Name string + Data types.Map +} + +type mutation struct{} + +func (_ *mutation) Hello(args Args) string { + + fmt.Println(args) + + return "Args accept!" +} + +func main() { + s := ` + scalar Map + + type Query {} + + type Mutation { + hello( + name: String! + data: Map! + ): String! + } + ` + schema := graphql.MustParseSchema(s, &mutation{}) + http.Handle("/query", &relay.Handler{Schema: schema}) + + log.Println("Listen in port :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/example/scalar_map/types/map.go b/example/scalar_map/types/map.go new file mode 100644 index 000000000..22dcf4f70 --- /dev/null +++ b/example/scalar_map/types/map.go @@ -0,0 +1,19 @@ +package types + +import "fmt" + +type Map map[string]interface{} + +func (Map) ImplementsGraphQLType(name string) bool { + return name == "Map" +} + +func (j *Map) UnmarshalGraphQL(input interface{}) error { + json, ok := input.(map[string]interface{}) + if !ok { + return fmt.Errorf("wrong type") + } + + *j = json + return nil +} diff --git a/example/scalar_time/server.go b/example/scalar_time/server.go new file mode 100644 index 000000000..4715df8cc --- /dev/null +++ b/example/scalar_time/server.go @@ -0,0 +1,31 @@ +package main + +import ( + "log" + "net/http" + "time" + + graphql "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/relay" +) + +type query struct{} + +func (_ *query) CurrentTime() graphql.Time { + return graphql.Time{Time: time.Now()} +} + +func main() { + s := ` + scalar Time + + type Query { + currentTime: Time! + } + ` + schema := graphql.MustParseSchema(s, &query{}) + http.Handle("/query", &relay.Handler{Schema: schema}) + + log.Println("Listen in port :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/example/social/README.md b/example/social/README.md index 5ab316fd7..d2cf8dbdf 100644 --- a/example/social/README.md +++ b/example/social/README.md @@ -4,6 +4,6 @@ A simple example of how to use struct fields as resolvers instead of methods. To run this server -`go run ./example/field-resolvers/server/server.go` +`go run ./example/social/server/server.go` -and go to localhost:9011 to interact \ No newline at end of file +and go to localhost:9011 to interact diff --git a/example/social/introspect.json b/example/social/introspect.json index 88c4c00bc..f344b600a 100644 --- a/example/social/introspect.json +++ b/example/social/introspect.json @@ -214,6 +214,39 @@ "name": "Pagination", "possibleTypes": null }, + { + "description": null, + "enumValues": null, + "fields": [ + { + "args": [], + "deprecationReason": null, + "description": null, + "isDeprecated": false, + "name": "name", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + } + ], + "inputFields": null, + "interfaces": null, + "kind": "INTERFACE", + "name": "Person", + "possibleTypes": [ + { + "kind": "OBJECT", + "name": "User", + "ofType": null + } + ] + }, { "description": null, "enumValues": null, @@ -545,12 +578,44 @@ "kind": "INTERFACE", "name": "Admin", "ofType": null + }, + { + "kind": "INTERFACE", + "name": "Person", + "ofType": null } ], "kind": "OBJECT", "name": "User", "possibleTypes": null }, + { + "description": null, + "enumValues": null, + "fields": [ + { + "args": [], + "deprecationReason": null, + "description": null, + "isDeprecated": false, + "name": "sdl", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + } + ], + "inputFields": null, + "interfaces": [], + "kind": "OBJECT", + "name": "_Service", + "possibleTypes": null + }, { "description": "A Directive provides a way to describe alternate runtime execution and type validation behavior in a GraphQL document.\n\nIn some cases, you need to provide options to alter GraphQL's execution behavior\nin ways field arguments will not suffice, such as conditionally including or\nskipping a field. Directives provide this by describing additional information\nto the executor.", "enumValues": null, diff --git a/example/social/social.go b/example/social/social.go index 4182dff5e..5758fa8df 100644 --- a/example/social/social.go +++ b/example/social/social.go @@ -27,9 +27,13 @@ const Schema = ` role: Role! } + interface Person { + name: String! + } + scalar Time - type User implements Admin { + type User implements Admin & Person { id: ID! name: String! email: String! @@ -64,6 +68,15 @@ type admin interface { Role() string } +type adminResolver struct { + admin +} + +func (r *adminResolver) ToUser() (*user, bool) { + n, ok := r.admin.(user) + return &n, ok +} + type searchResult struct { result interface{} } @@ -73,15 +86,19 @@ func (r *searchResult) ToUser() (*user, bool) { return res, ok } +type contact struct { + Email string + Phone string +} + type user struct { - IDField string - NameField string - RoleField string - Email string - Phone string - Address *[]string - Friends *[]*user - CreatedAt graphql.Time + IDField string + NameField string + RoleField string + Address *[]string + FriendsField *[]*user + CreatedAt graphql.Time + contact } func (u user) ID() graphql.ID { @@ -96,9 +113,9 @@ func (u user) Role() string { return u.RoleField } -func (u user) FriendsResolver(args struct{ Page *page }) (*[]*user, error) { +func (u user) Friends(args struct{ Page *page }) (*[]*user, error) { var from int - numFriends := len(*u.Friends) + numFriends := len(*u.FriendsField) to := numFriends if args.Page != nil { @@ -116,7 +133,7 @@ func (u user) FriendsResolver(args struct{ Page *page }) (*[]*user, error) { } } - friends := (*u.Friends)[from:to] + friends := (*u.FriendsField)[from:to] return &friends, nil } @@ -126,47 +143,55 @@ var users = []*user{ IDField: "0x01", NameField: "Albus Dumbledore", RoleField: "ADMIN", - Email: "Albus@hogwarts.com", - Phone: "000-000-0000", Address: &[]string{"Office @ Hogwarts", "where Horcruxes are"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "Albus@hogwarts.com", + Phone: "000-000-0000", + }, }, { IDField: "0x02", NameField: "Harry Potter", RoleField: "USER", - Email: "harry@hogwarts.com", - Phone: "000-000-0001", Address: &[]string{"123 dorm room @ Hogwarts", "456 random place"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "harry@hogwarts.com", + Phone: "000-000-0001", + }, }, { IDField: "0x03", NameField: "Hermione Granger", RoleField: "USER", - Email: "hermione@hogwarts.com", - Phone: "000-000-0011", Address: &[]string{"233 dorm room @ Hogwarts", "786 @ random place"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "hermione@hogwarts.com", + Phone: "000-000-0011", + }, }, { IDField: "0x04", NameField: "Ronald Weasley", RoleField: "USER", - Email: "ronald@hogwarts.com", - Phone: "000-000-0111", Address: &[]string{"411 dorm room @ Hogwarts", "981 @ random place"}, CreatedAt: graphql.Time{Time: time.Now()}, + contact: contact{ + Email: "ronald@hogwarts.com", + Phone: "000-000-0111", + }, }, } var usersMap = make(map[string]*user) func init() { - users[0].Friends = &[]*user{users[1]} - users[1].Friends = &[]*user{users[0], users[2], users[3]} - users[2].Friends = &[]*user{users[1], users[3]} - users[3].Friends = &[]*user{users[1], users[2]} + users[0].FriendsField = &[]*user{users[1]} + users[1].FriendsField = &[]*user{users[0], users[2], users[3]} + users[2].FriendsField = &[]*user{users[1], users[3]} + users[3].FriendsField = &[]*user{users[1], users[2]} for _, usr := range users { usersMap[usr.IDField] = usr } @@ -177,14 +202,14 @@ type Resolver struct{} func (r *Resolver) Admin(ctx context.Context, args struct { ID string Role string -}) (admin, error) { +}) (*adminResolver, error) { if usr, ok := usersMap[args.ID]; ok { if usr.RoleField == args.Role { - return *usr, nil + return &adminResolver{*usr}, nil } } err := fmt.Errorf("user with id=%s and role=%s does not exist", args.ID, args.Role) - return user{}, err + return nil, err } func (r *Resolver) User(ctx context.Context, args struct{ Id string }) (user, error) { diff --git a/example/starwars/introspect.json b/example/starwars/introspect.json index 2b955ee9c..89ed5a3f2 100644 --- a/example/starwars/introspect.json +++ b/example/starwars/introspect.json @@ -1231,6 +1231,33 @@ "name": "String", "possibleTypes": null }, + { + "description": null, + "enumValues": null, + "fields": [ + { + "args": [], + "deprecationReason": null, + "description": null, + "isDeprecated": false, + "name": "sdl", + "type": { + "kind": "NON_NULL", + "name": null, + "ofType": { + "kind": "SCALAR", + "name": "String", + "ofType": null + } + } + } + ], + "inputFields": null, + "interfaces": [], + "kind": "OBJECT", + "name": "_Service", + "possibleTypes": null + }, { "description": "A Directive provides a way to describe alternate runtime execution and type validation behavior in a GraphQL document.\n\nIn some cases, you need to provide options to alter GraphQL's execution behavior\nin ways field arguments will not suffice, such as conditionally including or\nskipping a field. Directives provide this by describing additional information\nto the executor.", "enumValues": null, diff --git a/go.mod b/go.mod index 1c77f7a91..d50d8319a 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,9 @@ module github.com/tokopedia/graphql-go +go 1.13 + require ( - github.com/graph-gophers/graphql-go v0.0.0-20190724201507-010347b5f9e6 // indirect - github.com/opentracing/opentracing-go v1.1.0 + github.com/opentracing/opentracing-go v1.2.0 + go.opentelemetry.io/otel v1.6.3 + go.opentelemetry.io/otel/trace v1.6.3 ) diff --git a/go.sum b/go.sum index 049bbf669..b987a5d21 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,27 @@ -github.com/graph-gophers/graphql-go v0.0.0-20190724201507-010347b5f9e6 h1:9WiNlI9Cds5S5YITwRpRs8edNaq0nxTEymhDW20A1QE= -github.com/graph-gophers/graphql-go v0.0.0-20190724201507-010347b5f9e6/go.mod h1:Au3iQ8DvDis8hZ4q2OzRcaKYlAsPt+fYvib5q4nIqu4= -github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/tokopedia/graphql-go v1.1.1 h1:Rreg5sSgFQklU8w+dXQNLY7Hh+NDk9GRTOXMhDR8T5w= -github.com/tokopedia/graphql-go v1.1.1/go.mod h1:2iiWM8Ad38HEvlsZyKExjEeC9y0tDmIwQQeKnReZ4Fw= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/otel v1.6.3 h1:FLOfo8f9JzFVFVyU+MSRJc2HdEAXQgm7pIv2uFKRSZE= +go.opentelemetry.io/otel v1.6.3/go.mod h1:7BgNga5fNlF/iZjG06hM3yofffp0ofKCDwSXx1GC4dI= +go.opentelemetry.io/otel/trace v1.6.3 h1:IqN4L+5b0mPNjdXIiZ90Ni4Bl5BRkDQywePLWemd9bc= +go.opentelemetry.io/otel/trace v1.6.3/go.mod h1:GNJQusJlUgZl9/TQBPKU/Y/ty+0iVB5fjhKeJGZPGFs= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gqltesting/testing.go b/gqltesting/testing.go index 6dbc5bd23..0755a9071 100644 --- a/gqltesting/testing.go +++ b/gqltesting/testing.go @@ -23,6 +23,7 @@ type Test struct { Variables map[string]interface{} ExpectedResult string ExpectedErrors []*errors.QueryError + RawResponse bool } // RunTests runs the given GraphQL test cases as subtests. @@ -57,10 +58,22 @@ func RunTest(t *testing.T, test *Test) { } // Verify JSON to avoid red herring errors. - got, err := formatJSON(result.Data) - if err != nil { - t.Fatalf("got: invalid JSON: %s", err) + var got []byte + + if test.RawResponse { + value, err := result.Data.MarshalJSON() + if err != nil { + t.Fatalf("got: unable to marshal JSON response: %s", err) + } + got = value + } else { + value, err := formatJSON(result.Data) + if err != nil { + t.Fatalf("got: invalid JSON: %s", err) + } + got = value } + want, err := formatJSON([]byte(test.ExpectedResult)) if err != nil { t.Fatalf("want: invalid JSON: %s", err) @@ -89,6 +102,12 @@ func checkErrors(t *testing.T, want, got []*errors.QueryError) { sortErrors(want) sortErrors(got) + // Clear the underlying error before the DeepEqual check. It's too + // much to ask the tester to include the raw failing error. + for _, err := range got { + err.Err = nil + } + if !reflect.DeepEqual(got, want) { t.Fatalf("unexpected error: got %+v, want %+v", got, want) } diff --git a/graphql.go b/graphql.go index 42ece77b4..9a1643a41 100644 --- a/graphql.go +++ b/graphql.go @@ -4,7 +4,7 @@ import ( "context" "encoding/json" "fmt" - "reflect" + "time" "github.com/tokopedia/graphql-go/errors" "github.com/tokopedia/graphql-go/internal/common" @@ -16,7 +16,9 @@ import ( "github.com/tokopedia/graphql-go/internal/validation" "github.com/tokopedia/graphql-go/introspection" "github.com/tokopedia/graphql-go/log" - "github.com/tokopedia/graphql-go/trace" + "github.com/tokopedia/graphql-go/trace/noop" + "github.com/tokopedia/graphql-go/trace/tracer" + "github.com/tokopedia/graphql-go/types" ) // ParseSchema parses a GraphQL schema and attaches the given root resolver. It returns an error if @@ -24,17 +26,28 @@ import ( // resolver, then the schema can not be executed, but it may be inspected (e.g. with ToJSON). func ParseSchema(schemaString string, resolver interface{}, opts ...SchemaOpt) (*Schema, error) { s := &Schema{ - schema: schema.New(), - maxParallelism: 10, - tracer: trace.OpenTracingTracer{}, - validationTracer: trace.NoopValidationTracer{}, - logger: &log.DefaultLogger{}, + schema: schema.New(), + maxParallelism: 10, + tracer: noop.Tracer{}, + logger: &log.DefaultLogger{}, + panicHandler: &errors.DefaultPanicHandler{}, } for _, opt := range opts { opt(s) } - if err := s.schema.Parse(schemaString, s.useStringDescriptions); err != nil { + if s.validationTracer == nil { + if t, ok := s.tracer.(tracer.ValidationTracer); ok { + s.validationTracer = t + } else { + s.validationTracer = &validationBridgingTracer{tracer: tracer.LegacyNoopValidationTracer{}} //nolint:staticcheck + } + } + + if err := schema.Parse(s.schema, schemaString, s.useStringDescriptions); err != nil { + return nil, err + } + if err := s.validateSchema(); err != nil { return nil, err } @@ -58,16 +71,22 @@ func MustParseSchema(schemaString string, resolver interface{}, opts ...SchemaOp // Schema represents a GraphQL schema with an optional resolver. type Schema struct { - schema *schema.Schema + schema *types.Schema res *resolvable.Schema - maxDepth int - maxParallelism int - tracer trace.Tracer - validationTracer trace.ValidationTracer - logger log.Logger - useStringDescriptions bool - disableIntrospection bool + maxDepth int + maxParallelism int + tracer tracer.Tracer + validationTracer tracer.ValidationTracer + logger log.Logger + panicHandler errors.PanicHandler + useStringDescriptions bool + disableIntrospection bool + subscribeResolverTimeout time.Duration +} + +func (s *Schema) ASTSchema() *types.Schema { + return s.schema } // SchemaOpt is an option to pass to ParseSchema or MustParseSchema. @@ -104,17 +123,18 @@ func MaxParallelism(n int) SchemaOpt { } } -// Tracer is used to trace queries and fields. It defaults to trace.OpenTracingTracer. -func Tracer(tracer trace.Tracer) SchemaOpt { +// Tracer is used to trace queries and fields. It defaults to tracer.Noop. +func Tracer(t tracer.Tracer) SchemaOpt { return func(s *Schema) { - s.tracer = tracer + s.tracer = t } } -// ValidationTracer is used to trace validation errors. It defaults to trace.NoopValidationTracer. -func ValidationTracer(tracer trace.ValidationTracer) SchemaOpt { +// ValidationTracer is used to trace validation errors. It defaults to tracer.LegacyNoopValidationTracer. +// Deprecated: context is needed to support tracing correctly. Use a Tracer which implements tracer.ValidationTracer. +func ValidationTracer(tracer tracer.LegacyValidationTracer) SchemaOpt { //nolint:staticcheck return func(s *Schema) { - s.validationTracer = tracer + s.validationTracer = &validationBridgingTracer{tracer: tracer} } } @@ -125,6 +145,14 @@ func Logger(logger log.Logger) SchemaOpt { } } +// PanicHandler is used to customize the panic errors during query execution. +// It defaults to errors.DefaultPanicHandler. +func PanicHandler(panicHandler errors.PanicHandler) SchemaOpt { + return func(s *Schema) { + s.panicHandler = panicHandler + } +} + // DisableIntrospection disables introspection queries. func DisableIntrospection() SchemaOpt { return func(s *Schema) { @@ -132,6 +160,15 @@ func DisableIntrospection() SchemaOpt { } } +// SubscribeResolverTimeout is an option to control the amount of time +// we allow for a single subscribe message resolver to complete it's job +// before it times out and returns an error to the subscriber. +func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt { + return func(s *Schema) { + s.subscribeResolverTimeout = timeout + } +} + // Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or // it may be further processed to a custom response type, for example to include custom error data. // Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107 @@ -141,21 +178,28 @@ type Response struct { Extensions map[string]interface{} `json:"extensions,omitempty"` } -// Validate validates the given query with the schema. -func (s *Schema) Validate(queryString string, variables map[string]interface{}) ([]string, bool, []*errors.QueryError) { - var queries []string +//Validate validates the given query with the schema. +func (s *Schema) Validate(queryString string) ([]string, bool, []*errors.QueryError) { + return s.ValidateWithVariables(queryString, nil) +} + +// ValidateWithVariables validates the given query with the schema and the input variables. +func (s *Schema) ValidateWithVariables(queryString string, variables map[string]interface{}) ([]string, bool, []*errors.QueryError) { + var queries []string doc, qErr := query.Parse(queryString) if qErr != nil { return queries, true, []*errors.QueryError{qErr} } - for _, op := range doc.Operations{ - for _, sel := range op.Selections{ - query, ok := sel.(*query.Field) + + for _, op := range doc.Operations { + for _, sel := range op.Selections { + query, ok := sel.(*types.Field) if ok { queries = append(queries, query.Name.Name) } } } + return queries, false, validation.Validate(s.schema, doc, variables, s.maxDepth) } @@ -163,7 +207,7 @@ func (s *Schema) Validate(queryString string, variables map[string]interface{}) // without a resolver. If the context get cancelled, no further resolvers will be called and a // the context error will be returned as soon as possible (not immediately). func (s *Schema) Exec(ctx context.Context, queryString string, operationName string, variables map[string]interface{}) *Response { - if s.res.Resolver == (reflect.Value{}) { + if !s.res.Resolver.IsValid() { panic("schema created without resolver, can not exec") } return s.exec(ctx, queryString, operationName, variables, s.res) @@ -177,11 +221,10 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str return &Response{Errors: []*errors.QueryError{qErr}} } - validationFinish := s.validationTracer.TraceValidation() + validationFinish := s.validationTracer.TraceValidation(ctx) errs := validation.Validate(s.schema, doc, variables, s.maxDepth) validationFinish(errs) if len(errs) != 0 { - for _, err := range errs { if err.Rule != "VariablesOfCorrectType" && err.Rule != "" { anyOtherValidationError = true @@ -197,9 +240,20 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str return &Response{Errors: []*errors.QueryError{errors.Errorf("%s", err)}} } + // If the optional "operationName" POST parameter is not provided then + // use the query's operation name for improved tracing. + if operationName == "" { + operationName = op.Name.Name + } + // Subscriptions are not valid in Exec. Use schema.Subscribe() instead. if op.Type == query.Subscription { - return &Response{Errors: []*errors.QueryError{&errors.QueryError{Message: "graphql-ws protocol header is missing"}}} + return &Response{Errors: []*errors.QueryError{{Message: "graphql-ws protocol header is missing"}}} + } + if op.Type == query.Mutation { + if _, ok := s.schema.EntryPoints["mutation"]; !ok { + return &Response{Errors: []*errors.QueryError{{Message: "no mutations are offered by the schema"}}} + } } // Fill in variables with the defaults from the operation @@ -208,7 +262,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str } for _, v := range op.Vars { if _, ok := variables[v.Name.Name]; !ok && v.Default != nil { - variables[v.Name.Name] = v.Default.Value(nil) + variables[v.Name.Name] = v.Default.Deserialize(nil) } } @@ -219,9 +273,10 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str Schema: s.schema, DisableIntrospection: s.disableIntrospection, }, - Limiter: make(chan struct{}, s.maxParallelism), - Tracer: s.tracer, - Logger: s.logger, + Limiter: make(chan struct{}, s.maxParallelism), + Tracer: s.tracer, + Logger: s.logger, + PanicHandler: s.panicHandler, } varTypes := make(map[string]*introspection.Type) for _, v := range op.Vars { @@ -241,7 +296,48 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str } } -func getOperation(document *query.Document, operationName string) (*query.Operation, error) { +func (s *Schema) validateSchema() error { + // https://graphql.github.io/graphql-spec/June2018/#sec-Root-Operation-Types + // > The query root operation type must be provided and must be an Object type. + if err := validateRootOp(s.schema, "query", true); err != nil { + return err + } + // > The mutation root operation type is optional; if it is not provided, the service does not support mutations. + // > If it is provided, it must be an Object type. + if err := validateRootOp(s.schema, "mutation", false); err != nil { + return err + } + // > Similarly, the subscription root operation type is also optional; if it is not provided, the service does not + // > support subscriptions. If it is provided, it must be an Object type. + if err := validateRootOp(s.schema, "subscription", false); err != nil { + return err + } + return nil +} + +type validationBridgingTracer struct { + tracer tracer.LegacyValidationTracer //nolint:staticcheck +} + +func (t *validationBridgingTracer) TraceValidation(context.Context) func([]*errors.QueryError) { + return t.tracer.TraceValidation() +} + +func validateRootOp(s *types.Schema, name string, mandatory bool) error { + t, ok := s.EntryPoints[name] + if !ok { + if mandatory { + return fmt.Errorf("root operation %q must be defined", name) + } + return nil + } + if t.Kind() != "OBJECT" { + return fmt.Errorf("root operation %q must be an OBJECT", name) + } + return nil +} + +func getOperation(document *types.ExecutableDefinition, operationName string) (*types.OperationDefinition, error) { if len(document.Operations) == 0 { return nil, fmt.Errorf("no operations in query document") } diff --git a/graphql_test.go b/graphql_test.go index 60f34c51f..f9671c4fa 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "testing" "time" @@ -11,6 +12,8 @@ import ( gqlerrors "github.com/tokopedia/graphql-go/errors" "github.com/tokopedia/graphql-go/example/starwars" "github.com/tokopedia/graphql-go/gqltesting" + "github.com/tokopedia/graphql-go/introspection" + "github.com/tokopedia/graphql-go/trace/tracer" ) type helloWorldResolver1 struct{} @@ -313,6 +316,236 @@ func TestHelloSnakeArguments(t *testing.T) { }) } +func TestRootOperations_invalidSchema(t *testing.T) { + type args struct { + Schema string + } + type want struct { + Error string + } + testTable := map[string]struct { + Args args + Want want + }{ + "Empty schema": { + Want: want{Error: `root operation "query" must be defined`}, + }, + "Query declared by schema, but type not present": { + Args: args{ + Schema: ` + schema { + query: Query + } + `, + }, + Want: want{Error: `graphql: type "Query" not found`}, + }, + "Query as incorrect type": { + Args: args{ + Schema: ` + schema { + query: String + } + `, + }, + Want: want{Error: `root operation "query" must be an OBJECT`}, + }, + "Query with custom name, schema omitted": { + Args: args{ + Schema: ` + type QueryType { + hello: String! + } + `, + }, + Want: want{Error: `root operation "query" must be defined`}, + }, + "Mutation as incorrect type": { + Args: args{ + Schema: ` + schema { + query: Query + mutation: String + } + type Query { + thing: String + } + `, + }, + Want: want{Error: `root operation "mutation" must be an OBJECT`}, + }, + "Mutation declared by schema, but type not present": { + Args: args{ + Schema: ` + schema { + query: Query + mutation: Mutation + } + type Query { + hello: String! + } + `, + }, + Want: want{Error: `graphql: type "Mutation" not found`}, + }, + } + + for name, tt := range testTable { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + + _, err := graphql.ParseSchema(tt.Args.Schema, nil) + if err == nil || err.Error() != tt.Want.Error { + t.Logf("got: %v", err) + t.Logf("want: %s", tt.Want.Error) + t.Fail() + } + }) + } +} + +func TestRootOperations_validSchema(t *testing.T) { + type resolver struct { + helloSaidResolver + helloWorldResolver1 + theNumberResolver + } + gqltesting.RunTests(t, []*gqltesting.Test{ + { + // Query only, default name with `schema` omitted + Schema: graphql.MustParseSchema(` + type Query { + hello: String! + } + `, &resolver{}), + Query: `{ hello }`, + ExpectedResult: `{"hello": "Hello world!"}`, + }, + { + // Query only, default name with `schema` present + Schema: graphql.MustParseSchema(` + schema { + query: Query + } + type Query { + hello: String! + } + `, &resolver{}), + Query: `{ hello }`, + ExpectedResult: `{"hello": "Hello world!"}`, + }, + { + // Query only, custom name + Schema: graphql.MustParseSchema(` + schema { + query: QueryType + } + type QueryType { + hello: String! + } + `, &resolver{}), + Query: `{ hello }`, + ExpectedResult: `{"hello": "Hello world!"}`, + }, + { + // Query+Mutation, default names with `schema` omitted + Schema: graphql.MustParseSchema(` + type Query { + hello: String! + } + type Mutation { + changeTheNumber(newNumber: Int!): ChangedNumber! + } + type ChangedNumber { + theNumber: Int! + } + `, &resolver{}), + Query: ` + mutation { + changeTheNumber(newNumber: 1) { + theNumber + } + } + `, + ExpectedResult: `{"changeTheNumber": {"theNumber": 1}}`, + }, + { + // Query+Mutation, custom names + Schema: graphql.MustParseSchema(` + schema { + query: QueryType + mutation: MutationType + } + type QueryType { + hello: String! + } + type MutationType { + changeTheNumber(newNumber: Int!): ChangedNumber! + } + type ChangedNumber { + theNumber: Int! + } + `, &resolver{}), + Query: ` + mutation { + changeTheNumber(newNumber: 1) { + theNumber + } + } + `, + ExpectedResult: `{"changeTheNumber": {"theNumber": 1}}`, + }, + { + // Mutation with custom name, schema omitted + Schema: graphql.MustParseSchema(` + type Query { + hello: String! + } + type MutationType { + changeTheNumber(newNumber: Int!): ChangedNumber! + } + type ChangedNumber { + theNumber: Int! + } + `, &resolver{}), + Query: ` + mutation { + changeTheNumber(newNumber: 1) { + theNumber + } + } + `, + ExpectedErrors: []*gqlerrors.QueryError{{Message: "no mutations are offered by the schema"}}, + }, + { + // Explicit schema without mutation field + Schema: graphql.MustParseSchema(` + schema { + query: Query + } + type Query { + hello: String! + } + type Mutation { + changeTheNumber(newNumber: Int!): ChangedNumber! + } + type ChangedNumber { + theNumber: Int! + } + `, &resolver{}), + Query: ` + mutation { + changeTheNumber(newNumber: 1) { + theNumber + } + } + `, + ExpectedErrors: []*gqlerrors.QueryError{{Message: "no mutations are offered by the schema"}}, + }, + }) +} + func TestBasic(t *testing.T) { gqltesting.RunTests(t, []*gqltesting.Test{ { @@ -351,6 +584,88 @@ func TestBasic(t *testing.T) { }) } +type testEmbeddedStructResolver struct{} + +func (_ *testEmbeddedStructResolver) Course() courseResolver { + return courseResolver{ + CourseMeta: CourseMeta{ + Name: "Biology", + Timestamps: Timestamps{CreatedAt: "yesterday", UpdatedAt: "today"}, + }, + Instructor: Instructor{Name: "Socrates"}, + } +} + +type courseResolver struct { + CourseMeta + Instructor Instructor +} + +type CourseMeta struct { + Name string + Timestamps +} + +type Instructor struct { + Name string +} + +type Timestamps struct { + CreatedAt string + UpdatedAt string +} + +func TestEmbeddedStruct(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + schema { + query: Query + } + + type Query { + course: Course! + } + + type Course { + name: String! + createdAt: String! + updatedAt: String! + instructor: Instructor! + } + + type Instructor { + name: String! + } + `, &testEmbeddedStructResolver{}, graphql.UseFieldResolvers()), + Query: ` + { + course{ + name + createdAt + updatedAt + instructor { + name + } + } + } + `, + ExpectedResult: ` + { + "course": { + "name": "Biology", + "createdAt": "yesterday", + "updatedAt": "today", + "instructor": { + "name":"Socrates" + } + } + } + `, + }, + }) +} + type testNilInterfaceResolver struct{} func (r *testNilInterfaceResolver) A() interface{ Z() int32 } { @@ -400,7 +715,7 @@ func TestNilInterface(t *testing.T) { } `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: "x", Path: []interface{}{"b"}, ResolverError: errors.New("x"), @@ -438,7 +753,7 @@ func TestErrorPropagationInLists(t *testing.T) { null `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: droidNotFoundError.Error(), Path: []interface{}{"findDroids", 1, "name"}, ResolverError: droidNotFoundError, @@ -480,7 +795,7 @@ func TestErrorPropagationInLists(t *testing.T) { } `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: droidNotFoundError.Error(), Path: []interface{}{"findDroids", 1, "name"}, ResolverError: droidNotFoundError, @@ -514,7 +829,7 @@ func TestErrorPropagationInLists(t *testing.T) { } `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: `graphql: got nil for non-null "Droid"`, Path: []interface{}{"findNilDroids", 1}, }, @@ -591,7 +906,7 @@ func TestErrorPropagationInLists(t *testing.T) { } `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: quoteError.Error(), ResolverError: quoteError, Path: []interface{}{"findDroids", 0, "quotes"}, @@ -626,12 +941,12 @@ func TestErrorPropagationInLists(t *testing.T) { } `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: quoteError.Error(), ResolverError: quoteError, Path: []interface{}{"findNilDroids", 0, "quotes"}, }, - &gqlerrors.QueryError{ + { Message: `graphql: got nil for non-null "Droid"`, Path: []interface{}{"findNilDroids", 1}, }, @@ -670,7 +985,7 @@ func TestErrorWithExtensions(t *testing.T) { null `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: droidNotFoundError.Error(), Path: []interface{}{"FindDroid"}, ResolverError: droidNotFoundError, @@ -706,7 +1021,7 @@ func TestErrorWithNoExtensions(t *testing.T) { null `, ExpectedErrors: []*gqlerrors.QueryError{ - &gqlerrors.QueryError{ + { Message: err.Error(), Path: []interface{}{"DismissVader"}, ResolverError: err, @@ -1391,6 +1706,77 @@ func TestInlineFragments(t *testing.T) { } `, }, + + { + Schema: starwarsSchema, + Query: ` + query CharacterSearch { + search(text: "C-3PO") { + ... on Character { + name + } + } + } + `, + ExpectedResult: ` + { + "search": [ + { + "name": "C-3PO" + } + ] + } + `, + }, + + { + Schema: starwarsSchema, + Query: ` + query CharacterSearch { + hero { + ... on Character { + ... on Human { + name + } + ... on Droid { + name + } + } + } + } + `, + ExpectedResult: ` + { + "hero": { + "name": "R2-D2" + } + } + `, + }, + + { + Schema: socialSchema, + Query: ` + query { + admin(id: "0x01") { + ... on User { + email + } + ... on Person { + name + } + } + } + `, + ExpectedResult: ` + { + "admin": { + "email": "Albus@hogwarts.com", + "name": "Albus Dumbledore" + } + } + `, + }, }) } @@ -1453,22 +1839,46 @@ func TestTypeName(t *testing.T) { } `, }, - }) -} -func TestConnections(t *testing.T) { - gqltesting.RunTests(t, []*gqltesting.Test{ { Schema: starwarsSchema, Query: ` { hero { + __typename name - friendsConnection { - totalCount - pageInfo { - startCursor - endCursor + ... on Character { + ...Droid + name + __typename + } + } + } + + fragment Droid on Droid { + name + __typename + } + `, + RawResponse: true, + ExpectedResult: `{"hero":{"__typename":"Droid","name":"R2-D2"}}`, + }, + }) +} + +func TestConnections(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: starwarsSchema, + Query: ` + { + hero { + name + friendsConnection { + totalCount + pageInfo { + startCursor + endCursor hasNextPage } edges { @@ -1734,6 +2144,7 @@ func TestIntrospection(t *testing.T) { { "name": "SearchResult" }, { "name": "Starship" }, { "name": "String" }, + { "name": "_Service" }, { "name": "__Directive" }, { "name": "__DirectiveLocation" }, { "name": "__EnumValue" }, @@ -2231,6 +2642,44 @@ func TestIntrospectionDisableIntrospection(t *testing.T) { } `, }, + + { + Schema: starwarsSchemaNoIntrospection, + Query: ` + { + search(text: "an") { + __typename + ... on Human { + name + } + ... on Droid { + name + } + ... on Starship { + name + } + } + } + `, + ExpectedResult: ` + { + "search": [ + { + "__typename": "Human", + "name": "Han Solo" + }, + { + "__typename": "Human", + "name": "Leia Organa" + }, + { + "__typename": "Starship", + "name": "TIE Advanced x1" + } + ] + } + `, + }, }) } @@ -2614,6 +3063,205 @@ func TestInput(t *testing.T) { }) } +type inputArgumentsHello struct{} + +type inputArgumentsScalarMismatch1 struct{} + +type inputArgumentsScalarMismatch2 struct{} + +type inputArgumentsObjectMismatch1 struct{} + +type inputArgumentsObjectMismatch2 struct{} + +type inputArgumentsObjectMismatch3 struct{} + +type fieldNameMismatch struct{} + +type helloInput struct { + Name string +} + +type helloOutput struct { + Name string +} + +func (*fieldNameMismatch) Hello() helloOutput { + return helloOutput{} +} + +type helloInputMismatch struct { + World string +} + +func (r *inputArgumentsHello) Hello(args struct{ Input *helloInput }) string { + return "Hello " + args.Input.Name + "!" +} + +func (r *inputArgumentsScalarMismatch1) Hello(name string) string { + return "Hello " + name + "!" +} + +func (r *inputArgumentsScalarMismatch2) Hello(args struct{ World string }) string { + return "Hello " + args.World + "!" +} + +func (r *inputArgumentsObjectMismatch1) Hello(in helloInput) string { + return "Hello " + in.Name + "!" +} + +func (r *inputArgumentsObjectMismatch2) Hello(args struct{ Input *helloInputMismatch }) string { + return "Hello " + args.Input.World + "!" +} + +func (r *inputArgumentsObjectMismatch3) Hello(args struct{ Input *struct{ Thing string } }) string { + return "Hello " + args.Input.Thing + "!" +} + +func TestInputArguments_failSchemaParsing(t *testing.T) { + type args struct { + Resolver interface{} + Schema string + Opts []graphql.SchemaOpt + } + type want struct { + Error string + } + testTable := map[string]struct { + Args args + Want want + }{ + "Non-input type used with field arguments": { + Args: args{ + Resolver: &inputArgumentsHello{}, + Schema: ` + schema { + query: Query + } + type Query { + hello(input: HelloInput): String! + } + type HelloInput { + name: String + } + `, + }, + Want: want{Error: "field \"Input\": type of kind OBJECT can not be used as input\n\tused by (*graphql_test.inputArgumentsHello).Hello"}, + }, + "Missing Args Wrapper for scalar input": { + Args: args{ + Resolver: &inputArgumentsScalarMismatch1{}, + Schema: ` + schema { + query: Query + } + type Query { + hello(name: String): String! + } + input HelloInput { + name: String + } + `, + }, + Want: want{Error: "expected struct or pointer to struct, got string (hint: missing `args struct { ... }` wrapper for field arguments?)\n\tused by (*graphql_test.inputArgumentsScalarMismatch1).Hello"}, + }, + "Mismatching field name for scalar input": { + Args: args{ + Resolver: &inputArgumentsScalarMismatch2{}, + Schema: ` + schema { + query: Query + } + type Query { + hello(name: String): String! + } + `, + }, + Want: want{Error: "struct { World string } does not define field \"name\" (hint: missing `args struct { ... }` wrapper for field arguments, or missing field on input struct)\n\tused by (*graphql_test.inputArgumentsScalarMismatch2).Hello"}, + }, + "Missing Args Wrapper for Input type": { + Args: args{ + Resolver: &inputArgumentsObjectMismatch1{}, + Schema: ` + schema { + query: Query + } + type Query { + hello(input: HelloInput): String! + } + input HelloInput { + name: String + } + `, + }, + Want: want{Error: "graphql_test.helloInput does not define field \"input\" (hint: missing `args struct { ... }` wrapper for field arguments, or missing field on input struct)\n\tused by (*graphql_test.inputArgumentsObjectMismatch1).Hello"}, + }, + "Input struct missing field": { + Args: args{ + Resolver: &inputArgumentsObjectMismatch2{}, + Schema: ` + schema { + query: Query + } + type Query { + hello(input: HelloInput): String! + } + input HelloInput { + name: String + } + `, + }, + Want: want{Error: "field \"Input\": *graphql_test.helloInputMismatch does not define field \"name\" (hint: missing `args struct { ... }` wrapper for field arguments, or missing field on input struct)\n\tused by (*graphql_test.inputArgumentsObjectMismatch2).Hello"}, + }, + "Inline Input struct missing field": { + Args: args{ + Resolver: &inputArgumentsObjectMismatch3{}, + Schema: ` + schema { + query: Query + } + type Query { + hello(input: HelloInput): String! + } + input HelloInput { + name: String + } + `, + }, + Want: want{Error: "field \"Input\": *struct { Thing string } does not define field \"name\" (hint: missing `args struct { ... }` wrapper for field arguments, or missing field on input struct)\n\tused by (*graphql_test.inputArgumentsObjectMismatch3).Hello"}, + }, + "Struct field name inclusion": { + Args: args{ + Resolver: &fieldNameMismatch{}, + Opts: []graphql.SchemaOpt{graphql.UseFieldResolvers()}, + Schema: ` + type Query { + hello(): HelloOutput! + } + type HelloOutput { + name: Int + } + `, + }, + Want: want{Error: "string is not a pointer\n\tused by (graphql_test.helloOutput).Name\n\tused by (*graphql_test.fieldNameMismatch).Hello"}, + }, + } + + for name, tt := range testTable { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + + _, err := graphql.ParseSchema(tt.Args.Schema, tt.Args.Resolver, tt.Args.Opts...) + if err == nil || err.Error() != tt.Want.Error { + t.Log("Schema parsing error mismatch") + t.Logf("got: %s", err) + t.Logf("exp: %s", tt.Want.Error) + t.Fail() + } + }) + } +} + func TestComposedFragments(t *testing.T) { gqltesting.RunTests(t, []*gqltesting.Test{ { @@ -3018,6 +3666,148 @@ func TestErrorPropagation(t *testing.T) { }) } +type assertionResolver struct{} + +func (r *assertionResolver) ToHuman() (*struct{ Name string }, bool) { + return &struct{ Name string }{Name: "Luke Skywalker"}, true +} + +type assertionQueryResolver struct{} + +func (*assertionQueryResolver) Character() *assertionResolver { + return &assertionResolver{} +} + +type badAssertionResolver struct{} + +func (r *badAssertionResolver) ToHuman(ctx context.Context) (*struct{ Name string }, bool) { + return &struct{ Name string }{Name: "Luke Skywalker"}, true +} + +type badAssertionQueryResolver struct{} + +func (*badAssertionQueryResolver) Character() *badAssertionResolver { + return &badAssertionResolver{} +} + +func TestTypeAssertions(t *testing.T) { + assertionSchema := ` + schema { + query: Query + } + + type Query { + character: Character! + } + + type Human { + name: String! + } + + union Character = Human + ` + query := ` + query { + character { + ... on Human { + name + } + } + } + ` + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(assertionSchema, &assertionQueryResolver{}, graphql.UseFieldResolvers()), + Query: query, + ExpectedResult: ` + { + "character": { + "name": "Luke Skywalker" + } + } + `, + }, + }) +} + +func TestPanicTypeAssertionArguments(t *testing.T) { + panicMessage := `*graphql_test.badAssertionResolver does not resolve "Character": method "ToHuman" should't have any arguments + used by (*graphql_test.badAssertionQueryResolver).Character` + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected schema parse to panic") + } + + if r.(error).Error() != panicMessage { + t.Logf("got: %s", r) + t.Logf("want: %s", panicMessage) + t.Fail() + } + }() + + schema := ` + schema { + query: Query + } + + type Query { + character: Character! + } + + type Human { + name: String! + } + + union Character = Human + ` + graphql.MustParseSchema(schema, &badAssertionQueryResolver{}, graphql.UseFieldResolvers()) +} + +type ambiguousResolver struct { + Name string // ambiguous + University +} + +type University struct { + Name string // ambiguous +} + +func TestPanicAmbiguity(t *testing.T) { + panicMessage := `*graphql_test.ambiguousResolver does not resolve "Query": ambiguous field "name"` + + defer func() { + r := recover() + if r == nil { + t.Fatal("expected schema parse to panic") + } + + if r.(error).Error() != panicMessage { + t.Logf("got: %s", r) + t.Logf("want: %s", panicMessage) + t.Fail() + } + }() + + schema := ` + schema { + query: Query + } + + type Query { + name: String! + university: University! + } + + type University { + name: String! + } + ` + graphql.MustParseSchema(schema, &ambiguousResolver{}, graphql.UseFieldResolvers()) +} + func TestSchema_Exec_without_resolver(t *testing.T) { t.Parallel() @@ -3080,16 +3870,22 @@ func (r *subscriptionsInExecResolver) AppUpdated() <-chan string { } func TestSubscriptions_In_Exec(t *testing.T) { + r := &struct { + *helloResolver + *subscriptionsInExecResolver + }{ + helloResolver: &helloResolver{}, + subscriptionsInExecResolver: &subscriptionsInExecResolver{}, + } gqltesting.RunTest(t, &gqltesting.Test{ Schema: graphql.MustParseSchema(` - schema { - subscription: Subscription + type Query { + hello: String! } - type Subscription { appUpdated : String! } - `, &subscriptionsInExecResolver{}), + `, r), Query: ` subscription { appUpdated @@ -3102,3 +3898,565 @@ func TestSubscriptions_In_Exec(t *testing.T) { }, }) } + +type nilPointerReturnValue struct{} + +func (r *nilPointerReturnValue) Value() *string { + return nil +} + +type nilPointerReturnResolver struct{} + +func (r *nilPointerReturnResolver) PointerReturn() *nilPointerReturnValue { + return &nilPointerReturnValue{} +} + +func TestPointerReturnForNonNull(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(` + type Query { + pointerReturn: PointerReturnValue + } + + type PointerReturnValue { + value: Hello! + } + enum Hello { + WORLD + } + `, &nilPointerReturnResolver{}), + Query: ` + query { + pointerReturn { + value + } + } + `, + ExpectedResult: ` + { + "pointerReturn": null + } + `, + ExpectedErrors: []*gqlerrors.QueryError{ + { + Message: `graphql: got nil for non-null "Hello"`, + Path: []interface{}{"pointerReturn", "value"}, + }, + }, + }, + }) +} + +type nullableInput struct { + String graphql.NullString + Int graphql.NullInt + Bool graphql.NullBool + Time graphql.NullTime + Float graphql.NullFloat +} + +type nullableResult struct { + String string + Int string + Bool string + Time string + Float string +} + +type nullableResolver struct { +} + +func (r *nullableResolver) TestNullables(args struct { + Input *nullableInput +}) nullableResult { + var res nullableResult + if args.Input.String.Set { + if args.Input.String.Value == nil { + res.String = "" + } else { + res.String = *args.Input.String.Value + } + } + + if args.Input.Int.Set { + if args.Input.Int.Value == nil { + res.Int = "" + } else { + res.Int = fmt.Sprintf("%d", *args.Input.Int.Value) + } + } + + if args.Input.Float.Set { + if args.Input.Float.Value == nil { + res.Float = "" + } else { + res.Float = fmt.Sprintf("%.2f", *args.Input.Float.Value) + } + } + + if args.Input.Bool.Set { + if args.Input.Bool.Value == nil { + res.Bool = "" + } else { + res.Bool = fmt.Sprintf("%t", *args.Input.Bool.Value) + } + } + + if args.Input.Time.Set { + if args.Input.Time.Value == nil { + res.Time = "" + } else { + res.Time = args.Input.Time.Value.Format(time.RFC3339) + } + } + + return res +} + +func TestNullable(t *testing.T) { + schema := ` + scalar Time + + input MyInput { + string: String + int: Int + float: Float + bool: Boolean + time: Time + } + + type Result { + string: String! + int: String! + float: String! + bool: String! + time: String! + } + + type Query { + testNullables(input: MyInput): Result! + } + ` + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(schema, &nullableResolver{}, graphql.UseFieldResolvers()), + Query: ` + query { + testNullables(input: { + string: "test" + int: 1234 + float: 42.42 + bool: true + time: "2021-01-02T15:04:05Z" + }) { + string + int + float + bool + time + } + } + `, + ExpectedResult: ` + { + "testNullables": { + "string": "test", + "int": "1234", + "float": "42.42", + "bool": "true", + "time": "2021-01-02T15:04:05Z" + } + } + `, + }, + { + Schema: graphql.MustParseSchema(schema, &nullableResolver{}, graphql.UseFieldResolvers()), + Query: ` + query { + testNullables(input: { + string: null + int: null + float: null + bool: null + time: null + }) { + string + int + float + bool + time + } + } + `, + ExpectedResult: ` + { + "testNullables": { + "string": "", + "int": "", + "float": "", + "bool": "", + "time": "" + } + } + `, + }, + { + Schema: graphql.MustParseSchema(schema, &nullableResolver{}, graphql.UseFieldResolvers()), + Query: ` + query { + testNullables(input: {}) { + string + int + float + bool + time + } + } + `, + ExpectedResult: ` + { + "testNullables": { + "string": "", + "int": "", + "float": "", + "bool": "", + "time": "" + } + } + `, + }, + }) +} + +type testTracer struct { + mu *sync.Mutex + fields []fieldTrace + queries []queryTrace +} + +type fieldTrace struct { + label string + typeName string + fieldName string + isTrivial bool + args map[string]interface{} + err *gqlerrors.QueryError +} + +type queryTrace struct { + document string + opName string + variables map[string]interface{} + varTypes map[string]*introspection.Type + errors []*gqlerrors.QueryError +} + +func (t *testTracer) TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, func(*gqlerrors.QueryError)) { + return ctx, func(qe *gqlerrors.QueryError) { + t.mu.Lock() + defer t.mu.Unlock() + + ft := fieldTrace{ + label: label, + typeName: typeName, + fieldName: fieldName, + isTrivial: trivial, + args: args, + err: qe, + } + + t.fields = append(t.fields, ft) + } +} + +func (t *testTracer) TraceQuery(ctx context.Context, document string, opName string, vars map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, func([]*gqlerrors.QueryError)) { + return ctx, func(qe []*gqlerrors.QueryError) { + t.mu.Lock() + defer t.mu.Unlock() + + qt := queryTrace{ + document: document, + opName: opName, + variables: vars, + varTypes: varTypes, + errors: qe, + } + + t.queries = append(t.queries, qt) + } +} + +var _ tracer.Tracer = (*testTracer)(nil) + +func TestTracer(t *testing.T) { + t.Parallel() + + tt := &testTracer{mu: &sync.Mutex{}} + + schema, err := graphql.ParseSchema(starwars.Schema, &starwars.Resolver{}, graphql.Tracer(tt)) + if err != nil { + t.Fatalf("graphql.ParseSchema: %s", err) + } + + ctx := context.Background() + doc := ` + query TestTracer($id: ID!) { + HanSolo: human(id: $id) { + __typename + name + } + } + ` + opName := "TestTracer" + variables := map[string]interface{}{ + "id": "1002", + } + + _ = schema.Exec(ctx, doc, opName, variables) + + tt.mu.Lock() + defer tt.mu.Unlock() + + if len(tt.queries) != 1 { + t.Fatalf("expected one query trace, but got %d: %#v", len(tt.queries), tt.queries) + } + + qt := tt.queries[0] + if qt.document != doc { + t.Errorf("mismatched query trace document:\nwant: %q\ngot : %q", doc, qt.document) + } + if qt.opName != opName { + t.Errorf("mismated query trace operationName:\nwant: %q\ngot : %q", opName, qt.opName) + } + + expectedFieldTraces := []fieldTrace{ + {fieldName: "human", typeName: "Query"}, + {fieldName: "__typename", typeName: "Human"}, + {fieldName: "name", typeName: "Human"}, + } + + checkFieldTraces(t, expectedFieldTraces, tt.fields) +} + +func checkFieldTraces(t *testing.T, want, have []fieldTrace) { + if len(want) != len(have) { + t.Errorf("mismatched field traces: expected %d but got %d: %#v", len(want), len(have), have) + } + + type comparison struct { + want fieldTrace + have fieldTrace + } + + m := map[string]comparison{} + + for _, ft := range want { + m[ft.fieldName] = comparison{want: ft} + } + + for _, ft := range have { + c := m[ft.fieldName] + c.have = ft + m[ft.fieldName] = c + } + + for _, c := range m { + if err := stringsEqual(c.want.fieldName, c.have.fieldName); err != "" { + t.Error("mismatched field name:", err) + } + if err := stringsEqual(c.want.typeName, c.have.typeName); err != "" { + t.Error("mismatched field parent type:", err) + } + } +} + +func stringsEqual(want, have string) string { + if want != have { + return fmt.Sprintf("mismatched values:\nwant: %q\nhave: %q", want, have) + } + + return "" +} + +type queryVarResolver struct{} +type filterArgs struct { + Required string + Optional *string +} +type filterSearchResults struct { + Match *string +} + +func (r *queryVarResolver) Search(ctx context.Context, args *struct{ Filter filterArgs }) []filterSearchResults { + return []filterSearchResults{} +} + +func TestQueryVariablesValidation(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{{ + Schema: graphql.MustParseSchema(` + input SearchFilter { + required: String! + optional: String + } + + type SearchResults { + match: String + } + + type Query { + search(filter: SearchFilter!): [SearchResults!]! + }`, &queryVarResolver{}, graphql.UseFieldResolvers()), + Query: ` + query { + search(filter: {}) { + match + } + }`, + ExpectedErrors: []*gqlerrors.QueryError{{ + Message: "Argument \"filter\" has invalid value {}.\nIn field \"required\": Expected \"String!\", found null.", + Locations: []gqlerrors.Location{{Line: 3, Column: 27}}, + Rule: "ArgumentsOfCorrectType", + }}, + }, { + Schema: graphql.MustParseSchema(` + input SearchFilter { + required: String! + optional: String + } + + type SearchResults { + match: String + } + + type Query { + search(filter: SearchFilter!): [SearchResults!]! + }`, &queryVarResolver{}, graphql.UseFieldResolvers()), + Query: ` + query q($filter: SearchFilter!) { + search(filter: $filter) { + match + } + }`, + Variables: map[string]interface{}{"filter": map[string]interface{}{}}, + ExpectedErrors: []*gqlerrors.QueryError{{ + Message: "Variable \"required\" has invalid value null.\nExpected type \"String!\", found null.", + Locations: []gqlerrors.Location{{Line: 3, Column: 5}}, + Rule: "VariablesOfCorrectType", + }}, + }}) +} + +type interfaceImplementingInterfaceResolver struct{} +type interfaceImplementingInterfaceExample struct { + A string + B string + C bool +} + +func (r *interfaceImplementingInterfaceResolver) Hey() *interfaceImplementingInterfaceExample { + return &interfaceImplementingInterfaceExample{ + A: "testing", + B: "test", + C: true, + } +} + +func TestInterfaceImplementingInterface(t *testing.T) { + gqltesting.RunTests(t, []*gqltesting.Test{{ + Schema: graphql.MustParseSchema(` + interface A { + a: String! + } + interface B implements A { + a: String! + b: String! + } + interface C implements B & A { + a: String! + b: String! + c: Boolean! + } + type ABC implements C { + a: String! + b: String! + c: Boolean! + } + type Query { + hey: ABC + }`, &interfaceImplementingInterfaceResolver{}, graphql.UseFieldResolvers(), graphql.UseFieldResolvers()), + Query: `query {hey { a b c }}`, + ExpectedResult: ` + { + "hey": { + "a": "testing", + "b": "test", + "c": true + } + } + `, + }}) +} + +func TestCircularFragmentMaxDepth(t *testing.T) { + withMaxDepth := graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}, graphql.MaxDepth(2)) + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: withMaxDepth, + Query: ` + query { + ...X + } + + fragment X on Query { + ...Y + } + fragment Y on Query { + ...X + } + `, + ExpectedErrors: []*gqlerrors.QueryError{{ + Message: `Cannot spread fragment "X" within itself via Y.`, + Rule: "NoFragmentCycles", + Locations: []gqlerrors.Location{ + {Line: 7, Column: 20}, + {Line: 10, Column: 20}, + }, + }}, + }, + }) +} + +func TestQueryService(t *testing.T) { + t.Parallel() + + schemaString := ` + schema { + query: Query + } + + type Query { + hello: String! + }` + + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: graphql.MustParseSchema(schemaString, &helloWorldResolver1{}), + Query: ` + { + _service{ + sdl + } + } + `, + ExpectedResult: ` + { + "_service": { + "sdl": "\n\tschema {\n\t\tquery: Query\n\t}\n\n\ttype Query {\n\t\thello: String!\n\t}" + } + } + `, + }, + }) +} diff --git a/id.go b/id.go index 52771c413..80bdac906 100644 --- a/id.go +++ b/id.go @@ -1,7 +1,7 @@ package graphql import ( - "errors" + "fmt" "strconv" ) @@ -20,7 +20,7 @@ func (id *ID) UnmarshalGraphQL(input interface{}) error { case int32: *id = ID(strconv.Itoa(int(input))) default: - err = errors.New("wrong type") + err = fmt.Errorf("wrong type for ID: %T", input) } return err } diff --git a/internal/common/blockstring.go b/internal/common/blockstring.go new file mode 100644 index 000000000..1f7fe8133 --- /dev/null +++ b/internal/common/blockstring.go @@ -0,0 +1,103 @@ +// MIT License +// +// Copyright (c) 2019 GraphQL Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// This implementation has been adapted from the graphql-js reference implementation +// https://github.com/graphql/graphql-js/blob/5eb7c4ded7ceb83ac742149cbe0dae07a8af9a30/src/language/blockString.js +// which is released under the MIT License above. + +package common + +import ( + "strings" +) + +// Produces the value of a block string from its parsed raw value, similar to +// CoffeeScript's block string, Python's docstring trim or Ruby's strip_heredoc. +// +// This implements the GraphQL spec's BlockStringValue() static algorithm. +func blockString(raw string) string { + lines := strings.Split(raw, "\n") + + // Remove common indentation from all lines except the first (which has none) + ind := blockStringIndentation(lines) + if ind > 0 { + for i := 1; i < len(lines); i++ { + l := lines[i] + if len(l) < ind { + lines[i] = "" + continue + } + lines[i] = l[ind:] + } + } + + // Remove leading and trailing blank lines + trimStart := 0 + for i := 0; i < len(lines) && isBlank(lines[i]); i++ { + trimStart++ + } + lines = lines[trimStart:] + trimEnd := 0 + for i := len(lines) - 1; i > 0 && isBlank(lines[i]); i-- { + trimEnd++ + } + lines = lines[:len(lines)-trimEnd] + + return strings.Join(lines, "\n") +} + +func blockStringIndentation(lines []string) int { + var commonIndent *int + for i := 1; i < len(lines); i++ { + l := lines[i] + indent := leadingWhitespace(l) + if indent == len(l) { + // don't consider blank/empty lines + continue + } + if indent == 0 { + return 0 + } + if commonIndent == nil || indent < *commonIndent { + commonIndent = &indent + } + } + if commonIndent == nil { + return 0 + } + return *commonIndent +} + +func isBlank(s string) bool { + return len(s) == 0 || leadingWhitespace(s) == len(s) +} + +func leadingWhitespace(s string) int { + i := 0 + for _, r := range s { + if r != '\t' && r != ' ' { + break + } + i++ + } + return i +} diff --git a/internal/common/directive.go b/internal/common/directive.go index 62dca47f8..8d647b0f3 100644 --- a/internal/common/directive.go +++ b/internal/common/directive.go @@ -1,32 +1,18 @@ package common -type Directive struct { - Name Ident - Args ArgumentList -} +import "github.com/tokopedia/graphql-go/types" -func ParseDirectives(l *Lexer) DirectiveList { - var directives DirectiveList +func ParseDirectives(l *Lexer) types.DirectiveList { + var directives types.DirectiveList for l.Peek() == '@' { l.ConsumeToken('@') - d := &Directive{} + d := &types.Directive{} d.Name = l.ConsumeIdentWithLoc() d.Name.Loc.Column-- if l.Peek() == '(' { - d.Args = ParseArguments(l) + d.Arguments = ParseArgumentList(l) } directives = append(directives, d) } return directives } - -type DirectiveList []*Directive - -func (l DirectiveList) Get(name string) *Directive { - for _, d := range l { - if d.Name.Name == name { - return d - } - } - return nil -} diff --git a/internal/common/lexer.go b/internal/common/lexer.go index 6771e3a3b..957e598f4 100644 --- a/internal/common/lexer.go +++ b/internal/common/lexer.go @@ -8,6 +8,7 @@ import ( "text/scanner" "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/types" ) type syntaxError string @@ -30,7 +31,10 @@ func NewLexer(s string, useStringDescriptions bool) *Lexer { } sc.Init(strings.NewReader(s)) - return &Lexer{sc: sc, useStringDescriptions: useStringDescriptions} + l := Lexer{sc: sc, useStringDescriptions: useStringDescriptions} + l.sc.Error = l.CatchScannerError + + return &l } func (l *Lexer) CatchSyntaxError(f func()) (errRes *errors.QueryError) { @@ -115,11 +119,11 @@ func (l *Lexer) ConsumeIdent() string { return name } -func (l *Lexer) ConsumeIdentWithLoc() Ident { +func (l *Lexer) ConsumeIdentWithLoc() types.Ident { loc := l.Location() name := l.sc.TokenText() l.ConsumeToken(scanner.Ident) - return Ident{name, loc} + return types.Ident{Name: name, Loc: loc} } func (l *Lexer) ConsumeKeyword(keyword string) { @@ -129,8 +133,8 @@ func (l *Lexer) ConsumeKeyword(keyword string) { l.ConsumeWhitespace() } -func (l *Lexer) ConsumeLiteral() *BasicLit { - lit := &BasicLit{Type: l.next, Text: l.sc.TokenText()} +func (l *Lexer) ConsumeLiteral() *types.PrimitiveValue { + lit := &types.PrimitiveValue{Type: l.next, Text: l.sc.TokenText()} l.ConsumeWhitespace() return lit } @@ -184,8 +188,7 @@ func (l *Lexer) consumeTripleQuoteComment() string { } val := buf.String() val = val[:len(val)-numQuotes] - val = strings.TrimSpace(val) - return val + return blockString(val) } func (l *Lexer) consumeStringComment() string { @@ -220,3 +223,7 @@ func (l *Lexer) consumeComment() { l.comment.WriteRune(next) } } + +func (l *Lexer) CatchScannerError(s *scanner.Scanner, msg string) { + l.SyntaxError(msg) +} diff --git a/internal/common/lexer_test.go b/internal/common/lexer_test.go index 997c73158..321fb7b21 100644 --- a/internal/common/lexer_test.go +++ b/internal/common/lexer_test.go @@ -93,3 +93,42 @@ func TestConsume(t *testing.T) { }) } } + +var multilineStringTests = []consumeTestCase{ + { + description: "Oneline strings are okay", + definition: `"Hello World"`, + expected: "", + failureExpected: false, + useStringDescriptions: true, + }, + { + description: "Multiline strings are not allowed", + definition: `"Hello + World"`, + expected: `graphql: syntax error: literal not terminated (line 1, column 1)`, + failureExpected: true, + useStringDescriptions: true, + }, +} + +func TestMultilineString(t *testing.T) { + for _, test := range multilineStringTests { + t.Run(test.description, func(t *testing.T) { + lex := common.NewLexer(test.definition, test.useStringDescriptions) + + err := lex.CatchSyntaxError(func() { lex.ConsumeWhitespace() }) + if test.failureExpected && err == nil { + t.Fatalf("Test '%s' should fail", test.description) + } else if test.failureExpected && err != nil { + if test.expected != err.Error() { + t.Fatalf("Test '%s' failed with wrong error: '%s'. Error should be: '%s'", test.description, err.Error(), test.expected) + } + } + + if !test.failureExpected && err != nil { + t.Fatalf("Test '%s' failed with error: '%s'", test.description, err.Error()) + } + }) + } +} diff --git a/internal/common/literals.go b/internal/common/literals.go index d1483239c..aac20fe92 100644 --- a/internal/common/literals.go +++ b/internal/common/literals.go @@ -1,170 +1,12 @@ package common import ( - "strconv" - "strings" "text/scanner" - "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/types" ) -type Literal interface { - Value(vars map[string]interface{}) interface{} - String() string - Location() errors.Location -} - -type BasicLit struct { - Type rune - Text string - Loc errors.Location -} - -func (lit *BasicLit) Value(vars map[string]interface{}) interface{} { - switch lit.Type { - case scanner.Int: - value, err := strconv.ParseInt(lit.Text, 10, 32) - if err != nil { - // check if it is out of range error. - // which probably mean that the input use int64 data type - // as needed by scalar.Int64 data type of tokopedia/gqlserver - if strings.Contains(err.Error(), strconv.ErrRange.Error()) { - val64, err := strconv.ParseInt(lit.Text, 10, 64) - if err != nil { - panic(err) - } - return int64(val64) - } - panic(err) - } - return int32(value) - - case scanner.Float: - value, err := strconv.ParseFloat(lit.Text, 64) - if err != nil { - panic(err) - } - return value - - case scanner.String: - value, err := strconv.Unquote(lit.Text) - if err != nil { - panic(err) - } - return value - - case scanner.Ident: - switch lit.Text { - case "true": - return true - case "false": - return false - default: - return lit.Text - } - - default: - panic("invalid literal") - } -} - -func (lit *BasicLit) String() string { - return lit.Text -} - -func (lit *BasicLit) Location() errors.Location { - return lit.Loc -} - -type ListLit struct { - Entries []Literal - Loc errors.Location -} - -func (lit *ListLit) Value(vars map[string]interface{}) interface{} { - entries := make([]interface{}, len(lit.Entries)) - for i, entry := range lit.Entries { - entries[i] = entry.Value(vars) - } - return entries -} - -func (lit *ListLit) String() string { - entries := make([]string, len(lit.Entries)) - for i, entry := range lit.Entries { - entries[i] = entry.String() - } - return "[" + strings.Join(entries, ", ") + "]" -} - -func (lit *ListLit) Location() errors.Location { - return lit.Loc -} - -type ObjectLit struct { - Fields []*ObjectLitField - Loc errors.Location -} - -type ObjectLitField struct { - Name Ident - Value Literal -} - -func (lit *ObjectLit) Value(vars map[string]interface{}) interface{} { - fields := make(map[string]interface{}, len(lit.Fields)) - for _, f := range lit.Fields { - fields[f.Name.Name] = f.Value.Value(vars) - } - return fields -} - -func (lit *ObjectLit) String() string { - entries := make([]string, 0, len(lit.Fields)) - for _, f := range lit.Fields { - entries = append(entries, f.Name.Name+": "+f.Value.String()) - } - return "{" + strings.Join(entries, ", ") + "}" -} - -func (lit *ObjectLit) Location() errors.Location { - return lit.Loc -} - -type NullLit struct { - Loc errors.Location -} - -func (lit *NullLit) Value(vars map[string]interface{}) interface{} { - return nil -} - -func (lit *NullLit) String() string { - return "null" -} - -func (lit *NullLit) Location() errors.Location { - return lit.Loc -} - -type Variable struct { - Name string - Loc errors.Location -} - -func (v Variable) Value(vars map[string]interface{}) interface{} { - return vars[v.Name] -} - -func (v Variable) String() string { - return "$" + v.Name -} - -func (v *Variable) Location() errors.Location { - return v.Loc -} - -func ParseLiteral(l *Lexer, constOnly bool) Literal { +func ParseLiteral(l *Lexer, constOnly bool) types.Value { loc := l.Location() switch l.Peek() { case '$': @@ -173,12 +15,12 @@ func ParseLiteral(l *Lexer, constOnly bool) Literal { panic("unreachable") } l.ConsumeToken('$') - return &Variable{l.ConsumeIdent(), loc} + return &types.Variable{Name: l.ConsumeIdent(), Loc: loc} case scanner.Int, scanner.Float, scanner.String, scanner.Ident: lit := l.ConsumeLiteral() if lit.Type == scanner.Ident && lit.Text == "null" { - return &NullLit{loc} + return &types.NullValue{Loc: loc} } lit.Loc = loc return lit @@ -190,24 +32,24 @@ func ParseLiteral(l *Lexer, constOnly bool) Literal { return lit case '[': l.ConsumeToken('[') - var list []Literal + var list []types.Value for l.Peek() != ']' { list = append(list, ParseLiteral(l, constOnly)) } l.ConsumeToken(']') - return &ListLit{list, loc} + return &types.ListValue{Values: list, Loc: loc} case '{': l.ConsumeToken('{') - var fields []*ObjectLitField + var fields []*types.ObjectField for l.Peek() != '}' { name := l.ConsumeIdentWithLoc() l.ConsumeToken(':') value := ParseLiteral(l, constOnly) - fields = append(fields, &ObjectLitField{name, value}) + fields = append(fields, &types.ObjectField{Name: name, Value: value}) } l.ConsumeToken('}') - return &ObjectLit{fields, loc} + return &types.ObjectValue{Fields: fields, Loc: loc} default: l.SyntaxError("invalid value") diff --git a/internal/common/types.go b/internal/common/types.go index 9ceb83bd4..d1b5ef9fd 100644 --- a/internal/common/types.go +++ b/internal/common/types.go @@ -2,70 +2,57 @@ package common import ( "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/types" ) -type Type interface { - Kind() string - String() string -} - -type List struct { - OfType Type -} - -type NonNull struct { - OfType Type -} - -type TypeName struct { - Ident -} - -func (*List) Kind() string { return "LIST" } -func (*NonNull) Kind() string { return "NON_NULL" } -func (*TypeName) Kind() string { panic("TypeName needs to be resolved to actual type") } - -func (t *List) String() string { return "[" + t.OfType.String() + "]" } -func (t *NonNull) String() string { return t.OfType.String() + "!" } -func (*TypeName) String() string { panic("TypeName needs to be resolved to actual type") } - -func ParseType(l *Lexer) Type { +func ParseType(l *Lexer) types.Type { t := parseNullType(l) if l.Peek() == '!' { l.ConsumeToken('!') - return &NonNull{OfType: t} + return &types.NonNull{OfType: t} } return t } -func parseNullType(l *Lexer) Type { +func parseNullType(l *Lexer) types.Type { if l.Peek() == '[' { l.ConsumeToken('[') ofType := ParseType(l) l.ConsumeToken(']') - return &List{OfType: ofType} + return &types.List{OfType: ofType} } - return &TypeName{Ident: l.ConsumeIdentWithLoc()} + return &types.TypeName{Ident: l.ConsumeIdentWithLoc()} } -type Resolver func(name string) Type +type Resolver func(name string) types.Type -func ResolveType(t Type, resolver Resolver) (Type, *errors.QueryError) { +// ResolveType attempts to resolve a type's name against a resolving function. +// This function is used when one needs to check if a TypeName exists in the resolver (typically a Schema). +// +// In the example below, ResolveType would be used to check if the resolving function +// returns a valid type for Dimension: +// +// type Profile { +// picture(dimensions: Dimension): Url +// } +// +// ResolveType recursively unwraps List and NonNull types until a NamedType is reached. +func ResolveType(t types.Type, resolver Resolver) (types.Type, *errors.QueryError) { switch t := t.(type) { - case *List: + case *types.List: ofType, err := ResolveType(t.OfType, resolver) if err != nil { return nil, err } - return &List{OfType: ofType}, nil - case *NonNull: + return &types.List{OfType: ofType}, nil + case *types.NonNull: ofType, err := ResolveType(t.OfType, resolver) if err != nil { return nil, err } - return &NonNull{OfType: ofType}, nil - case *TypeName: + return &types.NonNull{OfType: ofType}, nil + case *types.TypeName: refT := resolver(t.Name) if refT == nil { err := errors.Errorf("Unknown type %q.", t.Name) diff --git a/internal/common/values.go b/internal/common/values.go index 9345790dc..580fbfaa1 100644 --- a/internal/common/values.go +++ b/internal/common/values.go @@ -1,32 +1,11 @@ package common import ( - "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/types" ) -// http://facebook.github.io/graphql/draft/#InputValueDefinition -type InputValue struct { - Name Ident - Type Type - Default Literal - Desc string - Loc errors.Location - TypeLoc errors.Location -} - -type InputValueList []*InputValue - -func (l InputValueList) Get(name string) *InputValue { - for _, v := range l { - if v.Name.Name == name { - return v - } - } - return nil -} - -func ParseInputValue(l *Lexer) *InputValue { - p := &InputValue{} +func ParseInputValue(l *Lexer) *types.InputValueDefinition { + p := &types.InputValueDefinition{} p.Loc = l.Location() p.Desc = l.DescComment() p.Name = l.ConsumeIdentWithLoc() @@ -37,41 +16,21 @@ func ParseInputValue(l *Lexer) *InputValue { l.ConsumeToken('=') p.Default = ParseLiteral(l, true) } + p.Directives = ParseDirectives(l) return p } -type Argument struct { - Name Ident - Value Literal -} - -type ArgumentList []Argument - -func (l ArgumentList) Get(name string) (Literal, bool) { - for _, arg := range l { - if arg.Name.Name == name { - return arg.Value, true - } - } - return nil, false -} - -func (l ArgumentList) MustGet(name string) Literal { - value, ok := l.Get(name) - if !ok { - panic("argument not found") - } - return value -} - -func ParseArguments(l *Lexer) ArgumentList { - var args ArgumentList +func ParseArgumentList(l *Lexer) types.ArgumentList { + var args types.ArgumentList l.ConsumeToken('(') for l.Peek() != ')' { name := l.ConsumeIdentWithLoc() l.ConsumeToken(':') value := ParseLiteral(l, false) - args = append(args, Argument{Name: name, Value: value}) + args = append(args, &types.Argument{ + Name: name, + Value: value, + }) } l.ConsumeToken(')') return args diff --git a/internal/exec/exec.go b/internal/exec/exec.go index 07e4c7ed1..3561b26dd 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -4,32 +4,33 @@ import ( "bytes" "context" "encoding/json" - errlib "errors" "fmt" "reflect" "sync" + "time" "github.com/tokopedia/graphql-go/errors" - "github.com/tokopedia/graphql-go/internal/common" "github.com/tokopedia/graphql-go/internal/exec/resolvable" "github.com/tokopedia/graphql-go/internal/exec/selected" "github.com/tokopedia/graphql-go/internal/query" - "github.com/tokopedia/graphql-go/internal/schema" "github.com/tokopedia/graphql-go/log" - "github.com/tokopedia/graphql-go/trace" + "github.com/tokopedia/graphql-go/trace/tracer" + "github.com/tokopedia/graphql-go/types" ) type Request struct { selected.Request - Limiter chan struct{} - Tracer trace.Tracer - Logger log.Logger + Limiter chan struct{} + Tracer tracer.Tracer + Logger log.Logger + PanicHandler errors.PanicHandler + SubscribeResolverTimeout time.Duration } func (r *Request) handlePanic(ctx context.Context) { if value := recover(); value != nil { r.Logger.LogPanic(ctx, value) - r.AddError(makePanicError(value)) + r.AddError(r.PanicHandler.MakePanicError(ctx, value)) } } @@ -37,11 +38,7 @@ type extensionser interface { Extensions() map[string]interface{} } -func makePanicError(value interface{}) *errors.QueryError { - return errors.Errorf("graphql: panic occurred: %v", value) -} - -func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *query.Operation) ([]byte, []*errors.QueryError) { +func (r *Request) Execute(ctx context.Context, s *resolvable.Schema, op *types.OperationDefinition) ([]byte, []*errors.QueryError) { var out bytes.Buffer func() { defer r.handlePanic(ctx) @@ -97,7 +94,7 @@ func (r *Request) execSelections(ctx context.Context, sels []selected.Selection, // If a non-nullable child resolved to null, an error was added to the // "errors" list in the response, so this field resolves to null. // If this field is non-nullable, the error is propagated to its parent. - if _, ok := f.field.Type.(*common.NonNull); ok && resolvedToNull(f.out) { + if _, ok := f.field.Type.(*types.NonNull); ok && resolvedToNull(f.out) { out.Reset() out.Write([]byte("null")) return @@ -128,12 +125,22 @@ func collectFieldsToResolve(sels []selected.Selection, s *resolvable.Schema, res field.sels = append(field.sels, sel.Sels...) case *selected.TypenameField: - sf := &selected.SchemaField{ - Field: s.Meta.FieldTypename, - Alias: sel.Alias, - FixedResult: reflect.ValueOf(typeOf(sel, resolver)), + _, ok := fieldByAlias[sel.Alias] + if !ok { + res := reflect.ValueOf(typeOf(sel, resolver)) + f := s.FieldTypename + f.TypeName = res.String() + + sf := &selected.SchemaField{ + Field: f, + Alias: sel.Alias, + FixedResult: res, + } + + field := &fieldToExec{field: sf, resolver: resolver} + *fields = append(*fields, field) + fieldByAlias[sel.Alias] = field } - *fields = append(*fields, &fieldToExec{field: sf, resolver: resolver}) case *selected.TypeAssertion: out := resolver.Method(sel.MethodIndex).Call(nil) @@ -178,7 +185,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f defer func() { if panicValue := recover(); panicValue != nil { r.Logger.LogPanic(ctx, panicValue) - err = makePanicError(panicValue) + err = r.PanicHandler.MakePanicError(ctx, panicValue) err.Path = path.toSlice() } }() @@ -201,21 +208,16 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f if f.field.ArgsPacker != nil { in = append(in, f.field.PackedArgs) } - callOut := f.resolver.Method(f.field.MethodIndex).Call(in) + callOut := res.Method(f.field.MethodIndex).Call(in) result = callOut[0] if f.field.HasError && !callOut[1].IsNil() { - graphQLErr, ok := callOut[1].Interface().(errors.GraphQLError) - if ok { - extnErr := graphQLErr.PrepareExtErr() - extnErr.Path = path.toSlice() - extnErr.ResolverError = errlib.New(extnErr.Message) - return extnErr - } - resolverErr := callOut[1].Interface().(error) err := errors.Errorf("%s", resolverErr) err.Path = path.toSlice() err.ResolverError = resolverErr + if ex, ok := callOut[1].Interface().(extensionser); ok { + err.Extensions = ex.Extensions() + } return err } } else { @@ -223,7 +225,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f if res.Kind() == reflect.Ptr { res = res.Elem() } - result = res.Field(f.field.FieldIndex) + result = res.FieldByIndex(f.field.FieldIndex) } return nil }() @@ -243,41 +245,40 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f r.execSelectionSet(traceCtx, f.sels, f.field.Type, path, s, result, f.out) } -func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selection, typ common.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { +func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selection, typ types.Type, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { t, nonNull := unwrapNonNull(typ) - switch t := t.(type) { - case *schema.Object, *schema.Interface, *schema.Union: - // a reflect.Value of a nil interface will show up as an Invalid value - if resolver.Kind() == reflect.Invalid || ((resolver.Kind() == reflect.Ptr || resolver.Kind() == reflect.Interface) && resolver.IsNil()) { - // If a field of a non-null type resolves to null (either because the - // function to resolve the field returned null or because an error occurred), - // add an error to the "errors" list in the response. - if nonNull { - err := errors.Errorf("graphql: got nil for non-null %q", t) - err.Path = path.toSlice() - r.AddError(err) - } - out.WriteString("null") - return + + // a reflect.Value of a nil interface will show up as an Invalid value + if resolver.Kind() == reflect.Invalid || ((resolver.Kind() == reflect.Ptr || resolver.Kind() == reflect.Interface) && resolver.IsNil()) { + // If a field of a non-null type resolves to null (either because the + // function to resolve the field returned null or because an error occurred), + // add an error to the "errors" list in the response. + if nonNull { + err := errors.Errorf("graphql: got nil for non-null %q", t) + err.Path = path.toSlice() + r.AddError(err) } + out.WriteString("null") + return + } + switch t.(type) { + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: r.execSelections(ctx, sels, path, s, resolver, out, false) return } - if !nonNull { - if resolver.IsNil() { - out.WriteString("null") - return - } + // Any pointers or interfaces at this point should be non-nil, so we can get the actual value of them + // for serialization + if resolver.Kind() == reflect.Ptr || resolver.Kind() == reflect.Interface { resolver = resolver.Elem() } switch t := t.(type) { - case *common.List: + case *types.List: r.execList(ctx, sels, t, path, s, resolver, out) - case *schema.Scalar: + case *types.ScalarTypeDefinition: v := resolver.Interface() data, err := json.Marshal(v) if err != nil { @@ -285,15 +286,15 @@ func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selectio } out.Write(data) - case *schema.Enum: + case *types.EnumTypeDefinition: var stringer fmt.Stringer = resolver if s, ok := resolver.Interface().(fmt.Stringer); ok { stringer = s } name := stringer.String() var valid bool - for _, v := range t.Values { - if v.Name == name { + for _, v := range t.EnumValuesDefinition { + if v.EnumValue == name { valid = true break } @@ -314,28 +315,33 @@ func (r *Request) execSelectionSet(ctx context.Context, sels []selected.Selectio } } -func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ *common.List, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { +func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ *types.List, path *pathSegment, s *resolvable.Schema, resolver reflect.Value, out *bytes.Buffer) { l := resolver.Len() entryouts := make([]bytes.Buffer, l) if selected.HasAsyncSel(sels) { - var wg sync.WaitGroup - wg.Add(l) + // Limit the number of concurrent goroutines spawned as it can lead to large + // memory spikes for large lists. + concurrency := cap(r.Limiter) + sem := make(chan struct{}, concurrency) for i := 0; i < l; i++ { + sem <- struct{}{} go func(i int) { - defer wg.Done() + defer func() { <-sem }() defer r.handlePanic(ctx) r.execSelectionSet(ctx, sels, typ.OfType, &pathSegment{path, i}, s, resolver.Index(i), &entryouts[i]) }(i) } - wg.Wait() + for i := 0; i < concurrency; i++ { + sem <- struct{}{} + } } else { for i := 0; i < l; i++ { r.execSelectionSet(ctx, sels, typ.OfType, &pathSegment{path, i}, s, resolver.Index(i), &entryouts[i]) } } - _, listOfNonNull := typ.OfType.(*common.NonNull) + _, listOfNonNull := typ.OfType.(*types.NonNull) out.WriteByte('[') for i, entryout := range entryouts { @@ -355,8 +361,8 @@ func (r *Request) execList(ctx context.Context, sels []selected.Selection, typ * out.WriteByte(']') } -func unwrapNonNull(t common.Type) (common.Type, bool) { - if nn, ok := t.(*common.NonNull); ok { +func unwrapNonNull(t types.Type) (types.Type, bool) { + if nn, ok := t.(*types.NonNull); ok { return nn.OfType, true } return t, false diff --git a/internal/exec/packer/packer.go b/internal/exec/packer/packer.go index ff5a2ce74..6686bccc6 100644 --- a/internal/exec/packer/packer.go +++ b/internal/exec/packer/packer.go @@ -6,9 +6,9 @@ import ( "reflect" "strings" + "github.com/tokopedia/graphql-go/decode" "github.com/tokopedia/graphql-go/errors" - "github.com/tokopedia/graphql-go/internal/common" - "github.com/tokopedia/graphql-go/internal/schema" + "github.com/tokopedia/graphql-go/types" ) type packer interface { @@ -21,7 +21,7 @@ type Builder struct { } type typePair struct { - graphQLType common.Type + graphQLType types.Type resolverType reflect.Type } @@ -47,7 +47,7 @@ func (b *Builder) Finish() error { p.defaultStruct = reflect.New(p.structType).Elem() for _, f := range p.fields { if defaultVal := f.field.Default; defaultVal != nil { - v, err := f.fieldPacker.Pack(defaultVal.Value(nil)) + v, err := f.fieldPacker.Pack(defaultVal.Deserialize(nil)) if err != nil { return err } @@ -59,7 +59,7 @@ func (b *Builder) Finish() error { return nil } -func (b *Builder) assignPacker(target *packer, schemaType common.Type, reflectType reflect.Type) error { +func (b *Builder) assignPacker(target *packer, schemaType types.Type, reflectType reflect.Type) error { k := typePair{schemaType, reflectType} ref, ok := b.packerMap[k] if !ok { @@ -75,34 +75,47 @@ func (b *Builder) assignPacker(target *packer, schemaType common.Type, reflectTy return nil } -func (b *Builder) makePacker(schemaType common.Type, reflectType reflect.Type) (packer, error) { +func (b *Builder) makePacker(schemaType types.Type, reflectType reflect.Type) (packer, error) { t, nonNull := unwrapNonNull(schemaType) if !nonNull { - if reflectType.Kind() != reflect.Ptr { - return nil, fmt.Errorf("%s is not a pointer", reflectType) - } - elemType := reflectType.Elem() - addPtr := true - if _, ok := t.(*schema.InputObject); ok { - elemType = reflectType // keep pointer for input objects - addPtr = false - } - elem, err := b.makeNonNullPacker(t, elemType) - if err != nil { - return nil, err + if reflectType.Kind() == reflect.Ptr { + elemType := reflectType.Elem() + addPtr := true + if _, ok := t.(*types.InputObject); ok { + elemType = reflectType // keep pointer for input objects + addPtr = false + } + elem, err := b.makeNonNullPacker(t, elemType) + if err != nil { + return nil, err + } + return &nullPacker{ + elemPacker: elem, + valueType: reflectType, + addPtr: addPtr, + }, nil + } else if isNullable(reflectType) { + elemType := reflectType + addPtr := false + elem, err := b.makeNonNullPacker(t, elemType) + if err != nil { + return nil, err + } + return &nullPacker{ + elemPacker: elem, + valueType: reflectType, + addPtr: addPtr, + }, nil + } else { + return nil, fmt.Errorf("%s is not a pointer or a nullable type", reflectType) } - return &nullPacker{ - elemPacker: elem, - valueType: reflectType, - addPtr: addPtr, - }, nil } return b.makeNonNullPacker(t, reflectType) } -func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect.Type) (packer, error) { - if u, ok := reflect.New(reflectType).Interface().(Unmarshaler); ok { +func (b *Builder) makeNonNullPacker(schemaType types.Type, reflectType reflect.Type) (packer, error) { + if u, ok := reflect.New(reflectType).Interface().(decode.Unmarshaler); ok { if !u.ImplementsGraphQLType(schemaType.String()) { return nil, fmt.Errorf("can not unmarshal %s into %s", schemaType, reflectType) } @@ -112,12 +125,12 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. } switch t := schemaType.(type) { - case *schema.Scalar: + case *types.ScalarTypeDefinition: return &ValuePacker{ ValueType: reflectType, }, nil - case *schema.Enum: + case *types.EnumTypeDefinition: if reflectType.Kind() != reflect.String { return nil, fmt.Errorf("wrong type, expected %s", reflect.String) } @@ -125,14 +138,14 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. ValueType: reflectType, }, nil - case *schema.InputObject: + case *types.InputObject: e, err := b.MakeStructPacker(t.Values, reflectType) if err != nil { return nil, err } return e, nil - case *common.List: + case *types.List: if reflectType.Kind() != reflect.Slice { return nil, fmt.Errorf("expected slice, got %s", reflectType) } @@ -144,7 +157,7 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. } return p, nil - case *schema.Object, *schema.Interface, *schema.Union: + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: return nil, fmt.Errorf("type of kind %s can not be used as input", t.Kind()) default: @@ -152,7 +165,7 @@ func (b *Builder) makeNonNullPacker(schemaType common.Type, reflectType reflect. } } -func (b *Builder) MakeStructPacker(values common.InputValueList, typ reflect.Type) (*StructPacker, error) { +func (b *Builder) MakeStructPacker(values []*types.InputValueDefinition, typ reflect.Type) (*StructPacker, error) { structType := typ usePtr := false if typ.Kind() == reflect.Ptr { @@ -160,7 +173,7 @@ func (b *Builder) MakeStructPacker(values common.InputValueList, typ reflect.Typ usePtr = true } if structType.Kind() != reflect.Struct { - return nil, fmt.Errorf("expected struct or pointer to struct, got %s", typ) + return nil, fmt.Errorf("expected struct or pointer to struct, got %s (hint: missing `args struct { ... }` wrapper for field arguments?)", typ) } var fields []*structPackerField @@ -172,7 +185,7 @@ func (b *Builder) MakeStructPacker(values common.InputValueList, typ reflect.Typ sf, ok := structType.FieldByNameFunc(fx) if !ok { - return nil, fmt.Errorf("missing argument %q", v.Name) + return nil, fmt.Errorf("%s does not define field %q (hint: missing `args struct { ... }` wrapper for field arguments, or missing field on input struct)", typ, v.Name.Name) } if sf.PkgPath != "" { return nil, fmt.Errorf("field %q must be exported", sf.Name) @@ -182,7 +195,7 @@ func (b *Builder) MakeStructPacker(values common.InputValueList, typ reflect.Typ ft := v.Type if v.Default != nil { ft, _ = unwrapNonNull(ft) - ft = &common.NonNull{OfType: ft} + ft = &types.NonNull{OfType: ft} } if err := b.assignPacker(&fe.fieldPacker, ft, sf.Type); err != nil { @@ -209,7 +222,7 @@ type StructPacker struct { } type structPackerField struct { - field *common.InputValue + field *types.InputValueDefinition fieldIndex []int fieldPacker packer } @@ -266,7 +279,7 @@ type nullPacker struct { } func (p *nullPacker) Pack(value interface{}) (reflect.Value, error) { - if value == nil { + if value == nil && !isNullable(p.valueType) { return reflect.Zero(p.valueType), nil } @@ -305,22 +318,17 @@ type unmarshalerPacker struct { } func (p *unmarshalerPacker) Pack(value interface{}) (reflect.Value, error) { - if value == nil { + if value == nil && !isNullable(p.ValueType) { return reflect.Value{}, errors.Errorf("got null for non-null") } v := reflect.New(p.ValueType) - if err := v.Interface().(Unmarshaler).UnmarshalGraphQL(value); err != nil { + if err := v.Interface().(decode.Unmarshaler).UnmarshalGraphQL(value); err != nil { return reflect.Value{}, err } return v.Elem(), nil } -type Unmarshaler interface { - ImplementsGraphQLType(name string) bool - UnmarshalGraphQL(input interface{}) error -} - func unmarshalInput(typ reflect.Type, input interface{}) (interface{}, error) { if reflect.TypeOf(input) == typ { return input, nil @@ -359,8 +367,8 @@ func unmarshalInput(typ reflect.Type, input interface{}) (interface{}, error) { return nil, fmt.Errorf("incompatible type") } -func unwrapNonNull(t common.Type) (common.Type, bool) { - if nn, ok := t.(*common.NonNull); ok { +func unwrapNonNull(t types.Type) (types.Type, bool) { + if nn, ok := t.(*types.NonNull); ok { return nn.OfType, true } return t, false @@ -369,3 +377,14 @@ func unwrapNonNull(t common.Type) (common.Type, bool) { func stripUnderscore(s string) string { return strings.Replace(s, "_", "", -1) } + +// NullUnmarshaller is an unmarshaller that can handle a nil input +type NullUnmarshaller interface { + decode.Unmarshaler + Nullable() +} + +func isNullable(t reflect.Type) bool { + _, ok := reflect.New(t).Interface().(NullUnmarshaller) + return ok +} diff --git a/internal/exec/resolvable/meta.go b/internal/exec/resolvable/meta.go index da7ebe568..c8e8c1dbe 100644 --- a/internal/exec/resolvable/meta.go +++ b/internal/exec/resolvable/meta.go @@ -1,12 +1,10 @@ package resolvable import ( - "fmt" "reflect" - "github.com/tokopedia/graphql-go/internal/common" - "github.com/tokopedia/graphql-go/internal/schema" "github.com/tokopedia/graphql-go/introspection" + "github.com/tokopedia/graphql-go/types" ) // Meta defines the details of the metadata schema for introspection. @@ -14,59 +12,77 @@ type Meta struct { FieldSchema Field FieldType Field FieldTypename Field + FieldService Field Schema *Object Type *Object + Service *Object } -func newMeta(s *schema.Schema) *Meta { +func newMeta(s *types.Schema) *Meta { var err error b := newBuilder(s) - metaSchema := s.Types["__Schema"].(*schema.Object) + metaSchema := s.Types["__Schema"].(*types.ObjectTypeDefinition) so, err := b.makeObjectExec(metaSchema.Name, metaSchema.Fields, nil, false, reflect.TypeOf(&introspection.Schema{})) if err != nil { panic(err) } - metaType := s.Types["__Type"].(*schema.Object) + metaType := s.Types["__Type"].(*types.ObjectTypeDefinition) t, err := b.makeObjectExec(metaType.Name, metaType.Fields, nil, false, reflect.TypeOf(&introspection.Type{})) if err != nil { panic(err) } + metaService := s.Types["_Service"].(*types.ObjectTypeDefinition) + sv, err := b.makeObjectExec(metaService.Name, metaService.Fields, nil, false, reflect.TypeOf(&introspection.Service{})) + if err != nil { + panic(err) + } + if err := b.finish(); err != nil { panic(err) } fieldTypename := Field{ - Field: schema.Field{ + FieldDefinition: types.FieldDefinition{ Name: "__typename", - Type: &common.NonNull{OfType: s.Types["String"]}, + Type: &types.NonNull{OfType: s.Types["String"]}, }, - TraceLabel: fmt.Sprintf("GraphQL field: __typename"), + TraceLabel: "GraphQL field: __typename", } fieldSchema := Field{ - Field: schema.Field{ + FieldDefinition: types.FieldDefinition{ Name: "__schema", Type: s.Types["__Schema"], }, - TraceLabel: fmt.Sprintf("GraphQL field: __schema"), + TraceLabel: "GraphQL field: __schema", } fieldType := Field{ - Field: schema.Field{ + FieldDefinition: types.FieldDefinition{ Name: "__type", Type: s.Types["__Type"], }, - TraceLabel: fmt.Sprintf("GraphQL field: __type"), + TraceLabel: "GraphQL field: __type", + } + + fieldService := Field{ + FieldDefinition: types.FieldDefinition{ + Name: "_service", + Type: s.Types["_Service"], + }, + TraceLabel: "GraphQL field: _service", } return &Meta{ FieldSchema: fieldSchema, FieldTypename: fieldTypename, FieldType: fieldType, + FieldService: fieldService, Schema: so, Type: t, + Service: sv, } } diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index 70229d06e..0fe7e1c5e 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -3,18 +3,18 @@ package resolvable import ( "context" "fmt" + "github.com/tokopedia/graphql-go/errors" "reflect" "strings" - "github.com/tokopedia/graphql-go/errors" - "github.com/tokopedia/graphql-go/internal/common" + "github.com/tokopedia/graphql-go/decode" "github.com/tokopedia/graphql-go/internal/exec/packer" - "github.com/tokopedia/graphql-go/internal/schema" + "github.com/tokopedia/graphql-go/types" ) type Schema struct { *Meta - schema.Schema + types.Schema Query Resolvable Mutation Resolvable Subscription Resolvable @@ -32,10 +32,10 @@ type Object struct { } type Field struct { - schema.Field + types.FieldDefinition TypeName string MethodIndex int - FieldIndex int + FieldIndex []int HasContext bool HasError bool ArgsPacker *packer.StructPacker @@ -44,7 +44,7 @@ type Field struct { } func (f *Field) UseMethodResolver() bool { - return f.FieldIndex == -1 + return len(f.FieldIndex) == 0 } type TypeAssertion struct { @@ -62,7 +62,7 @@ func (*Object) isResolvable() {} func (*List) isResolvable() {} func (*Scalar) isResolvable() {} -func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) { +func ApplyResolver(s *types.Schema, resolver interface{}) (*Schema, error) { if resolver == nil { return &Schema{Meta: newMeta(s), Schema: *s}, nil } @@ -104,13 +104,13 @@ func ApplyResolver(s *schema.Schema, resolver interface{}) (*Schema, error) { } type execBuilder struct { - schema *schema.Schema + schema *types.Schema resMap map[typePair]*resMapEntry packerBuilder *packer.Builder } type typePair struct { - graphQLType common.Type + graphQLType types.Type resolverType reflect.Type } @@ -119,7 +119,7 @@ type resMapEntry struct { targets []*Resolvable } -func newBuilder(s *schema.Schema) *execBuilder { +func newBuilder(s *types.Schema) *execBuilder { return &execBuilder{ schema: s, resMap: make(map[typePair]*resMapEntry), @@ -137,7 +137,7 @@ func (b *execBuilder) finish() error { return b.packerBuilder.Finish() } -func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType reflect.Type) error { +func (b *execBuilder) assignExec(target *Resolvable, t types.Type, resolverType reflect.Type) error { k := typePair{t, resolverType} ref, ok := b.resMap[k] if !ok { @@ -153,19 +153,19 @@ func (b *execBuilder) assignExec(target *Resolvable, t common.Type, resolverType return nil } -func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolvable, error) { +func (b *execBuilder) makeExec(t types.Type, resolverType reflect.Type) (Resolvable, error) { var nonNull bool t, nonNull = unwrapNonNull(t) switch t := t.(type) { - case *schema.Object: + case *types.ObjectTypeDefinition: return b.makeObjectExec(t.Name, t.Fields, nil, nonNull, resolverType) - case *schema.Interface: + case *types.InterfaceTypeDefinition: return b.makeObjectExec(t.Name, t.Fields, t.PossibleTypes, nonNull, resolverType) - case *schema.Union: - return b.makeObjectExec(t.Name, nil, t.PossibleTypes, nonNull, resolverType) + case *types.Union: + return b.makeObjectExec(t.Name, nil, t.UnionMemberTypes, nonNull, resolverType) } if !nonNull { @@ -176,13 +176,13 @@ func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolv } switch t := t.(type) { - case *schema.Scalar: + case *types.ScalarTypeDefinition: return makeScalarExec(t, resolverType) - case *schema.Enum: + case *types.EnumTypeDefinition: return &Scalar{}, nil - case *common.List: + case *types.List: if resolverType.Kind() != reflect.Slice { return nil, fmt.Errorf("%s is not a slice", resolverType) } @@ -197,7 +197,7 @@ func (b *execBuilder) makeExec(t common.Type, resolverType reflect.Type) (Resolv } } -func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, error) { +func makeScalarExec(t *types.ScalarTypeDefinition, resolverType reflect.Type) (Resolvable, error) { implementsType := false switch r := reflect.New(resolverType).Interface().(type) { case *int32: @@ -208,16 +208,17 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er implementsType = t.Name == "String" case *bool: implementsType = t.Name == "Boolean" - case packer.Unmarshaler: + case decode.Unmarshaler: implementsType = r.ImplementsGraphQLType(t.Name) } + if !implementsType { return nil, fmt.Errorf("can not use %s as %s", resolverType, t.Name) } return &Scalar{}, nil } -func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, +func (b *execBuilder) makeObjectExec(typeName string, fields types.FieldsDefinition, possibleTypes []*types.ObjectTypeDefinition, nonNull bool, resolverType reflect.Type) (*Object, error) { if !nonNull { if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface { @@ -229,13 +230,17 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p Fields := make(map[string]*Field) rt := unwrapPtr(resolverType) + fieldsCount := fieldCount(rt, map[string]int{}) for _, f := range fields { - fieldIndex := -1 + var fieldIndex []int methodIndex := findMethod(resolverType, f.Name) if b.schema.UseFieldResolvers && methodIndex == -1 { - fieldIndex = findField(rt, f.Name) + if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 { + return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name) + } + fieldIndex = findField(rt, f.Name, []int{}) } - if methodIndex == -1 && fieldIndex == -1 { + if methodIndex == -1 && len(fieldIndex) == 0 { hint := "" if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 { hint = " (hint: the method exists on the pointer type)" @@ -248,11 +253,17 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p if methodIndex != -1 { m = resolverType.Method(methodIndex) } else { - sf = rt.Field(fieldIndex) + sf = rt.FieldByIndex(fieldIndex) } fe, err := b.makeFieldExec(typeName, f, m, sf, methodIndex, fieldIndex, methodHasReceiver) if err != nil { - return nil, fmt.Errorf("%s\n\treturned by (%s).%s", err, resolverType, m.Name) + var resolverName string + if methodIndex != -1 { + resolverName = m.Name + } else { + resolverName = sf.Name + } + return nil, fmt.Errorf("%s\n\tused by (%s).%s", err, resolverType, resolverName) } Fields[f.Name] = fe } @@ -267,7 +278,15 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p if methodIndex == -1 { return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name, impl.Name) } - if resolverType.Method(methodIndex).Type.NumOut() != 2 { + m := resolverType.Method(methodIndex) + expectedIn := 0 + if methodHasReceiver { + expectedIn = 1 + } + if m.Type.NumIn() != expectedIn { + return nil, fmt.Errorf("%s does not resolve %q: method %q should't have any arguments", resolverType, typeName, "To"+impl.Name) + } + if m.Type.NumOut() != 2 { return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name) } a := &TypeAssertion{ @@ -291,8 +310,8 @@ var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() var extnErrorInterfaceType = reflect.TypeOf((*errors.GraphQLError)(nil)).Elem() -func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField, - methodIndex, fieldIndex int, methodHasReceiver bool) (*Field, error) { +func (b *execBuilder) makeFieldExec(typeName string, f *types.FieldDefinition, m reflect.Method, sf reflect.StructField, + methodIndex int, fieldIndex []int, methodHasReceiver bool) (*Field, error) { var argsPacker *packer.StructPacker var hasError bool @@ -313,20 +332,20 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect. in = in[1:] } - if len(f.Args) > 0 { + if len(f.Arguments) > 0 { if len(in) == 0 { - return nil, fmt.Errorf("must have parameter for field arguments") + return nil, fmt.Errorf("must have `args struct { ... }` argument for field arguments") } var err error - argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0]) + argsPacker, err = b.packerBuilder.MakeStructPacker(f.Arguments, in[0]) if err != nil { return nil, err } in = in[1:] } - if len(in) > 0 { - return nil, fmt.Errorf("too many parameters") + if len(in) > 0 { + return nil, fmt.Errorf("too many arguments") } maxNumOfReturns := 2 @@ -340,27 +359,28 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect. hasError = m.Type.NumOut() == maxNumOfReturns if hasError { - if m.Type.Out(1) != errorType && !m.Type.Out(1).Implements(extnErrorInterfaceType) { - return nil, fmt.Errorf(`must have "error" or implements errors.GraphQLError interface as its second return value but the type is %v`, m.Type.Out(1)) + if m.Type.Out(maxNumOfReturns-1) != errorType && !m.Type.Out(1).Implements(extnErrorInterfaceType) { + return nil, fmt.Errorf(`must have "error" as its last return value`) } } } fe := &Field{ - Field: *f, - TypeName: typeName, - MethodIndex: methodIndex, - FieldIndex: fieldIndex, - HasContext: hasContext, - ArgsPacker: argsPacker, - HasError: hasError, - TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name), + FieldDefinition: *f, + TypeName: typeName, + MethodIndex: methodIndex, + FieldIndex: fieldIndex, + HasContext: hasContext, + ArgsPacker: argsPacker, + HasError: hasError, + TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name), } var out reflect.Type if methodIndex != -1 { out = m.Type.Out(0) - if typeName == "Subscription" && out.Kind() == reflect.Chan { + sub, ok := b.schema.EntryPoints["subscription"] + if ok && typeName == sub.TypeName() && out.Kind() == reflect.Chan { out = m.Type.Out(0).Elem() } } else { @@ -382,17 +402,50 @@ func findMethod(t reflect.Type, name string) int { return -1 } -func findField(t reflect.Type, name string) int { +func findField(t reflect.Type, name string, index []int) []int { for i := 0; i < t.NumField(); i++ { - if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Field(i).Name)) { - return i + field := t.Field(i) + + if field.Type.Kind() == reflect.Struct && field.Anonymous { + newIndex := findField(field.Type, name, []int{i}) + if len(newIndex) > 1 { + return append(index, newIndex...) + } + } + + if strings.EqualFold(stripUnderscore(name), stripUnderscore(field.Name)) { + return append(index, i) } } - return -1 + + return index +} + +// fieldCount helps resolve ambiguity when more than one embedded struct contains fields with the same name. +func fieldCount(t reflect.Type, count map[string]int) map[string]int { + if t.Kind() != reflect.Struct { + return nil + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + fieldName := strings.ToLower(stripUnderscore(field.Name)) + + if field.Type.Kind() == reflect.Struct && field.Anonymous { + count = fieldCount(field.Type, count) + } else { + if _, ok := count[fieldName]; !ok { + count[fieldName] = 0 + } + count[fieldName]++ + } + } + + return count } -func unwrapNonNull(t common.Type) (common.Type, bool) { - if nn, ok := t.(*common.NonNull); ok { +func unwrapNonNull(t types.Type) (types.Type, bool) { + if nn, ok := t.(*types.NonNull); ok { return nn.OfType, true } return t, false diff --git a/internal/exec/selected/selected.go b/internal/exec/selected/selected.go index b638bd942..d0f897e95 100644 --- a/internal/exec/selected/selected.go +++ b/internal/exec/selected/selected.go @@ -6,17 +6,16 @@ import ( "sync" "github.com/tokopedia/graphql-go/errors" - "github.com/tokopedia/graphql-go/internal/common" "github.com/tokopedia/graphql-go/internal/exec/packer" "github.com/tokopedia/graphql-go/internal/exec/resolvable" "github.com/tokopedia/graphql-go/internal/query" - "github.com/tokopedia/graphql-go/internal/schema" "github.com/tokopedia/graphql-go/introspection" + "github.com/tokopedia/graphql-go/types" ) type Request struct { - Schema *schema.Schema - Doc *query.Document + Schema *types.Schema + Doc *types.ExecutableDefinition Vars map[string]interface{} Mu sync.Mutex Errs []*errors.QueryError @@ -29,7 +28,7 @@ func (r *Request) AddError(err *errors.QueryError) { r.Mu.Unlock() } -func ApplyOperation(r *Request, s *resolvable.Schema, op *query.Operation) []Selection { +func ApplyOperation(r *Request, s *resolvable.Schema, op *types.OperationDefinition) []Selection { var obj *resolvable.Object switch op.Type { case query.Query: @@ -70,10 +69,10 @@ func (*SchemaField) isSelection() {} func (*TypeAssertion) isSelection() {} func (*TypenameField) isSelection() {} -func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, sels []query.Selection) (flattenedSels []Selection) { +func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, sels []types.Selection) (flattenedSels []Selection) { for _, sel := range sels { switch sel := sel.(type) { - case *query.Field: + case *types.Field: field := sel if skipByDirective(r, field.Directives) { continue @@ -81,19 +80,19 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s switch field.Name.Name { case "__typename": - if !r.DisableIntrospection { - flattenedSels = append(flattenedSels, &TypenameField{ - Object: *e, - Alias: field.Alias.Name, - }) - } + // __typename is available even though r.DisableIntrospection == true + // because it is necessary when using union types and interfaces: https://graphql.org/learn/schema/#union-types + flattenedSels = append(flattenedSels, &TypenameField{ + Object: *e, + Alias: field.Alias.Name, + }) case "__schema": if !r.DisableIntrospection { flattenedSels = append(flattenedSels, &SchemaField{ Field: s.Meta.FieldSchema, Alias: field.Alias.Name, - Sels: applySelectionSet(r, s, s.Meta.Schema, field.Selections), + Sels: applySelectionSet(r, s, s.Meta.Schema, field.SelectionSet), Async: true, FixedResult: reflect.ValueOf(introspection.WrapSchema(r.Schema)), }) @@ -102,7 +101,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s case "__type": if !r.DisableIntrospection { p := packer.ValuePacker{ValueType: reflect.TypeOf("")} - v, err := p.Pack(field.Arguments.MustGet("name").Value(r.Vars)) + v, err := p.Pack(field.Arguments.MustGet("name").Deserialize(r.Vars)) if err != nil { r.AddError(errors.Errorf("%s", err)) return nil @@ -116,12 +115,23 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s flattenedSels = append(flattenedSels, &SchemaField{ Field: s.Meta.FieldType, Alias: field.Alias.Name, - Sels: applySelectionSet(r, s, s.Meta.Type, field.Selections), + Sels: applySelectionSet(r, s, s.Meta.Type, field.SelectionSet), Async: true, FixedResult: reflect.ValueOf(introspection.WrapType(t)), }) } + case "_service": + if !r.DisableIntrospection { + flattenedSels = append(flattenedSels, &SchemaField{ + Field: s.Meta.FieldService, + Alias: field.Alias.Name, + Sels: applySelectionSet(r, s, s.Meta.Service, field.SelectionSet), + Async: true, + FixedResult: reflect.ValueOf(introspection.WrapService(r.Schema)), + }) + } + default: fe := e.Fields[field.Name.Name] @@ -130,7 +140,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s if fe.ArgsPacker != nil { args = make(map[string]interface{}) for _, arg := range field.Arguments { - args[arg.Name.Name] = arg.Value.Value(r.Vars) + args[arg.Name.Name] = arg.Value.Deserialize(r.Vars) } var err error packedArgs, err = fe.ArgsPacker.Pack(args) @@ -140,7 +150,7 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s } } - fieldSels := applyField(r, s, fe.ValueExec, field.Selections) + fieldSels := applyField(r, s, fe.ValueExec, field.SelectionSet) flattenedSels = append(flattenedSels, &SchemaField{ Field: *fe, Alias: field.Alias.Name, @@ -151,14 +161,14 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s }) } - case *query.InlineFragment: + case *types.InlineFragment: frag := sel if skipByDirective(r, frag.Directives) { continue } flattenedSels = append(flattenedSels, applyFragment(r, s, e, &frag.Fragment)...) - case *query.FragmentSpread: + case *types.FragmentSpread: spread := sel if skipByDirective(r, spread.Directives) { continue @@ -172,22 +182,45 @@ func applySelectionSet(r *Request, s *resolvable.Schema, e *resolvable.Object, s return } -func applyFragment(r *Request, s *resolvable.Schema, e *resolvable.Object, frag *query.Fragment) []Selection { - if frag.On.Name != "" && frag.On.Name != e.Name { - a, ok := e.TypeAssertions[frag.On.Name] - if !ok { - panic(fmt.Errorf("%q does not implement %q", frag.On, e.Name)) // TODO proper error handling +func applyFragment(r *Request, s *resolvable.Schema, e *resolvable.Object, frag *types.Fragment) []Selection { + if frag.On.Name != e.Name { + t := r.Schema.Resolve(frag.On.Name) + face, ok := t.(*types.InterfaceTypeDefinition) + if !ok && frag.On.Name != "" { + a, ok2 := e.TypeAssertions[frag.On.Name] + if !ok2 { + panic(fmt.Errorf("%q does not implement %q", frag.On, e.Name)) // TODO proper error handling + } + + return []Selection{&TypeAssertion{ + TypeAssertion: *a, + Sels: applySelectionSet(r, s, a.TypeExec.(*resolvable.Object), frag.Selections), + }} } + if ok && len(face.PossibleTypes) > 0 { + sels := []Selection{} + for _, t := range face.PossibleTypes { + if t.Name == e.Name { + return applySelectionSet(r, s, e, frag.Selections) + } - return []Selection{&TypeAssertion{ - TypeAssertion: *a, - Sels: applySelectionSet(r, s, a.TypeExec.(*resolvable.Object), frag.Selections), - }} + if a, ok := e.TypeAssertions[t.Name]; ok { + sels = append(sels, &TypeAssertion{ + TypeAssertion: *a, + Sels: applySelectionSet(r, s, a.TypeExec.(*resolvable.Object), frag.Selections), + }) + } + } + if len(sels) == 0 { + panic(fmt.Errorf("%q does not implement %q", e.Name, frag.On)) // TODO proper error handling + } + return sels + } } return applySelectionSet(r, s, e, frag.Selections) } -func applyField(r *Request, s *resolvable.Schema, e resolvable.Resolvable, sels []query.Selection) []Selection { +func applyField(r *Request, s *resolvable.Schema, e resolvable.Resolvable, sels []types.Selection) []Selection { switch e := e.(type) { case *resolvable.Object: return applySelectionSet(r, s, e, sels) @@ -200,10 +233,10 @@ func applyField(r *Request, s *resolvable.Schema, e resolvable.Resolvable, sels } } -func skipByDirective(r *Request, directives common.DirectiveList) bool { +func skipByDirective(r *Request, directives types.DirectiveList) bool { if d := directives.Get("skip"); d != nil { p := packer.ValuePacker{ValueType: reflect.TypeOf(false)} - v, err := p.Pack(d.Args.MustGet("if").Value(r.Vars)) + v, err := p.Pack(d.Arguments.MustGet("if").Deserialize(r.Vars)) if err != nil { r.AddError(errors.Errorf("%s", err)) } @@ -214,7 +247,7 @@ func skipByDirective(r *Request, directives common.DirectiveList) bool { if d := directives.Get("include"); d != nil { p := packer.ValuePacker{ValueType: reflect.TypeOf(false)} - v, err := p.Pack(d.Args.MustGet("if").Value(r.Vars)) + v, err := p.Pack(d.Arguments.MustGet("if").Deserialize(r.Vars)) if err != nil { r.AddError(errors.Errorf("%s", err)) } diff --git a/internal/exec/subscribe.go b/internal/exec/subscribe.go index 10d8aede0..625710ee4 100644 --- a/internal/exec/subscribe.go +++ b/internal/exec/subscribe.go @@ -9,10 +9,9 @@ import ( "time" "github.com/tokopedia/graphql-go/errors" - "github.com/tokopedia/graphql-go/internal/common" "github.com/tokopedia/graphql-go/internal/exec/resolvable" "github.com/tokopedia/graphql-go/internal/exec/selected" - "github.com/tokopedia/graphql-go/internal/query" + "github.com/tokopedia/graphql-go/types" ) type Response struct { @@ -20,7 +19,7 @@ type Response struct { Errors []*errors.QueryError } -func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query.Operation) <-chan *Response { +func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *types.OperationDefinition) <-chan *Response { var result reflect.Value var f *fieldToExec var err *errors.QueryError @@ -49,14 +48,29 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query result = callOut[0] if f.field.HasError && !callOut[1].IsNil() { - resolverErr := callOut[1].Interface().(error) - err = errors.Errorf("%s", resolverErr) - err.ResolverError = resolverErr + switch resolverErr := callOut[1].Interface().(type) { + case *errors.QueryError: + err = resolverErr + case error: + err = errors.Errorf("%s", resolverErr) + err.ResolverError = resolverErr + default: + panic(fmt.Errorf("can only deal with *QueryError and error types, got %T", resolverErr)) + } } }() + // Handles the case where the locally executed func above panicked + if len(r.Request.Errs) > 0 { + return sendAndReturnClosed(&Response{Errors: r.Request.Errs}) + } + + if f == nil { + return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}}) + } + if err != nil { - if _, nonNullChild := f.field.Type.(*common.NonNull); nonNullChild { + if _, nonNullChild := f.field.Type.(*types.NonNull); nonNullChild { return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}}) } return sendAndReturnClosed(&Response{Data: []byte(fmt.Sprintf(`{"%s":null}`, f.field.Alias)), Errors: []*errors.QueryError{err}}) @@ -68,7 +82,7 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query c := make(chan *Response) // TODO: handle resolver nil channel better? - if result == reflect.Zero(result.Type()) { + if result.IsZero() { close(c) return c } @@ -111,8 +125,12 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query } var out bytes.Buffer func() { - // TODO: configurable timeout - subCtx, cancel := context.WithTimeout(ctx, time.Second) + timeout := r.SubscribeResolverTimeout + if timeout == 0 { + timeout = time.Second + } + + subCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() // resolve response @@ -123,7 +141,7 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *query subR.execSelectionSet(subCtx, f.sels, f.field.Type, &pathSegment{nil, f.field.Alias}, s, resp, &buf) propagateChildError := false - if _, nonNullChild := f.field.Type.(*common.NonNull); nonNullChild && resolvedToNull(&buf) { + if _, nonNullChild := f.field.Type.(*types.NonNull); nonNullChild && resolvedToNull(&buf) { propagateChildError = true } diff --git a/internal/query/query.go b/internal/query/query.go index 95585795b..0d270ebc2 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -6,113 +6,35 @@ import ( "github.com/tokopedia/graphql-go/errors" "github.com/tokopedia/graphql-go/internal/common" + "github.com/tokopedia/graphql-go/types" ) -type Document struct { - Operations OperationList - Fragments FragmentList -} - -type OperationList []*Operation - -func (l OperationList) Get(name string) *Operation { - for _, f := range l { - if f.Name.Name == name { - return f - } - } - return nil -} - -type FragmentList []*FragmentDecl - -func (l FragmentList) Get(name string) *FragmentDecl { - for _, f := range l { - if f.Name.Name == name { - return f - } - } - return nil -} - -type Operation struct { - Type OperationType - Name common.Ident - Vars common.InputValueList - Selections []Selection - Directives common.DirectiveList - Loc errors.Location -} - -type OperationType string - const ( - Query OperationType = "QUERY" - Mutation = "MUTATION" - Subscription = "SUBSCRIPTION" + Query types.OperationType = "QUERY" + Mutation types.OperationType = "MUTATION" + Subscription types.OperationType = "SUBSCRIPTION" ) -type Fragment struct { - On common.TypeName - Selections []Selection -} - -type FragmentDecl struct { - Fragment - Name common.Ident - Directives common.DirectiveList - Loc errors.Location -} - -type Selection interface { - isSelection() -} - -type Field struct { - Alias common.Ident - Name common.Ident - Arguments common.ArgumentList - Directives common.DirectiveList - Selections []Selection - SelectionSetLoc errors.Location -} - -type InlineFragment struct { - Fragment - Directives common.DirectiveList - Loc errors.Location -} - -type FragmentSpread struct { - Name common.Ident - Directives common.DirectiveList - Loc errors.Location -} - -func (Field) isSelection() {} -func (InlineFragment) isSelection() {} -func (FragmentSpread) isSelection() {} - -func Parse(queryString string) (*Document, *errors.QueryError) { +func Parse(queryString string) (*types.ExecutableDefinition, *errors.QueryError) { l := common.NewLexer(queryString, false) - var doc *Document - err := l.CatchSyntaxError(func() { doc = parseDocument(l) }) + var execDef *types.ExecutableDefinition + err := l.CatchSyntaxError(func() { execDef = parseExecutableDefinition(l) }) if err != nil { return nil, err } - return doc, nil + return execDef, nil } -func parseDocument(l *common.Lexer) *Document { - d := &Document{} +func parseExecutableDefinition(l *common.Lexer) *types.ExecutableDefinition { + ed := &types.ExecutableDefinition{} l.ConsumeWhitespace() for l.Peek() != scanner.EOF { if l.Peek() == '{' { - op := &Operation{Type: Query, Loc: l.Location()} + op := &types.OperationDefinition{Type: Query, Loc: l.Location()} op.Selections = parseSelectionSet(l) - d.Operations = append(d.Operations, op) + ed.Operations = append(ed.Operations, op) continue } @@ -121,28 +43,28 @@ func parseDocument(l *common.Lexer) *Document { case "query": op := parseOperation(l, Query) op.Loc = loc - d.Operations = append(d.Operations, op) + ed.Operations = append(ed.Operations, op) case "mutation": - d.Operations = append(d.Operations, parseOperation(l, Mutation)) + ed.Operations = append(ed.Operations, parseOperation(l, Mutation)) case "subscription": - d.Operations = append(d.Operations, parseOperation(l, Subscription)) + ed.Operations = append(ed.Operations, parseOperation(l, Subscription)) case "fragment": frag := parseFragment(l) frag.Loc = loc - d.Fragments = append(d.Fragments, frag) + ed.Fragments = append(ed.Fragments, frag) default: l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "fragment"`, x)) } } - return d + return ed } -func parseOperation(l *common.Lexer, opType OperationType) *Operation { - op := &Operation{Type: opType} +func parseOperation(l *common.Lexer, opType types.OperationType) *types.OperationDefinition { + op := &types.OperationDefinition{Type: opType} op.Name.Loc = l.Location() if l.Peek() == scanner.Ident { op.Name = l.ConsumeIdentWithLoc() @@ -163,18 +85,18 @@ func parseOperation(l *common.Lexer, opType OperationType) *Operation { return op } -func parseFragment(l *common.Lexer) *FragmentDecl { - f := &FragmentDecl{} +func parseFragment(l *common.Lexer) *types.FragmentDefinition { + f := &types.FragmentDefinition{} f.Name = l.ConsumeIdentWithLoc() l.ConsumeKeyword("on") - f.On = common.TypeName{Ident: l.ConsumeIdentWithLoc()} + f.On = types.TypeName{Ident: l.ConsumeIdentWithLoc()} f.Directives = common.ParseDirectives(l) f.Selections = parseSelectionSet(l) return f } -func parseSelectionSet(l *common.Lexer) []Selection { - var sels []Selection +func parseSelectionSet(l *common.Lexer) []types.Selection { + var sels []types.Selection l.ConsumeToken('{') for l.Peek() != '}' { sels = append(sels, parseSelection(l)) @@ -183,15 +105,15 @@ func parseSelectionSet(l *common.Lexer) []Selection { return sels } -func parseSelection(l *common.Lexer) Selection { +func parseSelection(l *common.Lexer) types.Selection { if l.Peek() == '.' { return parseSpread(l) } - return parseField(l) + return parseFieldDef(l) } -func parseField(l *common.Lexer) *Field { - f := &Field{} +func parseFieldDef(l *common.Lexer) *types.Field { + f := &types.Field{} f.Alias = l.ConsumeIdentWithLoc() f.Name = f.Alias if l.Peek() == ':' { @@ -199,34 +121,34 @@ func parseField(l *common.Lexer) *Field { f.Name = l.ConsumeIdentWithLoc() } if l.Peek() == '(' { - f.Arguments = common.ParseArguments(l) + f.Arguments = common.ParseArgumentList(l) } f.Directives = common.ParseDirectives(l) if l.Peek() == '{' { f.SelectionSetLoc = l.Location() - f.Selections = parseSelectionSet(l) + f.SelectionSet = parseSelectionSet(l) } return f } -func parseSpread(l *common.Lexer) Selection { +func parseSpread(l *common.Lexer) types.Selection { loc := l.Location() l.ConsumeToken('.') l.ConsumeToken('.') l.ConsumeToken('.') - f := &InlineFragment{Loc: loc} + f := &types.InlineFragment{Loc: loc} if l.Peek() == scanner.Ident { ident := l.ConsumeIdentWithLoc() if ident.Name != "on" { - fs := &FragmentSpread{ + fs := &types.FragmentSpread{ Name: ident, Loc: loc, } fs.Directives = common.ParseDirectives(l) return fs } - f.On = common.TypeName{Ident: l.ConsumeIdentWithLoc()} + f.On = types.TypeName{Ident: l.ConsumeIdentWithLoc()} } f.Directives = common.ParseDirectives(l) f.Selections = parseSelectionSet(l) diff --git a/internal/schema/meta.go b/internal/schema/meta.go index 2e3118301..a19715d5c 100644 --- a/internal/schema/meta.go +++ b/internal/schema/meta.go @@ -1,17 +1,23 @@ package schema +import ( + "github.com/tokopedia/graphql-go/types" +) + func init() { _ = newMeta() } // newMeta initializes an instance of the meta Schema. -func newMeta() *Schema { - s := &Schema{ - entryPointNames: make(map[string]string), - Types: make(map[string]NamedType), - Directives: make(map[string]*DirectiveDecl), +func newMeta() *types.Schema { + s := &types.Schema{ + EntryPointNames: make(map[string]string), + Types: make(map[string]types.NamedType), + Directives: make(map[string]*types.DirectiveDefinition), } - if err := s.Parse(metaSrc, false); err != nil { + + err := Parse(s, metaSrc, false) + if err != nil { panic(err) } return s @@ -194,4 +200,8 @@ var metaSrc = ` # Indicates this type is a non-null. ` + "`" + `ofType` + "`" + ` is a valid field. NON_NULL } + + type _Service { + sdl: String! + } ` diff --git a/internal/schema/schema.go b/internal/schema/schema.go index b98ec42b4..5a7e0308c 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -6,247 +6,15 @@ import ( "github.com/tokopedia/graphql-go/errors" "github.com/tokopedia/graphql-go/internal/common" + "github.com/tokopedia/graphql-go/types" ) -// Schema represents a GraphQL service's collective type system capabilities. -// A schema is defined in terms of the types and directives it supports as well as the root -// operation types for each kind of operation: `query`, `mutation`, and `subscription`. -// -// For a more formal definition, read the relevant section in the specification: -// -// http://facebook.github.io/graphql/draft/#sec-Schema -type Schema struct { - // EntryPoints determines the place in the type system where `query`, `mutation`, and - // `subscription` operations begin. - // - // http://facebook.github.io/graphql/draft/#sec-Root-Operation-Types - // - // NOTE: The specification refers to this concept as "Root Operation Types". - // TODO: Rename the `EntryPoints` field to `RootOperationTypes` to align with spec terminology. - EntryPoints map[string]NamedType - - // Types are the fundamental unit of any GraphQL schema. - // There are six kinds of named types, and two wrapping types. - // - // http://facebook.github.io/graphql/draft/#sec-Types - Types map[string]NamedType - - // TODO: Type extensions? - // http://facebook.github.io/graphql/draft/#sec-Type-Extensions - - // Directives are used to annotate various parts of a GraphQL document as an indicator that they - // should be evaluated differently by a validator, executor, or client tool such as a code - // generator. - // - // http://facebook.github.io/graphql/draft/#sec-Type-System.Directives - Directives map[string]*DirectiveDecl - - UseFieldResolvers bool - - entryPointNames map[string]string - objects []*Object - unions []*Union - enums []*Enum - extensions []*Extension -} - -// Resolve a named type in the schema by its name. -func (s *Schema) Resolve(name string) common.Type { - return s.Types[name] -} - -// NamedType represents a type with a name. -// -// http://facebook.github.io/graphql/draft/#NamedType -type NamedType interface { - common.Type - TypeName() string - Description() string -} - -// Scalar types represent primitive leaf values (e.g. a string or an integer) in a GraphQL type -// system. -// -// GraphQL responses take the form of a hierarchical tree; the leaves on these trees are GraphQL -// scalars. -// -// http://facebook.github.io/graphql/draft/#sec-Scalars -type Scalar struct { - Name string - Desc string - // TODO: Add a list of directives? -} - -// Object types represent a list of named fields, each of which yield a value of a specific type. -// -// GraphQL queries are hierarchical and composed, describing a tree of information. -// While Scalar types describe the leaf values of these hierarchical types, Objects describe the -// intermediate levels. -// -// http://facebook.github.io/graphql/draft/#sec-Objects -type Object struct { - Name string - Interfaces []*Interface - Fields FieldList - Desc string - // TODO: Add a list of directives? - - interfaceNames []string -} - -// Interface types represent a list of named fields and their arguments. -// -// GraphQL objects can then implement these interfaces which requires that the object type will -// define all fields defined by those interfaces. -// -// http://facebook.github.io/graphql/draft/#sec-Interfaces -type Interface struct { - Name string - PossibleTypes []*Object - Fields FieldList // NOTE: the spec refers to this as `FieldsDefinition`. - Desc string - // TODO: Add a list of directives? -} - -// Union types represent objects that could be one of a list of GraphQL object types, but provides no -// guaranteed fields between those types. -// -// They also differ from interfaces in that object types declare what interfaces they implement, but -// are not aware of what unions contain them. -// -// http://facebook.github.io/graphql/draft/#sec-Unions -type Union struct { - Name string - PossibleTypes []*Object // NOTE: the spec refers to this as `UnionMemberTypes`. - Desc string - // TODO: Add a list of directives? - - typeNames []string -} - -// Enum types describe a set of possible values. -// -// Like scalar types, Enum types also represent leaf values in a GraphQL type system. -// -// http://facebook.github.io/graphql/draft/#sec-Enums -type Enum struct { - Name string - Values []*EnumValue // NOTE: the spec refers to this as `EnumValuesDefinition`. - Desc string - // TODO: Add a list of directives? -} - -// EnumValue types are unique values that may be serialized as a string: the name of the -// represented value. -// -// http://facebook.github.io/graphql/draft/#EnumValueDefinition -type EnumValue struct { - Name string - Directives common.DirectiveList - Desc string - // TODO: Add a list of directives? -} - -// InputObject types define a set of input fields; the input fields are either scalars, enums, or -// other input objects. -// -// This allows arguments to accept arbitrarily complex structs. -// -// http://facebook.github.io/graphql/draft/#sec-Input-Objects -type InputObject struct { - Name string - Desc string - Values common.InputValueList - // TODO: Add a list of directives? -} - -// Extension type defines a GraphQL type extension. -// Schemas, Objects, Inputs and Scalars can be extended. -// -// https://facebook.github.io/graphql/draft/#sec-Type-System-Extensions -type Extension struct { - Type NamedType - // TODO: Add a list of directives -} - -// FieldsList is a list of an Object's Fields. -// -// http://facebook.github.io/graphql/draft/#FieldsDefinition -type FieldList []*Field - -// Get iterates over the field list, returning a pointer-to-Field when the field name matches the -// provided `name` argument. -// Returns nil when no field was found by that name. -func (l FieldList) Get(name string) *Field { - for _, f := range l { - if f.Name == name { - return f - } - } - return nil -} - -// Names returns a string slice of the field names in the FieldList. -func (l FieldList) Names() []string { - names := make([]string, len(l)) - for i, f := range l { - names[i] = f.Name - } - return names -} - -// http://facebook.github.io/graphql/draft/#sec-Type-System.Directives -type DirectiveDecl struct { - Name string - Desc string - Locs []string - Args common.InputValueList -} - -func (*Scalar) Kind() string { return "SCALAR" } -func (*Object) Kind() string { return "OBJECT" } -func (*Interface) Kind() string { return "INTERFACE" } -func (*Union) Kind() string { return "UNION" } -func (*Enum) Kind() string { return "ENUM" } -func (*InputObject) Kind() string { return "INPUT_OBJECT" } - -func (t *Scalar) String() string { return t.Name } -func (t *Object) String() string { return t.Name } -func (t *Interface) String() string { return t.Name } -func (t *Union) String() string { return t.Name } -func (t *Enum) String() string { return t.Name } -func (t *InputObject) String() string { return t.Name } - -func (t *Scalar) TypeName() string { return t.Name } -func (t *Object) TypeName() string { return t.Name } -func (t *Interface) TypeName() string { return t.Name } -func (t *Union) TypeName() string { return t.Name } -func (t *Enum) TypeName() string { return t.Name } -func (t *InputObject) TypeName() string { return t.Name } - -func (t *Scalar) Description() string { return t.Desc } -func (t *Object) Description() string { return t.Desc } -func (t *Interface) Description() string { return t.Desc } -func (t *Union) Description() string { return t.Desc } -func (t *Enum) Description() string { return t.Desc } -func (t *InputObject) Description() string { return t.Desc } - -// Field is a conceptual function which yields values. -// http://facebook.github.io/graphql/draft/#FieldDefinition -type Field struct { - Name string - Args common.InputValueList // NOTE: the spec refers to this as `ArgumentsDefinition`. - Type common.Type - Directives common.DirectiveList - Desc string -} - // New initializes an instance of Schema. -func New() *Schema { - s := &Schema{ - entryPointNames: make(map[string]string), - Types: make(map[string]NamedType), - Directives: make(map[string]*DirectiveDecl), +func New() *types.Schema { + s := &types.Schema{ + EntryPointNames: make(map[string]string), + Types: make(map[string]types.NamedType), + Directives: make(map[string]*types.DirectiveDefinition), } m := newMeta() for n, t := range m.Types { @@ -258,10 +26,8 @@ func New() *Schema { return s } -// Parse the schema string. -func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { +func Parse(s *types.Schema, schemaString string, useStringDescriptions bool) error { l := common.NewLexer(schemaString, useStringDescriptions) - err := l.CatchSyntaxError(func() { parseSchema(s, l) }) if err != nil { return err @@ -277,7 +43,7 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { } } for _, d := range s.Directives { - for _, arg := range d.Args { + for _, arg := range d.Arguments { t, err := common.ResolveType(arg.Type, s.Resolve) if err != nil { return err @@ -286,25 +52,73 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { } } - s.EntryPoints = make(map[string]NamedType) - for key, name := range s.entryPointNames { + // https://graphql.github.io/graphql-spec/June2018/#sec-Root-Operation-Types + // > While any type can be the root operation type for a GraphQL operation, the type system definition language can + // > omit the schema definition when the query, mutation, and subscription root types are named Query, Mutation, + // > and Subscription respectively. + if len(s.EntryPointNames) == 0 { + if _, ok := s.Types["Query"]; ok { + s.EntryPointNames["query"] = "Query" + } + if _, ok := s.Types["Mutation"]; ok { + s.EntryPointNames["mutation"] = "Mutation" + } + if _, ok := s.Types["Subscription"]; ok { + s.EntryPointNames["subscription"] = "Subscription" + } + } + s.EntryPoints = make(map[string]types.NamedType) + for key, name := range s.EntryPointNames { t, ok := s.Types[name] if !ok { - if !ok { - return errors.Errorf("type %q not found", name) - } + return errors.Errorf("type %q not found", name) } s.EntryPoints[key] = t } - for _, obj := range s.objects { - obj.Interfaces = make([]*Interface, len(obj.interfaceNames)) - for i, intfName := range obj.interfaceNames { + // Interface types need validation: https://spec.graphql.org/draft/#sec-Interfaces.Interfaces-Implementing-Interfaces + for _, typeDef := range s.Types { + switch t := typeDef.(type) { + case *types.InterfaceTypeDefinition: + for i, implements := range t.Interfaces { + typ, ok := s.Types[implements.Name] + if !ok { + return errors.Errorf("interface %q not found", implements) + } + inteface, ok := typ.(*types.InterfaceTypeDefinition) + if !ok { + return errors.Errorf("type %q is not an interface", inteface) + } + + for _, f := range inteface.Fields.Names() { + if t.Fields.Get(f) == nil { + return errors.Errorf("interface %q expects field %q but %q does not provide it", inteface.Name, f, t.Name) + } + } + + t.Interfaces[i] = inteface + } + default: + continue + } + } + + for _, obj := range s.Objects { + obj.Interfaces = make([]*types.InterfaceTypeDefinition, len(obj.InterfaceNames)) + if err := resolveDirectives(s, obj.Directives, "OBJECT"); err != nil { + return err + } + for _, field := range obj.Fields { + if err := resolveDirectives(s, field.Directives, "FIELD_DEFINITION"); err != nil { + return err + } + } + for i, intfName := range obj.InterfaceNames { t, ok := s.Types[intfName] if !ok { return errors.Errorf("interface %q not found", intfName) } - intf, ok := t.(*Interface) + intf, ok := t.(*types.InterfaceTypeDefinition) if !ok { return errors.Errorf("type %q is not an interface", intfName) } @@ -318,34 +132,48 @@ func (s *Schema) Parse(schemaString string, useStringDescriptions bool) error { } } - for _, union := range s.unions { - union.PossibleTypes = make([]*Object, len(union.typeNames)) - for i, name := range union.typeNames { + for _, union := range s.Unions { + if err := resolveDirectives(s, union.Directives, "UNION"); err != nil { + return err + } + union.UnionMemberTypes = make([]*types.ObjectTypeDefinition, len(union.TypeNames)) + for i, name := range union.TypeNames { t, ok := s.Types[name] if !ok { return errors.Errorf("object type %q not found", name) } - obj, ok := t.(*Object) + obj, ok := t.(*types.ObjectTypeDefinition) if !ok { return errors.Errorf("type %q is not an object", name) } - union.PossibleTypes[i] = obj + union.UnionMemberTypes[i] = obj } } - for _, enum := range s.enums { - for _, value := range enum.Values { - if err := resolveDirectives(s, value.Directives); err != nil { + for _, enum := range s.Enums { + if err := resolveDirectives(s, enum.Directives, "ENUM"); err != nil { + return err + } + for _, value := range enum.EnumValuesDefinition { + if err := resolveDirectives(s, value.Directives, "ENUM_VALUE"); err != nil { return err } } } + s.SchemaString = schemaString + return nil } -func mergeExtensions(s *Schema) error { - for _, ext := range s.extensions { +func ParseSchema(schemaString string, useStringDescriptions bool) (*types.Schema, error) { + s := New() + err := Parse(s, schemaString, useStringDescriptions) + return s, err +} + +func mergeExtensions(s *types.Schema) error { + for _, ext := range s.Extensions { typ := s.Types[ext.Type.TypeName()] if typ == nil { return fmt.Errorf("trying to extend unknown type %q", ext.Type.TypeName()) @@ -356,8 +184,8 @@ func mergeExtensions(s *Schema) error { } switch og := typ.(type) { - case *Object: - e := ext.Type.(*Object) + case *types.ObjectTypeDefinition: + e := ext.Type.(*types.ObjectTypeDefinition) for _, field := range e.Fields { if og.Fields.Get(field.Name) != nil { @@ -366,17 +194,17 @@ func mergeExtensions(s *Schema) error { } og.Fields = append(og.Fields, e.Fields...) - for _, en := range e.interfaceNames { - for _, on := range og.interfaceNames { + for _, en := range e.InterfaceNames { + for _, on := range og.InterfaceNames { if on == en { return fmt.Errorf("interface %q implemented in the extension is already implemented in %q", on, og.Name) } } } - og.interfaceNames = append(og.interfaceNames, e.interfaceNames...) + og.InterfaceNames = append(og.InterfaceNames, e.InterfaceNames...) - case *InputObject: - e := ext.Type.(*InputObject) + case *types.InputObject: + e := ext.Type.(*types.InputObject) for _, field := range e.Values { if og.Values.Get(field.Name.Name) != nil { @@ -385,8 +213,8 @@ func mergeExtensions(s *Schema) error { } og.Values = append(og.Values, e.Values...) - case *Interface: - e := ext.Type.(*Interface) + case *types.InterfaceTypeDefinition: + e := ext.Type.(*types.InterfaceTypeDefinition) for _, field := range e.Fields { if og.Fields.Get(field.Name) != nil { @@ -395,29 +223,29 @@ func mergeExtensions(s *Schema) error { } og.Fields = append(og.Fields, e.Fields...) - case *Union: - e := ext.Type.(*Union) + case *types.Union: + e := ext.Type.(*types.Union) - for _, en := range e.typeNames { - for _, on := range og.typeNames { + for _, en := range e.TypeNames { + for _, on := range og.TypeNames { if on == en { return fmt.Errorf("union type %q already declared in %q", on, og.Name) } } } - og.typeNames = append(og.typeNames, e.typeNames...) + og.TypeNames = append(og.TypeNames, e.TypeNames...) - case *Enum: - e := ext.Type.(*Enum) + case *types.EnumTypeDefinition: + e := ext.Type.(*types.EnumTypeDefinition) - for _, en := range e.Values { - for _, on := range og.Values { - if on.Name == en.Name { - return fmt.Errorf("enum value %q already declared in %q", on.Name, og.Name) + for _, en := range e.EnumValuesDefinition { + for _, on := range og.EnumValuesDefinition { + if on.EnumValue == en.EnumValue { + return fmt.Errorf("enum value %q already declared in %q", on.EnumValue, og.Name) } } } - og.Values = append(og.Values, e.Values...) + og.EnumValuesDefinition = append(og.EnumValuesDefinition, e.EnumValuesDefinition...) default: return fmt.Errorf(`unexpected %q, expecting "schema", "type", "enum", "interface", "union" or "input"`, og.TypeName()) } @@ -426,21 +254,21 @@ func mergeExtensions(s *Schema) error { return nil } -func resolveNamedType(s *Schema, t NamedType) error { +func resolveNamedType(s *types.Schema, t types.NamedType) error { switch t := t.(type) { - case *Object: + case *types.ObjectTypeDefinition: for _, f := range t.Fields { if err := resolveField(s, f); err != nil { return err } } - case *Interface: + case *types.InterfaceTypeDefinition: for _, f := range t.Fields { if err := resolveField(s, f); err != nil { return err } } - case *InputObject: + case *types.InputObject: if err := resolveInputObject(s, t.Values); err != nil { return err } @@ -448,40 +276,59 @@ func resolveNamedType(s *Schema, t NamedType) error { return nil } -func resolveField(s *Schema, f *Field) error { +func resolveField(s *types.Schema, f *types.FieldDefinition) error { t, err := common.ResolveType(f.Type, s.Resolve) if err != nil { return err } f.Type = t - if err := resolveDirectives(s, f.Directives); err != nil { + if err := resolveDirectives(s, f.Directives, "FIELD_DEFINITION"); err != nil { return err } - return resolveInputObject(s, f.Args) + return resolveInputObject(s, f.Arguments) } -func resolveDirectives(s *Schema, directives common.DirectiveList) error { +func resolveDirectives(s *types.Schema, directives types.DirectiveList, loc string) error { + alreadySeenNonRepeatable := make(map[string]struct{}) for _, d := range directives { dirName := d.Name.Name dd, ok := s.Directives[dirName] if !ok { return errors.Errorf("directive %q not found", dirName) } - for _, arg := range d.Args { - if dd.Args.Get(arg.Name.Name) == nil { + validLoc := false + for _, l := range dd.Locations { + if l == loc { + validLoc = true + break + } + } + if !validLoc { + return errors.Errorf("invalid location %q for directive %q (must be one of %v)", loc, dirName, dd.Locations) + } + for _, arg := range d.Arguments { + if dd.Arguments.Get(arg.Name.Name) == nil { return errors.Errorf("invalid argument %q for directive %q", arg.Name.Name, dirName) } } - for _, arg := range dd.Args { - if _, ok := d.Args.Get(arg.Name.Name); !ok { - d.Args = append(d.Args, common.Argument{Name: arg.Name, Value: arg.Default}) + for _, arg := range dd.Arguments { + if _, ok := d.Arguments.Get(arg.Name.Name); !ok { + d.Arguments = append(d.Arguments, &types.Argument{Name: arg.Name, Value: arg.Default}) } } + + if dd.Repeatable { + continue + } + if _, seen := alreadySeenNonRepeatable[dirName]; seen { + return errors.Errorf(`non repeatable directive %q can not be repeated. Consider adding "repeatable".`, dirName) + } + alreadySeenNonRepeatable[dirName] = struct{}{} } return nil } -func resolveInputObject(s *Schema, values common.InputValueList) error { +func resolveInputObject(s *types.Schema, values types.ArgumentsDefinition) error { for _, v := range values { t, err := common.ResolveType(v.Type, s.Resolve) if err != nil { @@ -492,7 +339,7 @@ func resolveInputObject(s *Schema, values common.InputValueList) error { return nil } -func parseSchema(s *Schema, l *common.Lexer) { +func parseSchema(s *types.Schema, l *common.Lexer) { l.ConsumeWhitespace() for l.Peek() != scanner.EOF { @@ -502,10 +349,11 @@ func parseSchema(s *Schema, l *common.Lexer) { case "schema": l.ConsumeToken('{') for l.Peek() != '}' { + name := l.ConsumeIdent() l.ConsumeToken(':') typ := l.ConsumeIdent() - s.entryPointNames[name] = typ + s.EntryPointNames[name] = typ } l.ConsumeToken('}') @@ -513,7 +361,7 @@ func parseSchema(s *Schema, l *common.Lexer) { obj := parseObjectDef(l) obj.Desc = desc s.Types[obj.Name] = obj - s.objects = append(s.objects, obj) + s.Objects = append(s.Objects, obj) case "interface": iface := parseInterfaceDef(l) @@ -524,13 +372,13 @@ func parseSchema(s *Schema, l *common.Lexer) { union := parseUnionDef(l) union.Desc = desc s.Types[union.Name] = union - s.unions = append(s.unions, union) + s.Unions = append(s.Unions, union) case "enum": enum := parseEnumDef(l) enum.Desc = desc s.Types[enum.Name] = enum - s.enums = append(s.enums, enum) + s.Enums = append(s.Enums, enum) case "input": input := parseInputDef(l) @@ -538,8 +386,10 @@ func parseSchema(s *Schema, l *common.Lexer) { s.Types[input.Name] = input case "scalar": + loc := l.Location() name := l.ConsumeIdent() - s.Types[name] = &Scalar{Name: name, Desc: desc} + directives := common.ParseDirectives(l) + s.Types[name] = &types.ScalarTypeDefinition{Name: name, Desc: desc, Directives: directives, Loc: loc} case "directive": directive := parseDirectiveDef(l) @@ -556,30 +406,55 @@ func parseSchema(s *Schema, l *common.Lexer) { } } -func parseObjectDef(l *common.Lexer) *Object { - object := &Object{Name: l.ConsumeIdent()} +func parseObjectDef(l *common.Lexer) *types.ObjectTypeDefinition { + object := &types.ObjectTypeDefinition{Loc: l.Location(), Name: l.ConsumeIdent()} + + for { + if l.Peek() == '{' { + break + } + + if l.Peek() == '@' { + object.Directives = common.ParseDirectives(l) + continue + } + + if l.Peek() != scanner.Ident { + break + } - if l.Peek() == scanner.Ident { l.ConsumeKeyword("implements") - for l.Peek() != '{' { + for l.Peek() != '{' && l.Peek() != '@' { if l.Peek() == '&' { l.ConsumeToken('&') } - object.interfaceNames = append(object.interfaceNames, l.ConsumeIdent()) + object.InterfaceNames = append(object.InterfaceNames, l.ConsumeIdent()) } } - l.ConsumeToken('{') object.Fields = parseFieldsDef(l) l.ConsumeToken('}') return object + } -func parseInterfaceDef(l *common.Lexer) *Interface { - i := &Interface{Name: l.ConsumeIdent()} +func parseInterfaceDef(l *common.Lexer) *types.InterfaceTypeDefinition { + i := &types.InterfaceTypeDefinition{Loc: l.Location(), Name: l.ConsumeIdent()} + + if l.Peek() == scanner.Ident { + l.ConsumeKeyword("implements") + i.Interfaces = append(i.Interfaces, &types.InterfaceTypeDefinition{Name: l.ConsumeIdent()}) + + for l.Peek() == '&' { + l.ConsumeToken('&') + i.Interfaces = append(i.Interfaces, &types.InterfaceTypeDefinition{Name: l.ConsumeIdent()}) + } + } + + i.Directives = common.ParseDirectives(l) l.ConsumeToken('{') i.Fields = parseFieldsDef(l) @@ -588,22 +463,25 @@ func parseInterfaceDef(l *common.Lexer) *Interface { return i } -func parseUnionDef(l *common.Lexer) *Union { - union := &Union{Name: l.ConsumeIdent()} +func parseUnionDef(l *common.Lexer) *types.Union { + union := &types.Union{Loc: l.Location(), Name: l.ConsumeIdent()} + union.Directives = common.ParseDirectives(l) l.ConsumeToken('=') - union.typeNames = []string{l.ConsumeIdent()} + union.TypeNames = []string{l.ConsumeIdent()} for l.Peek() == '|' { l.ConsumeToken('|') - union.typeNames = append(union.typeNames, l.ConsumeIdent()) + union.TypeNames = append(union.TypeNames, l.ConsumeIdent()) } return union } -func parseInputDef(l *common.Lexer) *InputObject { - i := &InputObject{} +func parseInputDef(l *common.Lexer) *types.InputObject { + i := &types.InputObject{} + i.Loc = l.Location() i.Name = l.ConsumeIdent() + i.Directives = common.ParseDirectives(l) l.ConsumeToken('{') for l.Peek() != '}' { i.Values = append(i.Values, common.ParseInputValue(l)) @@ -612,41 +490,54 @@ func parseInputDef(l *common.Lexer) *InputObject { return i } -func parseEnumDef(l *common.Lexer) *Enum { - enum := &Enum{Name: l.ConsumeIdent()} +func parseEnumDef(l *common.Lexer) *types.EnumTypeDefinition { + enum := &types.EnumTypeDefinition{Loc: l.Location(), Name: l.ConsumeIdent()} + enum.Directives = common.ParseDirectives(l) l.ConsumeToken('{') for l.Peek() != '}' { - v := &EnumValue{ + v := &types.EnumValueDefinition{ Desc: l.DescComment(), - Name: l.ConsumeIdent(), + Loc: l.Location(), + EnumValue: l.ConsumeIdent(), Directives: common.ParseDirectives(l), } - enum.Values = append(enum.Values, v) + enum.EnumValuesDefinition = append(enum.EnumValuesDefinition, v) } l.ConsumeToken('}') return enum } - -func parseDirectiveDef(l *common.Lexer) *DirectiveDecl { +func parseDirectiveDef(l *common.Lexer) *types.DirectiveDefinition { l.ConsumeToken('@') - d := &DirectiveDecl{Name: l.ConsumeIdent()} + loc := l.Location() + d := &types.DirectiveDefinition{Name: l.ConsumeIdent(), Loc: loc} if l.Peek() == '(' { l.ConsumeToken('(') for l.Peek() != ')' { v := common.ParseInputValue(l) - d.Args = append(d.Args, v) + d.Arguments = append(d.Arguments, v) } l.ConsumeToken(')') } - l.ConsumeKeyword("on") + switch x := l.ConsumeIdent(); x { + case "on": + // no-op; Go doesn't fallthrough by default + case "repeatable": + d.Repeatable = true + l.ConsumeKeyword("on") + default: + l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "on" or "repeatable"`, x)) + } for { loc := l.ConsumeIdent() - d.Locs = append(d.Locs, loc) + if _, ok := legalDirectiveLocationNames[loc]; !ok { + l.SyntaxError(fmt.Sprintf("%q is not a legal directive location (options: %v)", loc, legalDirectiveLocationNames)) + } + d.Locations = append(d.Locations, loc) if l.Peek() != '|' { break } @@ -655,7 +546,8 @@ func parseDirectiveDef(l *common.Lexer) *DirectiveDecl { return d } -func parseExtension(s *Schema, l *common.Lexer) { +func parseExtension(s *types.Schema, l *common.Lexer) { + loc := l.Location() switch x := l.ConsumeIdent(); x { case "schema": l.ConsumeToken('{') @@ -663,46 +555,47 @@ func parseExtension(s *Schema, l *common.Lexer) { name := l.ConsumeIdent() l.ConsumeToken(':') typ := l.ConsumeIdent() - s.entryPointNames[name] = typ + s.EntryPointNames[name] = typ } l.ConsumeToken('}') case "type": obj := parseObjectDef(l) - s.extensions = append(s.extensions, &Extension{Type: obj}) + s.Extensions = append(s.Extensions, &types.Extension{Type: obj, Loc: loc}) case "interface": iface := parseInterfaceDef(l) - s.extensions = append(s.extensions, &Extension{Type: iface}) + s.Extensions = append(s.Extensions, &types.Extension{Type: iface, Loc: loc}) case "union": union := parseUnionDef(l) - s.extensions = append(s.extensions, &Extension{Type: union}) + s.Extensions = append(s.Extensions, &types.Extension{Type: union, Loc: loc}) case "enum": enum := parseEnumDef(l) - s.extensions = append(s.extensions, &Extension{Type: enum}) + s.Extensions = append(s.Extensions, &types.Extension{Type: enum, Loc: loc}) case "input": input := parseInputDef(l) - s.extensions = append(s.extensions, &Extension{Type: input}) + s.Extensions = append(s.Extensions, &types.Extension{Type: input, Loc: loc}) default: - // TODO: Add Scalar when adding directives + // TODO: Add ScalarTypeDefinition when adding directives l.SyntaxError(fmt.Sprintf(`unexpected %q, expecting "schema", "type", "enum", "interface", "union" or "input"`, x)) } } -func parseFieldsDef(l *common.Lexer) FieldList { - var fields FieldList +func parseFieldsDef(l *common.Lexer) types.FieldsDefinition { + var fields types.FieldsDefinition for l.Peek() != '}' { - f := &Field{} + f := &types.FieldDefinition{} f.Desc = l.DescComment() + f.Loc = l.Location() f.Name = l.ConsumeIdent() if l.Peek() == '(' { l.ConsumeToken('(') for l.Peek() != ')' { - f.Args = append(f.Args, common.ParseInputValue(l)) + f.Arguments = append(f.Arguments, common.ParseInputValue(l)) } l.ConsumeToken(')') } @@ -713,3 +606,25 @@ func parseFieldsDef(l *common.Lexer) FieldList { } return fields } + +var legalDirectiveLocationNames = map[string]struct{}{ + "SCHEMA": {}, + "SCALAR": {}, + "OBJECT": {}, + "FIELD_DEFINITION": {}, + "ARGUMENT_DEFINITION": {}, + "INTERFACE": {}, + "UNION": {}, + "ENUM": {}, + "ENUM_VALUE": {}, + "INPUT_OBJECT": {}, + "INPUT_FIELD_DEFINITION": {}, + "QUERY": {}, + "MUTATION": {}, + "SUBSCRIPTION": {}, + "FIELD": {}, + "FRAGMENT_DEFINITION": {}, + "FRAGMENT_SPREAD": {}, + "INLINE_FRAGMENT": {}, + "VARIABLE_DEFINITION": {}, +} diff --git a/internal/schema/schema_internal_test.go b/internal/schema/schema_internal_test.go index 26cb64498..6bc1e57e8 100644 --- a/internal/schema/schema_internal_test.go +++ b/internal/schema/schema_internal_test.go @@ -1,29 +1,34 @@ package schema import ( + "reflect" "testing" "github.com/tokopedia/graphql-go/errors" "github.com/tokopedia/graphql-go/internal/common" + "github.com/tokopedia/graphql-go/types" ) func TestParseInterfaceDef(t *testing.T) { type testCase struct { description string definition string - expected *Interface + expected *types.InterfaceTypeDefinition err *errors.QueryError } tests := []testCase{{ description: "Parses simple interface", definition: "Greeting { field: String }", - expected: &Interface{Name: "Greeting", Fields: []*Field{{Name: "field"}}}, + expected: &types.InterfaceTypeDefinition{ + Name: "Greeting", + Loc: errors.Location{Line: 1, Column: 1}, + Fields: types.FieldsDefinition{&types.FieldDefinition{Name: "field"}}}, }} for _, test := range tests { t.Run(test.description, func(t *testing.T) { - var actual *Interface + var actual *types.InterfaceTypeDefinition lex := setup(t, test.definition) parse := func() { actual = parseInterfaceDef(lex) } @@ -41,31 +46,31 @@ func TestParseObjectDef(t *testing.T) { type testCase struct { description string definition string - expected *Object + expected *types.ObjectTypeDefinition err *errors.QueryError } tests := []testCase{{ description: "Parses type inheriting single interface", definition: "Hello implements World { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"World"}}, + expected: &types.ObjectTypeDefinition{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}, InterfaceNames: []string{"World"}}, }, { description: "Parses type inheriting multiple interfaces", definition: "Hello implements Wo & rld { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"Wo", "rld"}}, + expected: &types.ObjectTypeDefinition{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}, InterfaceNames: []string{"Wo", "rld"}}, }, { description: "Parses type inheriting multiple interfaces with leading ampersand", definition: "Hello implements & Wo & rld { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"Wo", "rld"}}, + expected: &types.ObjectTypeDefinition{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}, InterfaceNames: []string{"Wo", "rld"}}, }, { description: "Allows legacy SDL interfaces", definition: "Hello implements Wo, rld { field: String }", - expected: &Object{Name: "Hello", interfaceNames: []string{"Wo", "rld"}}, + expected: &types.ObjectTypeDefinition{Name: "Hello", Loc: errors.Location{Line: 1, Column: 1}, InterfaceNames: []string{"Wo", "rld"}}, }} for _, test := range tests { t.Run(test.description, func(t *testing.T) { - var actual *Object + var actual *types.ObjectTypeDefinition lex := setup(t, test.definition) parse := func() { actual = parseObjectDef(lex) } @@ -77,6 +82,224 @@ func TestParseObjectDef(t *testing.T) { } } +func TestParseUnionDef(t *testing.T) { + type testCase struct { + description string + definition string + expected *types.Union + err *errors.QueryError + } + + tests := []testCase{ + { + description: "Parses a union", + definition: "Foo = Bar | Qux | Quux", + expected: &types.Union{ + Name: "Foo", + TypeNames: []string{"Bar", "Qux", "Quux"}, + Loc: errors.Location{Line: 1, Column: 1}, + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var actual *types.Union + lex := setup(t, test.definition) + + parse := func() { actual = parseUnionDef(lex) } + err := lex.CatchSyntaxError(parse) + + compareErrors(t, test.err, err) + compareUnions(t, test.expected, actual) + }) + } +} + +func TestParseEnumDef(t *testing.T) { + type testCase struct { + description string + definition string + expected *types.EnumTypeDefinition + err *errors.QueryError + } + + tests := []testCase{ + { + description: "parses EnumTypeDefinition on single line", + definition: "Foo { BAR QUX }", + expected: &types.EnumTypeDefinition{ + Name: "Foo", + EnumValuesDefinition: []*types.EnumValueDefinition{ + { + EnumValue: "BAR", + Loc: errors.Location{Line: 1, Column: 7}, + }, + { + EnumValue: "QUX", + Loc: errors.Location{Line: 1, Column: 11}, + }, + }, + Loc: errors.Location{Line: 1, Column: 1}, + }, + }, + { + description: "parses EnumtypeDefinition with new lines", + definition: `Foo { + BAR + QUX + }`, + expected: &types.EnumTypeDefinition{ + Name: "Foo", + EnumValuesDefinition: []*types.EnumValueDefinition{ + { + EnumValue: "BAR", + Loc: errors.Location{Line: 2, Column: 5}, + }, + { + EnumValue: "QUX", + Loc: errors.Location{Line: 3, Column: 5}, + }, + }, + Loc: errors.Location{Line: 1, Column: 1}, + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var actual *types.EnumTypeDefinition + lex := setup(t, test.definition) + + parse := func() { actual = parseEnumDef(lex) } + err := lex.CatchSyntaxError(parse) + + compareErrors(t, test.err, err) + compareEnumTypeDefs(t, test.expected, actual) + }) + } +} + +func TestParseDirectiveDef(t *testing.T) { + type testCase struct { + description string + definition string + expected *types.DirectiveDefinition + err *errors.QueryError + } + + tests := []*testCase{ + { + description: "parses DirectiveDefinition", + definition: "@Foo on FIELD", + expected: &types.DirectiveDefinition{ + Name: "Foo", + Loc: errors.Location{Line: 1, Column: 2}, + Locations: []string{"FIELD"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var actual *types.DirectiveDefinition + lex := setup(t, test.definition) + + parse := func() { actual = parseDirectiveDef(lex) } + err := lex.CatchSyntaxError(parse) + + compareErrors(t, test.err, err) + compareDirectiveDefinitions(t, test.expected, actual) + }) + } +} + +func TestParseInputDef(t *testing.T) { + type testCase struct { + description string + definition string + expected *types.InputObject + err *errors.QueryError + } + + tests := []testCase{ + { + description: "parses an input object type definition", + definition: "Foo { qux: String }", + expected: &types.InputObject{ + Name: "Foo", + Values: nil, + Loc: errors.Location{Line: 1, Column: 1}, + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + var actual *types.InputObject + lex := setup(t, test.definition) + + parse := func() { actual = parseInputDef(lex) } + err := lex.CatchSyntaxError(parse) + + compareErrors(t, test.err, err) + compareInputObjectTypeDefinition(t, test.expected, actual) + }) + } +} + +func compareDirectiveDefinitions(t *testing.T, expected *types.DirectiveDefinition, actual *types.DirectiveDefinition) { + t.Helper() + + if expected.Name != actual.Name { + t.Fatalf("wrong DirectiveDefinition name: want %q, got %q", expected.Name, actual.Name) + } + + if !reflect.DeepEqual(expected.Locations, actual.Locations) { + t.Errorf("wrong DirectiveDefinition locations: want %v, got %v", expected.Locations, actual.Locations) + } + + compareLoc(t, "DirectiveDefinition", expected.Loc, actual.Loc) +} + +func compareInputObjectTypeDefinition(t *testing.T, expected, actual *types.InputObject) { + t.Helper() + + if expected.Name != actual.Name { + t.Fatalf("wrong InputObject name: want %q, got %q", expected.Name, actual.Name) + } + + compareLoc(t, "InputObjectTypeDefinition", expected.Loc, actual.Loc) +} + +func compareEnumTypeDefs(t *testing.T, expected, actual *types.EnumTypeDefinition) { + t.Helper() + + if expected.Name != actual.Name { + t.Fatalf("wrong EnumTypeDefinition name: want %q, got %q", expected.Name, actual.Name) + } + + compareLoc(t, "EnumValueDefinition", expected.Loc, actual.Loc) + + for i, definition := range expected.EnumValuesDefinition { + expectedValue, expectedLoc := definition.EnumValue, definition.Loc + actualDef := actual.EnumValuesDefinition[i] + + if expectedValue != actualDef.EnumValue { + t.Fatalf("wrong EnumValue: want %q, got %q", expectedValue, actualDef.EnumValue) + } + + compareLoc(t, "EnumValue "+expectedValue, expectedLoc, actualDef.Loc) + } +} + +func compareLoc(t *testing.T, typeName string, expected, actual errors.Location) { + t.Helper() + if expected != actual { + t.Errorf("wrong location on %s: want %v, got %v", typeName, expected, actual) + } +} + func compareErrors(t *testing.T, expected, actual *errors.QueryError) { t.Helper() @@ -95,23 +318,15 @@ func compareErrors(t *testing.T, expected, actual *errors.QueryError) { } } -func compareInterfaces(t *testing.T, expected, actual *Interface) { +func compareInterfaces(t *testing.T, expected, actual *types.InterfaceTypeDefinition) { t.Helper() - // TODO: We can probably extract this switch statement into its own function. - switch { - case expected == nil && actual == nil: - return - case expected == nil && actual != nil: - t.Fatalf("wanted nil, got an unexpected result: %#v", actual) - case expected != nil && actual == nil: - t.Fatalf("wanted non-nil result, got nil") - } - if expected.Name != actual.Name { t.Errorf("wrong interface name: want %q, got %q", expected.Name, actual.Name) } + compareLoc(t, "InterfaceTypeDefinition", expected.Loc, actual.Loc) + if len(expected.Fields) != len(actual.Fields) { t.Fatalf("wanted %d field definitions, got %d", len(expected.Fields), len(actual.Fields)) } @@ -123,32 +338,35 @@ func compareInterfaces(t *testing.T, expected, actual *Interface) { } } -func compareObjects(t *testing.T, expected, actual *Object) { +func compareUnions(t *testing.T, expected, actual *types.Union) { t.Helper() - switch { - case expected == nil && expected == actual: - return - case expected == nil && actual != nil: - t.Fatalf("wanted nil, got an unexpected result: %#v", actual) - case expected != nil && actual == nil: - t.Fatalf("wanted non-nil result, got nil") + if expected.Name != actual.Name { + t.Errorf("wrong object name: want %q, got %q", expected.Name, actual.Name) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("wrong type names: want %v, got %v", expected.TypeNames, actual.TypeNames) } +} + +func compareObjects(t *testing.T, expected, actual *types.ObjectTypeDefinition) { + t.Helper() if expected.Name != actual.Name { t.Errorf("wrong object name: want %q, got %q", expected.Name, actual.Name) } - if len(expected.interfaceNames) != len(actual.interfaceNames) { + if len(expected.InterfaceNames) != len(actual.InterfaceNames) { t.Fatalf( "wrong number of interface names: want %s, got %s", - expected.interfaceNames, - actual.interfaceNames, + expected.InterfaceNames, + actual.InterfaceNames, ) } - for i, expectedName := range expected.interfaceNames { - actualName := actual.interfaceNames[i] + for i, expectedName := range expected.InterfaceNames { + actualName := actual.InterfaceNames[i] if expectedName != actualName { t.Errorf("wrong interface name: want %q, got %q", expectedName, actualName) } diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 94a4776a8..01974ff42 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -2,9 +2,11 @@ package schema_test import ( "fmt" + "strings" "testing" "github.com/tokopedia/graphql-go/internal/schema" + "github.com/tokopedia/graphql-go/types" ) func TestParse(t *testing.T) { @@ -13,14 +15,14 @@ func TestParse(t *testing.T) { sdl string useStringDescriptions bool validateError func(err error) error - validateSchema func(s *schema.Schema) error + validateSchema func(s *types.Schema) error }{ { name: "Parses interface definition", sdl: "interface Greeting { message: String! }", - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Greeting" - typ, ok := s.Types[typeName].(*schema.Interface) + typ, ok := s.Types[typeName].(*types.InterfaceTypeDefinition) if !ok { return fmt.Errorf("interface %q not found", typeName) } @@ -60,9 +62,9 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -73,7 +75,7 @@ func TestParse(t *testing.T) { }, }, { - name: "Parses type with multi-line description string", + name: "Parses type with simple multi-line 'BlockString' description", sdl: ` """ Multi-line description. @@ -82,9 +84,9 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -95,7 +97,120 @@ func TestParse(t *testing.T) { }, }, { - name: "Parses type with multi-line description and ignores comments", + name: "Parses type with empty multi-line 'BlockString' description", + sdl: ` + """ + """ + type Type { + field: String + }`, + useStringDescriptions: true, + validateSchema: func(s *types.Schema) error { + const typeName = "Type" + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", typeName) + } + if want, have := "", typ.Description(); want != have { + return fmt.Errorf("invalid description: want %q, have %q", want, have) + } + return nil + }, + }, + { + name: "Parses type with multi-line 'BlockString' description", + sdl: ` + """ + First line of the description. + + Second line of the description. + + query { + code { + example + } + } + + Notes: + + * First note + * Second note + """ + type Type { + field: String + }`, + useStringDescriptions: true, + validateSchema: func(s *types.Schema) error { + const typeName = "Type" + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", typeName) + } + want := "First line of the description.\n\nSecond line of the description.\n\n\tquery {\n\t\tcode {\n\t\t\texample\n\t\t}\n\t}\n\nNotes:\n\n * First note\n * Second note" + if have := typ.Description(); want != have { + return fmt.Errorf("invalid description: want %q, have %q", want, have) + } + return nil + }, + }, + { + name: "Parses type with un-indented multi-line 'BlockString' description", + sdl: ` + """ +First line of the description. + +Second line of the description. + """ + type Type { + field: String + }`, + useStringDescriptions: true, + validateSchema: func(s *types.Schema) error { + const typeName = "Type" + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", typeName) + } + want := "First line of the description.\n\nSecond line of the description." + if have := typ.Description(); want != have { + return fmt.Errorf("invalid description: want %q, have %q", want, have) + } + return nil + }, + }, + { + name: "Parses type with space-indented multi-line 'BlockString' description", + sdl: ` + """ + First line of the description. + + Second line of the description. + + query { + code { + example + } + } + """ + type Type { + field: String + }`, + useStringDescriptions: true, + validateSchema: func(s *types.Schema) error { + const typeName = "Type" + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", typeName) + } + want := "First line of the description.\n\nSecond line of the description.\n\n query {\n code {\n example\n }\n }" + if have := typ.Description(); want != have { + return fmt.Errorf("invalid description: want %q, have %q", want, have) + } + return nil + }, + }, + { + name: "Parses type with multi-line 'BlockString' description and ignores comments", sdl: ` """ Multi-line description with ignored comments. @@ -105,9 +220,9 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { const typeName = "Type" - typ, ok := s.Types[typeName].(*schema.Object) + typ, ok := s.Types[typeName].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", typeName) } @@ -117,6 +232,19 @@ func TestParse(t *testing.T) { return nil }, }, + { + name: "Parses type invalid syntax", + sdl: ` + type U = T + `, + validateError: func(err error) error { + msg := `graphql: syntax error: unexpected "=", expecting "{" (line 2, column 11)` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, { name: "Description is correctly parsed for non-described types", sdl: ` @@ -126,7 +254,7 @@ func TestParse(t *testing.T) { field: String }`, useStringDescriptions: true, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { typ, ok := s.Types["Type"] if !ok { return fmt.Errorf("type %q not found", "Type") @@ -147,7 +275,7 @@ func TestParse(t *testing.T) { type Type { field: String }`, - validateSchema: func(s *schema.Schema) error { + validateSchema: func(s *types.Schema) error { typ, ok := s.Types["MyInt"] if !ok { return fmt.Errorf("scalar %q not found", "MyInt") @@ -166,7 +294,47 @@ func TestParse(t *testing.T) { }, }, { - name: "Type extension works correctly", + name: "Default Root schema", + sdl: ` + type Query { + hello: String! + } + type Mutation { + concat(a: String!, b: String!): String! + } + `, + validateSchema: func(s *types.Schema) error { + typq, ok := s.Types["Query"].(*types.ObjectTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", "Query") + } + helloField := typq.Fields.Get("hello") + if helloField == nil { + return fmt.Errorf("field %q not found", "hello") + } + if helloField.Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid type: %q", "hello", helloField.Type.String()) + } + + typm, ok := s.Types["Mutation"].(*types.ObjectTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", "Mutation") + } + concatField := typm.Fields.Get("concat") + if concatField == nil { + return fmt.Errorf("field %q not found", "concat") + } + if concatField.Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid type: %q", "concat", concatField.Type.String()) + } + if len(concatField.Arguments) != 2 || concatField.Arguments[0] == nil || concatField.Arguments[1] == nil || concatField.Arguments[0].Type.String() != "String!" || concatField.Arguments[1].Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid args: %+v", "concat", concatField.Arguments) + } + return nil + }, + }, + { + name: "Extend type", sdl: ` type Query { hello: String! @@ -175,8 +343,8 @@ func TestParse(t *testing.T) { extend type Query { world: String! }`, - validateSchema: func(s *schema.Schema) error { - typ, ok := s.Types["Query"].(*schema.Object) + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Query"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Query") } @@ -200,7 +368,7 @@ func TestParse(t *testing.T) { }, }, { - name: "Schema extension works correctly", + name: "Extend schema", sdl: ` schema { query: Query @@ -215,8 +383,8 @@ func TestParse(t *testing.T) { concat(a: String!, b: String!): String! } `, - validateSchema: func(s *schema.Schema) error { - typq, ok := s.Types["Query"].(*schema.Object) + validateSchema: func(s *types.Schema) error { + typq, ok := s.Types["Query"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Query") } @@ -228,7 +396,7 @@ func TestParse(t *testing.T) { return fmt.Errorf("field %q has an invalid type: %q", "hello", helloField.Type.String()) } - typm, ok := s.Types["Mutation"].(*schema.Object) + typm, ok := s.Types["Mutation"].(*types.ObjectTypeDefinition) if !ok { return fmt.Errorf("type %q not found", "Mutation") } @@ -239,21 +407,544 @@ func TestParse(t *testing.T) { if concatField.Type.String() != "String!" { return fmt.Errorf("field %q has an invalid type: %q", "concat", concatField.Type.String()) } - if len(concatField.Args) != 2 || concatField.Args[0] == nil || concatField.Args[1] == nil || concatField.Args[0].Type.String() != "String!" || concatField.Args[1].Type.String() != "String!" { - return fmt.Errorf("field %q has an invalid args: %+v", "concat", concatField.Args) + if len(concatField.Arguments) != 2 || concatField.Arguments[0] == nil || concatField.Arguments[1] == nil || concatField.Arguments[0].Type.String() != "String!" || concatField.Arguments[1].Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid args: %+v", "concat", concatField.Arguments) + } + return nil + }, + }, + { + name: "Extend type with interface implementation", + sdl: ` + interface Named { + name: String! + } + type Product { + id: ID! + } + extend type Product implements Named { + name: String! + }`, + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Product"].(*types.ObjectTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", "Product") + } + idField := typ.Fields.Get("id") + if idField == nil { + return fmt.Errorf("field %q not found", "id") + } + if idField.Type.String() != "ID!" { + return fmt.Errorf("field %q has an invalid type: %q", "id", idField.Type.String()) + } + nameField := typ.Fields.Get("name") + if nameField == nil { + return fmt.Errorf("field %q not found", "name") + } + if nameField.Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid type: %q", "name", nameField.Type.String()) + } + + ifc, ok := s.Types["Named"].(*types.InterfaceTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", "Named") + } + nameField = ifc.Fields.Get("name") + if nameField == nil { + return fmt.Errorf("field %q not found", "name") + } + if nameField.Type.String() != "String!" { + return fmt.Errorf("field %q has an invalid type: %q", "name", nameField.Type.String()) + } + return nil + }, + }, + { + name: "Extend union type", + sdl: ` + type Named { + name: String! + } + type Numbered { + num: Int! + } + union Item = Named | Numbered + type Coloured { + Colour: String! + } + extend union Item = Coloured + `, + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Item"].(*types.Union) + if !ok { + return fmt.Errorf("type %q not found", "Item") + } + if len(typ.UnionMemberTypes) != 3 { + return fmt.Errorf("Expected 3 possible types, but instead got %d types", len(typ.UnionMemberTypes)) + } + posible := map[string]struct{}{ + "Coloured": {}, + "Named": {}, + "Numbered": {}, + } + for _, pt := range typ.UnionMemberTypes { + if _, ok := posible[pt.Name]; !ok { + return fmt.Errorf("Unexpected possible type %q", pt.Name) + } + } + return nil + }, + }, + { + name: "Extend enum type", + sdl: ` + enum Currencies{ + AUD + USD + EUR + } + extend enum Currencies { + BGN + GBP + } + `, + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Currencies"].(*types.EnumTypeDefinition) + if !ok { + return fmt.Errorf("enum %q not found", "Currencies") + } + if len(typ.EnumValuesDefinition) != 5 { + return fmt.Errorf("Expected 5 enum values, but instead got %d types", len(typ.EnumValuesDefinition)) + } + posible := map[string]struct{}{ + "AUD": {}, + "USD": {}, + "EUR": {}, + "BGN": {}, + "GBP": {}, + } + for _, v := range typ.EnumValuesDefinition { + if _, ok := posible[v.EnumValue]; !ok { + return fmt.Errorf("Unexpected enum value %q", v.EnumValue) + } + } + return nil + }, + }, + { + name: "Extend incompatible type", + sdl: ` + type Query { + hello: String! + } + + extend interface Query { + name: String! + }`, + validateError: func(err error) error { + msg := `trying to extend type "OBJECT" with type "INTERFACE"` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Extend type already implements an interface", + sdl: ` + interface Named { + name: String! + } + type Product implements Named { + id: ID! + name: String! + } + extend type Product implements Named { + }`, + validateError: func(err error) error { + msg := `interface "Named" implemented in the extension is already implemented in "Product"` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Extend union already contains type", + sdl: ` + type Named { + name: String! + } + type Numbered { + num: Int! + } + union Item = Named | Numbered + type Coloured { + Colour: String! + } + extend union Item = Coloured | Named + `, + validateError: func(err error) error { + msg := `union type "Named" already declared in "Item"` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Extend union contains type", + sdl: ` + type Named { + name: String! + } + type Numbered { + num: Int! + } + union Item = Named | Numbered + + type Coloured { + Colour: String! + } + + extend union Item = Coloured + `, + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Item"].(*types.Union) + if !ok { + return fmt.Errorf("type %q not found", "Item") + } + if len(typ.UnionMemberTypes) != 3 { + return fmt.Errorf("Expected 3 possible types, but instead got %d types", len(typ.UnionMemberTypes)) + } + posible := map[string]struct{}{ + "Coloured": {}, + "Named": {}, + "Numbered": {}, + } + for _, pt := range typ.UnionMemberTypes { + if _, ok := posible[pt.Name]; !ok { + return fmt.Errorf("Unexpected possible type %q", pt.Name) + } + } + return nil + }, + }, + { + name: "Extend input", + sdl: ` + input Product { + id: ID! + name: String! + } + extend input Product { + category: Category! + tags: [String!]! = ["sale", "shoes"] + } + input Category { + id: ID! + name: String! + } + `, + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Product"].(*types.InputObject) + if !ok { + return fmt.Errorf("type %q not found", "Product") + } + if len(typ.Values) != 4 { + return fmt.Errorf("Expected 4 fields, but instead got %d types", len(typ.Values)) + } + posible := map[string]struct{}{ + "id": {}, + "name": {}, + "category": {}, + "tags": {}, + } + for _, pt := range typ.Values { + if _, ok := posible[pt.Name.Name]; !ok { + return fmt.Errorf("Unexpected possible type %q", pt.Name) + } + } + categoryField := typ.Values.Get("category") + if categoryField == nil { + return fmt.Errorf("field %q not found", "category") + } + if categoryField.Type.String() != "Category!" { + return fmt.Errorf("expected type %q, but got %q", "Category!", categoryField.Type.String()) + } + if categoryField.Type.Kind() != "NON_NULL" { + return fmt.Errorf("expected kind %q, but got %q", "NON_NULL", categoryField.Type.Kind()) + } + return nil + }, + }, + { + name: "Extend enum value already exists", + sdl: ` + enum Currencies{ + AUD + USD + EUR + } + extend enum Currencies { + AUD + }`, + validateError: func(err error) error { + msg := `enum value "AUD" already declared in "Currencies"` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Extend input field already exists", + sdl: ` + input Product{ + name: String! + } + extend input Product { + name: String! + }`, + validateError: func(err error) error { + msg := `extended field {"name" {'\x06' '\x05'}} already exists` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Extend field already exists", + sdl: ` + interface Named { + name: String! + } + type Product implements Named { + id: ID! + name: String! + } + extend type Product { + name: String! + }`, + validateError: func(err error) error { + msg := `extended field "name" already exists` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Extend interface type", + sdl: ` + interface Product { + id: ID! + name: String! + } + extend interface Product { + category: String! + } + `, + validateSchema: func(s *types.Schema) error { + typ, ok := s.Types["Product"].(*types.InterfaceTypeDefinition) + if !ok { + return fmt.Errorf("type %q not found", "Product") + } + if len(typ.Fields) != 3 { + return fmt.Errorf("Expected 3 fields, but instead got %d types", len(typ.Fields)) + } + fields := map[string]struct{}{ + "id": {}, + "name": {}, + "category": {}, + } + for _, f := range typ.Fields { + if _, ok := fields[f.Name]; !ok { + return fmt.Errorf("Unexpected field %q", f.Name) + } + } + return nil + }, + }, + { + name: "Extend unknown type", + sdl: ` + extend type User { + name: String! + } + `, + validateError: func(err error) error { + msg := `trying to extend unknown type "User"` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Extend invalid syntax", + sdl: ` + extend invalid Node { + id: ID! + } + `, + validateError: func(err error) error { + msg := `graphql: syntax error: unexpected "invalid", expecting "schema", "type", "enum", "interface", "union" or "input" (line 2, column 19)` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Parses directives", + sdl: ` + directive @objectdirective on OBJECT + directive @fielddirective on FIELD_DEFINITION + directive @enumdirective on ENUM + directive @uniondirective on UNION + directive @directive on SCALAR + | OBJECT + | FIELD_DEFINITION + | ARGUMENT_DEFINITION + | INTERFACE + | UNION + | ENUM + | ENUM_VALUE + | INPUT_OBJECT + | INPUT_FIELD_DEFINITION + directive @repeatabledirective repeatable on SCALAR + + interface NamedEntity @directive { name: String } + + scalar Time @directive + + type Photo @objectdirective { + id: ID! @deprecated @fielddirective + } + + type Person implements NamedEntity @objectdirective { + name: String + } + + enum Direction @enumdirective { + NORTH @deprecated + EAST + SOUTH + WEST + } + + union Union @uniondirective = Photo | Person + + scalar Mass @repeatabledirective @repeatabledirective + `, + validateSchema: func(s *types.Schema) error { + namedEntityDirectives := s.Types["NamedEntity"].(*types.InterfaceTypeDefinition).Directives + if len(namedEntityDirectives) != 1 || namedEntityDirectives[0].Name.Name != "directive" { + return fmt.Errorf("missing directive on NamedEntity interface, expected @directive but got %v", namedEntityDirectives) + } + + timeDirectives := s.Types["Time"].(*types.ScalarTypeDefinition).Directives + if len(timeDirectives) != 1 || timeDirectives[0].Name.Name != "directive" { + return fmt.Errorf("missing directive on Time scalar, expected @directive but got %v", timeDirectives) + } + + photo := s.Types["Photo"].(*types.ObjectTypeDefinition) + photoDirectives := photo.Directives + if len(photoDirectives) != 1 || photoDirectives[0].Name.Name != "objectdirective" { + return fmt.Errorf("missing directive on Time scalar, expected @objectdirective but got %v", photoDirectives) + } + if len(photo.Fields.Get("id").Directives) != 2 { + return fmt.Errorf("expected Photo.id to have 2 directives but got %v", photoDirectives) + } + + directionDirectives := s.Types["Direction"].(*types.EnumTypeDefinition).Directives + if len(directionDirectives) != 1 || directionDirectives[0].Name.Name != "enumdirective" { + return fmt.Errorf("missing directive on Direction enum, expected @enumdirective but got %v", directionDirectives) + } + + unionDirectives := s.Types["Union"].(*types.Union).Directives + if len(unionDirectives) != 1 || unionDirectives[0].Name.Name != "uniondirective" { + return fmt.Errorf("missing directive on Union union, expected @uniondirective but got %v", unionDirectives) + } + + massDirectives := s.Types["Mass"].(*types.ScalarTypeDefinition).Directives + if len(massDirectives) != 2 || massDirectives[0].Name.Name != "repeatabledirective" || massDirectives[1].Name.Name != "repeatabledirective" { + return fmt.Errorf("missing directive on Repeatable scalar, expected @repeatabledirective @repeatabledirective but got %v", massDirectives) + } + return nil + }, + }, + { + name: "Sets Directive.Repeatable if `repeatable` keyword is given", + sdl: ` + directive @nonrepeatabledirective on SCALAR + directive @repeatabledirective repeatable on SCALAR + `, + validateSchema: func(s *types.Schema) error { + if dir := s.Directives["nonrepeatabledirective"]; dir.Repeatable { + return fmt.Errorf("did not expect directive to be repeatable: %v", dir) + } + if dir := s.Directives["repeatabledirective"]; !dir.Repeatable { + return fmt.Errorf("expected directive to be repeatable: %v", dir) + } + return nil + }, + }, + { + name: "Directive definition does not allow double-`repeatable`", + sdl: ` + directive @mydirective repeatable repeatable SCALAR + scalar MyScalar @mydirective + `, + validateError: func(err error) error { + msg := `graphql: syntax error: unexpected "repeatable", expecting "on" (line 2, column 38)` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + { + name: "Directive definition does not allow double-`on` instead of `repeatable on`", + sdl: ` + directive @mydirective on on SCALAR + scalar MyScalar @mydirective + `, + validateError: func(err error) error { + prefix := `graphql: syntax error: "on" is not a legal directive location` + if err == nil || !strings.HasPrefix(err.Error(), prefix) { + return fmt.Errorf("expected error starting with %q, but got %q", prefix, err) + } + return nil + }, + }, + { + name: "Disallow repeat of a directive if it is not `repeatable`", + sdl: ` + directive @nonrepeatabledirective on FIELD_DEFINITION + type Foo { + bar: String @nonrepeatabledirective @nonrepeatabledirective + } + `, + validateError: func(err error) error { + prefix := `graphql: non repeatable directive "nonrepeatabledirective" can not be repeated. Consider adding "repeatable"` + if err == nil || !strings.HasPrefix(err.Error(), prefix) { + return fmt.Errorf("expected error starting with %q, but got %q", prefix, err) } return nil }, }, } { t.Run(test.name, func(t *testing.T) { - s := schema.New() - if err := s.Parse(test.sdl, test.useStringDescriptions); err != nil { + s, err := schema.ParseSchema(test.sdl, test.useStringDescriptions) + if err != nil { if test.validateError == nil { t.Fatal(err) } - if err := test.validateError(err); err != nil { - t.Fatal(err) + if err2 := test.validateError(err); err2 != nil { + t.Fatal(err2) } } if test.validateSchema != nil { @@ -264,3 +955,137 @@ func TestParse(t *testing.T) { }) } } + +func TestInterfaceImplementsInterface(t *testing.T) { + for _, tt := range []struct { + name string + sdl string + useStringDescriptions bool + validateError func(err error) error + validateSchema func(s *types.Schema) error + }{ + { + name: "Parses interface implementing other interface", + sdl: ` + interface Foo { + field: String! + } + interface Bar implements Foo { + field: String! + } + `, + validateSchema: func(s *types.Schema) error { + const implementedInterfaceName = "Bar" + typ, ok := s.Types[implementedInterfaceName].(*types.InterfaceTypeDefinition) + if !ok { + return fmt.Errorf("interface %q not found", implementedInterfaceName) + } + if len(typ.Fields) != 1 { + return fmt.Errorf("invalid number of fields: want %d, have %d", 1, len(typ.Fields)) + } + const fieldName = "field" + + if typ.Fields[0].Name != fieldName { + return fmt.Errorf("field %q not found", fieldName) + } + + if len(typ.Interfaces) != 1 { + return fmt.Errorf("invalid number of implementing interfaces found on %q: want %d, have %d", implementedInterfaceName, 1, len(typ.Interfaces)) + } + + const implementingInterfaceName = "Foo" + if typ.Interfaces[0].Name != implementingInterfaceName { + return fmt.Errorf("interface %q not found", implementingInterfaceName) + } + + return nil + }, + }, + { + name: "Parses interface transitively implementing an interface that implements an interface", + sdl: ` + interface Foo { + field: String! + } + interface Bar implements Foo { + field: String! + } + interface Baz implements Bar & Foo { + field: String! + } + `, + validateSchema: func(s *types.Schema) error { + const implementedInterfaceName = "Baz" + typ, ok := s.Types[implementedInterfaceName].(*types.InterfaceTypeDefinition) + if !ok { + return fmt.Errorf("interface %q not found", implementedInterfaceName) + } + if len(typ.Fields) != 1 { + return fmt.Errorf("invalid number of fields: want %d, have %d", 1, len(typ.Fields)) + } + const fieldName = "field" + + if typ.Fields[0].Name != fieldName { + return fmt.Errorf("field %q not found", fieldName) + } + + if len(typ.Interfaces) != 2 { + return fmt.Errorf("invalid number of implementing interfaces found on %q: want %d, have %d", implementedInterfaceName, 2, len(typ.Interfaces)) + } + + const firstImplementingInterfaceName = "Bar" + if typ.Interfaces[0].Name != firstImplementingInterfaceName { + return fmt.Errorf("first interface %q not found", firstImplementingInterfaceName) + } + + const secondImplementingInterfaceName = "Foo" + if typ.Interfaces[1].Name != secondImplementingInterfaceName { + return fmt.Errorf("second interface %q not found", secondImplementingInterfaceName) + } + + return nil + }, + }, + { + name: "Transitively implemented interfaces must also be defined on an implementing type or interface", + sdl: ` + interface A { + message: String! + } + interface B implements A { + message: String! + name: String! + } + interface C implements B { + message: String! + name: String! + hug: Boolean! + } + `, + validateError: func(err error) error { + msg := `graphql: interface "C" must explicitly implement transitive interface "A"` + if err == nil || err.Error() != msg { + return fmt.Errorf("expected error %q, but got %q", msg, err) + } + return nil + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + s, err := schema.ParseSchema(tt.sdl, tt.useStringDescriptions) + if err != nil { + if tt.validateError == nil { + t.Fatal(err) + } + if err2 := tt.validateError(err); err2 != nil { + t.Fatal(err2) + } + } + if tt.validateSchema != nil { + if err2 := tt.validateSchema(s); err2 != nil { + t.Fatal(err2) + } + } + }) + } +} diff --git a/internal/validation/testdata/tests.json b/internal/validation/testdata/tests.json index 46df80d3f..88211ca60 100644 --- a/internal/validation/testdata/tests.json +++ b/internal/validation/testdata/tests.json @@ -1457,6 +1457,125 @@ "query": "\n query Foo($a: String, $b: String, $c: String) {\n field(a: $a) {\n field(b: $b) {\n field(c: $c)\n }\n }\n }\n ", "errors": [] }, + { + "name": "Validate: No invalid default String variable values", + "rule": "DefaultValuesOfCorrectType", + "schema": 0, + "query": "\n query Foo($a: String = -\"\") {\n field(a: $a)\n }\n ", + "errors": [ + { + "message": "Variable \"$a\" of type \"String\" has invalid default value -\"\".\nExpected type \"String\", found -\"\".", + "locations": [ + { + "line": 2, + "column": 30 + } + ] + } + ] + }, + { + "name": "Validate: No invalid default Int variable values/bad input", + "rule": "DefaultValuesOfCorrectType", + "schema": 0, + "query": "\n query Foo($a: Int = -\"\") {\n field(a: $a)\n }\n ", + "errors": [ + { + "message": "Variable \"$a\" of type \"Int\" has invalid default value -\"\".\nExpected type \"Int\", found -\"\".", + "locations": [ + { + "line": 2, + "column": 27 + } + ] + } + ] + }, + { + "name": "Validate: No invalid default Int variable values/value out of range", + "rule": "DefaultValuesOfCorrectType", + "schema": 0, + "query": "\n query Foo($a: Int = -2147483649) {\n field(a: $a)\n }\n ", + "errors": [ + { + "message": "Variable \"$a\" of type \"Int\" has invalid default value -2147483649.\nExpected type \"Int\", found -2147483649.", + "locations": [ + { + "line": 2, + "column": 27 + } + ] + } + ] + }, + { + "name": "Validate: No invalid default Float variable values", + "rule": "DefaultValuesOfCorrectType", + "schema": 0, + "query": "\n query Foo($a: Float = -\"\") {\n field(a: $a)\n }\n ", + "errors": [ + { + "message": "Variable \"$a\" of type \"Float\" has invalid default value -\"\".\nExpected type \"Float\", found -\"\".", + "locations": [ + { + "line": 2, + "column": 29 + } + ] + } + ] + }, + { + "name": "Validate: No invalid default Float variable values/value out of range", + "rule": "DefaultValuesOfCorrectType", + "schema": 0, + "query": "\n query Foo($a: Float = 1.8e+308) {\n field(a: $a)\n }\n ", + "errors": [ + { + "message": "Variable \"$a\" of type \"Float\" has invalid default value 1.8e+308.\nExpected type \"Float\", found 1.8e+308.", + "locations": [ + { + "line": 2, + "column": 29 + } + ] + } + ] + }, + { + "name": "Validate: No invalid default Boolean variable values", + "rule": "DefaultValuesOfCorrectType", + "schema": 0, + "query": "\n query Foo($a: Boolean = \"false\") {\n field(a: $a)\n }\n ", + "errors": [ + { + "message": "Variable \"$a\" of type \"Boolean\" has invalid default value \"false\".\nExpected type \"Boolean\", found \"false\".", + "locations": [ + { + "line": 2, + "column": 31 + } + ] + } + ] + }, + { + "name": "Validate: No invalid default ID variable values", + "rule": "DefaultValuesOfCorrectType", + "schema": 0, + "query": "\n query Foo($a: ID = false) {\n field(a: $a)\n }\n ", + "errors": [ + { + "message": "Variable \"$a\" of type \"ID\" has invalid default value false.\nExpected type \"ID\", found false.", + "locations": [ + { + "line": 2, + "column": 26 + } + ] + } + ] + }, { "name": "Validate: No unused variables/uses all variables deeply in inline fragments", "rule": "NoUnusedVariables", @@ -1464,6 +1583,13 @@ "query": "\n query Foo($a: String, $b: String, $c: String) {\n ... on Type {\n field(a: $a) {\n field(b: $b) {\n ... on Type {\n field(c: $c)\n }\n }\n }\n }\n }\n ", "errors": [] }, + { + "name": "Validate: fragments are used even when they are nested", + "rule": "NoUnusedFragments", + "schema": 1, + "query": "\n query Foo() {\n ...StringFragment\n stringBox {\n ...StringFragment\n ...StringFragmentPrime\n}\n}\n\n\n fragment StringFragment on StringBox {\n scalar\n}\n\n fragment StringFragmentPrime on StringBox {\n unrelatedField\n}\n", + "errors": [] + }, { "name": "Validate: No unused variables/uses all variables in fragments", "rule": "NoUnusedVariables", @@ -3853,4 +3979,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/internal/validation/validate_max_depth_test.go b/internal/validation/validate_max_depth_test.go index b31523feb..65c96cf8d 100644 --- a/internal/validation/validate_max_depth_test.go +++ b/internal/validation/validate_max_depth_test.go @@ -5,6 +5,7 @@ import ( "github.com/tokopedia/graphql-go/internal/query" "github.com/tokopedia/graphql-go/internal/schema" + "github.com/tokopedia/graphql-go/types" ) const ( @@ -33,6 +34,7 @@ const ( id: ID! name: String! friends: [Character] + enemies: [Character] appearsIn: [Episode]! } @@ -42,12 +44,15 @@ const ( JEDI } - type Starship {} + type Starship { + id: ID! + } type Human implements Character { id: ID! name: String! friends: [Character] + enemies: [Character] appearsIn: [Episode]! starships: [Starship] totalCredits: Int @@ -57,6 +62,7 @@ const ( id: ID! name: String! friends: [Character] + enemies: [Character] appearsIn: [Episode]! primaryFunction: String }` @@ -70,7 +76,7 @@ type maxDepthTestCase struct { expectedErrors []string } -func (tc maxDepthTestCase) Run(t *testing.T, s *schema.Schema) { +func (tc maxDepthTestCase) Run(t *testing.T, s *types.Schema) { t.Run(tc.name, func(t *testing.T) { doc, qErr := query.Parse(tc.query) if qErr != nil { @@ -103,9 +109,7 @@ func (tc maxDepthTestCase) Run(t *testing.T, s *schema.Schema) { } func TestMaxDepth(t *testing.T) { - s := schema.New() - - err := s.Parse(simpleSchema, false) + s, err := schema.ParseSchema(simpleSchema, false) if err != nil { t.Fatal(err) } @@ -179,9 +183,7 @@ func TestMaxDepth(t *testing.T) { } func TestMaxDepthInlineFragments(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } @@ -228,9 +230,7 @@ func TestMaxDepthInlineFragments(t *testing.T) { } func TestMaxDepthFragmentSpreads(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } @@ -309,15 +309,71 @@ func TestMaxDepthFragmentSpreads(t *testing.T) { depth: 6, failure: true, }, + { + name: "spreadAtDifferentDepths", + query: ` + fragment character on Character { + name # depth + 0 + friends { # depth + 0 + name # depth + 1 + } + } + + query laterDepthValidated { + ...character # depth 1 (+1) + enemies { # depth 1 + friends { # depth 2 + ...character # depth 2 (+1), should error! + } + } + } + `, + depth: 2, + failure: true, + }, + { + name: "spreadAtSameDepth", + query: ` + fragment character on Character { + name # depth + 0 + friends { # depth + 0 + name # depth + 1 + } + } + query { + characters { # depth 1 + friends { # depth 2 + ...character # depth 3 (+1) + } + enemies { # depth 2 + ...character # depth 3 (+1) + } + } + } + `, + depth: 4, + }, + { + name: "fragmentCycle", + query: ` + fragment X on Query { ...Y } + fragment Y on Query { ...Z } + fragment Z on Query { ...X } + + query { + ...X + } + `, + depth: 10, + failure: true, + }, } { tc.Run(t, s) } } func TestMaxDepthUnknownFragmentSpreads(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } @@ -350,9 +406,7 @@ func TestMaxDepthUnknownFragmentSpreads(t *testing.T) { } func TestMaxDepthValidation(t *testing.T) { - s := schema.New() - - err := s.Parse(interfaceSimple, false) + s, err := schema.ParseSchema(interfaceSimple, false) if err != nil { t.Fatal(err) } @@ -440,7 +494,7 @@ func TestMaxDepthValidation(t *testing.T) { opc := &opContext{context: context, ops: doc.Operations} - actual := validateMaxDepth(opc, op.Selections, 1) + actual := validateMaxDepth(opc, op.Selections, nil, 1) if actual != tc.expected { t.Errorf("expected %t, actual %t", tc.expected, actual) } diff --git a/internal/validation/validation.go b/internal/validation/validation.go index 2429c2dce..5ee18b480 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -11,25 +11,27 @@ import ( "github.com/tokopedia/graphql-go/errors" "github.com/tokopedia/graphql-go/internal/common" "github.com/tokopedia/graphql-go/internal/query" - "github.com/tokopedia/graphql-go/internal/schema" + "github.com/tokopedia/graphql-go/types" ) -type varSet map[*common.InputValue]struct{} +type varSet map[*types.InputValueDefinition]struct{} -type selectionPair struct{ a, b query.Selection } +type selectionPair struct{ a, b types.Selection } + +type nameSet map[string]errors.Location type fieldInfo struct { - sf *schema.Field - parent schema.NamedType + sf *types.FieldDefinition + parent types.NamedType } type context struct { - schema *schema.Schema - doc *query.Document + schema *types.Schema + doc *types.ExecutableDefinition errs []*errors.QueryError - opErrs map[*query.Operation][]*errors.QueryError - usedVars map[*query.Operation]varSet - fieldMap map[*query.Field]fieldInfo + opErrs map[*types.OperationDefinition][]*errors.QueryError + usedVars map[*types.OperationDefinition]varSet + fieldMap map[*types.Field]fieldInfo overlapValidated map[selectionPair]struct{} maxDepth int } @@ -48,33 +50,33 @@ func (c *context) addErrMultiLoc(locs []errors.Location, rule string, format str type opContext struct { *context - ops []*query.Operation + ops []*types.OperationDefinition } -func newContext(s *schema.Schema, doc *query.Document, maxDepth int) *context { +func newContext(s *types.Schema, doc *types.ExecutableDefinition, maxDepth int) *context { return &context{ schema: s, doc: doc, - opErrs: make(map[*query.Operation][]*errors.QueryError), - usedVars: make(map[*query.Operation]varSet), - fieldMap: make(map[*query.Field]fieldInfo), + opErrs: make(map[*types.OperationDefinition][]*errors.QueryError), + usedVars: make(map[*types.OperationDefinition]varSet), + fieldMap: make(map[*types.Field]fieldInfo), overlapValidated: make(map[selectionPair]struct{}), maxDepth: maxDepth, } } -func Validate(s *schema.Schema, doc *query.Document, variables map[string]interface{}, maxDepth int) []*errors.QueryError { +func Validate(s *types.Schema, doc *types.ExecutableDefinition, variables map[string]interface{}, maxDepth int) []*errors.QueryError { c := newContext(s, doc, maxDepth) opNames := make(nameSet) - fragUsedBy := make(map[*query.FragmentDecl][]*query.Operation) + fragUsedBy := make(map[*types.FragmentDefinition][]*types.OperationDefinition) for _, op := range doc.Operations { c.usedVars[op] = make(varSet) - opc := &opContext{c, []*query.Operation{op}} + opc := &opContext{c, []*types.OperationDefinition{op}} // Check if max depth is exceeded, if it's set. If max depth is exceeded, // don't continue to validate the document and exit early. - if validateMaxDepth(opc, op.Selections, 1) { + if validateMaxDepth(opc, op.Selections, nil, 1) { return c.errs } @@ -101,7 +103,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf validateLiteral(opc, v.Default) if t != nil { - if nn, ok := t.(*common.NonNull); ok { + if nn, ok := t.(*types.NonNull); ok { c.addErr(v.Default.Location(), "DefaultValuesOfCorrectType", "Variable %q of type %q is required and will not use the default value. Perhaps you meant to use type %q.", "$"+v.Name.Name, t, nn.OfType) } @@ -112,7 +114,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf } } - var entryPoint schema.NamedType + var entryPoint types.NamedType switch op.Type { case query.Query: entryPoint = s.EntryPoints["query"] @@ -126,7 +128,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf validateSelectionSet(opc, op.Selections, entryPoint) - fragUsed := make(map[*query.FragmentDecl]struct{}) + fragUsed := make(map[*types.FragmentDefinition]struct{}) markUsedFragments(c, op.Selections, fragUsed) for frag := range fragUsed { fragUsedBy[frag] = append(fragUsedBy[frag], op) @@ -134,7 +136,7 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf } fragNames := make(nameSet) - fragVisited := make(map[*query.FragmentDecl]struct{}) + fragVisited := make(map[*types.FragmentDefinition]struct{}) for _, frag := range doc.Fragments { opc := &opContext{c, fragUsedBy[frag]} @@ -179,15 +181,15 @@ func Validate(s *schema.Schema, doc *query.Document, variables map[string]interf return c.errs } -func validateValue(c *opContext, v *common.InputValue, val interface{}, t common.Type) { +func validateValue(c *opContext, v *types.InputValueDefinition, val interface{}, t types.Type) { switch t := t.(type) { - case *common.NonNull: + case *types.NonNull: if val == nil { c.addErr(v.Loc, "VariablesOfCorrectType", "Variable \"%s\" has invalid value null.\nExpected type \"%s\", found null.", v.Name.Name, t) return } validateValue(c, v, val, t.OfType) - case *common.List: + case *types.List: if val == nil { return } @@ -200,7 +202,7 @@ func validateValue(c *opContext, v *common.InputValue, val interface{}, t common for _, elem := range vv { validateValue(c, v, elem, t.OfType) } - case *schema.Enum: + case *types.EnumTypeDefinition: if val == nil { return } @@ -209,13 +211,13 @@ func validateValue(c *opContext, v *common.InputValue, val interface{}, t common c.addErr(v.Loc, "VariablesOfCorrectType", "Variable \"%s\" has invalid type %T.\nExpected type \"%s\", found %v.", v.Name.Name, val, t, val) return } - for _, option := range t.Values { - if option.Name == e { + for _, option := range t.EnumValuesDefinition { + if option.EnumValue == e { return } } c.addErr(v.Loc, "VariablesOfCorrectType", "Variable \"%s\" has invalid value %s.\nExpected type \"%s\", found %s.", v.Name.Name, e, t, e) - case *schema.InputObject: + case *types.InputObject: if val == nil { return } @@ -233,28 +235,35 @@ func validateValue(c *opContext, v *common.InputValue, val interface{}, t common // validates the query doesn't go deeper than maxDepth (if set). Returns whether // or not query validated max depth to avoid excessive recursion. -func validateMaxDepth(c *opContext, sels []query.Selection, depth int) bool { +// +// The visited map is necessary to ensure that max depth validation does not get stuck in cyclical +// fragment spreads. +func validateMaxDepth(c *opContext, sels []types.Selection, visited map[*types.FragmentDefinition]struct{}, depth int) bool { // maxDepth checking is turned off when maxDepth is 0 if c.maxDepth == 0 { return false } exceededMaxDepth := false + if visited == nil { + visited = map[*types.FragmentDefinition]struct{}{} + } for _, sel := range sels { switch sel := sel.(type) { - case *query.Field: + case *types.Field: if depth > c.maxDepth { exceededMaxDepth = true c.addErr(sel.Alias.Loc, "MaxDepthExceeded", "Field %q has depth %d that exceeds max depth %d", sel.Name.Name, depth, c.maxDepth) continue } - exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.Selections, depth+1) - case *query.InlineFragment: + exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.SelectionSet, visited, depth+1) + + case *types.InlineFragment: // Depth is not checked because inline fragments resolve to other fields which are checked. // Depth is not incremented because inline fragments have the same depth as neighboring fields - exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.Selections, depth) - case *query.FragmentSpread: + exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.Selections, visited, depth) + case *types.FragmentSpread: // Depth is not checked because fragments resolve to other fields which are checked. frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { @@ -262,15 +271,22 @@ func validateMaxDepth(c *opContext, sels []query.Selection, depth int) bool { c.addErr(sel.Loc, "MaxDepthEvaluationError", "Unknown fragment %q. Unable to evaluate depth.", sel.Name.Name) continue } + + if _, ok := visited[frag]; ok { + // we've already seen this fragment, don't check depth again. + continue + } + visited[frag] = struct{}{} + // Depth is not incremented because fragments have the same depth as surrounding fields - exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, frag.Selections, depth) + exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, frag.Selections, visited, depth) } } return exceededMaxDepth } -func validateSelectionSet(c *opContext, sels []query.Selection, t schema.NamedType) { +func validateSelectionSet(c *opContext, sels []types.Selection, t types.NamedType) { for _, sel := range sels { validateSelection(c, sel, t) } @@ -282,35 +298,40 @@ func validateSelectionSet(c *opContext, sels []query.Selection, t schema.NamedTy } } -func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { +func validateSelection(c *opContext, sel types.Selection, t types.NamedType) { switch sel := sel.(type) { - case *query.Field: + case *types.Field: validateDirectives(c, "FIELD", sel.Directives) fieldName := sel.Name.Name - var f *schema.Field + var f *types.FieldDefinition switch fieldName { case "__typename": - f = &schema.Field{ + f = &types.FieldDefinition{ Name: "__typename", Type: c.schema.Types["String"], } case "__schema": - f = &schema.Field{ + f = &types.FieldDefinition{ Name: "__schema", Type: c.schema.Types["__Schema"], } case "__type": - f = &schema.Field{ + f = &types.FieldDefinition{ Name: "__type", - Args: common.InputValueList{ - &common.InputValue{ - Name: common.Ident{Name: "name"}, - Type: &common.NonNull{OfType: c.schema.Types["String"]}, + Arguments: types.ArgumentsDefinition{ + &types.InputValueDefinition{ + Name: types.Ident{Name: "name"}, + Type: &types.NonNull{OfType: c.schema.Types["String"]}, }, }, Type: c.schema.Types["__Type"], } + case "_service": + f = &types.FieldDefinition{ + Name: "_service", + Type: c.schema.Types["_Service"], + } default: f = fields(t).Get(fieldName) if f == nil && t != nil { @@ -322,28 +343,28 @@ func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { validateArgumentLiterals(c, sel.Arguments) if f != nil { - validateArgumentTypes(c, sel.Arguments, f.Args, sel.Alias.Loc, + validateArgumentTypes(c, sel.Arguments, f.Arguments, sel.Alias.Loc, func() string { return fmt.Sprintf("field %q of type %q", fieldName, t) }, func() string { return fmt.Sprintf("Field %q", fieldName) }, ) } - var ft common.Type + var ft types.Type if f != nil { ft = f.Type sf := hasSubfields(ft) - if sf && sel.Selections == nil { + if sf && sel.SelectionSet == nil { c.addErr(sel.Alias.Loc, "ScalarLeafs", "Field %q of type %q must have a selection of subfields. Did you mean \"%s { ... }\"?", fieldName, ft, fieldName) } - if !sf && sel.Selections != nil { + if !sf && sel.SelectionSet != nil { c.addErr(sel.SelectionSetLoc, "ScalarLeafs", "Field %q must not have a selection since type %q has no subfields.", fieldName, ft) } } - if sel.Selections != nil { - validateSelectionSet(c, sel.Selections, unwrapType(ft)) + if sel.SelectionSet != nil { + validateSelectionSet(c, sel.SelectionSet, unwrapType(ft)) } - case *query.InlineFragment: + case *types.InlineFragment: validateDirectives(c, "INLINE_FRAGMENT", sel.Directives) if sel.On.Name != "" { fragTyp := unwrapType(resolveType(c.context, &sel.On)) @@ -359,7 +380,7 @@ func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { } validateSelectionSet(c, sel.Selections, unwrapType(t)) - case *query.FragmentSpread: + case *types.FragmentSpread: validateDirectives(c, "FRAGMENT_SPREAD", sel.Directives) frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { @@ -376,7 +397,7 @@ func validateSelection(c *opContext, sel query.Selection, t schema.NamedType) { } } -func compatible(a, b common.Type) bool { +func compatible(a, b types.Type) bool { for _, pta := range possibleTypes(a) { for _, ptb := range possibleTypes(b) { if pta == ptb { @@ -387,39 +408,40 @@ func compatible(a, b common.Type) bool { return false } -func possibleTypes(t common.Type) []*schema.Object { +func possibleTypes(t types.Type) []*types.ObjectTypeDefinition { switch t := t.(type) { - case *schema.Object: - return []*schema.Object{t} - case *schema.Interface: - return t.PossibleTypes - case *schema.Union: + case *types.ObjectTypeDefinition: + return []*types.ObjectTypeDefinition{t} + case *types.InterfaceTypeDefinition: return t.PossibleTypes + case *types.Union: + return t.UnionMemberTypes default: return nil } } -func markUsedFragments(c *context, sels []query.Selection, fragUsed map[*query.FragmentDecl]struct{}) { +func markUsedFragments(c *context, sels []types.Selection, fragUsed map[*types.FragmentDefinition]struct{}) { for _, sel := range sels { switch sel := sel.(type) { - case *query.Field: - if sel.Selections != nil { - markUsedFragments(c, sel.Selections, fragUsed) + case *types.Field: + if sel.SelectionSet != nil { + markUsedFragments(c, sel.SelectionSet, fragUsed) } - case *query.InlineFragment: + case *types.InlineFragment: markUsedFragments(c, sel.Selections, fragUsed) - case *query.FragmentSpread: + case *types.FragmentSpread: frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { return } if _, ok := fragUsed[frag]; ok { - return + continue } + fragUsed[frag] = struct{}{} markUsedFragments(c, frag.Selections, fragUsed) @@ -429,23 +451,23 @@ func markUsedFragments(c *context, sels []query.Selection, fragUsed map[*query.F } } -func detectFragmentCycle(c *context, sels []query.Selection, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) { +func detectFragmentCycle(c *context, sels []types.Selection, fragVisited map[*types.FragmentDefinition]struct{}, spreadPath []*types.FragmentSpread, spreadPathIndex map[string]int) { for _, sel := range sels { detectFragmentCycleSel(c, sel, fragVisited, spreadPath, spreadPathIndex) } } -func detectFragmentCycleSel(c *context, sel query.Selection, fragVisited map[*query.FragmentDecl]struct{}, spreadPath []*query.FragmentSpread, spreadPathIndex map[string]int) { +func detectFragmentCycleSel(c *context, sel types.Selection, fragVisited map[*types.FragmentDefinition]struct{}, spreadPath []*types.FragmentSpread, spreadPathIndex map[string]int) { switch sel := sel.(type) { - case *query.Field: - if sel.Selections != nil { - detectFragmentCycle(c, sel.Selections, fragVisited, spreadPath, spreadPathIndex) + case *types.Field: + if sel.SelectionSet != nil { + detectFragmentCycle(c, sel.SelectionSet, fragVisited, spreadPath, spreadPathIndex) } - case *query.InlineFragment: + case *types.InlineFragment: detectFragmentCycle(c, sel.Selections, fragVisited, spreadPath, spreadPathIndex) - case *query.FragmentSpread: + case *types.FragmentSpread: frag := c.doc.Fragments.Get(sel.Name.Name) if frag == nil { return @@ -485,7 +507,7 @@ func detectFragmentCycleSel(c *context, sel query.Selection, fragVisited map[*qu } } -func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs *[]errors.Location) { +func (c *context) validateOverlap(a, b types.Selection, reasons *[]string, locs *[]errors.Location) { if a == b { return } @@ -497,9 +519,9 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs c.overlapValidated[selectionPair{b, a}] = struct{}{} switch a := a.(type) { - case *query.Field: + case *types.Field: switch b := b.(type) { - case *query.Field: + case *types.Field: if b.Alias.Loc.Before(a.Alias.Loc) { a, b = b, a } @@ -515,12 +537,12 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs *locs = append(*locs, locs2...) } - case *query.InlineFragment: + case *types.InlineFragment: for _, sel := range b.Selections { c.validateOverlap(a, sel, reasons, locs) } - case *query.FragmentSpread: + case *types.FragmentSpread: if frag := c.doc.Fragments.Get(b.Name.Name); frag != nil { for _, sel := range frag.Selections { c.validateOverlap(a, sel, reasons, locs) @@ -531,12 +553,12 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs panic("unreachable") } - case *query.InlineFragment: + case *types.InlineFragment: for _, sel := range a.Selections { c.validateOverlap(sel, b, reasons, locs) } - case *query.FragmentSpread: + case *types.FragmentSpread: if frag := c.doc.Fragments.Get(a.Name.Name); frag != nil { for _, sel := range frag.Selections { c.validateOverlap(sel, b, reasons, locs) @@ -548,7 +570,7 @@ func (c *context) validateOverlap(a, b query.Selection, reasons *[]string, locs } } -func (c *context) validateFieldOverlap(a, b *query.Field) ([]string, []errors.Location) { +func (c *context) validateFieldOverlap(a, b *types.Field) ([]string, []errors.Location) { if a.Alias.Name != b.Alias.Name { return nil, nil } @@ -575,49 +597,49 @@ func (c *context) validateFieldOverlap(a, b *query.Field) ([]string, []errors.Lo var reasons []string var locs []errors.Location - for _, a2 := range a.Selections { - for _, b2 := range b.Selections { + for _, a2 := range a.SelectionSet { + for _, b2 := range b.SelectionSet { c.validateOverlap(a2, b2, &reasons, &locs) } } return reasons, locs } -func argumentsConflict(a, b common.ArgumentList) bool { +func argumentsConflict(a, b types.ArgumentList) bool { if len(a) != len(b) { return true } for _, argA := range a { valB, ok := b.Get(argA.Name.Name) - if !ok || !reflect.DeepEqual(argA.Value.Value(nil), valB.Value(nil)) { + if !ok || !reflect.DeepEqual(argA.Value.Deserialize(nil), valB.Deserialize(nil)) { return true } } return false } -func fields(t common.Type) schema.FieldList { +func fields(t types.Type) types.FieldsDefinition { switch t := t.(type) { - case *schema.Object: + case *types.ObjectTypeDefinition: return t.Fields - case *schema.Interface: + case *types.InterfaceTypeDefinition: return t.Fields default: return nil } } -func unwrapType(t common.Type) schema.NamedType { +func unwrapType(t types.Type) types.NamedType { if t == nil { return nil } for { switch t2 := t.(type) { - case schema.NamedType: + case types.NamedType: return t2 - case *common.List: + case *types.List: t = t2.OfType - case *common.NonNull: + case *types.NonNull: t = t2.OfType default: panic("unreachable") @@ -625,7 +647,7 @@ func unwrapType(t common.Type) schema.NamedType { } } -func resolveType(c *context, t common.Type) common.Type { +func resolveType(c *context, t types.Type) types.Type { t2, err := common.ResolveType(t, c.schema.Resolve) if err != nil { c.errs = append(c.errs, err) @@ -633,7 +655,7 @@ func resolveType(c *context, t common.Type) common.Type { return t2 } -func validateDirectives(c *opContext, loc string, directives common.DirectiveList) { +func validateDirectives(c *opContext, loc string, directives types.DirectiveList) { directiveNames := make(nameSet) for _, d := range directives { dirName := d.Name.Name @@ -641,7 +663,7 @@ func validateDirectives(c *opContext, loc string, directives common.DirectiveLis return fmt.Sprintf("The directive %q can only be used once at this location.", dirName) }) - validateArgumentLiterals(c, d.Args) + validateArgumentLiterals(c, d.Arguments) dd, ok := c.schema.Directives[dirName] if !ok { @@ -650,7 +672,7 @@ func validateDirectives(c *opContext, loc string, directives common.DirectiveLis } locOK := false - for _, allowedLoc := range dd.Locs { + for _, allowedLoc := range dd.Locations { if loc == allowedLoc { locOK = true break @@ -660,22 +682,20 @@ func validateDirectives(c *opContext, loc string, directives common.DirectiveLis c.addErr(d.Name.Loc, "KnownDirectives", "Directive %q may not be used on %s.", dirName, loc) } - validateArgumentTypes(c, d.Args, dd.Args, d.Name.Loc, + validateArgumentTypes(c, d.Arguments, dd.Arguments, d.Name.Loc, func() string { return fmt.Sprintf("directive %q", "@"+dirName) }, func() string { return fmt.Sprintf("Directive %q", "@"+dirName) }, ) } } -type nameSet map[string]errors.Location - -func validateName(c *context, set nameSet, name common.Ident, rule string, kind string) { +func validateName(c *context, set nameSet, name types.Ident, rule string, kind string) { validateNameCustomMsg(c, set, name, rule, func() string { return fmt.Sprintf("There can be only one %s named %q.", kind, name.Name) }) } -func validateNameCustomMsg(c *context, set nameSet, name common.Ident, rule string, msg func() string) { +func validateNameCustomMsg(c *context, set nameSet, name types.Ident, rule string, msg func() string) { if loc, ok := set[name.Name]; ok { c.addErrMultiLoc([]errors.Location{loc, name.Loc}, rule, msg()) return @@ -683,7 +703,7 @@ func validateNameCustomMsg(c *context, set nameSet, name common.Ident, rule stri set[name.Name] = name.Loc } -func validateArgumentTypes(c *opContext, args common.ArgumentList, argDecls common.InputValueList, loc errors.Location, owner1, owner2 func() string) { +func validateArgumentTypes(c *opContext, args types.ArgumentList, argDecls types.ArgumentsDefinition, loc errors.Location, owner1, owner2 func() string) { for _, selArg := range args { arg := argDecls.Get(selArg.Name.Name) if arg == nil { @@ -696,7 +716,7 @@ func validateArgumentTypes(c *opContext, args common.ArgumentList, argDecls comm } } for _, decl := range argDecls { - if _, ok := decl.Type.(*common.NonNull); ok { + if _, ok := decl.Type.(*types.NonNull); ok { if _, ok := args.Get(decl.Name.Name); !ok { c.addErr(loc, "ProvidedNonNullArguments", "%s argument %q of type %q is required but not provided.", owner2(), decl.Name.Name, decl.Type) } @@ -704,7 +724,7 @@ func validateArgumentTypes(c *opContext, args common.ArgumentList, argDecls comm } } -func validateArgumentLiterals(c *opContext, args common.ArgumentList) { +func validateArgumentLiterals(c *opContext, args types.ArgumentList) { argNames := make(nameSet) for _, arg := range args { validateName(c.context, argNames, arg.Name, "UniqueArgumentNames", "argument") @@ -712,19 +732,19 @@ func validateArgumentLiterals(c *opContext, args common.ArgumentList) { } } -func validateLiteral(c *opContext, l common.Literal) { +func validateLiteral(c *opContext, l types.Value) { switch l := l.(type) { - case *common.ObjectLit: + case *types.ObjectValue: fieldNames := make(nameSet) for _, f := range l.Fields { validateName(c.context, fieldNames, f.Name, "UniqueInputFieldNames", "input field") validateLiteral(c, f.Value) } - case *common.ListLit: - for _, entry := range l.Entries { + case *types.ListValue: + for _, entry := range l.Values { validateLiteral(c, entry) } - case *common.Variable: + case *types.Variable: for _, op := range c.ops { v := op.Vars.Get(l.Name) if v == nil { @@ -745,13 +765,13 @@ func validateLiteral(c *opContext, l common.Literal) { } } -func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, string) { - if v, ok := v.(*common.Variable); ok { +func validateValueType(c *opContext, v types.Value, t types.Type) (bool, string) { + if v, ok := v.(*types.Variable); ok { for _, op := range c.ops { if v2 := op.Vars.Get(v.Name); v2 != nil { t2, err := common.ResolveType(v2.Type, c.schema.Resolve) - if _, ok := t2.(*common.NonNull); !ok && v2.Default != nil { - t2 = &common.NonNull{OfType: t2} + if _, ok := t2.(*types.NonNull); !ok && v2.Default != nil { + t2 = &types.NonNull{OfType: t2} } if err == nil && !typeCanBeUsedAs(t2, t) { c.addErrMultiLoc([]errors.Location{v2.Loc, v.Loc}, "VariablesInAllowedPosition", "Variable %q of type %q used in position expecting type %q.", "$"+v.Name, t2, t) @@ -761,7 +781,7 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str return true, "" } - if nn, ok := t.(*common.NonNull); ok { + if nn, ok := t.(*types.NonNull); ok { if isNull(v) { return false, fmt.Sprintf("Expected %q, found null.", t) } @@ -772,27 +792,29 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str } switch t := t.(type) { - case *schema.Scalar, *schema.Enum: - if lit, ok := v.(*common.BasicLit); ok { + case *types.ScalarTypeDefinition, *types.EnumTypeDefinition: + if lit, ok := v.(*types.PrimitiveValue); ok { if validateBasicLit(lit, t) { return true, "" } + return false, fmt.Sprintf("Expected type %q, found %s.", t, v) } + return true, "" - case *common.List: - list, ok := v.(*common.ListLit) + case *types.List: + list, ok := v.(*types.ListValue) if !ok { return validateValueType(c, v, t.OfType) // single value instead of list } - for i, entry := range list.Entries { + for i, entry := range list.Values { if ok, reason := validateValueType(c, entry, t.OfType); !ok { return false, fmt.Sprintf("In element #%d: %s", i, reason) } } return true, "" - case *schema.InputObject: - v, ok := v.(*common.ObjectLit) + case *types.InputObject: + v, ok := v.(*types.ObjectValue) if !ok { return false, fmt.Sprintf("Expected %q, found not an object.", t) } @@ -815,7 +837,7 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str } } if !found { - if _, ok := iv.Type.(*common.NonNull); ok && iv.Default == nil { + if _, ok := iv.Type.(*types.NonNull); ok && iv.Default == nil { return false, fmt.Sprintf("In field %q: Expected %q, found null.", iv.Name.Name, iv.Type) } } @@ -826,38 +848,34 @@ func validateValueType(c *opContext, v common.Literal, t common.Type) (bool, str return false, fmt.Sprintf("Expected type %q, found %s.", t, v) } -func validateBasicLit(v *common.BasicLit, t common.Type) bool { +func validateBasicLit(v *types.PrimitiveValue, t types.Type) bool { switch t := t.(type) { - case *schema.Scalar: + case *types.ScalarTypeDefinition: switch t.Name { case "Int": if v.Type != scanner.Int { return false } - f, err := strconv.ParseFloat(v.Text, 64) - if err != nil { - panic(err) - } - return f >= math.MinInt32 && f <= math.MaxInt32 + return validateBuiltInScalar(v.Text, "Int") case "Float": - return v.Type == scanner.Int || v.Type == scanner.Float + return (v.Type == scanner.Int || v.Type == scanner.Float) && validateBuiltInScalar(v.Text, "Float") case "String": - return v.Type == scanner.String + return v.Type == scanner.String && validateBuiltInScalar(v.Text, "String") case "Boolean": - return v.Type == scanner.Ident && (v.Text == "true" || v.Text == "false") + return v.Type == scanner.Ident && validateBuiltInScalar(v.Text, "Boolean") case "ID": - return v.Type == scanner.Int || v.Type == scanner.String + return (v.Type == scanner.Int && validateBuiltInScalar(v.Text, "Int")) || (v.Type == scanner.String && validateBuiltInScalar(v.Text, "String")) default: //TODO: Type-check against expected type by Unmarshalling return true } - case *schema.Enum: + case *types.EnumTypeDefinition: if v.Type != scanner.Ident { return false } - for _, option := range t.Values { - if option.Name == v.Text { + for _, option := range t.EnumValuesDefinition { + if option.EnumValue == v.Text { return true } } @@ -867,44 +885,65 @@ func validateBasicLit(v *common.BasicLit, t common.Type) bool { return false } -func canBeFragment(t common.Type) bool { +func validateBuiltInScalar(v string, n string) bool { + switch n { + case "Int": + f, err := strconv.ParseFloat(v, 64) + if err != nil { + return false + } + return f >= math.MinInt32 && f <= math.MaxInt32 + case "Float": + f, fe := strconv.ParseFloat(v, 64) + return fe == nil && f >= math.SmallestNonzeroFloat64 && f <= math.MaxFloat64 + case "String": + vl := len(v) + return vl >= 2 && v[0] == '"' && v[vl-1] == '"' + case "Boolean": + return v == "true" || v == "false" + default: + return false + } +} + +func canBeFragment(t types.Type) bool { switch t.(type) { - case *schema.Object, *schema.Interface, *schema.Union: + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: return true default: return false } } -func canBeInput(t common.Type) bool { +func canBeInput(t types.Type) bool { switch t := t.(type) { - case *schema.InputObject, *schema.Scalar, *schema.Enum: + case *types.InputObject, *types.ScalarTypeDefinition, *types.EnumTypeDefinition: return true - case *common.List: + case *types.List: return canBeInput(t.OfType) - case *common.NonNull: + case *types.NonNull: return canBeInput(t.OfType) default: return false } } -func hasSubfields(t common.Type) bool { +func hasSubfields(t types.Type) bool { switch t := t.(type) { - case *schema.Object, *schema.Interface, *schema.Union: + case *types.ObjectTypeDefinition, *types.InterfaceTypeDefinition, *types.Union: return true - case *common.List: + case *types.List: return hasSubfields(t.OfType) - case *common.NonNull: + case *types.NonNull: return hasSubfields(t.OfType) default: return false } } -func isLeaf(t common.Type) bool { +func isLeaf(t types.Type) bool { switch t.(type) { - case *schema.Scalar, *schema.Enum: + case *types.ScalarTypeDefinition, *types.EnumTypeDefinition: return true default: return false @@ -912,19 +951,19 @@ func isLeaf(t common.Type) bool { } func isNull(lit interface{}) bool { - _, ok := lit.(*common.NullLit) + _, ok := lit.(*types.NullValue) return ok } -func typesCompatible(a, b common.Type) bool { - al, aIsList := a.(*common.List) - bl, bIsList := b.(*common.List) +func typesCompatible(a, b types.Type) bool { + al, aIsList := a.(*types.List) + bl, bIsList := b.(*types.List) if aIsList || bIsList { return aIsList && bIsList && typesCompatible(al.OfType, bl.OfType) } - ann, aIsNN := a.(*common.NonNull) - bnn, bIsNN := b.(*common.NonNull) + ann, aIsNN := a.(*types.NonNull) + bnn, bIsNN := b.(*types.NonNull) if aIsNN || bIsNN { return aIsNN && bIsNN && typesCompatible(ann.OfType, bnn.OfType) } @@ -936,13 +975,13 @@ func typesCompatible(a, b common.Type) bool { return true } -func typeCanBeUsedAs(t, as common.Type) bool { - nnT, okT := t.(*common.NonNull) +func typeCanBeUsedAs(t, as types.Type) bool { + nnT, okT := t.(*types.NonNull) if okT { t = nnT.OfType } - nnAs, okAs := as.(*common.NonNull) + nnAs, okAs := as.(*types.NonNull) if okAs { as = nnAs.OfType if !okT { @@ -954,8 +993,8 @@ func typeCanBeUsedAs(t, as common.Type) bool { return true } - if lT, ok := t.(*common.List); ok { - if lAs, ok := as.(*common.List); ok { + if lT, ok := t.(*types.List); ok { + if lAs, ok := as.(*types.List); ok { return typeCanBeUsedAs(lT.OfType, lAs.OfType) } } diff --git a/internal/validation/validation_test.go b/internal/validation/validation_test.go index b50555697..e74eb4b82 100644 --- a/internal/validation/validation_test.go +++ b/internal/validation/validation_test.go @@ -12,6 +12,7 @@ import ( "github.com/tokopedia/graphql-go/internal/query" "github.com/tokopedia/graphql-go/internal/schema" "github.com/tokopedia/graphql-go/internal/validation" + "github.com/tokopedia/graphql-go/types" ) type Test struct { @@ -37,10 +38,11 @@ func TestValidate(t *testing.T) { t.Fatal(err) } - schemas := make([]*schema.Schema, len(testData.Schemas)) + schemas := make([]*types.Schema, len(testData.Schemas)) for i, schemaStr := range testData.Schemas { schemas[i] = schema.New() - if err := schemas[i].Parse(schemaStr, false); err != nil { + err := schema.Parse(schemas[i], schemaStr, false) + if err != nil { t.Fatal(err) } } diff --git a/introspection/introspection.go b/introspection/introspection.go index 20aa091b1..7c292a62f 100644 --- a/introspection/introspection.go +++ b/introspection/introspection.go @@ -3,16 +3,15 @@ package introspection import ( "sort" - "github.com/tokopedia/graphql-go/internal/common" - "github.com/tokopedia/graphql-go/internal/schema" + "github.com/tokopedia/graphql-go/types" ) type Schema struct { - schema *schema.Schema + schema *types.Schema } // WrapSchema is only used internally. -func WrapSchema(schema *schema.Schema) *Schema { +func WrapSchema(schema *types.Schema) *Schema { return &Schema{schema} } @@ -69,11 +68,11 @@ func (r *Schema) SubscriptionType() *Type { } type Type struct { - typ common.Type + typ types.Type } // WrapType is only used internally. -func WrapType(typ common.Type) *Type { +func WrapType(typ types.Type) *Type { return &Type{typ} } @@ -82,7 +81,7 @@ func (r *Type) Kind() string { } func (r *Type) Name() *string { - if named, ok := r.typ.(schema.NamedType); ok { + if named, ok := r.typ.(types.NamedType); ok { name := named.TypeName() return &name } @@ -90,7 +89,7 @@ func (r *Type) Name() *string { } func (r *Type) Description() *string { - if named, ok := r.typ.(schema.NamedType); ok { + if named, ok := r.typ.(types.NamedType); ok { desc := named.Description() if desc == "" { return nil @@ -101,11 +100,11 @@ func (r *Type) Description() *string { } func (r *Type) Fields(args *struct{ IncludeDeprecated bool }) *[]*Field { - var fields schema.FieldList + var fields types.FieldsDefinition switch t := r.typ.(type) { - case *schema.Object: + case *types.ObjectTypeDefinition: fields = t.Fields - case *schema.Interface: + case *types.InterfaceTypeDefinition: fields = t.Fields default: return nil @@ -114,14 +113,14 @@ func (r *Type) Fields(args *struct{ IncludeDeprecated bool }) *[]*Field { var l []*Field for _, f := range fields { if d := f.Directives.Get("deprecated"); d == nil || args.IncludeDeprecated { - l = append(l, &Field{f}) + l = append(l, &Field{field: f}) } } return &l } func (r *Type) Interfaces() *[]*Type { - t, ok := r.typ.(*schema.Object) + t, ok := r.typ.(*types.ObjectTypeDefinition) if !ok { return nil } @@ -134,12 +133,12 @@ func (r *Type) Interfaces() *[]*Type { } func (r *Type) PossibleTypes() *[]*Type { - var possibleTypes []*schema.Object + var possibleTypes []*types.ObjectTypeDefinition switch t := r.typ.(type) { - case *schema.Interface: - possibleTypes = t.PossibleTypes - case *schema.Union: + case *types.InterfaceTypeDefinition: possibleTypes = t.PossibleTypes + case *types.Union: + possibleTypes = t.UnionMemberTypes default: return nil } @@ -152,13 +151,13 @@ func (r *Type) PossibleTypes() *[]*Type { } func (r *Type) EnumValues(args *struct{ IncludeDeprecated bool }) *[]*EnumValue { - t, ok := r.typ.(*schema.Enum) + t, ok := r.typ.(*types.EnumTypeDefinition) if !ok { return nil } var l []*EnumValue - for _, v := range t.Values { + for _, v := range t.EnumValuesDefinition { if d := v.Directives.Get("deprecated"); d == nil || args.IncludeDeprecated { l = append(l, &EnumValue{v}) } @@ -167,7 +166,7 @@ func (r *Type) EnumValues(args *struct{ IncludeDeprecated bool }) *[]*EnumValue } func (r *Type) InputFields() *[]*InputValue { - t, ok := r.typ.(*schema.InputObject) + t, ok := r.typ.(*types.InputObject) if !ok { return nil } @@ -181,9 +180,9 @@ func (r *Type) InputFields() *[]*InputValue { func (r *Type) OfType() *Type { switch t := r.typ.(type) { - case *common.List: + case *types.List: return &Type{t.OfType} - case *common.NonNull: + case *types.NonNull: return &Type{t.OfType} default: return nil @@ -191,7 +190,7 @@ func (r *Type) OfType() *Type { } type Field struct { - field *schema.Field + field *types.FieldDefinition } func (r *Field) Name() string { @@ -206,8 +205,8 @@ func (r *Field) Description() *string { } func (r *Field) Args() []*InputValue { - l := make([]*InputValue, len(r.field.Args)) - for i, v := range r.field.Args { + l := make([]*InputValue, len(r.field.Arguments)) + for i, v := range r.field.Arguments { l[i] = &InputValue{v} } return l @@ -226,12 +225,12 @@ func (r *Field) DeprecationReason() *string { if d == nil { return nil } - reason := d.Args.MustGet("reason").Value(nil).(string) + reason := d.Arguments.MustGet("reason").Deserialize(nil).(string) return &reason } type InputValue struct { - value *common.InputValue + value *types.InputValueDefinition } func (r *InputValue) Name() string { @@ -258,11 +257,11 @@ func (r *InputValue) DefaultValue() *string { } type EnumValue struct { - value *schema.EnumValue + value *types.EnumValueDefinition } func (r *EnumValue) Name() string { - return r.value.Name + return r.value.EnumValue } func (r *EnumValue) Description() *string { @@ -281,12 +280,12 @@ func (r *EnumValue) DeprecationReason() *string { if d == nil { return nil } - reason := d.Args.MustGet("reason").Value(nil).(string) + reason := d.Arguments.MustGet("reason").Deserialize(nil).(string) return &reason } type Directive struct { - directive *schema.DirectiveDecl + directive *types.DirectiveDefinition } func (r *Directive) Name() string { @@ -301,13 +300,26 @@ func (r *Directive) Description() *string { } func (r *Directive) Locations() []string { - return r.directive.Locs + return r.directive.Locations } func (r *Directive) Args() []*InputValue { - l := make([]*InputValue, len(r.directive.Args)) - for i, v := range r.directive.Args { + l := make([]*InputValue, len(r.directive.Arguments)) + for i, v := range r.directive.Arguments { l[i] = &InputValue{v} } return l } + +type Service struct { + schema *types.Schema +} + +// WrapService is only used internally. +func WrapService(schema *types.Schema) *Service { + return &Service{schema} +} + +func (r *Service) SDL() string { + return r.schema.SchemaString +} diff --git a/introspection_test.go b/introspection_test.go index e66e0fdc2..c3e13b16b 100644 --- a/introspection_test.go +++ b/introspection_test.go @@ -11,6 +11,8 @@ import ( "github.com/tokopedia/graphql-go/example/starwars" ) +var socialSchema = graphql.MustParseSchema(social.Schema, &social.Resolver{}, graphql.UseFieldResolvers()) + func TestSchema_ToJSON(t *testing.T) { t.Parallel() @@ -27,7 +29,7 @@ func TestSchema_ToJSON(t *testing.T) { }{ { Name: "Social Schema", - Args: args{Schema: graphql.MustParseSchema(social.Schema, &social.Resolver{}, graphql.UseFieldResolvers())}, + Args: args{Schema: socialSchema}, Want: want{JSON: mustReadFile("example/social/introspect.json")}, }, { diff --git a/log/log.go b/log/log.go index 25569af7c..bdada8742 100644 --- a/log/log.go +++ b/log/log.go @@ -15,9 +15,9 @@ type Logger interface { type DefaultLogger struct{} // LogPanic is used to log recovered panic values that occur during query execution -func (l *DefaultLogger) LogPanic(_ context.Context, value interface{}) { +func (l *DefaultLogger) LogPanic(ctx context.Context, value interface{}) { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - log.Printf("graphql: panic occurred: %v\n%s", value, buf) + log.Printf("graphql: panic occurred: %v\n%s\ncontext: %v", value, buf, ctx) } diff --git a/nullable_types.go b/nullable_types.go new file mode 100644 index 000000000..fa5bbfd62 --- /dev/null +++ b/nullable_types.go @@ -0,0 +1,166 @@ +package graphql + +import ( + "fmt" + "math" +) + +// NullString is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullString struct { + Value *string + Set bool +} + +func (NullString) ImplementsGraphQLType(name string) bool { + return name == "String" +} + +func (s *NullString) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case string: + s.Value = &v + return nil + default: + return fmt.Errorf("wrong type for String: %T", v) + } +} + +func (s *NullString) Nullable() {} + +// NullBool is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullBool struct { + Value *bool + Set bool +} + +func (NullBool) ImplementsGraphQLType(name string) bool { + return name == "Boolean" +} + +func (s *NullBool) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case bool: + s.Value = &v + return nil + default: + return fmt.Errorf("wrong type for Boolean: %T", v) + } +} + +func (s *NullBool) Nullable() {} + +// NullInt is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullInt struct { + Value *int32 + Set bool +} + +func (NullInt) ImplementsGraphQLType(name string) bool { + return name == "Int" +} + +func (s *NullInt) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case int32: + s.Value = &v + return nil + case float64: + coerced := int32(v) + if v < math.MinInt32 || v > math.MaxInt32 || float64(coerced) != v { + return fmt.Errorf("not a 32-bit integer") + } + s.Value = &coerced + return nil + default: + return fmt.Errorf("wrong type for Int: %T", v) + } +} + +func (s *NullInt) Nullable() {} + +// NullFloat is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullFloat struct { + Value *float64 + Set bool +} + +func (NullFloat) ImplementsGraphQLType(name string) bool { + return name == "Float" +} + +func (s *NullFloat) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + switch v := input.(type) { + case float64: + s.Value = &v + return nil + case int32: + coerced := float64(v) + s.Value = &coerced + return nil + case int: + coerced := float64(v) + s.Value = &coerced + return nil + default: + return fmt.Errorf("wrong type for Float: %T", v) + } +} + +func (s *NullFloat) Nullable() {} + +// NullTime is a string that can be null. Use it in input structs to +// differentiate a value explicitly set to null from an omitted value. +// When the value is defined (either null or a value) Set is true. +type NullTime struct { + Value *Time + Set bool +} + +func (NullTime) ImplementsGraphQLType(name string) bool { + return name == "Time" +} + +func (s *NullTime) UnmarshalGraphQL(input interface{}) error { + s.Set = true + + if input == nil { + return nil + } + + s.Value = new(Time) + return s.Value.UnmarshalGraphQL(input) +} + +func (s *NullTime) Nullable() {} diff --git a/nullable_types_test.go b/nullable_types_test.go new file mode 100644 index 000000000..cb4a055d0 --- /dev/null +++ b/nullable_types_test.go @@ -0,0 +1,213 @@ +package graphql_test + +import ( + "math" + "testing" + + "github.com/tokopedia/graphql-go/decode" +) + +func TestNullInt_ImplementsUnmarshaler(t *testing.T) { + defer func() { + if err := recover(); err != nil { + t.Error(err) + } + }() + + // assert *NullInt implements decode.Unmarshaler interface + var _ decode.Unmarshaler = (*NullInt)(nil) +} + +func TestNullInt_UnmarshalGraphQL(t *testing.T) { + type args struct { + input interface{} + } + + a := float64(math.MaxInt32 + 1) + b := float64(math.MinInt32 - 1) + c := 1234.6 + good := int32(1234) + ref := NullInt{ + Value: &good, + Set: true, + } + + t.Run("invalid", func(t *testing.T) { + tests := []struct { + name string + args args + wantErr string + }{ + { + name: "boolean", + args: args{input: true}, + wantErr: "wrong type for Int: bool", + }, + { + name: "int32 out of range (+)", + args: args{ + input: a, + }, + wantErr: "not a 32-bit integer", + }, + { + name: "int32 out of range (-)", + args: args{ + input: b, + }, + wantErr: "not a 32-bit integer", + }, + { + name: "non-integer", + args: args{ + input: c, + }, + wantErr: "not a 32-bit integer", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gt := new(NullInt) + if err := gt.UnmarshalGraphQL(tt.args.input); err != nil { + if err.Error() != tt.wantErr { + t.Errorf("UnmarshalGraphQL() error = %v, want = %s", err, tt.wantErr) + } + + return + } + + t.Error("UnmarshalGraphQL() expected error not raised") + }) + } + }) + + tests := []struct { + name string + args args + wantEq NullInt + }{ + { + name: "int32", + args: args{ + input: good, + }, + wantEq: ref, + }, + { + name: "float64", + args: args{ + input: float64(good), + }, + wantEq: ref, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gt := new(NullInt) + if err := gt.UnmarshalGraphQL(tt.args.input); err != nil { + t.Errorf("UnmarshalGraphQL() error = %v", err) + return + } + + if *gt.Value != *tt.wantEq.Value { + t.Errorf("UnmarshalGraphQL() got = %v, want = %v", *gt.Value, *tt.wantEq.Value) + } + }) + } +} + +func TestNullFloat_ImplementsUnmarshaler(t *testing.T) { + defer func() { + if err := recover(); err != nil { + t.Error(err) + } + }() + + // assert *NullFloat implements decode.Unmarshaler interface + var _ decode.Unmarshaler = (*NullFloat)(nil) +} + +func TestNullFloat_UnmarshalGraphQL(t *testing.T) { + type args struct { + input interface{} + } + + good := float64(1234) + ref := NullFloat{ + Value: &good, + Set: true, + } + + t.Run("invalid", func(t *testing.T) { + tests := []struct { + name string + args args + wantErr string + }{ + { + name: "boolean", + args: args{input: true}, + wantErr: "wrong type for Float: bool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gt := new(NullFloat) + if err := gt.UnmarshalGraphQL(tt.args.input); err != nil { + if err.Error() != tt.wantErr { + t.Errorf("UnmarshalGraphQL() error = %v, want = %s", err, tt.wantErr) + } + + return + } + + t.Error("UnmarshalGraphQL() expected error not raised") + }) + } + }) + + tests := []struct { + name string + args args + wantEq NullFloat + }{ + { + name: "int", + args: args{ + input: int(good), + }, + wantEq: ref, + }, + { + name: "int32", + args: args{ + input: int32(good), + }, + wantEq: ref, + }, + { + name: "float64", + args: args{ + input: good, + }, + wantEq: ref, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gt := new(NullFloat) + if err := gt.UnmarshalGraphQL(tt.args.input); err != nil { + t.Errorf("UnmarshalGraphQL() error = %v", err) + return + } + + if *gt.Value != *tt.wantEq.Value { + t.Errorf("UnmarshalGraphQL() got = %v, want = %v", *gt.Value, *tt.wantEq.Value) + } + }) + } +} diff --git a/relay/relay.go b/relay/relay.go index cf104b242..7e0b71a62 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -40,7 +40,7 @@ func UnmarshalSpec(id graphql.ID, v interface{}) error { if i == -1 { return errors.New("invalid graphql.ID") } - return json.Unmarshal([]byte(s[i+1:]), v) + return json.Unmarshal(s[i+1:], v) } type Handler struct { diff --git a/scripts/golangci_install.sh b/scripts/golangci_install.sh new file mode 100755 index 000000000..bad6bb552 --- /dev/null +++ b/scripts/golangci_install.sh @@ -0,0 +1,407 @@ +#!/bin/sh +set -e +# Code generated by godownloader. DO NOT EDIT. +# + +usage() { + this=$1 + cat </dev/null +} +echoerr() { + echo "$@" 1>&2 +} +log_prefix() { + echo "$0" +} +_logp=6 +log_set_priority() { + _logp="$1" +} +log_priority() { + if test -z "$1"; then + echo "$_logp" + return + fi + [ "$1" -le "$_logp" ] +} +log_tag() { + case $1 in + 0) echo "emerg" ;; + 1) echo "alert" ;; + 2) echo "crit" ;; + 3) echo "err" ;; + 4) echo "warning" ;; + 5) echo "notice" ;; + 6) echo "info" ;; + 7) echo "debug" ;; + *) echo "$1" ;; + esac +} +log_debug() { + log_priority 7 || return 0 + echoerr "$(log_prefix)" "$(log_tag 7)" "$@" +} +log_info() { + log_priority 6 || return 0 + echoerr "$(log_prefix)" "$(log_tag 6)" "$@" +} +log_err() { + log_priority 3 || return 0 + echoerr "$(log_prefix)" "$(log_tag 3)" "$@" +} +log_crit() { + log_priority 2 || return 0 + echoerr "$(log_prefix)" "$(log_tag 2)" "$@" +} +uname_os() { + os=$(uname -s | tr '[:upper:]' '[:lower:]') + case "$os" in + cygwin_nt*) os="windows" ;; + mingw*) os="windows" ;; + msys_nt*) os="windows" ;; + esac + echo "$os" +} +uname_arch() { + arch=$(uname -m) + case $arch in + x86_64) arch="amd64" ;; + x86) arch="386" ;; + i686) arch="386" ;; + i386) arch="386" ;; + aarch64) arch="arm64" ;; + armv5*) arch="armv5" ;; + armv6*) arch="armv6" ;; + armv7*) arch="armv7" ;; + esac + echo ${arch} +} +uname_os_check() { + os=$(uname_os) + case "$os" in + darwin) return 0 ;; + dragonfly) return 0 ;; + freebsd) return 0 ;; + linux) return 0 ;; + android) return 0 ;; + nacl) return 0 ;; + netbsd) return 0 ;; + openbsd) return 0 ;; + plan9) return 0 ;; + solaris) return 0 ;; + windows) return 0 ;; + esac + log_crit "uname_os_check '$(uname -s)' got converted to '$os' which is not a GOOS value. Please file bug at https://github.com/client9/shlib" + return 1 +} +uname_arch_check() { + arch=$(uname_arch) + case "$arch" in + 386) return 0 ;; + amd64) return 0 ;; + arm64) return 0 ;; + armv5) return 0 ;; + armv6) return 0 ;; + armv7) return 0 ;; + ppc64) return 0 ;; + ppc64le) return 0 ;; + mips) return 0 ;; + mipsle) return 0 ;; + mips64) return 0 ;; + mips64le) return 0 ;; + s390x) return 0 ;; + amd64p32) return 0 ;; + esac + log_crit "uname_arch_check '$(uname -m)' got converted to '$arch' which is not a GOARCH value. Please file bug report at https://github.com/client9/shlib" + return 1 +} +untar() { + tarball=$1 + case "${tarball}" in + *.tar.gz | *.tgz) tar --no-same-owner -xzf "${tarball}" ;; + *.tar) tar --no-same-owner -xf "${tarball}" ;; + *.zip) unzip "${tarball}" ;; + *) + log_err "untar unknown archive format for ${tarball}" + return 1 + ;; + esac +} +http_download_curl() { + local_file=$1 + source_url=$2 + header=$3 + if [ -z "$header" ]; then + code=$(curl -w '%{http_code}' -sL -o "$local_file" "$source_url") + else + code=$(curl -w '%{http_code}' -sL -H "$header" -o "$local_file" "$source_url") + fi + if [ "$code" != "200" ]; then + log_debug "http_download_curl received HTTP status $code" + return 1 + fi + return 0 +} +http_download_wget() { + local_file=$1 + source_url=$2 + header=$3 + if [ -z "$header" ]; then + wget -q -O "$local_file" "$source_url" + else + wget -q --header "$header" -O "$local_file" "$source_url" + fi +} +http_download() { + log_debug "http_download $2" + if is_command curl; then + http_download_curl "$@" + return + elif is_command wget; then + http_download_wget "$@" + return + fi + log_crit "http_download unable to find wget or curl" + return 1 +} +http_copy() { + tmp=$(mktemp) + http_download "${tmp}" "$1" "$2" || return 1 + body=$(cat "$tmp") + rm -f "${tmp}" + echo "$body" +} +github_release() { + owner_repo=$1 + version=$2 + test -z "$version" && version="latest" + giturl="https://github.com/${owner_repo}/releases/${version}" + json=$(http_copy "$giturl" "Accept:application/json") + test -z "$json" && return 1 + version=$(echo "$json" | tr -s '\n' ' ' | sed 's/.*"tag_name":"//' | sed 's/".*//') + test -z "$version" && return 1 + echo "$version" +} +hash_sha256() { + TARGET=${1:-/dev/stdin} + if is_command gsha256sum; then + hash=$(gsha256sum "$TARGET") || return 1 + echo "$hash" | cut -d ' ' -f 1 + elif is_command sha256sum; then + hash=$(sha256sum "$TARGET") || return 1 + echo "$hash" | cut -d ' ' -f 1 + elif is_command shasum; then + hash=$(shasum -a 256 "$TARGET" 2>/dev/null) || return 1 + echo "$hash" | cut -d ' ' -f 1 + elif is_command openssl; then + hash=$(openssl -dst openssl dgst -sha256 "$TARGET") || return 1 + echo "$hash" | cut -d ' ' -f a + else + log_crit "hash_sha256 unable to find command to compute sha-256 hash" + return 1 + fi +} +hash_sha256_verify() { + TARGET=$1 + checksums=$2 + if [ -z "$checksums" ]; then + log_err "hash_sha256_verify checksum file not specified in arg2" + return 1 + fi + BASENAME=${TARGET##*/} + want=$(grep "${BASENAME}" "${checksums}" 2>/dev/null | tr '\t' ' ' | cut -d ' ' -f 1) + if [ -z "$want" ]; then + log_err "hash_sha256_verify unable to find checksum for '${TARGET}' in '${checksums}'" + return 1 + fi + got=$(hash_sha256 "$TARGET") + if [ "$want" != "$got" ]; then + log_err "hash_sha256_verify checksum for '$TARGET' did not verify ${want} vs $got" + return 1 + fi +} +cat /dev/null <= 1e10 { + sec := input / 1e9 + nsec := input - (sec * 1e9) + t.Time = time.Unix(sec, nsec) + } else { + t.Time = time.Unix(input, 0) + } + return nil case float64: t.Time = time.Unix(int64(input), 0) return nil default: - return fmt.Errorf("wrong type") + return fmt.Errorf("wrong type for Time: %T", input) } } diff --git a/time_test.go b/time_test.go new file mode 100644 index 000000000..9bda5c92f --- /dev/null +++ b/time_test.go @@ -0,0 +1,165 @@ +package graphql_test + +import ( + "bytes" + "encoding/json" + "testing" + "time" + + . "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/decode" +) + +func TestTime_ImplementsUnmarshaler(t *testing.T) { + defer func() { + if err := recover(); err != nil { + t.Error(err) + } + }() + + // assert *Time implements decode.Unmarshaler interface + var _ decode.Unmarshaler = (*Time)(nil) +} + +func TestTime_ImplementsGraphQLType(t *testing.T) { + gt := &Time{} + + if gt.ImplementsGraphQLType("foobar") { + t.Error("Type *Time must not claim to implement GraphQL type 'foobar'") + } + + if !gt.ImplementsGraphQLType("Time") { + t.Error("Failed asserting *Time implements GraphQL type Time") + } +} + +func TestTime_MarshalJSON(t *testing.T) { + var err error + var b1, b2 []byte + ref := time.Date(2021, time.April, 20, 12, 3, 23, 551476231, time.UTC) + + if b1, err = json.Marshal(ref); err != nil { + t.Error(err) + return + } + + if b2, err = json.Marshal(Time{Time: ref}); err != nil { + t.Errorf("MarshalJSON() error = %v", err) + return + } + + if !bytes.Equal(b1, b2) { + t.Errorf("MarshalJSON() got = %s, want = %s", b2, b1) + } +} + +func TestTime_UnmarshalGraphQL(t *testing.T) { + type args struct { + input interface{} + } + ref := time.Date(2021, time.April, 20, 12, 3, 23, 551476231, time.UTC) + refZeroNano := time.Unix(ref.Unix(), 0) + + t.Run("invalid", func(t *testing.T) { + tests := []struct { + name string + args args + wantErr string + }{ + { + name: "boolean", + args: args{input: true}, + wantErr: "wrong type for Time: bool", + }, + { + name: "invalid format", + args: args{input: ref.Format(time.ANSIC)}, + wantErr: `parsing time "Tue Apr 20 12:03:23 2021" as "2006-01-02T15:04:05Z07:00": cannot parse "Tue Apr 20 12:03:23 2021" as "2006"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gt := new(Time) + if err := gt.UnmarshalGraphQL(tt.args.input); err != nil { + if err.Error() != tt.wantErr { + t.Errorf("UnmarshalGraphQL() error = %v, want = %s", err, tt.wantErr) + } + + return + } + + t.Error("UnmarshalGraphQL() expected error not raised") + }) + } + }) + + tests := []struct { + name string + args args + wantEq time.Time + }{ + { + name: "time.Time", + args: args{ + input: ref, + }, + wantEq: ref, + }, + { + name: "string", + args: args{ + input: ref.Format(time.RFC3339), + }, + wantEq: refZeroNano, + }, + { + name: "bytes", + args: args{ + input: []byte(ref.Format(time.RFC3339)), + }, + wantEq: refZeroNano, + }, + { + name: "int32", + args: args{ + input: int32(ref.Unix()), + }, + wantEq: refZeroNano, + }, + { + name: "int64", + args: args{ + input: ref.Unix(), + }, + wantEq: refZeroNano, + }, + { + name: "int64-nano", + args: args{ + input: ref.UnixNano(), + }, + wantEq: ref, + }, + { + name: "float64", + args: args{ + input: float64(ref.Unix()), + }, + wantEq: refZeroNano, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gt := &Time{} + if err := gt.UnmarshalGraphQL(tt.args.input); err != nil { + t.Errorf("UnmarshalGraphQL() error = %v", err) + return + } + if !gt.Equal(tt.wantEq) { + t.Errorf("UnmarshalGraphQL() got = %v, want = %v", gt, tt.wantEq) + } + }) + } +} diff --git a/trace/noop/trace.go b/trace/noop/trace.go new file mode 100644 index 000000000..aa122450d --- /dev/null +++ b/trace/noop/trace.go @@ -0,0 +1,24 @@ +// Package noop defines a no-op tracer implementation. +package noop + +import ( + "context" + + "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/introspection" +) + +// Tracer is a no-op tracer that does nothing. +type Tracer struct{} + +func (Tracer) TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, func([]*errors.QueryError)) { + return ctx, func(errs []*errors.QueryError) {} +} + +func (Tracer) TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, func(*errors.QueryError)) { + return ctx, func(err *errors.QueryError) {} +} + +func (Tracer) TraceValidation(context.Context) func([]*errors.QueryError) { + return func(errs []*errors.QueryError) {} +} diff --git a/trace/noop/trace_test.go b/trace/noop/trace_test.go new file mode 100644 index 000000000..6d2e3ef17 --- /dev/null +++ b/trace/noop/trace_test.go @@ -0,0 +1,22 @@ +package noop_test + +import ( + "testing" + + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/example/starwars" + "github.com/tokopedia/graphql-go/trace/noop" + "github.com/tokopedia/graphql-go/trace/tracer" +) + +func TestInterfaceImplementation(t *testing.T) { + var _ tracer.ValidationTracer = &noop.Tracer{} + var _ tracer.Tracer = &noop.Tracer{} +} + +func TestTracerOption(t *testing.T) { + _, err := graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(noop.Tracer{})) + if err != nil { + t.Fatal(err) + } +} diff --git a/trace/opentracing/trace.go b/trace/opentracing/trace.go new file mode 100644 index 000000000..418d3520f --- /dev/null +++ b/trace/opentracing/trace.go @@ -0,0 +1,79 @@ +package opentracing + +import ( + "context" + "fmt" + + opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" + "github.com/opentracing/opentracing-go/log" + "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/introspection" +) + +// Tracer implements the graphql-go Tracer inteface and creates OpenTracing spans. +type Tracer struct{} + +func (Tracer) TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, func([]*errors.QueryError)) { + span, spanCtx := opentracing.StartSpanFromContext(ctx, "GraphQL request") + span.SetTag("graphql.query", queryString) + + if operationName != "" { + span.SetTag("graphql.operationName", operationName) + } + + if len(variables) != 0 { + span.LogFields(log.Object("graphql.variables", variables)) + } + + return spanCtx, func(errs []*errors.QueryError) { + if len(errs) > 0 { + msg := errs[0].Error() + if len(errs) > 1 { + msg += fmt.Sprintf(" (and %d more errors)", len(errs)-1) + } + ext.Error.Set(span, true) + span.SetTag("graphql.error", msg) + } + span.Finish() + } +} + +func (Tracer) TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, func(*errors.QueryError)) { + if trivial { + return ctx, noop + } + + span, spanCtx := opentracing.StartSpanFromContext(ctx, label) + span.SetTag("graphql.type", typeName) + span.SetTag("graphql.field", fieldName) + for name, value := range args { + span.SetTag("graphql.args."+name, value) + } + + return spanCtx, func(err *errors.QueryError) { + if err != nil { + ext.Error.Set(span, true) + span.SetTag("graphql.error", err.Error()) + } + span.Finish() + } +} + +func (Tracer) TraceValidation(ctx context.Context) func([]*errors.QueryError) { + span, _ := opentracing.StartSpanFromContext(ctx, "Validate Query") + + return func(errs []*errors.QueryError) { + if len(errs) > 0 { + msg := errs[0].Error() + if len(errs) > 1 { + msg += fmt.Sprintf(" (and %d more errors)", len(errs)-1) + } + ext.Error.Set(span, true) + span.SetTag("graphql.error", msg) + } + span.Finish() + } +} + +func noop(*errors.QueryError) {} diff --git a/trace/opentracing/trace_test.go b/trace/opentracing/trace_test.go new file mode 100644 index 000000000..d5bb8ec14 --- /dev/null +++ b/trace/opentracing/trace_test.go @@ -0,0 +1,22 @@ +package opentracing_test + +import ( + "testing" + + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/example/starwars" + "github.com/tokopedia/graphql-go/trace/opentracing" + "github.com/tokopedia/graphql-go/trace/tracer" +) + +func TestInterfaceImplementation(t *testing.T) { + var _ tracer.ValidationTracer = &opentracing.Tracer{} + var _ tracer.Tracer = &opentracing.Tracer{} +} + +func TestTracerOption(t *testing.T) { + _, err := graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(opentracing.Tracer{})) + if err != nil { + t.Fatal(err) + } +} diff --git a/trace/otel/trace.go b/trace/otel/trace.go new file mode 100644 index 000000000..281d6c2e0 --- /dev/null +++ b/trace/otel/trace.go @@ -0,0 +1,91 @@ +package otel + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + oteltrace "go.opentelemetry.io/otel/trace" + + "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/introspection" +) + +// DefaultTracer creates a tracer using a default name. +func DefaultTracer() *Tracer { + return &Tracer{ + Tracer: otel.Tracer("graphql-go"), + } +} + +// Tracer is an OpenTelemetry implementation for graphql-go. Set the Tracer +// property to your tracer instance as required. +type Tracer struct { + Tracer oteltrace.Tracer +} + +func (t *Tracer) TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, func([]*errors.QueryError)) { + spanCtx, span := t.Tracer.Start(ctx, "GraphQL Request") + + var attributes []attribute.KeyValue + attributes = append(attributes, attribute.String("graphql.query", queryString)) + if operationName != "" { + attributes = append(attributes, attribute.String("graphql.operationName", operationName)) + } + if len(variables) != 0 { + attributes = append(attributes, attribute.String("graphql.variables", fmt.Sprintf("%v", variables))) + } + span.SetAttributes(attributes...) + + return spanCtx, func(errs []*errors.QueryError) { + if len(errs) > 0 { + msg := errs[0].Error() + if len(errs) > 1 { + msg += fmt.Sprintf(" (and %d more errors)", len(errs)-1) + } + + span.SetStatus(codes.Error, msg) + } + span.End() + } +} + +func (t *Tracer) TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, func(*errors.QueryError)) { + if trivial { + return ctx, func(*errors.QueryError) {} + } + + var attributes []attribute.KeyValue + + spanCtx, span := t.Tracer.Start(ctx, fmt.Sprintf("Field: %v", label)) + attributes = append(attributes, attribute.String("graphql.type", typeName)) + attributes = append(attributes, attribute.String("graphql.field", fieldName)) + for name, value := range args { + attributes = append(attributes, attribute.String("graphql.args."+name, fmt.Sprintf("%v", value))) + } + span.SetAttributes(attributes...) + + return spanCtx, func(err *errors.QueryError) { + if err != nil { + span.SetStatus(codes.Error, err.Error()) + } + span.End() + } +} + +func (t *Tracer) TraceValidation(ctx context.Context) func([]*errors.QueryError) { + _, span := t.Tracer.Start(ctx, "GraphQL Validate") + + return func(errs []*errors.QueryError) { + if len(errs) > 0 { + msg := errs[0].Error() + if len(errs) > 1 { + msg += fmt.Sprintf(" (and %d more errors)", len(errs)-1) + } + span.SetStatus(codes.Error, msg) + } + span.End() + } +} diff --git a/trace/otel/trace_test.go b/trace/otel/trace_test.go new file mode 100644 index 000000000..5c2751f68 --- /dev/null +++ b/trace/otel/trace_test.go @@ -0,0 +1,29 @@ +package otel_test + +import ( + "testing" + + "go.opentelemetry.io/otel" + + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/example/starwars" + otelgraphql "github.com/tokopedia/graphql-go/trace/otel" + "github.com/tokopedia/graphql-go/trace/tracer" +) + +func TestInterfaceImplementation(t *testing.T) { + var _ tracer.ValidationTracer = &otelgraphql.Tracer{} + var _ tracer.Tracer = &otelgraphql.Tracer{} +} + +func TestTracerOption(t *testing.T) { + _, err := graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(otelgraphql.DefaultTracer())) + if err != nil { + t.Fatal(err) + } + + _, err = graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(&otelgraphql.Tracer{Tracer: otel.Tracer("example")})) + if err != nil { + t.Fatal(err) + } +} diff --git a/trace/trace.go b/trace/trace.go index 8c71196d7..ed84b3b85 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -1,80 +1,25 @@ +// The trace package provides tracing functionality. +// Deprecated: this package has been deprecated. Use package trace/tracer instead. package trace import ( - "context" - "fmt" - "github.com/tokopedia/graphql-go/errors" - "github.com/tokopedia/graphql-go/introspection" - opentracing "github.com/opentracing/opentracing-go" - "github.com/opentracing/opentracing-go/ext" - "github.com/opentracing/opentracing-go/log" + "github.com/tokopedia/graphql-go/trace/noop" + "github.com/tokopedia/graphql-go/trace/opentracing" + "github.com/tokopedia/graphql-go/trace/tracer" ) -type TraceQueryFinishFunc func([]*errors.QueryError) -type TraceFieldFinishFunc func(*errors.QueryError) - -type Tracer interface { - TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, TraceQueryFinishFunc) - TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, TraceFieldFinishFunc) -} - -type OpenTracingTracer struct{} - -func (OpenTracingTracer) TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, TraceQueryFinishFunc) { - span, spanCtx := opentracing.StartSpanFromContext(ctx, "GraphQL request") - span.SetTag("graphql.query", queryString) - - if operationName != "" { - span.SetTag("graphql.operationName", operationName) - } - - if len(variables) != 0 { - span.LogFields(log.Object("graphql.variables", variables)) - } - - return spanCtx, func(errs []*errors.QueryError) { - if len(errs) > 0 { - msg := errs[0].Error() - if len(errs) > 1 { - msg += fmt.Sprintf(" (and %d more errors)", len(errs)-1) - } - ext.Error.Set(span, true) - span.SetTag("graphql.error", msg) - } - span.Finish() - } -} - -func (OpenTracingTracer) TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, TraceFieldFinishFunc) { - if trivial { - return ctx, noop - } - - span, spanCtx := opentracing.StartSpanFromContext(ctx, label) - span.SetTag("graphql.type", typeName) - span.SetTag("graphql.field", fieldName) - for name, value := range args { - span.SetTag("graphql.args."+name, value) - } - - return spanCtx, func(err *errors.QueryError) { - if err != nil { - ext.Error.Set(span, true) - span.SetTag("graphql.error", err.Error()) - } - span.Finish() - } -} +// Deprecated: this type has been deprecated. Use tracer.QueryFinishFunc instead. +type TraceQueryFinishFunc = func([]*errors.QueryError) -func noop(*errors.QueryError) {} +// Deprecated: this type has been deprecated. Use tarcer.FieldFinishFunc instead. +type TraceFieldFinishFunc = func(*errors.QueryError) -type NoopTracer struct{} +// Deprecated: this interface has been deprecated. Use tracer.Tracer instead. +type Tracer = tracer.Tracer -func (NoopTracer) TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, TraceQueryFinishFunc) { - return ctx, func(errs []*errors.QueryError) {} -} +// Deprecated: this type has been deprecated. Use opentracing.Tracer instead. +type OpenTracingTracer = opentracing.Tracer -func (NoopTracer) TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, TraceFieldFinishFunc) { - return ctx, func(err *errors.QueryError) {} -} +// Deprecated: this type has been deprecated. Use noop.Tracer instead. +type NoopTracer = noop.Tracer diff --git a/trace/trace_test.go b/trace/trace_test.go new file mode 100644 index 000000000..73f04ee95 --- /dev/null +++ b/trace/trace_test.go @@ -0,0 +1,42 @@ +package trace_test + +import ( + "testing" + + "github.com/tokopedia/graphql-go" + "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/example/starwars" + "github.com/tokopedia/graphql-go/trace" + "github.com/tokopedia/graphql-go/trace/tracer" +) + +func TestInterfaceImplementation(t *testing.T) { + var _ tracer.ValidationTracer = &trace.OpenTracingTracer{} + var _ tracer.Tracer = &trace.OpenTracingTracer{} + + var _ tracer.ValidationTracer = &trace.NoopTracer{} + var _ tracer.Tracer = &trace.NoopTracer{} +} + +func TestTracerOption(t *testing.T) { + _, err := graphql.ParseSchema(starwars.Schema, nil, graphql.Tracer(trace.OpenTracingTracer{})) + if err != nil { + t.Fatal(err) + } +} + +// MockVlidationTracer is a struct that implements the tracer.LegacyValidationTracer inteface. +type MockValidationTracer struct{} + +func (MockValidationTracer) TraceValidation() func([]*errors.QueryError) { + return func([]*errors.QueryError) {} +} + +func TestValidationTracer(t *testing.T) { + // test the legacy validation tracer interface (validating without using context) to ensure backwards compatibility + vt := MockValidationTracer{} + _, err := graphql.ParseSchema(starwars.Schema, nil, graphql.ValidationTracer(vt)) + if err != nil { + t.Fatal(err) + } +} diff --git a/trace/tracer/tracer.go b/trace/tracer/tracer.go new file mode 100644 index 000000000..e3c347454 --- /dev/null +++ b/trace/tracer/tracer.go @@ -0,0 +1,34 @@ +package tracer + +import ( + "context" + + "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/introspection" +) + +type QueryFinishFunc = func([]*errors.QueryError) +type FieldFinishFunc = func(*errors.QueryError) +type ValidationFinishFunc = func([]*errors.QueryError) + +type Tracer interface { + TraceQuery(ctx context.Context, queryString string, operationName string, variables map[string]interface{}, varTypes map[string]*introspection.Type) (context.Context, QueryFinishFunc) + TraceField(ctx context.Context, label, typeName, fieldName string, trivial bool, args map[string]interface{}) (context.Context, FieldFinishFunc) +} + +type ValidationTracer interface { + TraceValidation(ctx context.Context) ValidationFinishFunc +} + +// Deprecated: use ValidationTracerContext instead. +type LegacyValidationTracer interface { + TraceValidation() func([]*errors.QueryError) +} + +// Deprecated: use a Tracer which implements ValidationTracerContext. +type LegacyNoopValidationTracer struct{} + +// Deprecated: use a Tracer which implements ValidationTracerContext. +func (LegacyNoopValidationTracer) TraceValidation() func([]*errors.QueryError) { + return func(errs []*errors.QueryError) {} +} diff --git a/trace/validation_trace.go b/trace/validation_trace.go index ce3cc156b..8315af0d0 100644 --- a/trace/validation_trace.go +++ b/trace/validation_trace.go @@ -2,16 +2,17 @@ package trace import ( "github.com/tokopedia/graphql-go/errors" + "github.com/tokopedia/graphql-go/trace/tracer" ) -type TraceValidationFinishFunc = TraceQueryFinishFunc +// Deprecated: this type has been deprecated. Use tracer.ValidationFinishFunc instead. +type TraceValidationFinishFunc = func([]*errors.QueryError) -type ValidationTracer interface { - TraceValidation() TraceValidationFinishFunc -} +// Deprecated: use ValidationTracerContext. +type ValidationTracer = tracer.LegacyValidationTracer //nolint:staticcheck -type NoopValidationTracer struct{} +// Deprecated: this type has been deprecated. Use tracer.ValidationTracer instead. +type ValidationTracerContext = tracer.ValidationTracer -func (NoopValidationTracer) TraceValidation() TraceValidationFinishFunc { - return func(errs []*errors.QueryError) {} -} +// Deprecated: use a tracer that implements ValidationTracerContext. +type NoopValidationTracer = tracer.LegacyNoopValidationTracer //nolint:staticcheck diff --git a/types/argument.go b/types/argument.go new file mode 100644 index 000000000..b2681a284 --- /dev/null +++ b/types/argument.go @@ -0,0 +1,44 @@ +package types + +// Argument is a representation of the GraphQL Argument. +// +// https://spec.graphql.org/draft/#sec-Language.Arguments +type Argument struct { + Name Ident + Value Value +} + +// ArgumentList is a collection of GraphQL Arguments. +type ArgumentList []*Argument + +// Returns a Value in the ArgumentList by name. +func (l ArgumentList) Get(name string) (Value, bool) { + for _, arg := range l { + if arg.Name.Name == name { + return arg.Value, true + } + } + return nil, false +} + +// MustGet returns a Value in the ArgumentList by name. +// MustGet will panic if the argument name is not found in the ArgumentList. +func (l ArgumentList) MustGet(name string) Value { + value, ok := l.Get(name) + if !ok { + panic("argument not found") + } + return value +} + +type ArgumentsDefinition []*InputValueDefinition + +// Get returns an InputValueDefinition in the ArgumentsDefinition by name or nil if not found. +func (a ArgumentsDefinition) Get(name string) *InputValueDefinition { + for _, inputValue := range a { + if inputValue.Name.Name == name { + return inputValue + } + } + return nil +} diff --git a/types/directive.go b/types/directive.go new file mode 100644 index 000000000..33e9d3ebd --- /dev/null +++ b/types/directive.go @@ -0,0 +1,35 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// Directive is a representation of the GraphQL Directive. +// +// http://spec.graphql.org/draft/#sec-Language.Directives +type Directive struct { + Name Ident + Arguments ArgumentList +} + +// DirectiveDefinition is a representation of the GraphQL DirectiveDefinition. +// +// http://spec.graphql.org/draft/#sec-Type-System.Directives +type DirectiveDefinition struct { + Name string + Desc string + Repeatable bool + Locations []string + Arguments ArgumentsDefinition + Loc errors.Location +} + +type DirectiveList []*Directive + +// Returns the Directive in the DirectiveList by name or nil if not found. +func (l DirectiveList) Get(name string) *Directive { + for _, d := range l { + if d.Name.Name == name { + return d + } + } + return nil +} diff --git a/types/doc.go b/types/doc.go new file mode 100644 index 000000000..87caa60b8 --- /dev/null +++ b/types/doc.go @@ -0,0 +1,9 @@ +/* + Package types represents all types from the GraphQL specification in code. + + + The names of the Go types, whenever possible, match 1:1 with the names from + the specification. + +*/ +package types diff --git a/types/enum.go b/types/enum.go new file mode 100644 index 000000000..a7be40261 --- /dev/null +++ b/types/enum.go @@ -0,0 +1,32 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// EnumTypeDefinition defines a set of possible enum values. +// +// Like scalar types, an EnumTypeDefinition also represents a leaf value in a GraphQL type system. +// +// http://spec.graphql.org/draft/#sec-Enums +type EnumTypeDefinition struct { + Name string + EnumValuesDefinition []*EnumValueDefinition + Desc string + Directives DirectiveList + Loc errors.Location +} + +// EnumValueDefinition are unique values that may be serialized as a string: the name of the +// represented value. +// +// http://spec.graphql.org/draft/#EnumValueDefinition +type EnumValueDefinition struct { + EnumValue string + Directives DirectiveList + Desc string + Loc errors.Location +} + +func (*EnumTypeDefinition) Kind() string { return "ENUM" } +func (t *EnumTypeDefinition) String() string { return t.Name } +func (t *EnumTypeDefinition) TypeName() string { return t.Name } +func (t *EnumTypeDefinition) Description() string { return t.Desc } diff --git a/types/extension.go b/types/extension.go new file mode 100644 index 000000000..e78540097 --- /dev/null +++ b/types/extension.go @@ -0,0 +1,13 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// Extension type defines a GraphQL type extension. +// Schemas, Objects, Inputs and Scalars can be extended. +// +// https://spec.graphql.org/draft/#sec-Type-System-Extensions +type Extension struct { + Type NamedType + Directives DirectiveList + Loc errors.Location +} diff --git a/types/field.go b/types/field.go new file mode 100644 index 000000000..1a521ca5a --- /dev/null +++ b/types/field.go @@ -0,0 +1,39 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// FieldDefinition is a representation of a GraphQL FieldDefinition. +// +// http://spec.graphql.org/draft/#FieldDefinition +type FieldDefinition struct { + Name string + Arguments ArgumentsDefinition + Type Type + Directives DirectiveList + Desc string + Loc errors.Location +} + +// FieldsDefinition is a list of an ObjectTypeDefinition's Fields. +// +// https://spec.graphql.org/draft/#FieldsDefinition +type FieldsDefinition []*FieldDefinition + +// Get returns a FieldDefinition in a FieldsDefinition by name or nil if not found. +func (l FieldsDefinition) Get(name string) *FieldDefinition { + for _, f := range l { + if f.Name == name { + return f + } + } + return nil +} + +// Names returns a slice of FieldDefinition names. +func (l FieldsDefinition) Names() []string { + names := make([]string, len(l)) + for i, f := range l { + names[i] = f.Name + } + return names +} diff --git a/types/fragment.go b/types/fragment.go new file mode 100644 index 000000000..3dc13eca9 --- /dev/null +++ b/types/fragment.go @@ -0,0 +1,51 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +type Fragment struct { + On TypeName + Selections SelectionSet +} + +// InlineFragment is a representation of the GraphQL InlineFragment. +// +// http://spec.graphql.org/draft/#InlineFragment +type InlineFragment struct { + Fragment + Directives DirectiveList + Loc errors.Location +} + +// FragmentDefinition is a representation of the GraphQL FragmentDefinition. +// +// http://spec.graphql.org/draft/#FragmentDefinition +type FragmentDefinition struct { + Fragment + Name Ident + Directives DirectiveList + Loc errors.Location +} + +// FragmentSpread is a representation of the GraphQL FragmentSpread. +// +// http://spec.graphql.org/draft/#FragmentSpread +type FragmentSpread struct { + Name Ident + Directives DirectiveList + Loc errors.Location +} + +type FragmentList []*FragmentDefinition + +// Returns a FragmentDefinition by name or nil if not found. +func (l FragmentList) Get(name string) *FragmentDefinition { + for _, f := range l { + if f.Name.Name == name { + return f + } + } + return nil +} + +func (InlineFragment) isSelection() {} +func (FragmentSpread) isSelection() {} diff --git a/types/input.go b/types/input.go new file mode 100644 index 000000000..91510caad --- /dev/null +++ b/types/input.go @@ -0,0 +1,47 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// InputValueDefinition is a representation of the GraphQL InputValueDefinition. +// +// http://spec.graphql.org/draft/#InputValueDefinition +type InputValueDefinition struct { + Name Ident + Type Type + Default Value + Desc string + Directives DirectiveList + Loc errors.Location + TypeLoc errors.Location +} + +type InputValueDefinitionList []*InputValueDefinition + +// Returns an InputValueDefinition by name or nil if not found. +func (l InputValueDefinitionList) Get(name string) *InputValueDefinition { + for _, v := range l { + if v.Name.Name == name { + return v + } + } + return nil +} + +// InputObject types define a set of input fields; the input fields are either scalars, enums, or +// other input objects. +// +// This allows arguments to accept arbitrarily complex structs. +// +// http://spec.graphql.org/draft/#sec-Input-Objects +type InputObject struct { + Name string + Desc string + Values ArgumentsDefinition + Directives DirectiveList + Loc errors.Location +} + +func (*InputObject) Kind() string { return "INPUT_OBJECT" } +func (t *InputObject) String() string { return t.Name } +func (t *InputObject) TypeName() string { return t.Name } +func (t *InputObject) Description() string { return t.Desc } diff --git a/types/interface.go b/types/interface.go new file mode 100644 index 000000000..9fc3a7353 --- /dev/null +++ b/types/interface.go @@ -0,0 +1,25 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// InterfaceTypeDefinition recusrively defines list of named fields with their arguments via the +// implementation chain of interfaces. +// +// GraphQL objects can then implement these interfaces which requires that the object type will +// define all fields defined by those interfaces. +// +// http://spec.graphql.org/draft/#sec-Interfaces +type InterfaceTypeDefinition struct { + Name string + PossibleTypes []*ObjectTypeDefinition + Fields FieldsDefinition + Desc string + Directives DirectiveList + Loc errors.Location + Interfaces []*InterfaceTypeDefinition +} + +func (*InterfaceTypeDefinition) Kind() string { return "INTERFACE" } +func (t *InterfaceTypeDefinition) String() string { return t.Name } +func (t *InterfaceTypeDefinition) TypeName() string { return t.Name } +func (t *InterfaceTypeDefinition) Description() string { return t.Desc } diff --git a/types/object.go b/types/object.go new file mode 100644 index 000000000..5654c4bb2 --- /dev/null +++ b/types/object.go @@ -0,0 +1,25 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// ObjectTypeDefinition represents a GraphQL ObjectTypeDefinition. +// +// type FooObject { +// foo: String +// } +// +// https://spec.graphql.org/draft/#sec-Objects +type ObjectTypeDefinition struct { + Name string + Interfaces []*InterfaceTypeDefinition + Fields FieldsDefinition + Desc string + Directives DirectiveList + InterfaceNames []string + Loc errors.Location +} + +func (*ObjectTypeDefinition) Kind() string { return "OBJECT" } +func (t *ObjectTypeDefinition) String() string { return t.Name } +func (t *ObjectTypeDefinition) TypeName() string { return t.Name } +func (t *ObjectTypeDefinition) Description() string { return t.Desc } diff --git a/types/query.go b/types/query.go new file mode 100644 index 000000000..627315ccf --- /dev/null +++ b/types/query.go @@ -0,0 +1,62 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// ExecutableDefinition represents a set of operations or fragments that can be executed +// against a schema. +// +// http://spec.graphql.org/draft/#ExecutableDefinition +type ExecutableDefinition struct { + Operations OperationList + Fragments FragmentList +} + +// OperationDefinition represents a GraphQL Operation. +// +// https://spec.graphql.org/draft/#sec-Language.Operations +type OperationDefinition struct { + Type OperationType + Name Ident + Vars ArgumentsDefinition + Selections SelectionSet + Directives DirectiveList + Loc errors.Location +} + +type OperationType string + +// A Selection is a field requested in a GraphQL operation. +// +// http://spec.graphql.org/draft/#Selection +type Selection interface { + isSelection() +} + +// A SelectionSet represents a collection of Selections +// +// http://spec.graphql.org/draft/#sec-Selection-Sets +type SelectionSet []Selection + +// Field represents a field used in a query. +type Field struct { + Alias Ident + Name Ident + Arguments ArgumentList + Directives DirectiveList + SelectionSet SelectionSet + SelectionSetLoc errors.Location +} + +func (Field) isSelection() {} + +type OperationList []*OperationDefinition + +// Get returns an OperationDefinition by name or nil if not found. +func (l OperationList) Get(name string) *OperationDefinition { + for _, f := range l { + if f.Name.Name == name { + return f + } + } + return nil +} diff --git a/types/scalar.go b/types/scalar.go new file mode 100644 index 000000000..5db08c605 --- /dev/null +++ b/types/scalar.go @@ -0,0 +1,22 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// ScalarTypeDefinition types represent primitive leaf values (e.g. a string or an integer) in a GraphQL type +// system. +// +// GraphQL responses take the form of a hierarchical tree; the leaves on these trees are GraphQL +// scalars. +// +// http://spec.graphql.org/draft/#sec-Scalars +type ScalarTypeDefinition struct { + Name string + Desc string + Directives DirectiveList + Loc errors.Location +} + +func (*ScalarTypeDefinition) Kind() string { return "SCALAR" } +func (t *ScalarTypeDefinition) String() string { return t.Name } +func (t *ScalarTypeDefinition) TypeName() string { return t.Name } +func (t *ScalarTypeDefinition) Description() string { return t.Desc } diff --git a/types/schema.go b/types/schema.go new file mode 100644 index 000000000..349c112b0 --- /dev/null +++ b/types/schema.go @@ -0,0 +1,43 @@ +package types + +// Schema represents a GraphQL service's collective type system capabilities. +// A schema is defined in terms of the types and directives it supports as well as the root +// operation types for each kind of operation: `query`, `mutation`, and `subscription`. +// +// For a more formal definition, read the relevant section in the specification: +// +// http://spec.graphql.org/draft/#sec-Schema +type Schema struct { + // EntryPoints determines the place in the type system where `query`, `mutation`, and + // `subscription` operations begin. + // + // http://spec.graphql.org/draft/#sec-Root-Operation-Types + // + EntryPoints map[string]NamedType + + // Types are the fundamental unit of any GraphQL schema. + // There are six kinds of named types, and two wrapping types. + // + // http://spec.graphql.org/draft/#sec-Types + Types map[string]NamedType + + // Directives are used to annotate various parts of a GraphQL document as an indicator that they + // should be evaluated differently by a validator, executor, or client tool such as a code + // generator. + // + // http://spec.graphql.org/#sec-Type-System.Directives + Directives map[string]*DirectiveDefinition + + UseFieldResolvers bool + + EntryPointNames map[string]string + Objects []*ObjectTypeDefinition + Unions []*Union + Enums []*EnumTypeDefinition + Extensions []*Extension + SchemaString string +} + +func (s *Schema) Resolve(name string) Type { + return s.Types[name] +} diff --git a/types/types.go b/types/types.go new file mode 100644 index 000000000..34ecdb3d9 --- /dev/null +++ b/types/types.go @@ -0,0 +1,63 @@ +package types + +import ( + "github.com/tokopedia/graphql-go/errors" +) + +// TypeName is a base building block for GraphQL type references. +type TypeName struct { + Ident +} + +// NamedType represents a type with a name. +// +// http://spec.graphql.org/draft/#NamedType +type NamedType interface { + Type + TypeName() string + Description() string +} + +type Ident struct { + Name string + Loc errors.Location +} + +type Type interface { + // Kind returns one possible GraphQL type kind. A type kind must be + // valid as defined by the GraphQL spec. + // + // https://spec.graphql.org/draft/#sec-Type-Kinds + Kind() string + + // String serializes a Type into a GraphQL specification format type. + // + // http://spec.graphql.org/draft/#sec-Serialization-Format + String() string +} + +// List represents a GraphQL ListType. +// +// http://spec.graphql.org/draft/#ListType +type List struct { + // OfType represents the inner-type of a List type. + // For example, the List type `[Foo]` has an OfType of Foo. + OfType Type +} + +// NonNull represents a GraphQL NonNullType. +// +// https://spec.graphql.org/draft/#NonNullType +type NonNull struct { + // OfType represents the inner-type of a NonNull type. + // For example, the NonNull type `Foo!` has an OfType of Foo. + OfType Type +} + +func (*List) Kind() string { return "LIST" } +func (*NonNull) Kind() string { return "NON_NULL" } +func (*TypeName) Kind() string { panic("TypeName needs to be resolved to actual type") } + +func (t *List) String() string { return "[" + t.OfType.String() + "]" } +func (t *NonNull) String() string { return t.OfType.String() + "!" } +func (*TypeName) String() string { panic("TypeName needs to be resolved to actual type") } diff --git a/types/union.go b/types/union.go new file mode 100644 index 000000000..8dd0a7103 --- /dev/null +++ b/types/union.go @@ -0,0 +1,24 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// Union types represent objects that could be one of a list of GraphQL object types, but provides no +// guaranteed fields between those types. +// +// They also differ from interfaces in that object types declare what interfaces they implement, but +// are not aware of what unions contain them. +// +// http://spec.graphql.org/draft/#sec-Unions +type Union struct { + Name string + UnionMemberTypes []*ObjectTypeDefinition + Desc string + Directives DirectiveList + TypeNames []string + Loc errors.Location +} + +func (*Union) Kind() string { return "UNION" } +func (t *Union) String() string { return t.Name } +func (t *Union) TypeName() string { return t.Name } +func (t *Union) Description() string { return t.Desc } diff --git a/types/value.go b/types/value.go new file mode 100644 index 000000000..f5d39d856 --- /dev/null +++ b/types/value.go @@ -0,0 +1,151 @@ +package types + +import ( + "strconv" + "strings" + "text/scanner" + + "github.com/tokopedia/graphql-go/errors" +) + +// Value represents a literal input or literal default value in the GraphQL Specification. +// +// http://spec.graphql.org/draft/#sec-Input-Values +type Value interface { + // Deserialize transforms a GraphQL specification format literal into a Go type. + Deserialize(vars map[string]interface{}) interface{} + + // String serializes a Value into a GraphQL specification format literal. + String() string + Location() errors.Location +} + +// PrimitiveValue represents one of the following GraphQL scalars: Int, Float, +// String, or Boolean +type PrimitiveValue struct { + Type rune + Text string + Loc errors.Location +} + +func (val *PrimitiveValue) Deserialize(vars map[string]interface{}) interface{} { + switch val.Type { + case scanner.Int: + value, err := strconv.ParseInt(val.Text, 10, 32) + if err != nil { + // check if it is out of range error. + // which probably mean that the input use int64 data type + // as needed by scalar.Int64 data type of tokopedia/gqlserver + if strings.Contains(err.Error(), strconv.ErrRange.Error()) { + val64, err := strconv.ParseInt(val.Text, 10, 64) + if err != nil { + panic(err) + } + return int64(val64) + } + panic(err) + } + return int32(value) + + case scanner.Float: + value, err := strconv.ParseFloat(val.Text, 64) + if err != nil { + panic(err) + } + return value + + case scanner.String: + value, err := strconv.Unquote(val.Text) + if err != nil { + panic(err) + } + return value + + case scanner.Ident: + switch val.Text { + case "true": + return true + case "false": + return false + default: + return val.Text + } + + default: + panic("invalid literal value") + } +} + +func (val *PrimitiveValue) String() string { return val.Text } +func (val *PrimitiveValue) Location() errors.Location { return val.Loc } + +// ListValue represents a literal list Value in the GraphQL specification. +// +// http://spec.graphql.org/draft/#sec-List-Value +type ListValue struct { + Values []Value + Loc errors.Location +} + +func (val *ListValue) Deserialize(vars map[string]interface{}) interface{} { + entries := make([]interface{}, len(val.Values)) + for i, entry := range val.Values { + entries[i] = entry.Deserialize(vars) + } + return entries +} + +func (val *ListValue) String() string { + entries := make([]string, len(val.Values)) + for i, entry := range val.Values { + entries[i] = entry.String() + } + return "[" + strings.Join(entries, ", ") + "]" +} + +func (val *ListValue) Location() errors.Location { return val.Loc } + +// ObjectValue represents a literal object Value in the GraphQL specification. +// +// http://spec.graphql.org/draft/#sec-Object-Value +type ObjectValue struct { + Fields []*ObjectField + Loc errors.Location +} + +// ObjectField represents field/value pairs in a literal ObjectValue. +type ObjectField struct { + Name Ident + Value Value +} + +func (val *ObjectValue) Deserialize(vars map[string]interface{}) interface{} { + fields := make(map[string]interface{}, len(val.Fields)) + for _, f := range val.Fields { + fields[f.Name.Name] = f.Value.Deserialize(vars) + } + return fields +} + +func (val *ObjectValue) String() string { + entries := make([]string, 0, len(val.Fields)) + for _, f := range val.Fields { + entries = append(entries, f.Name.Name+": "+f.Value.String()) + } + return "{" + strings.Join(entries, ", ") + "}" +} + +func (val *ObjectValue) Location() errors.Location { + return val.Loc +} + +// NullValue represents a literal `null` Value in the GraphQL specification. +// +// http://spec.graphql.org/draft/#sec-Null-Value +type NullValue struct { + Loc errors.Location +} + +func (val *NullValue) Deserialize(vars map[string]interface{}) interface{} { return nil } +func (val *NullValue) String() string { return "null" } +func (val *NullValue) Location() errors.Location { return val.Loc } diff --git a/types/variable.go b/types/variable.go new file mode 100644 index 000000000..d5a8959bc --- /dev/null +++ b/types/variable.go @@ -0,0 +1,15 @@ +package types + +import "github.com/tokopedia/graphql-go/errors" + +// Variable is used in GraphQL operations to parameterize an input value. +// +// http://spec.graphql.org/draft/#Variable +type Variable struct { + Name string + Loc errors.Location +} + +func (v Variable) Deserialize(vars map[string]interface{}) interface{} { return vars[v.Name] } +func (v Variable) String() string { return "$" + v.Name } +func (v *Variable) Location() errors.Location { return v.Loc }