-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
extension.go
131 lines (117 loc) · 3.76 KB
/
extension.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0
package headerssetterextension // import "github.com/open-telemetry/opentelemetry-collector-contrib/extension/headerssetterextension"
import (
"context"
"errors"
"fmt"
"net/http"
"go.opentelemetry.io/collector/extension/auth"
"go.uber.org/zap"
"google.golang.org/grpc/credentials"
"github.com/open-telemetry/opentelemetry-collector-contrib/extension/headerssetterextension/internal/action"
"github.com/open-telemetry/opentelemetry-collector-contrib/extension/headerssetterextension/internal/source"
)
// Header pairs the action to perform on one outgoing header (HTTP) or
// metadata entry (gRPC) with the source that supplies its value on each
// request.
type Header struct {
	// action applies the value to the outgoing request, e.g. by inserting,
	// upserting, updating or deleting the header identified by its key.
	action action.Action
	// source is queried for the value every time a request is prepared.
	source source.Source
}
// newHeadersSetterExtension builds an auth.Client that sets the headers listed
// in cfg.HeadersConfig on outgoing HTTP requests (through a wrapping
// http.RoundTripper) and on gRPC calls (through per-RPC credentials).
//
// For every configured header it resolves:
//   - a value source: a static Value takes precedence over a FromContext
//     lookup, which falls back to DefaultValue (or "" when that is unset);
//   - an action keyed by the header name; any action other than insert,
//     upsert, update or delete falls back to upsert and logs a warning.
//
// It returns an error when cfg is nil, when a header has no Key, or when a
// header has no value source although its action is not DELETE. A DELETE
// header without a source is given an empty static source, because the
// source of every header is queried on each request.
func newHeadersSetterExtension(cfg *Config, logger *zap.Logger) (auth.Client, error) {
	if cfg == nil {
		return nil, errors.New("extension configuration is not provided")
	}

	headers := make([]Header, 0, len(cfg.HeadersConfig))
	for i, header := range cfg.HeadersConfig {
		// Every action below dereferences Key; report a missing one instead of
		// panicking on a nil pointer.
		if header.Key == nil {
			return nil, fmt.Errorf("header #%d: key is not provided", i)
		}
		key := *header.Key

		var s source.Source
		switch {
		case header.Value != nil:
			s = &source.StaticSource{
				Value: *header.Value,
			}
		case header.FromContext != nil:
			defaultValue := ""
			if header.DefaultValue != nil {
				defaultValue = *header.DefaultValue
			}
			s = &source.ContextSource{
				Key:          *header.FromContext,
				DefaultValue: defaultValue,
			}
		case header.Action == DELETE:
			// Deleting needs no value, but the HTTP round tripper and the gRPC
			// credentials call Get on every header's source; a nil interface
			// would panic there, so supply an empty static source.
			s = &source.StaticSource{}
		default:
			// Without a source this header would panic on the first request;
			// fail at construction time instead.
			return nil, fmt.Errorf("header %q: neither a static value nor a context key is configured", key)
		}

		var a action.Action
		switch header.Action {
		case INSERT:
			a = action.Insert{Key: key}
		case UPSERT:
			a = action.Upsert{Key: key}
		case UPDATE:
			a = action.Update{Key: key}
		case DELETE:
			a = action.Delete{Key: key}
		default:
			a = action.Upsert{Key: key}
			logger.Warn("The action was not provided, using 'upsert'." +
				" In future versions, we'll require this to be explicitly set")
		}

		headers = append(headers, Header{action: a, source: s})
	}

	return auth.NewClient(
		auth.WithClientRoundTripper(
			func(base http.RoundTripper) (http.RoundTripper, error) {
				return &headersRoundTripper{
					base:    base,
					headers: headers,
				}, nil
			}),
		auth.WithClientPerRPCCredentials(func() (credentials.PerRPCCredentials, error) {
			return &headersPerRPC{headers: headers}, nil
		}),
	), nil
}
// headersPerRPC implements gRPC's credentials.PerRPCCredentials. It attaches
// the configured headers, with values resolved from their sources, to each
// call as request metadata.
type headersPerRPC struct {
	headers []Header
}

// GetRequestMetadata resolves every configured header against ctx and returns
// the metadata map that results from applying each header's action. The
// variadic arguments are ignored. If any source fails to produce a value, no
// metadata is returned and the source error is wrapped.
func (h *headersPerRPC) GetRequestMetadata(
	ctx context.Context,
	_ ...string,
) (map[string]string, error) {
	md := make(map[string]string, len(h.headers))
	for i := range h.headers {
		hdr := &h.headers[i]
		v, err := hdr.source.Get(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to determine the source: %w", err)
		}
		hdr.action.ApplyOnMetadata(md, v)
	}
	return md, nil
}

// RequireTransportSecurity reports false, so these headers may also be sent
// over connections without transport security. The header setter does not
// treat the values it sends as auth data.
func (*headersPerRPC) RequireTransportSecurity() bool { return false }
// headersRoundTripper is an http.RoundTripper that wraps base and sets the
// configured headers, with values extracted from their sources, on every
// outgoing request.
type headersRoundTripper struct {
	base    http.RoundTripper
	headers []Header
}

// RoundTrip clones req, applies every configured header to the clone using
// the value its source produces for the request's context, and forwards the
// clone to the base round tripper. The caller's request headers are left
// untouched.
//
// If a source fails, the request is not sent and an error wrapping the source
// error is returned. In that case the request body is closed here, because
// the http.RoundTripper contract requires RoundTrip to close the body on every
// path, including errors.
func (h *headersRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()
	req2 := req.Clone(ctx)
	if req2.Header == nil {
		// Actions write into req2.Header, so it must be a non-nil map.
		req2.Header = make(http.Header)
	}
	for _, header := range h.headers {
		value, err := header.source.Get(ctx)
		if err != nil {
			if req.Body != nil {
				// Best effort: the source error is what the caller needs to
				// see, so a Close failure is deliberately ignored.
				_ = req.Body.Close()
			}
			return nil, fmt.Errorf("failed to determine the source: %w", err)
		}
		header.action.ApplyOnHeaders(req2.Header, value)
	}
	return h.base.RoundTrip(req2)
}