Skip to content

Commit

Permalink
Actually check for auth flows in provider enrollment (#2601)
Browse files Browse the repository at this point in the history
Signed-off-by: Juan Antonio Osorio <[email protected]>
  • Loading branch information
JAORMX authored Mar 13, 2024
1 parent 8be9194 commit e174e74
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 12 deletions.
67 changes: 55 additions & 12 deletions cmd/cli/app/provider/provider_enroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ actions such as adding repositories.`,
// EnrollProviderCommand is the command for enrolling a provider
func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.ClientConn) error {
client := minderv1.NewOAuthServiceClient(conn)
provcli := minderv1.NewProvidersServiceClient(conn)

provider := viper.GetString("provider")
project := viper.GetString("project")
Expand Down Expand Up @@ -87,24 +88,66 @@ func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, conn *grpc.C
}
}

oAuthCallbackCtx, oAuthCancel := context.WithTimeout(context.Background(), MAX_WAIT+5*time.Second)
defer oAuthCancel()
prov, err := provcli.GetProvider(ctx, &minderv1.GetProviderRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
Name: provider,
})
if err != nil {
return cli.MessageAndError("Error getting provider", err)
}

if token != "" {
// use pat for enrollment
_, err := client.StoreProviderToken(context.Background(), &minderv1.StoreProviderTokenRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
AccessToken: token,
Owner: &owner,
})
if err != nil {
return cli.MessageAndError("Error storing token", err)
if !prov.Provider.SupportsAuthFlow(minderv1.AuthorizationFlow_AUTHORIZATION_FLOW_USER_INPUT) {
return fmt.Errorf("provider %s does not support token enrollment", provider)
}

cmd.Println("Provider enrolled successfully")
return nil
return enrollUsingToken(ctx, cmd, client, provider, project, token, owner)
}

if !prov.Provider.SupportsAuthFlow(
minderv1.AuthorizationFlow_AUTHORIZATION_FLOW_OAUTH2_AUTHORIZATION_CODE_FLOW) {
return fmt.Errorf("provider %s does not support OAuth2 enrollment", provider)
}

// This will have a different timeout
enrollemntCtx := cmd.Context()

return enrollUsingOAuth2Flow(enrollemntCtx, cmd, client, provider, project, owner)
}

func enrollUsingToken(
ctx context.Context,
cmd *cobra.Command,
client minderv1.OAuthServiceClient,
provider string,
project string,
token string,
owner string,
) error {
_, err := client.StoreProviderToken(ctx, &minderv1.StoreProviderTokenRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
AccessToken: token,
Owner: &owner,
})
if err != nil {
return cli.MessageAndError("Error storing token", err)
}

cmd.Println("Provider enrolled successfully")
return nil
}

func enrollUsingOAuth2Flow(
ctx context.Context,
cmd *cobra.Command,
client minderv1.OAuthServiceClient,
provider string,
project string,
owner string,
) error {
oAuthCallbackCtx, oAuthCancel := context.WithTimeout(ctx, MAX_WAIT+5*time.Second)
defer oAuthCancel()

// Get random port
port, err := rand.GetRandomPort()
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"

"github.com/google/uuid"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
Expand All @@ -39,6 +40,7 @@ import (
"github.com/stacklok/minder/internal/db"
"github.com/stacklok/minder/internal/engine"
"github.com/stacklok/minder/internal/logger"
"github.com/stacklok/minder/internal/util"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

Expand All @@ -56,6 +58,11 @@ func (s *Server) GetAuthorizationURL(ctx context.Context,
return nil, providerError(err)
}

if !slices.Contains(provider.AuthFlows, db.AuthorizationFlowOauth2AuthorizationCodeFlow) {
return nil, util.UserVisibleError(codes.InvalidArgument,
"provider does not support authorization code flow")
}

// Configure tracing
// trace call to AuthCodeURL
span := trace.SpanFromContext(ctx)
Expand Down Expand Up @@ -288,6 +295,11 @@ func (s *Server) StoreProviderToken(ctx context.Context,
return nil, providerError(err)
}

if !slices.Contains(provider.AuthFlows, db.AuthorizationFlowUserInput) {
return nil, util.UserVisibleError(codes.InvalidArgument,
"provider does not support token enrollment")
}

// validate token
err = auth.ValidateProviderToken(ctx, provider.Name, in.AccessToken)
if err != nil {
Expand Down
37 changes: 37 additions & 0 deletions internal/controlplane/handlers_oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/google"
Expand Down Expand Up @@ -114,6 +115,9 @@ func TestGetAuthorizationURL(t *testing.T) {
Return([]db.Provider{{
ID: providerID,
Name: "github",
AuthFlows: []db.AuthorizationFlow{
db.AuthorizationFlowOauth2AuthorizationCodeFlow,
},
}}, nil)
store.EXPECT().
CreateSessionState(gomock.Any(), gomock.Any()).
Expand All @@ -137,6 +141,39 @@ func TestGetAuthorizationURL(t *testing.T) {

expectedStatusCode: codes.OK,
},
{
name: "Unsupported auth flow",
req: &pb.GetAuthorizationURLRequest{
Context: &pb.Context{
Provider: &providerName,
Project: &projectIdStr,
},
Port: 8080,
Cli: true,
},
buildStubs: func(store *mockdb.MockStore) {
store.EXPECT().
GetParentProjects(gomock.Any(), projectID).
Return([]uuid.UUID{projectID}, nil)
store.EXPECT().
ListProvidersByProjectID(gomock.Any(), []uuid.UUID{projectID}).
Return([]db.Provider{{
ID: providerID,
Name: "github",
AuthFlows: []db.AuthorizationFlow{
db.AuthorizationFlowNone,
},
}}, nil)
},

checkResponse: func(t *testing.T, _ *pb.GetAuthorizationURLResponse, err error) {
t.Helper()

assert.Error(t, err, "Expected error in GetAuthorizationURL")
},

expectedStatusCode: codes.InvalidArgument,
},
}

rpcOptions := &pb.RpcOptions{
Expand Down
7 changes: 7 additions & 0 deletions pkg/api/protobuf/go/minder/v1/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package v1

import "slices"

// ToString returns the string representation of the ProviderType
func (provt ProviderType) ToString() string {
return enumToStringViaDescriptor(provt.Descriptor(), provt.Number())
Expand All @@ -23,3 +25,8 @@ func (provt ProviderType) ToString() string {
func (a AuthorizationFlow) ToString() string {
return enumToStringViaDescriptor(a.Descriptor(), a.Number())
}

// SupportsAuthFlow returns true if the provider supports the given auth flow
func (p *Provider) SupportsAuthFlow(flow AuthorizationFlow) bool {
return slices.Contains(p.GetAuthFlows(), flow)
}

0 comments on commit e174e74

Please sign in to comment.