From 8866984a076390b7b53f8af3d9c7e9f216f87b0c Mon Sep 17 00:00:00 2001 From: Evan Anderson Date: Wed, 13 Mar 2024 02:55:22 -0700 Subject: [PATCH] Store expected GitHub userid in database during enroll (#2566) * Update Keycloak config to store and expose gh_id and gh_login * Switch from kcadm.sh commands to keycloak-config-cli for most setup * Store expected GitHub userid in database during enroll * Fix lint errors * Further cleanup keycloak config (optional) * Fix remaining lint errors * Add tests for handlers_oauth.go * Add a metric to track what fraction of tokens are tied to a GitHub userid * Apply changes from Ria's review * Fix non-compiling code * Address Ozz's feedback * Update migration number * Fix test failures brought on by merge * re-run `make gen` Signed-off-by: Juan Antonio Osorio * Fix test Signed-off-by: Juan Antonio Osorio * Enable variable substitution in keycloak CLI container Signed-off-by: Juan Antonio Osorio --------- Signed-off-by: Juan Antonio Osorio Co-authored-by: Juan Antonio Osorio --- .mk/identity.mk | 6 +- cmd/dev/app/container/cmd_verify.go | 3 +- cmd/dev/app/rule_type/rttst.go | 3 +- .../migrations/000028_session_userid.down.sql | 15 ++ .../migrations/000028_session_userid.up.sql | 15 ++ database/mock/store.go | 30 --- database/query/session_store.sql | 12 +- docker-compose.yaml | 29 +- identity/config/stacklok.yaml | 137 ++++++++++ identity/scripts/initialize.sh | 56 ---- identity/scripts/kc-setup.sh | 22 -- internal/auth/jwtauth.go | 28 +- internal/authz/authz_test.go | 13 +- internal/controlplane/handlers_authz_test.go | 8 +- internal/controlplane/handlers_oauth.go | 118 ++++++-- internal/controlplane/handlers_oauth_test.go | 252 +++++++++++++++++- .../controlplane/handlers_projects_test.go | 8 +- .../handlers_repositories_test.go | 5 +- internal/controlplane/handlers_token.go | 2 +- internal/controlplane/metrics/metrics.go | 25 ++ internal/controlplane/metrics/noop.go | 3 + internal/controlplane/server.go | 46 ++-- internal/db/models.go | 1 + internal/db/querier.go | 2 - internal/db/session_store.sql.go | 51 +--- .../gh_branch_protect_test.go | 3 +- .../pull_request/pull_request_test.go | 3 +- .../actions/remediate/remediate_test.go | 3 +- .../actions/remediate/rest/rest_test.go | 7 +- .../engine/ingester/artifact/artifact_test.go | 3 +- internal/engine/ingester/git/git_test.go | 15 +- internal/engine/ingester/ingester_test.go | 3 +- internal/engine/ingester/rest/rest_test.go | 15 +- internal/providers/providers.go | 19 +- pkg/api/protobuf/go/minder/v1/minder.pb.go | 2 +- 35 files changed, 695 insertions(+), 268 deletions(-) create mode 100644 database/migrations/000028_session_userid.down.sql create mode 100644 database/migrations/000028_session_userid.up.sql create mode 100644 identity/config/stacklok.yaml delete mode 100755 identity/scripts/initialize.sh delete mode 100755 identity/scripts/kc-setup.sh diff --git a/.mk/identity.mk b/.mk/identity.mk index 53ebb44b36..b89f768695 100644 --- a/.mk/identity.mk +++ b/.mk/identity.mk @@ -22,11 +22,13 @@ ifndef KC_GITHUB_CLIENT_SECRET $(error KC_GITHUB_CLIENT_SECRET is not set) endif @echo "Setting up GitHub login..." - @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh config credentials --server http://localhost:8080 --realm master --user admin --password admin +# Delete the existing GitHub identity provider, if it exists. Otherwise, ignore the error. + @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh delete identity-provider/instances/github -r stacklok || true @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh create identity-provider/instances -r stacklok -s alias=github -s providerId=github -s enabled=true -s 'config.useJwksUrl="true"' -s config.clientId=$$KC_GITHUB_CLIENT_ID -s config.clientSecret=$$KC_GITHUB_CLIENT_SECRET + @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh create identity-provider/instances/github/mappers -r stacklok -s name=gh_id -s identityProviderAlias=github -s identityProviderMapper=github-user-attribute-mapper -s config='{"syncMode":"FORCE", "jsonField":"id", "userAttribute":"gh_id"}' + @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh create identity-provider/instances/github/mappers -r stacklok -s name=gh_login -s identityProviderAlias=github -s identityProviderMapper=github-user-attribute-mapper -s config='{"syncMode":"FORCE", "jsonField":"login", "userAttribute":"gh_login"}' password-login: @echo "Setting up password login..." - @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh config credentials --server http://localhost:8080 --realm master --user admin --password admin @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh create users -r stacklok -s username=testuser -s enabled=true @$(CONTAINER) exec -it keycloak_container /opt/keycloak/bin/kcadm.sh set-password -r stacklok --username testuser --new-password tester \ No newline at end of file diff --git a/cmd/dev/app/container/cmd_verify.go b/cmd/dev/app/container/cmd_verify.go index 2d723204e2..3afc2c7d25 100644 --- a/cmd/dev/app/container/cmd_verify.go +++ b/cmd/dev/app/container/cmd_verify.go @@ -17,6 +17,7 @@ package container import ( "context" + "database/sql" "encoding/json" "fmt" "os" @@ -121,7 +122,7 @@ func buildGitHubClient(token string) (provifv1.GitHub, error) { "github": {} }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, token, ) diff --git a/cmd/dev/app/rule_type/rttst.go b/cmd/dev/app/rule_type/rttst.go index ce3665c356..aef60a8c0b 100644 --- a/cmd/dev/app/rule_type/rttst.go +++ b/cmd/dev/app/rule_type/rttst.go @@ -17,6 +17,7 @@ package rule_type import ( "bytes" "context" + "database/sql" "encoding/json" "fmt" "os" @@ -132,7 +133,7 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { "github": {} }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, token, )) inf := &entities.EntityInfoWrapper{ diff --git a/database/migrations/000028_session_userid.down.sql b/database/migrations/000028_session_userid.down.sql new file mode 100644 index 0000000000..46490dd7b7 --- /dev/null +++ b/database/migrations/000028_session_userid.down.sql @@ -0,0 +1,15 @@ +-- Copyright 2024 Stacklok, Inc +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +ALTER TABLE session_store DROP COLUMN remote_user; \ No newline at end of file diff --git a/database/migrations/000028_session_userid.up.sql b/database/migrations/000028_session_userid.up.sql new file mode 100644 index 0000000000..f9c73d2fec --- /dev/null +++ b/database/migrations/000028_session_userid.up.sql @@ -0,0 +1,15 @@ +-- Copyright 2024 Stacklok, Inc +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +ALTER TABLE session_store ADD COLUMN remote_user TEXT; \ No newline at end of file diff --git a/database/mock/store.go b/database/mock/store.go index a628979f51..fab1f81b21 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -1032,36 +1032,6 @@ func (mr *MockStoreMockRecorder) GetRuleTypeByName(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRuleTypeByName", reflect.TypeOf((*MockStore)(nil).GetRuleTypeByName), arg0, arg1) } -// GetSessionState mocks base method. -func (m *MockStore) GetSessionState(arg0 context.Context, arg1 int32) (db.SessionStore, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionState", arg0, arg1) - ret0, _ := ret[0].(db.SessionStore) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSessionState indicates an expected call of GetSessionState. -func (mr *MockStoreMockRecorder) GetSessionState(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionState", reflect.TypeOf((*MockStore)(nil).GetSessionState), arg0, arg1) -} - -// GetSessionStateByProjectID mocks base method. -func (m *MockStore) GetSessionStateByProjectID(arg0 context.Context, arg1 uuid.UUID) (db.SessionStore, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionStateByProjectID", arg0, arg1) - ret0, _ := ret[0].(db.SessionStore) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSessionStateByProjectID indicates an expected call of GetSessionStateByProjectID. -func (mr *MockStoreMockRecorder) GetSessionStateByProjectID(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionStateByProjectID", reflect.TypeOf((*MockStore)(nil).GetSessionStateByProjectID), arg0, arg1) -} - // GetUserByID mocks base method. func (m *MockStore) GetUserByID(arg0 context.Context, arg1 int32) (db.User, error) { m.ctrl.T.Helper() diff --git a/database/query/session_store.sql b/database/query/session_store.sql index 0affa1e7af..d7d216f58d 100644 --- a/database/query/session_store.sql +++ b/database/query/session_store.sql @@ -1,20 +1,14 @@ -- name: CreateSessionState :one -INSERT INTO session_store (provider, project_id, session_state, owner_filter, redirect_url) VALUES ($1, $2, $3, $4, $5) RETURNING *; - --- name: GetSessionState :one -SELECT * FROM session_store WHERE id = $1; - --- name: GetSessionStateByProjectID :one -SELECT * FROM session_store WHERE project_id = $1; +INSERT INTO session_store (provider, project_id, remote_user, session_state, owner_filter, redirect_url) VALUES ($1, $2, $3, $4, $5, $6) RETURNING *; -- name: GetProjectIDBySessionState :one -SELECT provider, project_id, owner_filter, redirect_url FROM session_store WHERE session_state = $1; +SELECT provider, project_id, remote_user, owner_filter, redirect_url FROM session_store WHERE session_state = $1; -- name: DeleteSessionState :exec DELETE FROM session_store WHERE id = $1; -- name: DeleteSessionStateByProjectID :exec -DELETE FROM session_store WHERE provider=$1 AND project_id = $2; +DELETE FROM session_store WHERE provider = $1 AND project_id = $2; -- name: DeleteExpiredSessionStates :exec DELETE FROM session_store WHERE created_at < NOW() - INTERVAL '1 day'; \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml index c4347e6099..b68f746791 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -68,6 +68,10 @@ services: condition: service_healthy openfga: condition: service_healthy + migrate: + condition: service_completed_successfully + keycloak-config: + condition: service_completed_successfully migrate: container_name: minder_migrate_up build: @@ -126,10 +130,9 @@ services: environment: KEYCLOAK_ADMIN: admin KEYCLOAK_ADMIN_PASSWORD: admin - KC_MINDER_SERVER_SECRET: secret KC_HEALTH_ENABLED: "true" healthcheck: - test: ["CMD", "/opt/keycloak/bin/kcadm.sh", "get", "realms/stacklok", "--fields", "enabled"] + test: ["CMD", "/opt/keycloak/bin/kcadm.sh", "config", "credentials", "--server", "http://localhost:8080", "--realm", "master", "--user", "admin", "--password", "admin"] interval: 10s timeout: 5s retries: 10 @@ -137,10 +140,28 @@ services: - "8081:8080" volumes: - ./identity/themes:/opt/keycloak/themes:z - - ./identity/scripts:/opt/keycloak/scripts:z networks: - app_net - entrypoint: ["/opt/keycloak/scripts/kc-setup.sh"] + + keycloak-config: + container_name: keycloak_config + image: bitnami/keycloak-config-cli:5.10.0 + entrypoint: ["java", "-jar", "/opt/bitnami/keycloak-config-cli/keycloak-config-cli.jar"] + environment: + KEYCLOAK_URL: http://keycloak:8080 + KEYCLOAK_USER: admin + KEYCLOAK_PASSWORD: admin + KC_MINDER_SERVER_SECRET: secret + IMPORT_VARSUBSTITUTION_ENABLED: "true" + IMPORT_FILES_LOCATIONS: /config/*.yaml + volumes: + - ./identity/config:/config:z + networks: + - app_net + + depends_on: + keycloak: + condition: service_healthy openfga: container_name: openfga diff --git a/identity/config/stacklok.yaml b/identity/config/stacklok.yaml new file mode 100644 index 0000000000..c1bb9f075a --- /dev/null +++ b/identity/config/stacklok.yaml @@ -0,0 +1,137 @@ +# Copyright 2024 Stacklok, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# From: +# create realms -s realm=stacklok -s loginTheme=keycloak -s eventsEnabled=true -s 'enabledEventTypes=["DELETE_ACCOUNT"]' -s eventsExpiration=604800 -s enabled=true +realm: stacklok +enabled: true +loginTheme: keycloak +eventsEnabled: true +enabledEventTypes: + - DELETE_ACCOUNT +eventsExpiration: 604800 + +# From: +# Add account deletion capability to stacklok realm (see https://www.keycloak.org/docs/latest/server_admin/#authentication-operations) +# update "/authentication/required-actions/delete_account" -r stacklok -b '{ "alias" : "delete_account", "name" : "Delete Account", "providerId" : "delete_account", "enabled" : true, "defaultAction" : false, "priority" : 60, "config" : { }}' +requiredActions: + - alias: delete_account + name: Delete Account + providerId: delete_account + enabled: true + defaultAction: false + +# From: +# Give all users permission to delete their own account +# add-roles -r stacklok --rname default-roles-stacklok --rolename delete-account --cclientid account +roles: + realm: + - name: default-roles-stacklok + composites: + client: + account: + - delete-account + - view-profile + - manage-account + +# Collect gh_login and gh_id from GitHub and expose them in tokens +clientScopes: + - name: gh-data + description: "Add GitHub information to tokens" + protocol: openid-connect + attributes: + "include.in.token.scope": "true" + "display.on.consent.screen": "false" + protocolMappers: + - name: gh_id + protocol: openid-connect + protocolMapper: oidc-usermodel-attribute-mapper + consentRequired: false + config: + userinfo.token.claim: "true" + id.token.claim: "true" + access.token.claim: "true" + claim.name: "gh_id" + jsonType.label: "String" + user.attribute: "gh_id" + - name: gh_login + protocol: openid-connect + protocolMapper: oidc-usermodel-attribute-mapper + consentRequired: false + config: + userinfo.token.claim: "true" + id.token.claim: "true" + access.token.claim: "true" + claim.name: "gh_login" + jsonType.label: "String" + user.attribute: "gh_login" + + +clients: + # From: + # create clients -r stacklok -s clientId=minder-cli -s 'redirectUris=["http://localhost/*"]' -s publicClient=true -s enabled=true -s defaultClientScopes='["acr","email","profile","roles","web-origins","gh-data"]' -s optionalClientScopes='["microprofile-jwt","offline_access"]' + - clientId: minder-cli + enabled: true + redirectUris: + - "http://localhost/*" + publicClient: true + # If you set one of these, you seem to need to set both (per CLI experimentation) + defaultClientScopes: + - acr + - email + - profile + - roles + - web-origins + - gh-data + optionalClientScopes: + - microprofile-jwt + - offline_access + # From: + # create clients -r stacklok -s clientId=minder-ui -s 'redirectUris=["http://localhost/*"]' -s publicClient=true -s enabled=true -s defaultClientScopes='["acr","email","profile","roles","web-origins","gh-data"]' -s optionalClientScopes='["microprofile-jwt","offline_access"]' + - clientId: minder-ui + enabled: true + redirectUris: + - "http://localhost/*" + publicClient: true + # If you set one of these, you seem to need to set both (per CLI experimentation) + defaultClientScopes: + - acr + - email + - profile + - roles + - web-origins + - gh-data + optionalClientScopes: + - microprofile-jwt + - offline_access + # From: + # create clients -r stacklok -s clientId=minder-server -s serviceAccountsEnabled=true -s clientAuthenticatorType=client-secret -s secret="$KC_MINDER_SERVER_SECRET" -s enabled=true -s defaultClientScopes='["acr","email","profile","roles","web-origins","gh-data"]' -s optionalClientScopes='["microprofile-jwt","offline_access"]' + - clientId: minder-server + enabled: true + serviceAccountsEnabled: true + clientAuthenticatorType: client-secret + secret: "$(env:KC_MINDER_SERVER_SECRET)" + +users: + - username: service-account-minder-server + clientRoles: + realm-management: + # From: + # Give minder-server the capability to view events + # add-roles -r stacklok --uusername service-account-minder-server --cclientid realm-management --rolename view-events + - view-events + # From: + # Give minder-server the capability to delete users + # add-roles -r stacklok --uusername service-account-minder-server --cclientid realm-management --rolename manage-users + - manage-users \ No newline at end of file diff --git a/identity/scripts/initialize.sh b/identity/scripts/initialize.sh deleted file mode 100755 index 95ff9d66e4..0000000000 --- a/identity/scripts/initialize.sh +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env bash - -# -# Copyright 2023 Stacklok, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -euo pipefail - -# Wait for Keycloak to start and authenticate with admin credentials -while ! /opt/keycloak/bin/kcadm.sh config credentials --server http://keycloak:8080 --realm master --user "$KEYCLOAK_ADMIN" --password "$KEYCLOAK_ADMIN_PASSWORD"; do - sleep 1 -done - -status=0 -# Create realm stacklok, which stores account deletion events for 7 days -/opt/keycloak/bin/kcadm.sh get realms/stacklok >/dev/null 2>&1 || status="$?"; if [ $status -eq 1 ]; then - /opt/keycloak/bin/kcadm.sh create realms -s realm=stacklok -s loginTheme=keycloak -s eventsEnabled=true -s 'enabledEventTypes=["DELETE_ACCOUNT"]' -s eventsExpiration=604800 -s enabled=true - - # Add account deletion capability to stacklok realm (see https://www.keycloak.org/docs/latest/server_admin/#authentication-operations) - /opt/keycloak/bin/kcadm.sh update "/authentication/required-actions/delete_account" -r stacklok -b '{ "alias" : "delete_account", "name" : "Delete Account", "providerId" : "delete_account", "enabled" : true, "defaultAction" : false, "priority" : 60, "config" : { }}' - - # Give all users permission to delete their own account - /opt/keycloak/bin/kcadm.sh add-roles -r stacklok --rname default-roles-stacklok --rolename delete-account --cclientid account -fi - -# Create client minder-cli -if ! /opt/keycloak/bin/kcadm.sh get clients -r stacklok --fields 'clientId' | grep -q "minder-cli"; then - /opt/keycloak/bin/kcadm.sh create clients -r stacklok -s clientId=minder-cli -s 'redirectUris=["http://localhost/*"]' -s publicClient=true -s enabled=true -fi - -# Create client minder-ui -if ! /opt/keycloak/bin/kcadm.sh get clients -r stacklok --fields 'clientId' | grep -q "minder-ui"; then - /opt/keycloak/bin/kcadm.sh create clients -r stacklok -s clientId=minder-ui -s 'redirectUris=["http://localhost/*"]' -s publicClient=true -s enabled=true -fi - -# Create client minder-server to receive account deletion events -if ! /opt/keycloak/bin/kcadm.sh get clients -r stacklok --fields 'clientId' | grep -q "minder-server"; then - /opt/keycloak/bin/kcadm.sh create clients -r stacklok -s clientId=minder-server -s serviceAccountsEnabled=true -s clientAuthenticatorType=client-secret -s secret="$KC_MINDER_SERVER_SECRET" -s enabled=true - - # Give minder-server the capability to view events - /opt/keycloak/bin/kcadm.sh add-roles -r stacklok --uusername service-account-minder-server --cclientid realm-management --rolename view-events - - # Give minder-server the capability to delete users - /opt/keycloak/bin/kcadm.sh add-roles -r stacklok --uusername service-account-minder-server --cclientid realm-management --rolename manage-users -fi \ No newline at end of file diff --git a/identity/scripts/kc-setup.sh b/identity/scripts/kc-setup.sh deleted file mode 100755 index 69cc2d32b1..0000000000 --- a/identity/scripts/kc-setup.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env bash - -# -# Copyright 2023 Stacklok, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -euo pipefail - -/opt/keycloak/scripts/initialize.sh & disown - -/opt/keycloak/bin/kc.sh "$@" \ No newline at end of file diff --git a/internal/auth/jwtauth.go b/internal/auth/jwtauth.go index 3fd8dd86e8..45ebedb841 100644 --- a/internal/auth/jwtauth.go +++ b/internal/auth/jwtauth.go @@ -101,18 +101,34 @@ func NewJwtValidator(ctx context.Context, jwksUrl string) (JwtValidator, error) }, nil } -var userSubjectContextKey struct{} +var userTokenContextKey struct{} // GetUserSubjectFromContext returns the user subject from the context, or nil func GetUserSubjectFromContext(ctx context.Context) string { - subject, ok := ctx.Value(userSubjectContextKey).(string) + token, ok := ctx.Value(userTokenContextKey).(openid.Token) if !ok { return "" } - return subject + return token.Subject() } -// WithUserSubjectContext stores the specified user subject in the context. -func WithUserSubjectContext(ctx context.Context, subject string) context.Context { - return context.WithValue(ctx, userSubjectContextKey, subject) +// GetUserClaimFromContext returns the specified claim from the user subject in +// the context if found and of the correct type +func GetUserClaimFromContext[T any](ctx context.Context, claim string) (T, bool) { + var ret T + token, ok := ctx.Value(userTokenContextKey).(openid.Token) + if !ok { + return ret, false + } + data, ok := token.Get(claim) + if !ok { + return ret, false + } + ret, ok = data.(T) + return ret, ok +} + +// WithAuthTokenContext stores the specified user-identifying token in the context. +func WithAuthTokenContext(ctx context.Context, token openid.Token) context.Context { + return context.WithValue(ctx, userTokenContextKey, token) } diff --git a/internal/authz/authz_test.go b/internal/authz/authz_test.go index 9c3eaf80a2..0179f613df 100644 --- a/internal/authz/authz_test.go +++ b/internal/authz/authz_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwt/openid" fgasdk "github.com/openfga/go-sdk" "github.com/openfga/openfga/cmd/run" "github.com/openfga/openfga/pkg/logger" @@ -102,7 +103,9 @@ func TestVerifyOneProject(t *testing.T) { prj := uuid.New() assert.NoError(t, c.Write(ctx, "user-1", authz.AuthzRoleAdmin, prj), "failed to write project") - userctx := auth.WithUserSubjectContext(ctx, "user-1") + userJWT := openid.New() + assert.NoError(t, userJWT.Set("sub", "user-1")) + userctx := auth.WithAuthTokenContext(ctx, userJWT) // verify the project assert.NoError(t, c.Check(userctx, "get", prj), "failed to check project") @@ -154,7 +157,9 @@ func TestVerifyMultipleProjects(t *testing.T) { prj1 := uuid.New() assert.NoError(t, c.Write(ctx, "user-1", authz.AuthzRoleAdmin, prj1), "failed to write project") - userctx := auth.WithUserSubjectContext(ctx, "user-1") + user1JWT := openid.New() + assert.NoError(t, user1JWT.Set("sub", "user-1")) + userctx := auth.WithAuthTokenContext(ctx, user1JWT) // verify the project assert.NoError(t, c.Check(userctx, "get", prj1), "failed to check project") @@ -171,7 +176,9 @@ func TestVerifyMultipleProjects(t *testing.T) { assert.NoError(t, c.Write(ctx, "user-2", authz.AuthzRoleAdmin, prj3), "failed to write project") // verify the project - assert.NoError(t, c.Check(auth.WithUserSubjectContext(ctx, "user-2"), "get", prj3), "failed to check project") + user2JWT := openid.New() + assert.NoError(t, user2JWT.Set("sub", "user-2")) + assert.NoError(t, c.Check(auth.WithAuthTokenContext(ctx, user2JWT), "get", prj3), "failed to check project") // verify user-1 cannot operate on project 3 assert.Error(t, c.Check(userctx, "get", prj3), "expected user-1 to not be able to operate on project 3") diff --git a/internal/controlplane/handlers_authz_test.go b/internal/controlplane/handlers_authz_test.go index 44150d5b1f..2aa3894f76 100644 --- a/internal/controlplane/handlers_authz_test.go +++ b/internal/controlplane/handlers_authz_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwt/openid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -58,7 +59,8 @@ func TestEntityContextProjectInterceptor(t *testing.T) { malformedProjectID := "malformed" //nolint:goconst provider := "github" - subject := "subject1" + userJWT := openid.New() + assert.NoError(t, userJWT.Set("sub", "subject1")) assert.NotEqual(t, projectID, defaultProjectID) @@ -118,7 +120,7 @@ func TestEntityContextProjectInterceptor(t *testing.T) { buildStubs: func(t *testing.T, store *mockdb.MockStore) { t.Helper() store.EXPECT(). - GetUserBySubject(gomock.Any(), subject). + GetUserBySubject(gomock.Any(), userJWT.Subject()). Return(db.User{ ID: 1, }, nil) @@ -174,7 +176,7 @@ func TestEntityContextProjectInterceptor(t *testing.T) { if tc.buildStubs != nil { tc.buildStubs(t, mockStore) } - ctx := auth.WithUserSubjectContext(withRpcOptions(context.Background(), rpcOptions), subject) + ctx := auth.WithAuthTokenContext(withRpcOptions(context.Background(), rpcOptions), userJWT) authzClient := &mock.SimpleClient{} diff --git a/internal/controlplane/handlers_oauth.go b/internal/controlplane/handlers_oauth.go index 4928b96ab7..7a0f160d86 100644 --- a/internal/controlplane/handlers_oauth.go +++ b/internal/controlplane/handlers_oauth.go @@ -24,9 +24,11 @@ import ( "net/http" "net/url" "slices" + "strconv" "github.com/google/uuid" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -40,6 +42,7 @@ import ( "github.com/stacklok/minder/internal/db" "github.com/stacklok/minder/internal/engine" "github.com/stacklok/minder/internal/logger" + "github.com/stacklok/minder/internal/providers" "github.com/stacklok/minder/internal/util" pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) @@ -70,11 +73,9 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, span.SetAttributes(attribute.Key("provider").String(provider.Name)) defer span.End() - // Create a new OAuth2 config for the given provider - oauthConfig, err := auth.NewOAuthConfig(provider.Name, req.Cli) - if err != nil { - return nil, err - } + user, _ := auth.GetUserClaimFromContext[string](ctx, "gh_id") + // If the user's token doesn't have gh_id set yet, we'll pass it through for now. + s.mt.AddTokenOpCount(ctx, "issued", user != "") // Generate a random nonce based state state, err := mcrypto.GenerateNonce() @@ -90,20 +91,17 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, return nil, status.Errorf(codes.Unknown, "error deleting session state: %s", err) } - var owner sql.NullString - if req.Owner == nil { - owner = sql.NullString{Valid: false} - } else { - owner = sql.NullString{Valid: true, String: *req.Owner} + owner := sql.NullString{ + Valid: req.GetOwner() != "", + String: req.GetOwner(), } var redirectUrl sql.NullString - if req.RedirectUrl == nil { - redirectUrl = sql.NullString{Valid: false} - } else { + // Empty redirect URL means null string (default condition) + if req.GetRedirectUrl() != "" { encryptedRedirectUrl, err := s.cryptoEngine.EncryptString(*req.RedirectUrl) if err != nil { - return nil, status.Errorf(codes.Unknown, "error encrypting redirect URL: %s", err) + return nil, status.Errorf(codes.Internal, "error encrypting redirect URL: %s", err) } redirectUrl = sql.NullString{Valid: true, String: encryptedRedirectUrl} } @@ -113,6 +111,7 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, _, err = s.store.CreateSessionState(ctx, db.CreateSessionStateParams{ Provider: provider.Name, ProjectID: projectID, + RemoteUser: sql.NullString{Valid: user != "", String: user}, SessionState: state, OwnerFilter: owner, RedirectUrl: redirectUrl, @@ -125,6 +124,12 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, logger.BusinessRecord(ctx).Provider = provider.Name logger.BusinessRecord(ctx).Project = projectID + // Create a new OAuth2 config for the given provider + oauthConfig, err := s.providerAuthFactory(provider.Name, req.Cli) + if err != nil { + return nil, err + } + // Return the authorization URL and state return &pb.GetAuthorizationURLResponse{ Url: oauthConfig.AuthCodeURL(state, oauth2.AccessTypeOffline), @@ -140,6 +145,10 @@ func (s *Server) HandleProviderCallback() runtime.HandlerFunc { ctx := r.Context() if err := s.processCallback(ctx, w, r, pathParams); err != nil { + if httpErr, ok := err.(*httpResponseError); ok { + httpErr.WriteError(w) + return + } log.Printf("error handling provider callback: %s", err) w.WriteHeader(http.StatusInternalServerError) return @@ -212,7 +221,7 @@ func (s *Server) processCallback(ctx context.Context, w http.ResponseWriter, r * func (s *Server) generateOAuthToken(ctx context.Context, provider string, code string, stateData db.GetProjectIDBySessionStateRow) error { // generate a new OAuth2 config for the given provider - oauthConfig, err := auth.NewOAuthConfig(provider, true) + oauthConfig, err := s.providerAuthFactory(provider, true) if err != nil { return fmt.Errorf("error creating OAuth config: %w", err) } @@ -225,6 +234,18 @@ func (s *Server) generateOAuthToken(ctx context.Context, provider string, code s return fmt.Errorf("error exchanging code for token: %w", err) } + // Older enrollments may not have a RemoteUser stored; these should age out fairly quickly. + s.mt.AddTokenOpCount(ctx, "check", stateData.RemoteUser.Valid) + if stateData.RemoteUser.Valid { + if err := s.verifyProviderTokenIdentity(ctx, stateData, provider, token.AccessToken); err != nil { + // TODO: make this prettier? + return newHttpError(http.StatusForbidden, "User token mismatch").SetContents( + "The provided login token was associated with a different GitHub user.") + } + } else { + zerolog.Ctx(ctx).Warn().Msg("RemoteUser not found in session state") + } + ftoken := &oauth2.Token{ AccessToken: token.AccessToken, TokenType: token.TokenType, @@ -245,18 +266,11 @@ func (s *Server) generateOAuthToken(ctx context.Context, provider string, code s encodedToken := base64.StdEncoding.EncodeToString(encryptedToken) - var owner sql.NullString - if stateData.OwnerFilter.Valid { - owner = sql.NullString{Valid: true, String: stateData.OwnerFilter.String} - } else { - owner = sql.NullString{Valid: false} - } - _, err = s.store.UpsertAccessToken(ctx, db.UpsertAccessTokenParams{ ProjectID: stateData.ProjectID, Provider: provider, EncryptedToken: encodedToken, - OwnerFilter: owner, + OwnerFilter: stateData.OwnerFilter, }) if err != nil { return fmt.Errorf("error inserting access token: %w", err) @@ -264,6 +278,36 @@ func (s *Server) generateOAuthToken(ctx context.Context, provider string, code s return nil } +func (s *Server) verifyProviderTokenIdentity( + ctx context.Context, stateData db.GetProjectIDBySessionStateRow, provider string, token string) error { + dbProvider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{ + Name: provider, + Projects: []uuid.UUID{stateData.ProjectID}, + }) + if err != nil { + return fmt.Errorf("error getting provider by name: %w", err) + } + pbOpts := []providers.ProviderBuilderOption{ + providers.WithProviderMetrics(s.provMt), + providers.WithRestClientCache(s.restClientCache), + } + builder := providers.NewProviderBuilder(&dbProvider, sql.NullString{}, token, pbOpts...) + // NOTE: this is github-specific at the moment. We probably need to generally + // re-think token enrollment when we add more providers. + ghClient, err := builder.GetGitHub() + if err != nil { + return fmt.Errorf("error creating GitHub client: %w", err) + } + userId, err := ghClient.GetUserId(ctx) + if err != nil { + return fmt.Errorf("error getting user ID: %w", err) + } + if strconv.FormatInt(userId, 10) != stateData.RemoteUser.String { + return fmt.Errorf("user ID mismatch: %d != %s", userId, stateData.RemoteUser.String) + } + return nil +} + // getProviderAccessToken returns the access token for providers func (s *Server) getProviderAccessToken(ctx context.Context, provider string, projectID uuid.UUID) (oauth2.Token, string, error) { @@ -376,3 +420,31 @@ func (s *Server) VerifyProviderTokenFrom(ctx context.Context, return &pb.VerifyProviderTokenFromResponse{Status: "OK"}, nil } + +type httpResponseError struct { + statusCode int + short string + pageContents string +} + +func newHttpError(statusCode int, short string) *httpResponseError { + return &httpResponseError{ + statusCode: statusCode, + short: short, + pageContents: "An unknown error occurred", + } +} + +func (e *httpResponseError) SetContents(contents string, args ...any) *httpResponseError { + e.pageContents = fmt.Sprintf(contents, args...) + return e +} + +// Error implements error +func (e *httpResponseError) Error() string { + return fmt.Sprintf("HTTP error: %d %s", e.statusCode, e.short) +} + +func (e *httpResponseError) WriteError(w http.ResponseWriter) { + http.Error(w, e.pageContents, e.statusCode) +} diff --git a/internal/controlplane/handlers_oauth_test.go b/internal/controlplane/handlers_oauth_test.go index 5d4322bfd3..5d9b5fce0a 100644 --- a/internal/controlplane/handlers_oauth_test.go +++ b/internal/controlplane/handlers_oauth_test.go @@ -15,12 +15,24 @@ package controlplane import ( + "bytes" "context" + "database/sql" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" "testing" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwt/openid" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" + "golang.org/x/oauth2" "golang.org/x/oauth2/github" "golang.org/x/oauth2/google" "google.golang.org/grpc/codes" @@ -29,7 +41,9 @@ import ( "github.com/stacklok/minder/internal/auth" "github.com/stacklok/minder/internal/db" "github.com/stacklok/minder/internal/engine" + "github.com/stacklok/minder/internal/providers/ratecache" pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" + provinfv1 "github.com/stacklok/minder/pkg/providers/v1" ) func TestNewOAuthConfig(t *testing.T) { @@ -93,6 +107,7 @@ func TestGetAuthorizationURL(t *testing.T) { name string req *pb.GetAuthorizationURLRequest buildStubs func(store *mockdb.MockStore) + getToken func(openid.Token) openid.Token checkResponse func(t *testing.T, res *pb.GetAuthorizationURLResponse, err error) expectedStatusCode codes.Code }{ @@ -120,7 +135,10 @@ func TestGetAuthorizationURL(t *testing.T) { }, }}, nil) store.EXPECT(). - CreateSessionState(gomock.Any(), gomock.Any()). + CreateSessionState(gomock.Any(), partialDbParamsMatcher{db.CreateSessionStateParams{ + Provider: "github", + ProjectID: projectID, + }}). Return(db.SessionStore{}, nil) store.EXPECT(). DeleteSessionStateByProjectID(gomock.Any(), gomock.Any()). @@ -174,15 +192,75 @@ func TestGetAuthorizationURL(t *testing.T) { expectedStatusCode: codes.InvalidArgument, }, + { + name: "No GitHub id", + req: &pb.GetAuthorizationURLRequest{ + Context: &pb.Context{ + Provider: &providerName, + Project: &projectIdStr, + }, + Port: 8080, + Cli: true, + }, + buildStubs: func(store *mockdb.MockStore) { + store.EXPECT(). + GetParentProjects(gomock.Any(), projectID). + Return([]uuid.UUID{projectID}, nil) + store.EXPECT(). + ListProvidersByProjectID(gomock.Any(), []uuid.UUID{projectID}). + Return([]db.Provider{{ + ID: providerID, + Name: "github", + AuthFlows: []db.AuthorizationFlow{ + db.AuthorizationFlowOauth2AuthorizationCodeFlow, + }, + }}, nil) + store.EXPECT(). + CreateSessionState(gomock.Any(), partialDbParamsMatcher{db.CreateSessionStateParams{ + Provider: "github", + ProjectID: projectID, + RemoteUser: sql.NullString{Valid: true, String: "31337"}, + }}). + Return(db.SessionStore{}, nil) + store.EXPECT(). + DeleteSessionStateByProjectID(gomock.Any(), gomock.Any()). + Return(nil) + }, + getToken: func(tok openid.Token) openid.Token { + if err := tok.Set("gh_id", "31337"); err != nil { + t.Fatalf("Error setting gh_id: %v", err) + } + return tok + }, + checkResponse: func(t *testing.T, res *pb.GetAuthorizationURLResponse, err error) { + t.Helper() + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if res.Url == "" { + t.Errorf("Unexpected response from GetAuthorizationURL: %v", res) + } + }, + + expectedStatusCode: codes.OK, + }, } rpcOptions := &pb.RpcOptions{ TargetResource: pb.TargetResource_TARGET_RESOURCE_USER, } - ctx := withRpcOptions(context.Background(), rpcOptions) + baseCtx := withRpcOptions(context.Background(), rpcOptions) + + userJWT := openid.New() + if err := userJWT.Set("sub", "testuser"); err != nil { + t.Fatalf("Error setting sub: %v", err) + } + // Set the entity context - ctx = engine.WithEntityContext(ctx, &engine.EntityContext{ + baseCtx = engine.WithEntityContext(baseCtx, &engine.EntityContext{ Project: engine.Project{ ID: projectID, }, @@ -193,6 +271,16 @@ func TestGetAuthorizationURL(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() + tok, err := userJWT.Clone() + if err != nil { + t.Fatalf("Failed to clone token: %v", err) + } + token := tok.(openid.Token) + if tc.getToken != nil { + token = tc.getToken(token) + } + ctx := auth.WithAuthTokenContext(baseCtx, token) + ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -206,3 +294,161 @@ func TestGetAuthorizationURL(t *testing.T) { }) } } + +func TestProviderCallback(t *testing.T) { + t.Parallel() + + projectID := uuid.New() + code := "0xefbeadde" + + testCases := []struct { + name string + redirectUrl string + remoteUser sql.NullString + code int + err string + }{{ + name: "Success", + redirectUrl: "http://localhost:8080", + code: 307, + }, { + name: "Success with remote user", + redirectUrl: "http://localhost:8080", + remoteUser: sql.NullString{Valid: true, String: "31337"}, + code: 307, + }, { + name: "Wrong remote userid", + remoteUser: sql.NullString{Valid: true, String: "1234"}, + code: 403, + err: "The provided login token was associated with a different GitHub user.\n", + }} + + for _, tt := range testCases { + tc := tt + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + req := http.Request{} + + resp := httptest.ResponseRecorder{Body: new(bytes.Buffer)} + params := map[string]string{"provider": "github"} + + stateBinary := make([]byte, 8) + // Store a very large timestamp in the state to ensure it's not expired + binary.BigEndian.PutUint64(stateBinary, 0x0fffffffffffffff) + stateBinary = append(stateBinary, []byte(tc.name)...) + state := base64.RawURLEncoding.EncodeToString(stateBinary) + + req.URL = &url.URL{ + RawQuery: url.Values{"state": {state}, "code": {code}}.Encode(), + } + + oauthServer := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "anAccessToken", + }) + if err != nil { + t.Fatalf("Failed to write response: %v", err) + } + })) + defer oauthServer.Close() + + stubClient := StubGitHub{ + UserId: 31337, + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + store := mockdb.NewMockStore(ctrl) + s := newDefaultServer(t, store) + + encryptedUrl := sql.NullString{} + if tc.redirectUrl != "" { + var err error + encryptedUrl.String, err = s.cryptoEngine.EncryptString(tc.redirectUrl) + if err != nil { + t.Fatalf("Failed to encrypt redirect URL: %v", err) + } + encryptedUrl.Valid = true + } + + store.EXPECT().GetProjectIDBySessionState(gomock.Any(), state).Return( + db.GetProjectIDBySessionStateRow{ + ProjectID: projectID, + RedirectUrl: encryptedUrl, + RemoteUser: tc.remoteUser, + }, nil) + + if tc.remoteUser.String != "" { + // TODO: verfifyProviderTokenIdentity + store.EXPECT().GetProviderByName(gomock.Any(), gomock.Any()).Return( + db.Provider{ + Name: "github", + Implements: []db.ProviderType{db.ProviderTypeGithub}, + Version: provinfv1.V1, + }, nil) + cancelable, cancel := context.WithCancel(context.Background()) + defer cancel() + clientCache := ratecache.NewRestClientCache(cancelable) + clientCache.Set("", "anAccessToken", db.ProviderTypeGithub, &stubClient) + s.restClientCache = clientCache + } + if tc.code < http.StatusBadRequest { + store.EXPECT().UpsertAccessToken(gomock.Any(), gomock.Any()).Return( + db.ProviderAccessToken{}, nil) + } + + t.Logf("Request: %+v", req.URL) + s.providerAuthFactory = func(_ string, _ bool) (*oauth2.Config, error) { + return &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: oauthServer.URL, + }, + }, nil + } + s.HandleProviderCallback()(&resp, &req, params) + + t.Logf("Response: %v", resp.Code) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + t.Logf("Body: %s", string(body)) + + if resp.Code != tc.code { + t.Errorf("Unexpected status code: %v", resp.Code) + } + if tc.code >= http.StatusMovedPermanently && tc.code < http.StatusBadRequest { + if resp.Header().Get("Location") != tc.redirectUrl { + t.Errorf("Unexpected redirect URL: %v", resp.Header().Get("Location")) + } + } + if tc.err != "" { + if string(body) != tc.err { + t.Errorf("Unexpected error message: %q", string(body)) + } + } + }) + } +} + +type partialDbParamsMatcher struct { + value db.CreateSessionStateParams +} + +func (p partialDbParamsMatcher) Matches(x interface{}) bool { + typedX, ok := x.(db.CreateSessionStateParams) + if !ok { + return false + } + + typedX.SessionState = "" + + return typedX == p.value +} + +func (m partialDbParamsMatcher) String() string { + return fmt.Sprintf("matches %+v", m.value) +} diff --git a/internal/controlplane/handlers_projects_test.go b/internal/controlplane/handlers_projects_test.go index 13d268156a..4eaf222e05 100644 --- a/internal/controlplane/handlers_projects_test.go +++ b/internal/controlplane/handlers_projects_test.go @@ -19,6 +19,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwt/openid" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -32,7 +33,8 @@ import ( func TestListProjects(t *testing.T) { t.Parallel() - user := "testuser" + user := openid.New() + assert.NoError(t, user.Set("sub", "testuser")) authzClient := &mock.SimpleClient{ Allowed: []uuid.UUID{uuid.New()}, @@ -42,7 +44,7 @@ func TestListProjects(t *testing.T) { defer ctrl.Finish() mockStore := mockdb.NewMockStore(ctrl) - mockStore.EXPECT().GetUserBySubject(gomock.Any(), user).Return(db.User{ID: 1}, nil) + mockStore.EXPECT().GetUserBySubject(gomock.Any(), user.Subject()).Return(db.User{ID: 1}, nil) mockStore.EXPECT().GetProjectByID(gomock.Any(), authzClient.Allowed[0]).Return( db.Project{ID: authzClient.Allowed[0]}, nil) @@ -52,7 +54,7 @@ func TestListProjects(t *testing.T) { } ctx := context.Background() - ctx = auth.WithUserSubjectContext(ctx, user) + ctx = auth.WithAuthTokenContext(ctx, user) resp, err := server.ListProjects(ctx, &minder.ListProjectsRequest{}) assert.NoError(t, err) diff --git a/internal/controlplane/handlers_repositories_test.go b/internal/controlplane/handlers_repositories_test.go index f40ffcc4d5..c6316a0f81 100644 --- a/internal/controlplane/handlers_repositories_test.go +++ b/internal/controlplane/handlers_repositories_test.go @@ -240,6 +240,7 @@ type StubGitHub struct { T *testing.T Repo *github.Repository RepoErr error + UserId int64 ExistingHooks []*github.Hook DeletedHooks []int64 NewHooks []*github.Hook @@ -424,8 +425,8 @@ func (*StubGitHub) UpdateBranchProtection(context.Context, string, string, strin } // GetUserId implements v1.GitHub. -func (*StubGitHub) GetUserId(context.Context) (int64, error) { - panic("unimplemented") +func (s *StubGitHub) GetUserId(context.Context) (int64, error) { + return s.UserId, nil } // GetUsername implements v1.GitHub. diff --git a/internal/controlplane/handlers_token.go b/internal/controlplane/handlers_token.go index f2c00f589e..1530e8085c 100644 --- a/internal/controlplane/handlers_token.go +++ b/internal/controlplane/handlers_token.go @@ -71,7 +71,7 @@ func TokenValidationInterceptor(ctx context.Context, req interface{}, info *grpc return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err) } - ctx = auth.WithUserSubjectContext(ctx, parsedToken.Subject()) + ctx = auth.WithAuthTokenContext(ctx, parsedToken) // Attach the login sha for telemetry usage (hash of the user subject from the JWT) loginSHA := sha256.Sum256([]byte(parsedToken.Subject())) diff --git a/internal/controlplane/metrics/metrics.go b/internal/controlplane/metrics/metrics.go index 94685d8e74..e37042c3b0 100644 --- a/internal/controlplane/metrics/metrics.go +++ b/internal/controlplane/metrics/metrics.go @@ -45,6 +45,10 @@ type Metrics interface { // AddWebhookEventTypeCount adds a count to the webhook event type counter AddWebhookEventTypeCount(context.Context, *WebhookEventState) + + // AddTokenOpCount records a token operation (issued, check) and whether the + // github ID was present at the time of check. + AddTokenOpCount(context.Context, string, bool) } type metricsImpl struct { @@ -55,6 +59,10 @@ type metricsImpl struct { webhookStatusCodeCounter metric.Int64Counter // webhook event type counter webhookEventTypeCounter metric.Int64Counter + + // Track how often users who register a token are correlated with the + // GitHub user from GetAuthorizationURL + tokenOpCounter metric.Int64Counter } // NewMetrics creates a new controlplane metrics instance. @@ -160,6 +168,13 @@ func (m *metricsImpl) initInstrumentsOnce(store db.Store) error { return fmt.Errorf("failed to create webhook event type counter: %w", err) } + m.tokenOpCounter, err = m.meter.Int64Counter("token-checks", + metric.WithDescription("Number of times token URLs are issued and consumed"), + metric.WithUnit("ops")) + if err != nil { + return fmt.Errorf("failed to create token operations counter: %w", err) + } + return nil } @@ -176,3 +191,13 @@ func (m *metricsImpl) AddWebhookEventTypeCount(ctx context.Context, state *Webho } m.webhookEventTypeCounter.Add(ctx, 1, metric.WithAttributes(labels...)) } + +func (m *metricsImpl) AddTokenOpCount(ctx context.Context, stage string, hasId bool) { + if m.tokenOpCounter == nil { + return + } + + m.tokenOpCounter.Add(ctx, 1, metric.WithAttributes( + attribute.String("stage", stage), + attribute.Bool("has-id", hasId))) +} diff --git a/internal/controlplane/metrics/noop.go b/internal/controlplane/metrics/noop.go index 7aed2ce088..4f08f487ed 100644 --- a/internal/controlplane/metrics/noop.go +++ b/internal/controlplane/metrics/noop.go @@ -36,3 +36,6 @@ func (_ *noopMetrics) Init(_ db.Store) error { // AddWebhookEventTypeCount implements Metrics.AddWebhookEventTypeCount func (_ *noopMetrics) AddWebhookEventTypeCount(_ context.Context, _ *WebhookEventState) {} + +// AddTokenOpCount implements Metrics.AddTokenOpCount +func (_ *noopMetrics) AddTokenOpCount(_ context.Context, _ string, _ bool) {} diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 21a7eedab2..0d5b4c667f 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -70,19 +70,20 @@ var ( // Server represents the controlplane server type Server struct { - store db.Store - cfg *serverconfig.Config - evt events.Interface - mt metrics.Metrics - provMt provtelemetry.ProviderMetrics - grpcServer *grpc.Server - vldtr auth.JwtValidator - OAuth2 *oauth2.Config - ClientID string - ClientSecret string - authzClient authz.Client - cryptoEngine crypto.Engine - restClientCache ratecache.RestClientCache + store db.Store + cfg *serverconfig.Config + evt events.Interface + mt metrics.Metrics + provMt provtelemetry.ProviderMetrics + grpcServer *grpc.Server + vldtr auth.JwtValidator + OAuth2 *oauth2.Config + providerAuthFactory func(string, bool) (*oauth2.Config, error) + ClientID string + ClientSecret string + authzClient authz.Client + cryptoEngine crypto.Engine + restClientCache ratecache.RestClientCache // We may want to start breaking up the server struct if we use it to // inject more entity-specific interfaces. For example, we may want to // consider having a struct per grpc service @@ -146,15 +147,16 @@ func NewServer( return nil, fmt.Errorf("failed to create crypto engine: %w", err) } s := &Server{ - store: store, - cfg: cfg, - evt: evt, - cryptoEngine: eng, - vldtr: vldtr, - mt: metrics.NewNoopMetrics(), - provMt: provtelemetry.NewNoopMetrics(), - profileValidator: profiles.NewValidator(store), - webhookManager: webhooks.NewWebhookManager(cfg.WebhookConfig), + store: store, + cfg: cfg, + evt: evt, + cryptoEngine: eng, + vldtr: vldtr, + providerAuthFactory: auth.NewOAuthConfig, + mt: metrics.NewNoopMetrics(), + provMt: provtelemetry.NewNoopMetrics(), + profileValidator: profiles.NewValidator(store), + webhookManager: webhooks.NewWebhookManager(cfg.WebhookConfig), // TODO: this currently always returns authorized as a transitionary measure. // When OpenFGA is fully rolled out, we may want to make this a hard error or set to false. authzClient: &mock.NoopClient{Authorized: true}, diff --git a/internal/db/models.go b/internal/db/models.go index 6f9fe00df0..7a33fe5488 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -573,6 +573,7 @@ type SessionStore struct { SessionState string `json:"session_state"` CreatedAt time.Time `json:"created_at"` RedirectUrl sql.NullString `json:"redirect_url"` + RemoteUser sql.NullString `json:"remote_user"` } type User struct { diff --git a/internal/db/querier.go b/internal/db/querier.go index 709b2cac90..66facdb651 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -83,8 +83,6 @@ type Querier interface { GetRepositoryByRepoName(ctx context.Context, arg GetRepositoryByRepoNameParams) (Repository, error) GetRuleTypeByID(ctx context.Context, id uuid.UUID) (RuleType, error) GetRuleTypeByName(ctx context.Context, arg GetRuleTypeByNameParams) (RuleType, error) - GetSessionState(ctx context.Context, id int32) (SessionStore, error) - GetSessionStateByProjectID(ctx context.Context, projectID uuid.UUID) (SessionStore, error) GetUserByID(ctx context.Context, id int32) (User, error) GetUserBySubject(ctx context.Context, identitySubject string) (User, error) GlobalListProviders(ctx context.Context) ([]Provider, error) diff --git a/internal/db/session_store.sql.go b/internal/db/session_store.sql.go index e012fe3a84..f34de3eb04 100644 --- a/internal/db/session_store.sql.go +++ b/internal/db/session_store.sql.go @@ -13,12 +13,13 @@ import ( ) const createSessionState = `-- name: CreateSessionState :one -INSERT INTO session_store (provider, project_id, session_state, owner_filter, redirect_url) VALUES ($1, $2, $3, $4, $5) RETURNING id, provider, project_id, port, owner_filter, session_state, created_at, redirect_url +INSERT INTO session_store (provider, project_id, remote_user, session_state, owner_filter, redirect_url) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, provider, project_id, port, owner_filter, session_state, created_at, redirect_url, remote_user ` type CreateSessionStateParams struct { Provider string `json:"provider"` ProjectID uuid.UUID `json:"project_id"` + RemoteUser sql.NullString `json:"remote_user"` SessionState string `json:"session_state"` OwnerFilter sql.NullString `json:"owner_filter"` RedirectUrl sql.NullString `json:"redirect_url"` @@ -28,6 +29,7 @@ func (q *Queries) CreateSessionState(ctx context.Context, arg CreateSessionState row := q.db.QueryRowContext(ctx, createSessionState, arg.Provider, arg.ProjectID, + arg.RemoteUser, arg.SessionState, arg.OwnerFilter, arg.RedirectUrl, @@ -42,6 +44,7 @@ func (q *Queries) CreateSessionState(ctx context.Context, arg CreateSessionState &i.SessionState, &i.CreatedAt, &i.RedirectUrl, + &i.RemoteUser, ) return i, err } @@ -65,7 +68,7 @@ func (q *Queries) DeleteSessionState(ctx context.Context, id int32) error { } const deleteSessionStateByProjectID = `-- name: DeleteSessionStateByProjectID :exec -DELETE FROM session_store WHERE provider=$1 AND project_id = $2 +DELETE FROM session_store WHERE provider = $1 AND project_id = $2 ` type DeleteSessionStateByProjectIDParams struct { @@ -79,12 +82,13 @@ func (q *Queries) DeleteSessionStateByProjectID(ctx context.Context, arg DeleteS } const getProjectIDBySessionState = `-- name: GetProjectIDBySessionState :one -SELECT provider, project_id, owner_filter, redirect_url FROM session_store WHERE session_state = $1 +SELECT provider, project_id, remote_user, owner_filter, redirect_url FROM session_store WHERE session_state = $1 ` type GetProjectIDBySessionStateRow struct { Provider string `json:"provider"` ProjectID uuid.UUID `json:"project_id"` + RemoteUser sql.NullString `json:"remote_user"` OwnerFilter sql.NullString `json:"owner_filter"` RedirectUrl sql.NullString `json:"redirect_url"` } @@ -95,48 +99,9 @@ func (q *Queries) GetProjectIDBySessionState(ctx context.Context, sessionState s err := row.Scan( &i.Provider, &i.ProjectID, + &i.RemoteUser, &i.OwnerFilter, &i.RedirectUrl, ) return i, err } - -const getSessionState = `-- name: GetSessionState :one -SELECT id, provider, project_id, port, owner_filter, session_state, created_at, redirect_url FROM session_store WHERE id = $1 -` - -func (q *Queries) GetSessionState(ctx context.Context, id int32) (SessionStore, error) { - row := q.db.QueryRowContext(ctx, getSessionState, id) - var i SessionStore - err := row.Scan( - &i.ID, - &i.Provider, - &i.ProjectID, - &i.Port, - &i.OwnerFilter, - &i.SessionState, - &i.CreatedAt, - &i.RedirectUrl, - ) - return i, err -} - -const getSessionStateByProjectID = `-- name: GetSessionStateByProjectID :one -SELECT id, provider, project_id, port, owner_filter, session_state, created_at, redirect_url FROM session_store WHERE project_id = $1 -` - -func (q *Queries) GetSessionStateByProjectID(ctx context.Context, projectID uuid.UUID) (SessionStore, error) { - row := q.db.QueryRowContext(ctx, getSessionStateByProjectID, projectID) - var i SessionStore - err := row.Scan( - &i.ID, - &i.Provider, - &i.ProjectID, - &i.Port, - &i.OwnerFilter, - &i.SessionState, - &i.CreatedAt, - &i.RedirectUrl, - ) - return i, err -} diff --git a/internal/engine/actions/remediate/gh_branch_protect/gh_branch_protect_test.go b/internal/engine/actions/remediate/gh_branch_protect/gh_branch_protect_test.go index 7c864ed106..7fe8a75bd2 100644 --- a/internal/engine/actions/remediate/gh_branch_protect/gh_branch_protect_test.go +++ b/internal/engine/actions/remediate/gh_branch_protect/gh_branch_protect_test.go @@ -17,6 +17,7 @@ package gh_branch_protect import ( "context" + "database/sql" "encoding/json" "fmt" "strings" @@ -64,7 +65,7 @@ func testGithubProviderBuilder(baseURL string) *providers.ProviderBuilder { Implements: []db.ProviderType{db.ProviderTypeGithub, db.ProviderTypeRest}, Definition: json.RawMessage(definitionJSON), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) } diff --git a/internal/engine/actions/remediate/pull_request/pull_request_test.go b/internal/engine/actions/remediate/pull_request/pull_request_test.go index 70bb337b93..c14b06fb1f 100644 --- a/internal/engine/actions/remediate/pull_request/pull_request_test.go +++ b/internal/engine/actions/remediate/pull_request/pull_request_test.go @@ -18,6 +18,7 @@ package pull_request import ( "bytes" "context" + "database/sql" "encoding/json" "fmt" "io" @@ -107,7 +108,7 @@ func testGithubProviderBuilder() *providers.ProviderBuilder { Implements: []db.ProviderType{db.ProviderTypeGithub, db.ProviderTypeRest, db.ProviderTypeGit}, Definition: json.RawMessage(definitionJSON), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) } diff --git a/internal/engine/actions/remediate/remediate_test.go b/internal/engine/actions/remediate/remediate_test.go index 763e098f9f..4791b7fc9a 100644 --- a/internal/engine/actions/remediate/remediate_test.go +++ b/internal/engine/actions/remediate/remediate_test.go @@ -17,6 +17,7 @@ package remediate_test import ( + "database/sql" "encoding/json" "testing" @@ -48,7 +49,7 @@ var ( } }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) ) diff --git a/internal/engine/actions/remediate/rest/rest_test.go b/internal/engine/actions/remediate/rest/rest_test.go index 6e53853e27..a2e6655aae 100644 --- a/internal/engine/actions/remediate/rest/rest_test.go +++ b/internal/engine/actions/remediate/rest/rest_test.go @@ -18,6 +18,7 @@ package rest import ( "context" + "database/sql" "encoding/json" "fmt" "net/http" @@ -54,7 +55,7 @@ var ( } }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) invalidProviderBuilder = providers.NewProviderBuilder( @@ -69,7 +70,7 @@ var ( "base_url": "https://api.github.com/" }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) TestActionTypeValid interfaces.ActionType = "remediate-test" @@ -93,7 +94,7 @@ func testGithubProviderBuilder(baseURL string) *providers.ProviderBuilder { Implements: []db.ProviderType{db.ProviderTypeGithub, db.ProviderTypeRest}, Definition: json.RawMessage(definitionJSON), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) } diff --git a/internal/engine/ingester/artifact/artifact_test.go b/internal/engine/ingester/artifact/artifact_test.go index 8c96f5a740..1f54b2d265 100644 --- a/internal/engine/ingester/artifact/artifact_test.go +++ b/internal/engine/ingester/artifact/artifact_test.go @@ -17,6 +17,7 @@ package artifact import ( "context" + "database/sql" "encoding/json" "testing" "time" @@ -55,7 +56,7 @@ func testGithubProviderBuilder() *providers.ProviderBuilder { Implements: []db.ProviderType{db.ProviderTypeGithub, db.ProviderTypeRest, db.ProviderTypeGit}, Definition: json.RawMessage(definitionJSON), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) } diff --git a/internal/engine/ingester/git/git_test.go b/internal/engine/ingester/git/git_test.go index d2669c6032..5235b8fe2c 100644 --- a/internal/engine/ingester/git/git_test.go +++ b/internal/engine/ingester/git/git_test.go @@ -17,6 +17,7 @@ package git_test import ( "bytes" "context" + "database/sql" "testing" "github.com/stretchr/testify/require" @@ -42,7 +43,7 @@ func TestGitIngestWithCloneURLFromRepo(t *testing.T) { "git", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, "", )) require.NoError(t, err, "expected no error") @@ -79,7 +80,7 @@ func TestGitIngestWithCloneURLFromParams(t *testing.T) { "git", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, "", )) require.NoError(t, err, "expected no error") @@ -116,7 +117,7 @@ func TestGitIngestWithCustomBranchFromParams(t *testing.T) { "git", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, "", )) require.NoError(t, err, "expected no error") @@ -153,7 +154,7 @@ func TestGitIngestWithBranchFromRepoEntity(t *testing.T) { "git", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, "", )) require.NoError(t, err, "expected no error") @@ -192,7 +193,7 @@ func TestGitIngestWithUnexistentBranchFromParams(t *testing.T) { "git", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, "", )) require.NoError(t, err, "expected no error") @@ -220,7 +221,7 @@ func TestGitIngestFailsBecauseOfAuthorization(t *testing.T) { "git", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, "foobar", ), ) @@ -245,7 +246,7 @@ func TestGitIngestFailsBecauseOfUnexistentCloneUrl(t *testing.T) { "git", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, // No authentication is the right thing in this case. "", )) diff --git a/internal/engine/ingester/ingester_test.go b/internal/engine/ingester/ingester_test.go index 9759dcf133..91f8419d99 100644 --- a/internal/engine/ingester/ingester_test.go +++ b/internal/engine/ingester/ingester_test.go @@ -18,6 +18,7 @@ package ingester import ( + "database/sql" "encoding/json" "testing" @@ -175,7 +176,7 @@ func TestNewRuleDataIngest(t *testing.T) { "github": {} }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", )) if tt.wantErr { diff --git a/internal/engine/ingester/rest/rest_test.go b/internal/engine/ingester/rest/rest_test.go index 14ab720465..47d291d429 100644 --- a/internal/engine/ingester/rest/rest_test.go +++ b/internal/engine/ingester/rest/rest_test.go @@ -17,6 +17,7 @@ package rest import ( "context" + "database/sql" "encoding/json" "net/http" "net/http/httptest" @@ -65,7 +66,7 @@ func TestNewRestRuleDataIngest(t *testing.T) { } }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ), }, @@ -90,7 +91,7 @@ func TestNewRestRuleDataIngest(t *testing.T) { } }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ), }, @@ -115,7 +116,7 @@ func TestNewRestRuleDataIngest(t *testing.T) { } }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ), }, @@ -135,7 +136,7 @@ func TestNewRestRuleDataIngest(t *testing.T) { "rest", }, }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ), }, @@ -160,7 +161,7 @@ func TestNewRestRuleDataIngest(t *testing.T) { } }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ), }, @@ -184,7 +185,7 @@ func TestNewRestRuleDataIngest(t *testing.T) { "base_url": "https://api.github.com/" }`), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ), }, @@ -228,7 +229,7 @@ func testGithubProviderBuilder(baseURL string) *providers.ProviderBuilder { Implements: []db.ProviderType{db.ProviderTypeGithub, db.ProviderTypeRest}, Definition: json.RawMessage(definitionJSON), }, - db.ProviderAccessToken{}, + sql.NullString{}, "token", ) } diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 058a857a14..2057a6b0c8 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -18,6 +18,7 @@ package providers import ( "context" + "database/sql" "fmt" "golang.org/x/exp/slices" @@ -53,14 +54,14 @@ func GetProviderBuilder( return nil, fmt.Errorf("error decrypting access token: %w", err) } - return NewProviderBuilder(&prov, encToken, decryptedToken.AccessToken, opts...), nil + return NewProviderBuilder(&prov, encToken.OwnerFilter, decryptedToken.AccessToken, opts...), nil } // ProviderBuilder is a utility struct which allows for the creation of // provider clients. type ProviderBuilder struct { p *db.Provider - tokenInf db.ProviderAccessToken + ownerFilter sql.NullString // NOTE: we don't seem to actually use the null-ness anywhere. restClientCache ratecache.RestClientCache tok string metrics telemetry.ProviderMetrics @@ -86,15 +87,15 @@ func WithRestClientCache(cache ratecache.RestClientCache) ProviderBuilderOption // NewProviderBuilder creates a new provider builder. func NewProviderBuilder( p *db.Provider, - tokenInf db.ProviderAccessToken, + ownerFilter sql.NullString, tok string, opts ...ProviderBuilderOption, ) *ProviderBuilder { pb := &ProviderBuilder{ - p: p, - tokenInf: tokenInf, - tok: tok, - metrics: telemetry.NewNoopMetrics(), + p: p, + ownerFilter: ownerFilter, + tok: tok, + metrics: telemetry.NewNoopMetrics(), } for _, opt := range opts { @@ -170,7 +171,7 @@ func (pb *ProviderBuilder) GetGitHub() (provinfv1.GitHub, error) { } if pb.restClientCache != nil { - client, ok := pb.restClientCache.Get(pb.tokenInf.OwnerFilter.String, pb.GetToken(), db.ProviderTypeGithub) + client, ok := pb.restClientCache.Get(pb.ownerFilter.String, pb.GetToken(), db.ProviderTypeGithub) if ok { return client.(provinfv1.GitHub), nil } @@ -182,7 +183,7 @@ func (pb *ProviderBuilder) GetGitHub() (provinfv1.GitHub, error) { return nil, fmt.Errorf("error parsing github config: %w", err) } - cli, err := ghclient.NewRestClient(cfg, pb.metrics, pb.restClientCache, pb.GetToken(), pb.tokenInf.OwnerFilter.String) + cli, err := ghclient.NewRestClient(cfg, pb.metrics, pb.restClientCache, pb.GetToken(), pb.ownerFilter.String) if err != nil { return nil, fmt.Errorf("error creating github client: %w", err) } diff --git a/pkg/api/protobuf/go/minder/v1/minder.pb.go b/pkg/api/protobuf/go/minder/v1/minder.pb.go index d8ed9b54d8..b9704f0c1c 100644 --- a/pkg/api/protobuf/go/minder/v1/minder.pb.go +++ b/pkg/api/protobuf/go/minder/v1/minder.pb.go @@ -15,7 +15,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.32.0 +// protoc-gen-go v1.33.0 // protoc (unknown) // source: minder/v1/minder.proto