Skip to content

Commit

Permalink
Move websocket headers to opt function 'WithWebsocketHeaders' (#365)
Browse files Browse the repository at this point in the history
Follow-up up on the discussion in
#360 (review).

Move websocket headers to and opt function 'WithWebsocketHeaders'.

Note this is a breaking change for users using the main branch
(but not for users on tagged releases).

I have:
- [x] Written a clear PR title and description (above)
- [x] Signed the [Khan Academy CLA](https://www.khanacademy.org/r/cla)
- [x] Added tests covering my changes, if applicable
- [x] Included a link to the issue fixed, if applicable
- [x] Included documentation, for new features
- [x] Added an entry to the changelog
  • Loading branch information
HaraldNordgren authored Dec 2, 2024
1 parent 5913cd6 commit d3e516b
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 13 deletions.
21 changes: 13 additions & 8 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,10 @@ type WebSocketOption func(*webSocketClient)
//
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header, opts ...WebSocketOption) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
if headers.Get("Sec-WebSocket-Protocol") == "" {
headers.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, opts ...WebSocketOption) WebSocketClient {
client := &webSocketClient{
Dialer: wsDialer,
Header: headers,
header: http.Header{},
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
Expand All @@ -152,6 +146,10 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head
opt(client)
}

if client.header.Get("Sec-WebSocket-Protocol") == "" {
client.header.Add("Sec-WebSocket-Protocol", "graphql-transport-ws")
}

return client
}

Expand All @@ -163,6 +161,13 @@ func WithConnectionParams(connParams map[string]interface{}) WebSocketOption {
}
}

// WithWebsocketHeader sets a header to be sent to the server.
func WithWebsocketHeader(header http.Header) WebSocketOption {
return func(ws *webSocketClient) {
ws.header = header
}
}

func newClient(endpoint string, httpClient Doer, method string) Client {
if httpClient == nil || httpClient == (*http.Client)(nil) {
httpClient = http.DefaultClient
Expand Down
4 changes: 2 additions & 2 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ const (

type webSocketClient struct {
Dialer Dialer
Header http.Header
header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
Expand Down Expand Up @@ -169,7 +169,7 @@ func checkConnectionAckReceived(message []byte) (bool, error) {
}

func (w *webSocketClient) Start(ctx context.Context) (errChan chan error, err error) {
w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.Header)
w.conn, err = w.Dialer.DialContext(ctx, w.endpoint, w.header)
if err != nil {
return nil, err
}
Expand Down
10 changes: 9 additions & 1 deletion internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,21 @@ func TestSubscriptionConnectionParams(t *testing.T) {
opts []graphql.WebSocketOption
}{
{
name: "authorized_user_gets_counter",
name: "connection_params_authorized_user_gets_counter",
opts: []graphql.WebSocketOption{
graphql.WithConnectionParams(map[string]interface{}{
authKey: "authorized-user-token",
}),
},
},
{
name: "http_header_authorized_user_gets_counter",
opts: []graphql.WebSocketOption{
graphql.WithWebsocketHeader(http.Header{
authKey: []string{"authorized-user-token"},
}),
},
},
{
name: "unauthorized_user_gets_error",
expectedError: "input: countAuthorized unauthorized\n",
Expand Down
1 change: 0 additions & 1 deletion internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ func newRoundtripWebSocketClient(t *testing.T, endpoint string, opts ...graphql.
wsWrapped: graphql.NewClientUsingWebSocket(
endpoint,
&MyDialer{Dialer: dialer},
nil,
opts...,
),
t: t,
Expand Down
20 changes: 19 additions & 1 deletion internal/integration/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"time"
Expand Down Expand Up @@ -198,6 +199,20 @@ func getAuthToken(ctx context.Context) string {
return ""
}

func authHeaderMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

token := r.Header.Get(AuthKey)
if token != "" {
ctx = withAuthToken(ctx, token)
}

r = r.WithContext(ctx)
handler.ServeHTTP(w, r)
})
}

func RunServer() *httptest.Server {
gqlgenServer := handler.New(NewExecutableSchema(Config{Resolvers: &resolver{}}))
gqlgenServer.AddTransport(transport.POST{})
Expand All @@ -216,7 +231,10 @@ func RunServer() *httptest.Server {
graphql.RegisterExtension(ctx, "foobar", "test")
return next(ctx)
})
return httptest.NewServer(gqlgenServer)

server := authHeaderMiddleware(gqlgenServer)

return httptest.NewServer(server)
}

type (
Expand Down

0 comments on commit d3e516b

Please sign in to comment.