Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Miscellaneous Fixes #66

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions aws_signing_helper/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,25 @@ import (
)

type CredentialsOpts struct {
PrivateKeyId string
CertificateId string
CertificateBundleId string
CertIdentifier CertIdentifier
RoleArn string
ProfileArnStr string
TrustAnchorArnStr string
SessionDuration int
Region string
Endpoint string
NoVerifySSL bool
WithProxy bool
Debug bool
Version string
LibPkcs11 string
ReusePin bool
ServerTTL int
RoleSessionName string
PrivateKeyId string
CertificateId string
CertificateBundleId string
CertIdentifier CertIdentifier
UseLatestExpiringCertificate bool
RoleArn string
ProfileArnStr string
TrustAnchorArnStr string
SessionDuration int
Region string
Endpoint string
NoVerifySSL bool
WithProxy bool
Debug bool
Version string
LibPkcs11 string
ReusePin bool
ServerTTL int
RoleSessionName string
}

// Function to create session and generate credentials
Expand Down Expand Up @@ -99,12 +100,16 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: CreateRequestSignFunction(signer, signatureAlgorithm, certificate, certificateChain)})

certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw)
durationSeconds := int64(opts.SessionDuration)
var durationSecondsPtr *int64 = nil
if opts.SessionDuration != -1 {
durationSeconds := int64(opts.SessionDuration)
durationSecondsPtr = &durationSeconds
}
createSessionRequest := rolesanywhere.CreateSessionInput{
Cert: &certificateStr,
ProfileArn: &opts.ProfileArnStr,
TrustAnchorArn: &opts.TrustAnchorArnStr,
DurationSeconds: &(durationSeconds),
DurationSeconds: durationSecondsPtr,
InstanceProperties: nil,
RoleArn: &opts.RoleArn,
SessionName: nil,
Expand Down
12 changes: 8 additions & 4 deletions aws_signing_helper/darwin_cert_store_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,22 @@ func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, er
}

// Creates a DarwinCertStoreSigner based on the identifying certificate
func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) {
func GetCertStoreSigner(certIdentifier CertIdentifier, useLatestExpiringCert bool) (signer Signer, signingAlgorithm string, err error) {
identRef, certRef, certContainers, err := GetMatchingCertsAndIdentity(certIdentifier)
if err != nil {
return nil, "", err
}
if len(certContainers) == 0 {
return nil, "", errors.New("no matching identities")
}
if len(certContainers) > 1 {
return nil, "", errors.New("multiple matching identities")
if useLatestExpiringCert {
sort.Sort(CertificateContainerList(certContainers))
} else {
if len(certContainers) > 1 {
return nil, "", errors.New("multiple matching identities")
}
}
cert := certContainers[0].Cert
cert := certContainers[len(certContainers)-1].Cert

// Find the signing algorithm
switch cert.PublicKey.(type) {
Expand Down
15 changes: 15 additions & 0 deletions aws_signing_helper/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,21 @@ var ignoredHeaderKeys = map[string]bool{

var Debug bool = false

// CertificateContainerList implements the sort.Interface interface
type CertificateContainerList []CertificateContainer

func (certificateContainerList CertificateContainerList) Less(i, j int) bool {
return certificateContainerList[i].Cert.NotAfter.Before(certificateContainerList[j].Cert.NotAfter)
}

func (certificateContainerList CertificateContainerList) Swap(i, j int) {
certificateContainerList[i], certificateContainerList[j] = certificateContainerList[j], certificateContainerList[i]
}

func (certificateContainerList CertificateContainerList) Len() int {
return len(certificateContainerList)
}

// Find whether the current certificate matches the CertIdentifier
func certMatches(certIdentifier CertIdentifier, cert x509.Certificate) bool {
if certIdentifier.Subject != "" && certIdentifier.Subject != cert.Subject.String() {
Expand Down
24 changes: 18 additions & 6 deletions aws_signing_helper/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func Verify(payload []byte, publicKey crypto.PublicKey, digest crypto.Hash, sig
sum := sha512.Sum512(payload)
hash = sum[:]
default:
log.Fatal("unsupported digest")
log.Println("unsupported digest")
return false, errors.New("unsupported digest")
}

Expand Down Expand Up @@ -383,12 +383,24 @@ func TestSign(t *testing.T) {

func TestCredentialProcess(t *testing.T) {
testTable := []struct {
name string
server *httptest.Server
name string
server *httptest.Server
durationSeconds int
}{
{
name: "create-session-server-response",
server: GetMockedCreateSessionResponseServer(),
name: "create-session-server-response",
server: GetMockedCreateSessionResponseServer(),
durationSeconds: -1,
},
{
name: "create-session-server-response",
server: GetMockedCreateSessionResponseServer(),
durationSeconds: 900,
},
{
name: "create-session-server-response",
server: GetMockedCreateSessionResponseServer(),
durationSeconds: 3600,
},
}
for _, tc := range testTable {
Expand All @@ -399,7 +411,7 @@ func TestCredentialProcess(t *testing.T) {
ProfileArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:profile/41cl0bae-6783-40d4-ab20-65dc5d922e45",
TrustAnchorArnStr: "arn:aws:rolesanywhere:us-east-1:000000000000:trust-anchor/41cl0bae-6783-40d4-ab20-65dc5d922e45",
Endpoint: tc.server.URL,
SessionDuration: 900,
SessionDuration: tc.durationSeconds,
}
t.Run(tc.name, func(t *testing.T) {
defer tc.server.Close()
Expand Down
3 changes: 2 additions & 1 deletion aws_signing_helper/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ func Update(credentialsOptions CredentialsOpts, profile string, once bool) {
for {
credentialProcessOutput, err := GenerateCredentials(&credentialsOptions, signer, signatureAlgorithm)
if err != nil {
log.Fatal(err)
log.Println(err)
os.Exit(1)
}

// Assign credential values
Expand Down
17 changes: 11 additions & 6 deletions aws_signing_helper/windows_cert_store_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,22 +260,27 @@ func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, er
}

// Gets a WindowsCertStoreSigner based on the CertIdentifier
func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) {
func GetCertStoreSigner(certIdentifier CertIdentifier, useLatestExpiringCert bool) (signer Signer, signingAlgorithm string, err error) {
var privateKey *winPrivateKey
store, certCtx, certChain, certContainers, err := GetMatchingCertsAndChain(certIdentifier)
if err != nil {
goto fail
}
if len(certContainers) > 1 {
err = errors.New("more than one matching cert found in cert store")
goto fail
}
if len(certContainers) == 0 {
err = errors.New("no matching certs found in cert store")
goto fail
}

signer = &WindowsCertStoreSigner{store: store, cert: certContainers[0].Cert, certCtx: certCtx, certChain: certChain}
if useLatestExpiringCert {
sort.Sort(CertificateContainerList(certContainers))
} else {
if len(certContainers) > 1 {
err = errors.New("multiple matching identities")
goto fail
}
}

signer = &WindowsCertStoreSigner{store: store, cert: certContainers[len(certContainers)-1].Cert, certCtx: certCtx, certChain: certChain}

privateKey, err = signer.(*WindowsCertStoreSigner).getPrivateKey()
if err != nil {
Expand Down
118 changes: 74 additions & 44 deletions cmd/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"io/ioutil"
"math/big"
"regexp"
"strings"

helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper"
Expand All @@ -21,16 +22,18 @@ var (
noVerifySSL bool
withProxy bool
debug bool
reusePin bool
roleSessionName string

certificateId string
privateKeyId string
certificateBundleId string
certSelector string
systemStoreName string

certSelector string
systemStoreName string
useLatestExpiringCertificate bool

libPkcs11 string
reusePin bool

credentialsOptions helper.CredentialsOpts

Expand All @@ -43,6 +46,8 @@ var (
X509_ISSUER_KEY,
X509_SERIAL_KEY,
}

CERT_SELECTOR_KEY_VALUE_REGEX = `^\s*Key=(.+?),Value=(.+?)\s*(?:Key=|$)`
)

type MapEntry struct {
Expand All @@ -56,7 +61,7 @@ func initCredentialsSubCommand(subCmd *cobra.Command) {
subCmd.PersistentFlags().StringVar(&roleArnStr, "role-arn", "", "Target role to assume")
subCmd.PersistentFlags().StringVar(&profileArnStr, "profile-arn", "", "Profile to pull policies from")
subCmd.PersistentFlags().StringVar(&trustAnchorArnStr, "trust-anchor-arn", "", "Trust anchor to use for authentication")
subCmd.PersistentFlags().IntVar(&sessionDuration, "session-duration", 3600, "Duration, in seconds, for the resulting session")
subCmd.PersistentFlags().IntVar(&sessionDuration, "session-duration", -1, "Duration, in seconds, for the resulting session")
subCmd.PersistentFlags().StringVar(&region, "region", "", "Signing region")
subCmd.PersistentFlags().StringVar(&endpoint, "endpoint", "", "Endpoint used to call CreateSession")
subCmd.PersistentFlags().BoolVar(&noVerifySSL, "no-verify-ssl", false, "To disable SSL verification")
Expand All @@ -69,51 +74,72 @@ func initCredentialsSubCommand(subCmd *cobra.Command) {
"Can be passed in either as string or a file name (prefixed by \"file://\")")
subCmd.PersistentFlags().StringVar(&systemStoreName, "system-store-name", "MY", "Name of the system store to search for within the "+
"CERT_SYSTEM_STORE_CURRENT_USER context. Note that this flag is only relevant for Windows certificate stores and will be ignored otherwise")
subCmd.PersistentFlags().BoolVar(&useLatestExpiringCertificate, "use-latest-expiring-certificate", false, "If multiple certificates match "+
"a given certificate selector, the one that expires the latest will be chosen (if more than one still fits this criteria, an arbitrary "+
"one is chosen from those that meet the criteria)")
subCmd.PersistentFlags().StringVar(&libPkcs11, "pkcs11-lib", "", "Library for smart card / cryptographic device (OpenSC or vendor specific)")
subCmd.PersistentFlags().BoolVar(&reusePin, "reuse-pin", false, "Use the CKU_USER PIN as the CKU_CONTEXT_SPECIFIC PIN for "+
"private key objects, when they are first used to sign. If the CKU_USER PIN doesn't work as the CKU_CONTEXT_SPECIFIC PIN "+
"for a given private key object, fall back to prompting the user")
subCmd.PersistentFlags().StringVar(&roleSessionName, "role-session-name", "", "An identifier of a role session")
subCmd.PersistentFlags().StringVar(&roleSessionName, "role-session-name", "", "An identifier of a role session")

subCmd.MarkFlagsMutuallyExclusive("certificate", "cert-selector")
subCmd.MarkFlagsMutuallyExclusive("certificate", "system-store-name")
subCmd.MarkFlagsMutuallyExclusive("private-key", "cert-selector")
subCmd.MarkFlagsMutuallyExclusive("private-key", "system-store-name")
subCmd.MarkFlagsMutuallyExclusive("private-key", "use-latest-expiring-certificate")
subCmd.MarkFlagsMutuallyExclusive("use-latest-expiring-certificate", "intermediates")
subCmd.MarkFlagsMutuallyExclusive("use-latest-expiring-certificate", "reuse-pin")
subCmd.MarkFlagsMutuallyExclusive("cert-selector", "intermediates")
subCmd.MarkFlagsMutuallyExclusive("cert-selector", "reuse-pin")
subCmd.MarkFlagsMutuallyExclusive("system-store-name", "reuse-pin")
}

// Parses a cert selector string to a map
func getStringMap(s string) (map[string]string, error) {
entries := strings.Split(s, " ")
regex := regexp.MustCompile(CERT_SELECTOR_KEY_VALUE_REGEX)

m := make(map[string]string)
for _, e := range entries {
tokens := strings.SplitN(e, ",", 2)
keyTokens := strings.Split(tokens[0], "=")
if keyTokens[0] != "Key" {
return nil, errors.New("invalid cert selector map key")
}
key := strings.TrimSpace(strings.Join(keyTokens[1:], "="))
for {
match := regex.FindStringSubmatch(s)
if match == nil || len(match) == 0 {
break
} else {
if len(match) < 3 {
return nil, errors.New("unable to parse cert selector string")
}

isValidKey := false
for _, validKey := range validCertSelectorKeys {
if validKey == key {
isValidKey = true
break
key := match[1]
isValidKey := false
for _, validKey := range validCertSelectorKeys {
if validKey == key {
isValidKey = true
break
}
}
}
if !isValidKey {
return nil, errors.New("cert selector contained invalid key")
}
if !isValidKey {
return nil, errors.New("cert selector contained invalid key")
}
value := match[2]

if _, ok := m[key]; ok {
return nil, errors.New("cert selector contained duplicate key")
}
m[key] = value

valueTokens := strings.Split(tokens[1], "=")
if valueTokens[0] != "Value" {
return nil, errors.New("invalid cert selector map value")
// Remove the matching prefix from the input cert selector string
matchEnd := len(match[0])
if matchEnd != len(s) {
// Since the `Key=` part of the next key-value pair will have been matched, don't include it in the prefix to remove
matchEnd -= 4
}
s = s[matchEnd:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens to trailing spaces? Will they affect the certificate match?

Copy link
Contributor Author

@13ajay 13ajay Mar 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they will. Is that acceptable, or should I change it?

Edit: I realize that my comment may not have answered the original question(s) well. Trailing spaces will be removed, which I suppose is a problem if customers have RDNs with trailing spaces (unless they have other RDNs corresponding to the same key (attribute, like Subject, for example) that don't have trailing spaces (in which case, they can use that RDN as the last one for a particular key when they input their cert selector in this form). But the caveat is that users don't have control over the ordering that RDNs should be specified in the selector - that ordering is fixed currently.

I think we should advertise this capability as a shorthand that doesn't allow for as much flexibility as using a JSON file. I think this is the same sort of thing when it comes to passing in tags in the AWS CLI. The AWS CLI is a bit different since there are a restricted set of characters that can be used in tag keys and values (certificate RDNs can contain arbitrary characters, OTOH), but the parsing of the shorthand syntax doesn't seem to be particularly robust.

}
value := strings.TrimSpace(strings.Join(valueTokens[1:], "="))
m[key] = value
}

// There is some part of the cert selector string that couldn't be parsed by the above loop
if len(s) != 0 {
return nil, errors.New("unable to parse cert selector string")
}

return m, nil
Expand All @@ -138,6 +164,9 @@ func getMapFromJsonEntries(jsonStr string) (map[string]string, error) {
if !isValidKey {
return nil, errors.New("cert selector contained invalid key")
}
if _, ok := m[mapEntry.Key]; ok {
return nil, errors.New("cert selector contained duplicate key")
}
m[mapEntry.Key] = mapEntry.Value
}
return m, nil
Expand Down Expand Up @@ -228,23 +257,24 @@ func PopulateCredentialsOptions() error {
}

credentialsOptions = helper.CredentialsOpts{
PrivateKeyId: privateKeyId,
CertificateId: certificateId,
CertificateBundleId: certificateBundleId,
CertIdentifier: certIdentifier,
RoleArn: roleArnStr,
ProfileArnStr: profileArnStr,
TrustAnchorArnStr: trustAnchorArnStr,
SessionDuration: sessionDuration,
Region: region,
Endpoint: endpoint,
NoVerifySSL: noVerifySSL,
WithProxy: withProxy,
Debug: debug,
Version: Version,
LibPkcs11: libPkcs11,
ReusePin: reusePin,
RoleSessionName: roleSessionName,
PrivateKeyId: privateKeyId,
CertificateId: certificateId,
CertificateBundleId: certificateBundleId,
CertIdentifier: certIdentifier,
UseLatestExpiringCertificate: useLatestExpiringCertificate,
RoleArn: roleArnStr,
ProfileArnStr: profileArnStr,
TrustAnchorArnStr: trustAnchorArnStr,
SessionDuration: sessionDuration,
Region: region,
Endpoint: endpoint,
NoVerifySSL: noVerifySSL,
WithProxy: withProxy,
Debug: debug,
Version: Version,
LibPkcs11: libPkcs11,
ReusePin: reusePin,
RoleSessionName: roleSessionName,
}

return nil
Expand Down
Loading
Loading