diff --git a/.gitignore b/.gitignore index 9b753d1..1c8c754 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ build/ tst/certs/ credential-process-data/ +tst/softhsm/ +tst/softhsm2.conf + diff --git a/Makefile b/Makefile index 78d5aeb..2b3bf23 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,76 @@ VERSION=1.0.5 release: - go build -buildmode=pie -ldflags "-X 'main.Version=${VERSION}' -linkmode=external -w -s" -trimpath -o build/bin/aws_signing_helper cmd/aws_signing_helper/main.go + go build -buildmode=pie -ldflags "-X 'github.com/aws/rolesanywhere-credential-helper/cmd.Version=${VERSION}' -linkmode=external -w -s" -trimpath -o build/bin/aws_signing_helper main.go + +certsdir=tst/certs +curdir=$(shell pwd) + +RSAKEYS := $(foreach keylen, 1024 2048 4096, $(certsdir)/rsa-$(keylen)-key.pem) +ECKEYS := $(foreach curve, prime256v1 secp384r1, $(certsdir)/ec-$(curve)-key.pem) +PKCS8KEYS := $(patsubst %-key.pem,%-key-pkcs8.pem,$(RSAKEYS) $(ECKEYS)) +ECCERTS := $(foreach digest, sha1 sha256 sha384 sha512, $(patsubst %-key.pem, %-$(digest)-cert.pem, $(ECKEYS))) +RSACERTS := $(foreach digest, md5 sha1 sha256 sha384 sha512, $(patsubst %-key.pem, %-$(digest)-cert.pem, $(RSAKEYS))) +PKCS12CERTS := $(patsubst %-cert.pem, %.p12, $(RSACERTS) $(ECCERTS)) + +test: test-certs + go test -v ./... + +%-md5-cert.pem: %-key.pem + SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \ + openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -md5 +%-sha1-cert.pem: %-key.pem + SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \ + openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -sha1 +%-sha256-cert.pem: %-key.pem + SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \ + openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -sha256 +%-sha384-cert.pem: %-key.pem + SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \ + openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -sha384 +%-sha512-cert.pem: %-key.pem + SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \ + openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -sha512 + +# Go PKCS#12 only supports SHA1 and 3DES!! +%.p12: %-pass.p12 + echo Creating $@... + ls -l $< + KEY=$$(echo "$@" | sed 's/-[^-]*\.p12/-key.pem/'); \ + CERT=$$(echo "$@" | sed 's/.p12/-cert.pem/'); \ + openssl pkcs12 -export -passout pass: -macalg SHA1 \ + -certpbe pbeWithSHA1And3-KeyTripleDES-CBC \ + -keypbe pbeWithSHA1And3-KeyTripleDES-CBC \ + -inkey $${KEY} -out "$@" -in $${CERT} + +%-pass.p12: %-cert.pem + echo Creating $@... + ls -l $< + KEY=$$(echo "$@" | sed 's/-[^-]*\-pass.p12/-key.pem/'); \ + openssl pkcs12 -export -passout pass:test -macalg SHA1 \ + -certpbe pbeWithSHA1And3-KeyTripleDES-CBC \ + -keypbe pbeWithSHA1And3-KeyTripleDES-CBC \ + -inkey $${KEY} -out "$@" -in "$<" + +%-pkcs8.pem: %.pem + openssl pkcs8 -topk8 -inform PEM -outform PEM -in $< -out $@ -nocrypt + +$(RSAKEYS): + KEYLEN=$$(echo "$@" | sed 's/.*rsa-\([0-9]*\)-key.pem/\1/'); \ + openssl genrsa -out $@ $${KEYLEN} + +$(ECKEYS): + CURVE=$$(echo "$@" | sed 's/.*ec-\([^-]*\)-key.pem/\1/'); \ + openssl ecparam -name $${CURVE} -genkey -out $@ + +$(certsdir)/cert-bundle.pem: $(RSACERTS) $(ECCERTS) + cat $^ > $@ + +test-certs: $(PKCS8KEYS) $(RSAKEYS) $(ECKEYS) $(RSACERTS) $(ECCERTS) $(PKCS12CERTS) $(certsdir)/cert-bundle.pem + +test-clean: + rm -f $(RSAKEYS) $(ECKEYS) + rm -f $(PKCS8KEYS) + rm -f $(RSACERTS) $(ECCERTS) + rm -f $(PKCS12CERTS) + rm -f $(certsdir)/cert-bundle.pem diff --git a/README.md b/README.md index 95b8b73..8c6fa52 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,10 @@ ## AWS IAM Roles Anywhere Credential Helper - rolesanywhere-credential-helper implements the [signing process](https://docs.aws.amazon.com/rolesanywhere/latest/userguide/authentication-sign-process.html) for IAM Roles Anywhere's [CreateSession](https://docs.aws.amazon.com/rolesanywhere/latest/userguide/authentication-create-session.html) API and returns temporary credentials in a standard JSON format that is compatible with the `credential_process` feature available across the language SDKs. More information can be found [here](https://docs.aws.amazon.com/rolesanywhere/latest/userguide/credential-helper.html). It is released and licensed under the Apache License 2.0. ## Building ### Dependencies - -In order to build the source code, you will need to install git, gcc, make, and golang. +In order to build the source code, you will need to install git, gcc, make, and golang. #### Linux @@ -18,7 +16,7 @@ You can download Apple clang through the [following link](https://developer.appl #### Windows -In order to get gcc on Windows, one option is to use [MinGW-w64](https://www.mingw-w64.org/downloads/). After obtaining gcc, you can install golang through the [installer](https://go.dev/doc/install). Lastly, you can install git and make through `Chocolatey` with `choco install git` and `choco install make`, respectively. +In order to get gcc on Windows, one option is to use [MinGW-w64](https://www.mingw-w64.org/downloads/). After obtaining gcc, you can install golang through the [installer](https://go.dev/doc/install). Lastly, you can install git and make through `Chocolatey` with `choco install git` and `choco install make`, respectively. ### Build @@ -38,23 +36,104 @@ The project also comes with two bash scripts at its root, called `generate-certs ### read-certificate-data -Reads a certificate that is on disk. The path to the certificate must be provided with the `--certificate` parameter. +Reads a certificate that is on disk. Either the path to the certificate on disk is provided with the `--certificate` parameter, or the `--cert-selector` flag is provided to select a certificate within an OS certificate store. Further details about the flag are provided below. + +#### cert-selector flag + +If you use Windows or MacOS, the credential helper also supports leveraging private keys and certificates that are in their OS-specific secure stores. In Windows, both CNG and Cryptography are supported, while on MacOS, Keychain Access is supported. Through the `--cert-selector` flag, it is possible to specify which certificate (and associated private key) to use in calling `CreateSession`. The credential helper will then delegate signing operations to the keys within those secure stores, without those keys ever having to leave those stores. It is important to note that on Windows, only the user's "MY" certificate store will be searched by the credential helper, while for MacOS, Keychains on the search list will be searched. + +The `--cert-selector` flag allows one to search for a specific certificate (and associated private key) through the certificate Subject, Issuer, and Serial Number. The corresponding keys are `x509Subject`, `x509Issuer`, and `x509Serial`, respectively. These keys can be specified either through a JSON file format or through the command line. An example of both approaches can be found below. + +If you would like to use a JSON file, it should look something like this: + +``` +[ + { + "Key": "x509Subject", + "Value": "CN=Subject" + }, + { + "Key": "x509Issuer", + "Value": "CN=Issuer" + }, + { + "Key": "x509Serial", + "Value": "15D19632234BF759A32802C0DA88F9E8AFC8702D" + } +] +``` + +If the above is placed in a file called `selector.json`, it can be specified with the `--cert-selector` flag through `file://path/to/selector.json`. The very same certificate selector argument can be specified through the command line as follows: + +``` +--cert-selector Key=x509Subject,Value=CN=Subject Key=x509Issuer,Value=CN=Issuer Key=x509Serial,Value=15D19632234BF759A32802C0DA88F9E8AFC8702D +``` + +The example given here is quite simple (they each only contain a single RDN), so it may not be obvious, but the Subject and Issuer values roughly follow the [RFC 2253](https://www.rfc-editor.org/rfc/rfc2253.html) Distinguished Names syntax. ### sign-string -Signs a string from standard input. Useful for validating your on-disk private key and digest. The path to the private key must be provided with the `--private-key` parameter. Other parameters that can be used are `--digest`, which must be one of `SHA256 (*default*) | SHA384 | SHA512`, and `--format`, which must be one of `text (*default*) | json | bin`. +Signs a string from standard input. Useful for validating your on-disk private key and digest. The path to the private key must be provided with the `--private-key` parameter. Other parameters that can be used are `--digest`, which must be one of `SHA256 (*default*) | SHA384 | SHA512`, and `--format`, which must be one of `text (*default*) | json | bin`. ### credential-process -Vends temporary credentials by sending a `CreateSession` request to the Roles Anywhere service. The request is signed by the private key whose path must be provided with the `--private-key` parameter. Other required parameters include `--certificate` (the path to the end-entity certificate), `--role-arn` (the ARN of the role to obtain temporary credentials for), `--profile-arn` (the ARN of the profile that provides a mapping for the specified role), and `--trust-anchor-arn` (the ARN of the trust anchor used to authenticate). Optional parameters that can be used are `--debug` (to provide debugging output about the request sent), `--no-verify-ssl` (to skip verification of the SSL certificate on the endpoint called), `--intermediates` (the path to intermediate certificates), `--with-proxy` (to make the binary proxy aware), `--endpoint` (the endpoint to call), `--region` (the region to scope the request to), and `--session-duration` (the duration of the vended session). +Vends temporary credentials by sending a `CreateSession` request to the Roles Anywhere service. The request is signed by the private key whose path can be provided with the `--private-key` parameter. Other parameters include `--certificate` (the path to the end-entity certificate), `--role-arn` (the ARN of the role to obtain temporary credentials for), `--profile-arn` (the ARN of the profile that provides a mapping for the specified role), and `--trust-anchor-arn` (the ARN of the trust anchor used to authenticate). Optional parameters that can be used are `--debug` (to provide debugging output about the request sent), `--no-verify-ssl` (to skip verification of the SSL certificate on the endpoint called), `--intermediates` (the path to intermediate certificates), `--with-proxy` (to make the binary proxy aware), `--endpoint` (the endpoint to call), `--region` (the region to scope the request to), and `--session-duration` (the duration of the vended session). Instead of passing in paths to the plaintext private key on your file system, another option (depending on your OS) could be to use the `--cert-selector` flag. More details can be found below. + +Note that if more than one certificate matches the `--cert-selector` parameter within the OS-specific secure store, the `credential-process` command will fail. To find the list of certificates that match a given `--cert-selector` parameter, you can use the same flag with the `read-certificate-data` command. + +#### MacOS Keychain Guidance + +If you would like to secure keys through MacOS Keychain and use them with IAM Roles Anywhere, you may want to consider creating a new Keychain that only the credential helper can access and store your keys there. The steps to do this are listed below. Note that the commands should be executed in bash. + +First, create the new Keychain: + +``` +security create-keychain -p ${CREDENTIAL_HELPER_KEYCHAIN_PASSWORD} credential-helper.keychain +``` + +In the above command line, `${CREDENTIAL_HELPER_KEYCHAIN_PASSWORD}` should contain the password you want the new Keychain to have. Next, unlock the Keychain: + +``` +security unlock-keychain -p ${CREDENTIAL_HELPER_KEYCHAIN_PASSWORD} credential-helper.keychain +``` + +Once again, you will have to specify the password to the Keychain, but this time it will be used to unlock it. Next, modify the Keychain search list to include your newly created Keychain: + +``` +EXISTING_KEYCHAINS=$(security list-keychains | cut -d '"' -f2) security list-keychains -s credential-helper.keychain $(echo ${EXISTING_KEYCHAINS} | awk -v ORS=" " '{print $1}') +``` + +The above command line will extract existing Keychains in the search list and add the newly created Keychain to the top of it. Lastly, add your PFX file (that contains your client certificate and associated private key) to the Keychain: + +``` +security import /path/to/identity.pfx -T /path/to/aws_signing_helper -P ${UNWRAPPING_PASSWORD} -k credential-helper.keychain +``` + +The above command line will import your client certificate and private key that are in a PFX file (which will be unwrapped using the `UNWRAPPING_PASSWORD` environment variable) into the newly created Keychain and only allow for the credential helper to access it. It's important to note that since the credential helper isn't signed, it isn't trusted by MacOS. To get around this, you may have to specify your Keychain password whenever the credential helper wants to use the private key to perform a signing operation. If you don't want to have to specify the password each time, you can choose to always allow the credential helper to use the Keychain item. + +Also note that the above steps can be done through [MacOS Keychain APIs](https://developer.apple.com/documentation/security/keychain_services/keychains), as well as through the [Keychain Access application](https://support.apple.com/guide/keychain-access/welcome/mac). + +#### Windows CNG Guidance + +If you would like to secure keys through Windows CNG and use them with IAM Roles Anywhere, it should be sufficient to to import your certificate (and associated private key) into your user's "MY" certificate store. + +Add your certificate (and associated private key) to the certificate store by importing e.g. a PFX file through the below command line in Command Prompt: + +``` +certutil -user -p %UNWRAPPING_PASSWORD% -importPFX "MY" \path\to\identity.pfx +``` + +The above command will import the PFX file into the user's "MY" certificate store. The `UNWRAPPING_PASSWORD` environment variable should contain the password to unwrap the PFX file. + +Also note that the above step can be done through a [Powershell cmdlet](https://learn.microsoft.com/en-us/powershell/module/pki/import-pfxcertificate?view=windowsserver2022-ps) or through [Windows CNG/Cryptography APIs](https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-pfximportcertstore). ### update -Updates temporary credentials in the [credential file](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html). Parameters for this command include those for the `credential-process` command, as well as `--profile`, which specifies the named profile for which credentials should be updated (if the profile doesn't already exist, it will be created), and `--once`, which specifies that credentials should be updated only once. Both arguments are optional. If `--profile` isn't specified, the default profile will have its credentials updated, and if `--once` isn't specified, credentials will be continuously updated. In this case, credentials will be updated through a call to `CreateSession` five minutes before the previous set of credentials are set to expire. Please note that running the `update` command multiple times, creating multiple processes, may not work as intended. There may be issues with concurrent writes to the credentials file. +Updates temporary credentials in the [credential file](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html). Parameters for this command include those for the `credential-process` command, as well as `--profile`, which specifies the named profile for which credentials should be updated (if the profile doesn't already exist, it will be created), and `--once`, which specifies that credentials should be updated only once. Both arguments are optional. If `--profile` isn't specified, the default profile will have its credentials updated, and if `--once` isn't specified, credentials will be continuously updated. In this case, credentials will be updated through a call to `CreateSession` five minutes before the previous set of credentials are set to expire. Please note that running the `update` command multiple times, creating multiple processes, may not work as intended. There may be issues with concurrent writes to the credentials file. ### serve -Vends temporary credentials through an endpoint running on localhost. Parameters for this command include those for the `credential-process` command, as well as an optional `--port`, to specify the port on which the local endpoint will be exposed. By default, the port will be `9911`. Once again, credentials will be updated through a call to `CreateSession` five minutes before the previous set of credentials are set to expire. Note that the URIs and request headers are the same as those used in [IMDSv2](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html) (only the address of the endpoint changes from `169.254.169.254` to `127.0.0.1`). In order to make the credentials served from the local endpoint available to the SDK, set the `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable appropriately. +Vends temporary credentials through an endpoint running on localhost. Parameters for this command include those for the `credential-process` command, as well as an optional `--port`, to specify the port on which the local endpoint will be exposed. By default, the port will be `9911`. Once again, credentials will be updated through a call to `CreateSession` five minutes before the previous set of credentials are set to expire. Note that the URIs and request headers are the same as those used in [IMDSv2](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html) (only the address of the endpoint changes from `169.254.169.254` to `127.0.0.1`). In order to make the credentials served from the local endpoint available to the SDK, set the `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable appropriately. ### Scripts @@ -70,7 +149,7 @@ Used by unit tests and for manual testing of the credential-process command. Cre ### Example Usage ``` -/bin/bash generate-credential-process-data.sh +/bin/sh generate-credential-process-data.sh TA_ARN=$(aws rolesanywhere create-trust-anchor \ --name "Test TA" \ diff --git a/THIRD-PARTY-LICENSES.txt b/THIRD-PARTY-LICENSES.txt new file mode 100644 index 0000000..56b0638 --- /dev/null +++ b/THIRD-PARTY-LICENSES.txt @@ -0,0 +1,25 @@ +** smimesign; version v0.2.0-rc1 -- https://github.com/github/smimesign +Copyright (c) 2017 GitHub, Inc. + +MIT License + +Copyright (c) 2017 GitHub, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/aws_signing_helper/credentials.go b/aws_signing_helper/credentials.go index 030115c..be89516 100644 --- a/aws_signing_helper/credentials.go +++ b/aws_signing_helper/credentials.go @@ -2,7 +2,6 @@ package aws_signing_helper import ( "crypto/tls" - "crypto/x509" "encoding/base64" "errors" "net/http" @@ -19,6 +18,7 @@ type CredentialsOpts struct { PrivateKeyId string CertificateId string CertificateBundleId string + CertIdentifier CertIdentifier RoleArn string ProfileArnStr string TrustAnchorArnStr string @@ -29,10 +29,11 @@ type CredentialsOpts struct { WithProxy bool Debug bool Version string + LibPkcs11 string } // Function to create session and generate credentials -func GenerateCredentials(opts *CredentialsOpts) (CredentialProcessOutput, error) { +func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) { // assign values to region and endpoint if they haven't already been assigned trustAnchorArn, err := arn.Parse(opts.TrustAnchorArnStr) if err != nil { @@ -51,37 +52,10 @@ func GenerateCredentials(opts *CredentialsOpts) (CredentialProcessOutput, error) opts.Region = trustAnchorArn.Region } - privateKey, err := ReadPrivateKeyData(opts.PrivateKeyId) - if err != nil { - return CredentialProcessOutput{}, err - } - certificateData, err := ReadCertificateData(opts.CertificateId) - if err != nil { - return CredentialProcessOutput{}, err - } - certificateDerData, err := base64.StdEncoding.DecodeString(certificateData.CertificateData) - if err != nil { - return CredentialProcessOutput{}, err - } - certificate, err := x509.ParseCertificate([]byte(certificateDerData)) - if err != nil { - return CredentialProcessOutput{}, err - } - var certificateChain []x509.Certificate - if opts.CertificateBundleId != "" { - certificateChainPointers, err := ReadCertificateBundleData(opts.CertificateBundleId) - if err != nil { - return CredentialProcessOutput{}, err - } - for _, certificate := range certificateChainPointers { - certificateChain = append(certificateChain, *certificate) - } - } - mySession := session.Must(session.NewSession()) var logLevel aws.LogLevelType - if opts.Debug { + if Debug { logLevel = aws.LogDebug } else { logLevel = aws.LogOff @@ -107,11 +81,20 @@ func GenerateCredentials(opts *CredentialsOpts) (CredentialProcessOutput, error) rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: request.MakeAddToUserAgentHandler("CredHelper", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)}) rolesAnywhereClient.Handlers.Sign.Clear() - rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: CreateSignFunction(privateKey, *certificate, certificateChain)}) + certificate, err := signer.Certificate() + if err != nil { + return CredentialProcessOutput{}, errors.New("unable to find certificate") + } + certificateChain, err := signer.CertificateChain() + if err != nil { + return CredentialProcessOutput{}, errors.New("unable to find certificate chain") + } + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: CreateRequestSignFunction(signer, signatureAlgorithm, certificate, certificateChain)}) + certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw) durationSeconds := int64(opts.SessionDuration) createSessionRequest := rolesanywhere.CreateSessionInput{ - Cert: &certificateData.CertificateData, + Cert: &certificateStr, ProfileArn: &opts.ProfileArnStr, TrustAnchorArn: &opts.TrustAnchorArnStr, DurationSeconds: &(durationSeconds), diff --git a/aws_signing_helper/darwin_cert_store_signer.go b/aws_signing_helper/darwin_cert_store_signer.go new file mode 100644 index 0000000..7c9641d --- /dev/null +++ b/aws_signing_helper/darwin_cert_store_signer.go @@ -0,0 +1,502 @@ +//go:build darwin + +package aws_signing_helper + +// This code is based on the smimesign repository at +// https://github.com/github/smimesign + +/* +#cgo CFLAGS: -x objective-c +#cgo LDFLAGS: -framework CoreFoundation -framework Security +#include +#include +*/ +import "C" +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "errors" + "fmt" + "io" + "os" + "unsafe" +) + +type DarwinCertStoreSigner struct { + identRef C.SecIdentityRef + keyRef C.SecKeyRef + certRef C.SecCertificateRef + cert *x509.Certificate + certChain []*x509.Certificate +} + +// osStatus wraps a C.OSStatus +type osStatus C.OSStatus + +const ( + errSecItemNotFound = osStatus(C.errSecItemNotFound) +) + +// Gets the matching identity and certificate for this CertIdentifier +// If there is more than one, only a list of the matching certificates is returned +func GetMatchingCertsAndIdentity(certIdentifier CertIdentifier) (C.SecIdentityRef, C.SecCertificateRef, []CertificateContainer, error) { + queryMap := map[C.CFTypeRef]C.CFTypeRef{ + C.CFTypeRef(C.kSecClass): C.CFTypeRef(C.kSecClassIdentity), + C.CFTypeRef(C.kSecReturnRef): C.CFTypeRef(C.kCFBooleanTrue), + C.CFTypeRef(C.kSecMatchLimit): C.CFTypeRef(C.kSecMatchLimitAll), + } + + query := mapToCFDictionary(queryMap) + if query == 0 { + return 0, 0, nil, errors.New("error creating CFDictionary") + } + defer C.CFRelease(C.CFTypeRef(query)) + + var absResult C.CFTypeRef + if err := osStatusError(C.SecItemCopyMatching(query, &absResult)); err != nil { + if err == errSecItemNotFound { + return 0, 0, nil, errors.New("unable to find matching identity in cert store") + } + return 0, 0, nil, err + } + defer C.CFRelease(C.CFTypeRef(absResult)) + aryResult := C.CFArrayRef(absResult) + + // identRefs aren't owned by us initially + numIdentRefs := C.CFArrayGetCount(aryResult) + identRefs := make([]C.CFTypeRef, numIdentRefs) + C.CFArrayGetValues(aryResult, C.CFRange{0, numIdentRefs}, (*unsafe.Pointer)(unsafe.Pointer(&identRefs[0]))) + var certContainers []CertificateContainer + var certRef C.SecCertificateRef + var identRef C.SecIdentityRef + for _, curIdentRef := range identRefs { + curCertRef, err := getCertRef(C.SecIdentityRef(curIdentRef)) + if err != nil { + return 0, 0, nil, errors.New("unable to get cert ref") + } + curCert, err := getCert(curCertRef) + if err != nil { + return 0, 0, nil, errors.New("unable to get cert") + } + + // Find whether there is a matching certificate + certMatches := certMatches(certIdentifier, *curCert) + if certMatches { + certContainers = append(certContainers, CertificateContainer{curCert, ""}) + // Assign to certRef and identRef at most once in the loop + // Both values are only useful if there is exactly one match in the certificate store + // When creating a signer, there has to be exactly one matching certificate + if certRef == 0 { + certRef = curCertRef + identRef = C.SecIdentityRef(curIdentRef) + } + } + } + + if Debug { + fmt.Fprintf(os.Stderr, "found %d matching identities\n", len(certContainers)) + } + + // Only retain the SecIdentityRef if it should be used later on + // Note that only the SecIdentityRef needs to be retained since it was neither created nor copied + if len(certContainers) == 1 { + C.CFRetain(C.CFTypeRef(identRef)) + return identRef, certRef, certContainers, nil + } else { + return 0, 0, certContainers, nil + } +} + +// Gets the certificates that match the CertIdentifier +func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, error) { + identRef, certRef, certContainers, err := GetMatchingCertsAndIdentity(certIdentifier) + if len(certContainers) == 1 { + C.CFRelease(C.CFTypeRef(identRef)) + C.CFRelease(C.CFTypeRef(certRef)) + } + return certContainers, err +} + +// Creates a DarwinCertStoreSigner based on the identifying certificate +func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) { + identRef, certRef, certContainers, err := GetMatchingCertsAndIdentity(certIdentifier) + if err != nil { + return nil, "", err + } + if len(certContainers) == 0 { + return nil, "", errors.New("no matching identities") + } + if len(certContainers) > 1 { + return nil, "", errors.New("multiple matching identities") + } + cert := certContainers[0].Cert + + // Find the signing algorithm + switch cert.PublicKey.(type) { + case *ecdsa.PublicKey: + signingAlgorithm = aws4_x509_ecdsa_sha256 + case *rsa.PublicKey: + signingAlgorithm = aws4_x509_rsa_sha256 + default: + return nil, "", errors.New("unsupported algorithm") + } + + keyRef, err := getKeyRef(identRef) + if err != nil { + return nil, "", errors.New("unable to get key reference") + } + + return &DarwinCertStoreSigner{identRef, keyRef, certRef, cert, nil}, signingAlgorithm, nil +} + +// Gets a pointer to the certificate from a certificate reference +func getCert(certRef C.SecCertificateRef) (*x509.Certificate, error) { + cert, err := exportCertRef(certRef) + if err != nil { + return nil, errors.New("unable to export certificate reference to x509.Certificate") + } + + return cert, nil +} + +// Gets the certificate associated with this DarwinCertStoreSigner +func (signer *DarwinCertStoreSigner) Certificate() (*x509.Certificate, error) { + if signer.cert != nil { + return signer.cert, nil + } + + certRef, err := signer.getCertRef() + if err != nil { + return nil, err + } + + cert, err := getCert(certRef) + if err != nil { + return nil, err + } + signer.cert = cert + + return signer.cert, nil +} + +// Gets the certificate chain associated with this DarwinCertStoreSigner +func (signer *DarwinCertStoreSigner) CertificateChain() ([]*x509.Certificate, error) { + if signer.certChain != nil { + return signer.certChain, nil + } + + certRef, err := signer.getCertRef() + if err != nil { + return nil, err + } + + policy := C.SecPolicyCreateSSL(0, 0) + + var trustRef C.SecTrustRef + if err := osStatusError(C.SecTrustCreateWithCertificates(C.CFTypeRef(certRef), C.CFTypeRef(policy), &trustRef)); err != nil { + return nil, err + } + defer C.CFRelease(C.CFTypeRef(trustRef)) + + // var status C.SecTrustResultType + var cfErrRef C.CFErrorRef + if C.SecTrustEvaluateWithError(trustRef, &cfErrRef) { + return nil, cfErrorError(cfErrRef) + } + + var ( + nChain = C.SecTrustGetCertificateCount(trustRef) + certChain = make([]*x509.Certificate, 0, int(nChain)) + ) + + certChainArr := C.SecTrustCopyCertificateChain(trustRef) + defer C.CFRelease(C.CFTypeRef(certChainArr)) + for i := C.CFIndex(0); i < nChain; i++ { + chainCertRef := C.SecCertificateRef(C.CFArrayGetValueAtIndex(certChainArr, i)) + if chainCertRef == 0 { + return nil, errors.New("nil certificate in chain") + } + + chainCert, err := exportCertRef(chainCertRef) + if err != nil { + return nil, err + } + + certChain = append(certChain, chainCert) + } + + certChain = certChain[1:] + signer.certChain = certChain + + return signer.certChain, nil +} + +// Public implements the crypto.Signer interface and returns the public key associated with the signer +func (signer *DarwinCertStoreSigner) Public() crypto.PublicKey { + cert, err := signer.Certificate() + if err != nil { + return nil + } + + return cert.PublicKey +} + +// Closes the DarwinCertStoreSigner +func (signer *DarwinCertStoreSigner) Close() { + if signer.identRef != 0 { + C.CFRelease(C.CFTypeRef(signer.identRef)) + signer.identRef = 0 + } + + if signer.keyRef != 0 { + C.CFRelease(C.CFTypeRef(signer.keyRef)) + signer.keyRef = 0 + } + + if signer.certRef != 0 { + C.CFRelease(C.CFTypeRef(signer.certRef)) + signer.certRef = 0 + } +} + +// Sign implements the crypto.Signer interface and signs the digest +func (signer *DarwinCertStoreSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + var hash []byte + switch opts.HashFunc() { + case crypto.SHA256: + sum := sha256.Sum256(digest) + hash = sum[:] + case crypto.SHA384: + sum := sha512.Sum384(digest) + hash = sum[:] + case crypto.SHA512: + sum := sha512.Sum512(digest) + hash = sum[:] + default: + return nil, ErrUnsupportedHash + } + + keyRef, err := signer.getKeyRef() + if err != nil { + return nil, err + } + + chash, err := bytesToCFData(hash) + if err != nil { + return nil, err + } + defer C.CFRelease(C.CFTypeRef(chash)) + + cert, err := signer.Certificate() + if err != nil { + return nil, err + } + + algo, err := getAlgo(cert, opts.HashFunc()) + if err != nil { + return nil, err + } + + // sign the digest + var cfErrRef C.CFErrorRef + cSig := C.SecKeyCreateSignature(keyRef, algo, chash, &cfErrRef) + + if err := cfErrorError(cfErrRef); err != nil { + C.CFRelease(C.CFTypeRef(cfErrRef)) + + return nil, err + } + + if cSig == 0 { + return nil, errors.New("nil signature from SecKeyCreateSignature") + } + defer C.CFRelease(C.CFTypeRef(cSig)) + + sig := cfDataToBytes(cSig) + + return sig, nil +} + +// getAlgo decides which algorithm to use with this key type for the given hash. +func getAlgo(cert *x509.Certificate, hash crypto.Hash) (algo C.SecKeyAlgorithm, err error) { + switch cert.PublicKey.(type) { + case *ecdsa.PublicKey: + switch hash { + case crypto.SHA1: + algo = C.kSecKeyAlgorithmECDSASignatureDigestX962SHA1 + case crypto.SHA256: + algo = C.kSecKeyAlgorithmECDSASignatureDigestX962SHA256 + case crypto.SHA384: + algo = C.kSecKeyAlgorithmECDSASignatureDigestX962SHA384 + case crypto.SHA512: + algo = C.kSecKeyAlgorithmECDSASignatureDigestX962SHA512 + default: + err = ErrUnsupportedHash + } + case *rsa.PublicKey: + switch hash { + case crypto.SHA1: + algo = C.kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA1 + case crypto.SHA256: + algo = C.kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA256 + case crypto.SHA384: + algo = C.kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA384 + case crypto.SHA512: + algo = C.kSecKeyAlgorithmRSASignatureDigestPKCS1v15SHA512 + default: + err = ErrUnsupportedHash + } + default: + err = errors.New("unsupported key type") + } + + return algo, err +} + +// exportCertRef gets a *x509.Certificate for the given SecCertificateRef. +func exportCertRef(certRef C.SecCertificateRef) (*x509.Certificate, error) { + derRef := C.SecCertificateCopyData(certRef) + if derRef == 0 { + return nil, errors.New("error getting certificate from identity") + } + defer C.CFRelease(C.CFTypeRef(derRef)) + + der := cfDataToBytes(derRef) + crt, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + return crt, nil +} + +// getKeyRef gets the SecKeyRef for this identity's private key. +func getKeyRef(ref C.SecIdentityRef) (C.SecKeyRef, error) { + var keyRef C.SecKeyRef + if err := osStatusError(C.SecIdentityCopyPrivateKey(ref, &keyRef)); err != nil { + return 0, err + } + + return keyRef, nil +} + +// getKeyRef gets the SecKeyRef for this identity's private key. +func (signer DarwinCertStoreSigner) getKeyRef() (C.SecKeyRef, error) { + if signer.keyRef != 0 { + return signer.keyRef, nil + } + + keyRef, err := getKeyRef(signer.identRef) + signer.keyRef = keyRef + + return signer.keyRef, err +} + +// getCertRef gets the SecCertificateRef for this identity's certificate. +func getCertRef(ref C.SecIdentityRef) (C.SecCertificateRef, error) { + var certRef C.SecCertificateRef + if err := osStatusError(C.SecIdentityCopyCertificate(ref, &certRef)); err != nil { + return 0, err + } + + return certRef, nil +} + +// getCertRef gets the identity's certificate reference +func (signer *DarwinCertStoreSigner) getCertRef() (C.SecCertificateRef, error) { + if signer.certRef != 0 { + return signer.certRef, nil + } + + certRef, err := getCertRef(signer.identRef) + signer.certRef = certRef + + return signer.certRef, err +} + +// stringToCFData converts a Go string to a CFDataRef +func stringToCFData(str string) (C.CFDataRef, error) { + return bytesToCFData([]byte(str)) +} + +// cfDataToBytes converts a CFDataRef to a Go byte slice +func cfDataToBytes(cfdata C.CFDataRef) []byte { + nBytes := C.CFDataGetLength(cfdata) + bytesPtr := C.CFDataGetBytePtr(cfdata) + return C.GoBytes(unsafe.Pointer(bytesPtr), C.int(nBytes)) +} + +// bytesToCFData converts a Go byte slice to a CFDataRef +func bytesToCFData(gobytes []byte) (C.CFDataRef, error) { + var ( + cptr = (*C.UInt8)(nil) + clen = C.CFIndex(len(gobytes)) + ) + + if len(gobytes) > 0 { + cptr = (*C.UInt8)(&gobytes[0]) + } + + cdata := C.CFDataCreate(0, cptr, clen) + if cdata == 0 { + return 0, errors.New("error creating cdata") + } + + return cdata, nil +} + +// cfErrorError returns an error for a CFErrorRef unless it is nil +func cfErrorError(cerr C.CFErrorRef) error { + if cerr == 0 { + return nil + } + + code := int(C.CFErrorGetCode(cerr)) + + if cdescription := C.CFErrorCopyDescription(cerr); cdescription != 0 { + defer C.CFRelease(C.CFTypeRef(cdescription)) + + if cstr := C.CFStringGetCStringPtr(cdescription, C.kCFStringEncodingUTF8); cstr != nil { + str := C.GoString(cstr) + + return fmt.Errorf("CFError %d (%s)", code, str) + } + + } + + return fmt.Errorf("CFError %d", code) +} + +// mapToCFDictionary converts a Go map[C.CFTypeRef]C.CFTypeRef to a CFDictionaryRef +func mapToCFDictionary(gomap map[C.CFTypeRef]C.CFTypeRef) C.CFDictionaryRef { + var ( + n = len(gomap) + keys = make([]unsafe.Pointer, 0, n) + values = make([]unsafe.Pointer, 0, n) + ) + + for k, v := range gomap { + keys = append(keys, unsafe.Pointer(k)) + values = append(values, unsafe.Pointer(v)) + } + + return C.CFDictionaryCreate(0, &keys[0], &values[0], C.CFIndex(n), nil, nil) +} + +// osStatusError returns an error for an OSStatus unless it is errSecSuccess +func osStatusError(s C.OSStatus) error { + if s == C.errSecSuccess { + return nil + } + + return osStatus(s) +} + +// Error implements the error interface and returns a stringified error for this osStatus +func (s osStatus) Error() string { + return fmt.Sprintf("OSStatus %d", s) +} diff --git a/aws_signing_helper/file_system_signer.go b/aws_signing_helper/file_system_signer.go new file mode 100644 index 0000000..e793543 --- /dev/null +++ b/aws_signing_helper/file_system_signer.go @@ -0,0 +1,132 @@ +package aws_signing_helper + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "errors" + "golang.org/x/crypto/pkcs12" + "io" + "log" + "os" +) + +type FileSystemSigner struct { + PrivateKey crypto.PrivateKey + cert *x509.Certificate + certChain []*x509.Certificate +} + +func (fileSystemSigner FileSystemSigner) Public() crypto.PublicKey { + { + privateKey, ok := fileSystemSigner.PrivateKey.(ecdsa.PrivateKey) + if ok { + return &privateKey.PublicKey + } + } + { + privateKey, ok := fileSystemSigner.PrivateKey.(rsa.PrivateKey) + if ok { + return &privateKey.PublicKey + } + } + return nil +} + +func (fileSystemSigner FileSystemSigner) Close() { +} + +func (fileSystemSigner FileSystemSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + var hash []byte + switch opts.HashFunc() { + case crypto.SHA256: + sum := sha256.Sum256(digest) + hash = sum[:] + case crypto.SHA384: + sum := sha512.Sum384(digest) + hash = sum[:] + case crypto.SHA512: + sum := sha512.Sum512(digest) + hash = sum[:] + default: + return nil, ErrUnsupportedHash + } + + ecdsaPrivateKey, ok := fileSystemSigner.PrivateKey.(ecdsa.PrivateKey) + if ok { + sig, err := ecdsa.SignASN1(rand, &ecdsaPrivateKey, hash[:]) + if err == nil { + return sig, nil + } + } + + rsaPrivateKey, ok := fileSystemSigner.PrivateKey.(rsa.PrivateKey) + if ok { + sig, err := rsa.SignPKCS1v15(rand, &rsaPrivateKey, opts.HashFunc(), hash[:]) + if err == nil { + return sig, nil + } + } + + log.Println("unsupported algorithm") + return nil, errors.New("unsupported algorithm") +} + +func (fileSystemSigner FileSystemSigner) Certificate() (*x509.Certificate, error) { + return fileSystemSigner.cert, nil +} + +func (fileSystemSigner FileSystemSigner) CertificateChain() ([]*x509.Certificate, error) { + return fileSystemSigner.certChain, nil +} + +// Returns a FileSystemSigner, that signs a payload using the +// private key passed in +func GetFileSystemSigner(privateKey crypto.PrivateKey, certificate *x509.Certificate, certificateChain []*x509.Certificate) (signer Signer, signingAlgorithm string, err error) { + // Find the signing algorithm + _, isRsaKey := privateKey.(rsa.PrivateKey) + if isRsaKey { + signingAlgorithm = aws4_x509_rsa_sha256 + } + _, isEcKey := privateKey.(ecdsa.PrivateKey) + if isEcKey { + signingAlgorithm = aws4_x509_ecdsa_sha256 + } + if signingAlgorithm == "" { + log.Println("unsupported algorithm") + return nil, "", errors.New("unsupported algorithm") + } + + return FileSystemSigner{privateKey, certificate, certificateChain}, signingAlgorithm, nil +} + +func GetPKCS12Signer(certificateId string) (signer Signer, signingAlgorithm string, err error) { + bytes, err := os.ReadFile(certificateId) + if err != nil { + return nil, "", err + } + privateKey, certificate, err := pkcs12.Decode(bytes, "") + if err != nil { + return nil, "", err + } + if privateKey == nil { + return nil, "", errors.New("PKCS#12 has no private key") + } + + rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey) + if ok { + signingAlgorithm = aws4_x509_rsa_sha256 + return FileSystemSigner{*rsaPrivateKey, certificate, nil}, signingAlgorithm, nil + } + + ecPrivateKey, ok := privateKey.(*ecdsa.PrivateKey) + if ok { + signingAlgorithm = aws4_x509_ecdsa_sha256 + return FileSystemSigner{*ecPrivateKey, certificate, nil}, signingAlgorithm, nil + } + + return nil, "", errors.New("unsupported algorithm on PKCS#12 key") +} diff --git a/aws_signing_helper/linux_cert_store_signer.go b/aws_signing_helper/linux_cert_store_signer.go new file mode 100644 index 0000000..71595c9 --- /dev/null +++ b/aws_signing_helper/linux_cert_store_signer.go @@ -0,0 +1,15 @@ +//go:build linux + +package aws_signing_helper + +import ( + "errors" +) + +func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, error) { + return nil, errors.New("unable to use cert store signer on linux") +} + +func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) { + return nil, "", errors.New("unable to use cert store signer on linux") +} diff --git a/aws_signing_helper/serve.go b/aws_signing_helper/serve.go index b0b892f..343e54b 100644 --- a/aws_signing_helper/serve.go +++ b/aws_signing_helper/serve.go @@ -147,7 +147,7 @@ func FindTokenTTLSeconds(r *http.Request) (string, error) { } } -func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *CredentialsOpts) (http.HandlerFunc, http.HandlerFunc, http.HandlerFunc) { +func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (http.HandlerFunc, http.HandlerFunc, http.HandlerFunc) { // Handles PUT requests to /latest/api/token/ putTokenHandler := func(w http.ResponseWriter, r *http.Request) { if r.Method != "PUT" { @@ -224,7 +224,7 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials var nextRefreshTime = cred.Expiration.Add(-RefreshTime) if time.Until(nextRefreshTime) < RefreshTime { - credentialProcessOutput, _ := GenerateCredentials(opts) + credentialProcessOutput, _ := GenerateCredentials(opts, signer, signatureAlgorithm) cred.AccessKeyId = credentialProcessOutput.AccessKeyId cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey cred.Token = credentialProcessOutput.SessionToken @@ -268,7 +268,14 @@ func Serve(port int, credentialsOptions CredentialsOpts) { os.Exit(1) } - credentialProcessOutput, _ := GenerateCredentials(&credentialsOptions) + signer, signatureAlgorithm, err := GetSigner(&credentialsOptions) + if err != nil { + log.Println(err) + os.Exit(1) + } + defer signer.Close() + + credentialProcessOutput, _ := GenerateCredentials(&credentialsOptions, signer, signatureAlgorithm) refreshableCred.AccessKeyId = credentialProcessOutput.AccessKeyId refreshableCred.SecretAccessKey = credentialProcessOutput.SecretAccessKey refreshableCred.Token = credentialProcessOutput.SessionToken @@ -280,7 +287,7 @@ func Serve(port int, credentialsOptions CredentialsOpts) { endpoint.Server = &http.Server{} roleResourceParts := strings.Split(roleArn.Resource, "/") roleName := roleResourceParts[len(roleResourceParts)-1] // Find role name without path - putTokenHandler, getRoleNameHandler, getCredentialsHandler := AllIssuesHandlers(&endpoint.TmpCred, roleName, &credentialsOptions) + putTokenHandler, getRoleNameHandler, getCredentialsHandler := AllIssuesHandlers(&endpoint.TmpCred, roleName, &credentialsOptions, signer, signatureAlgorithm) http.HandleFunc(TOKEN_RESOURCE_PATH, putTokenHandler) http.HandleFunc(SECURITY_CREDENTIALS_RESOURCE_PATH, getRoleNameHandler) diff --git a/aws_signing_helper/signer.go b/aws_signing_helper/signer.go index 2e792b5..dd6f30f 100644 --- a/aws_signing_helper/signer.go +++ b/aws_signing_helper/signer.go @@ -7,8 +7,8 @@ import ( "crypto/rand" "crypto/rsa" "crypto/sha256" - "crypto/sha512" "crypto/x509" + "encoding/asn1" "encoding/base64" "encoding/hex" "encoding/pem" @@ -16,6 +16,7 @@ import ( "fmt" "io" "log" + "math/big" "net/http" "os" "sort" @@ -26,39 +27,6 @@ import ( "github.com/aws/aws-sdk-go/aws/request" ) -type SigningOpts struct { - // Private key to use for the signing operation. - PrivateKey crypto.PrivateKey - // Digest to use in the signing operation. For example, SHA256 - Digest crypto.Hash -} - -// Container for data that will be sent in a request to CreateSession. -type RequestOpts struct { - // ARN of the Role to assume in the CreateSession call. - RoleArn string - // ARN of the Configuration to use in the CreateSession call. - ConfigurationArn string - // Certificate, as base64-encoded DER; used in the `x-amz-x509` - // header in the API request. - CertificateData string - // Duration of the session that will be returned by CreateSession. - DurationSeconds int -} - -type RequestHeaderOpts struct { - // Certificate, as base64-encoded DER; used in the `x-amz-x509` - // header in the API request. - CertificateData string -} - -type RequestQueryStringOpts struct { - // ARN of the Role to assume in the CreateSession call. - RoleArn string - // ARN of the Configuration to use in the CreateSession call. - ConfigurationArn string -} - type SignerParams struct { OverriddenDate time.Time RegionName string @@ -66,10 +34,26 @@ type SignerParams struct { SigningAlgorithm string } -// Container for data returned after performing a signing operation. -type SigningResult struct { - // Signature encoded in hex. - Signature string `json:"signature"` +type CertIdentifier struct { + Subject string + Issuer string + SerialNumber *big.Int +} + +var ( + // ErrUnsupportedHash is returned by Signer.Sign() when the provided hash + // algorithm isn't supported. + ErrUnsupportedHash = errors.New("unsupported hash algorithm") +) + +// Interface that all signers will have to implement +// (as a result, they will also implement crypto.Signer) +type Signer interface { + Public() crypto.PublicKey + Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) + Certificate() (certificate *x509.Certificate, err error) + CertificateChain() (certificateChain []*x509.Certificate, err error) + Close() } // Container for certificate data returned to the SDK as JSON. @@ -101,10 +85,11 @@ type CredentialProcessOutput struct { Expiration string `json:"Expiration"` } -type RolesAnywhereSigner struct { - PrivateKey crypto.PrivateKey - Certificate x509.Certificate - CertificateChain []x509.Certificate +type CertificateContainer struct { + // Certificate data + Cert *x509.Certificate + // Certificate URI (only populated in the case that the certificate is a PKCS#11 object) + Uri string } // Define constants used in signing @@ -129,6 +114,105 @@ var ignoredHeaderKeys = map[string]bool{ "X-Amzn-Trace-Id": true, } +var Debug bool = false + +// Find whether the current certificate matches the CertIdentifier +func certMatches(certIdentifier CertIdentifier, cert x509.Certificate) bool { + if certIdentifier.Subject != "" && certIdentifier.Subject != cert.Subject.String() { + return false + } + if certIdentifier.Issuer != "" && certIdentifier.Issuer != cert.Issuer.String() { + return false + } + if certIdentifier.SerialNumber != nil && certIdentifier.SerialNumber.Cmp(cert.SerialNumber) != 0 { + return false + } + + return true +} + +// Because of *course* we have to do this for ourselves. +// +// Create the DER-encoded SEQUENCE containing R and S: +// +// Ecdsa-Sig-Value ::= SEQUENCE { +// r INTEGER, +// s INTEGER +// } +// +// This is defined in RFC3279 §2.2.3 as well as SEC.1. +// I can't find anything which mandates DER but I've seen +// OpenSSL refusing to verify it with indeterminate length. +func encodeEcdsaSigValue(signature []byte) (out []byte, err error) { + sigLen := len(signature) / 2 + + return asn1.Marshal(struct { + R *big.Int + S *big.Int + }{ + big.NewInt(0).SetBytes(signature[:sigLen]), + big.NewInt(0).SetBytes(signature[sigLen:])}) +} + +// Gets the Signer based on the flags passed in by the user (from which the CredentialsOpts structure is derived) +func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string, err error) { + var certificate *x509.Certificate + var certificateChain []*x509.Certificate + + privateKeyId := opts.PrivateKeyId + if privateKeyId == "" { + if opts.CertificateId == "" { + if Debug { + fmt.Fprintln(os.Stderr, "attempting to use CertStoreSigner") + } + return GetCertStoreSigner(opts.CertIdentifier) + } + privateKeyId = opts.CertificateId + } + + if opts.CertificateId != "" { + certificateData, err := ReadCertificateData(opts.CertificateId) + if err == nil { + certificateDerData, err := base64.StdEncoding.DecodeString(certificateData.CertificateData) + if err != nil { + return nil, "", err + } + certificate, err = x509.ParseCertificate([]byte(certificateDerData)) + if err != nil { + return nil, "", err + } + } else if opts.PrivateKeyId == "" { + if Debug { + fmt.Fprintln(os.Stderr, "not a PEM certificate, so trying PKCS#12") + } + // Not a PEM certificate? Try PKCS#12 + return GetPKCS12Signer(opts.CertificateId) + } else { + return nil, "", err + } + } + + if opts.CertificateBundleId != "" { + certificateChainPointers, err := ReadCertificateBundleData(opts.CertificateBundleId) + if err != nil { + return nil, "", err + } + for _, certificate := range certificateChainPointers { + certificateChain = append(certificateChain, certificate) + } + } + + privateKey, err := ReadPrivateKeyData(privateKeyId) + if err != nil { + return nil, "", err + } + + if Debug { + fmt.Fprintln(os.Stderr, "attempting to use FileSystemSigner") + } + return GetFileSystemSigner(privateKey, certificate, certificateChain) +} + // Obtain the date-time, formatted as specified by SigV4 func (signerParams *SignerParams) GetFormattedSigningDateTime() string { return signerParams.OverriddenDate.UTC().Format(timeFormat) @@ -153,12 +237,12 @@ func (signerParams *SignerParams) GetScope() string { } // Convert certificate to string, so that it can be present in the HTTP request header -func certificateToString(certificate x509.Certificate) string { +func certificateToString(certificate *x509.Certificate) string { return base64.StdEncoding.EncodeToString(certificate.Raw) } // Convert certificate chain to string, so that it can be pressent in the HTTP request header -func certificateChainToString(certificateChain []x509.Certificate) string { +func certificateChainToString(certificateChain []*x509.Certificate) string { var x509ChainString strings.Builder for i, certificate := range certificateChain { x509ChainString.WriteString(certificateToString(certificate)) @@ -169,65 +253,46 @@ func certificateChainToString(certificateChain []x509.Certificate) string { return x509ChainString.String() } -// Create a function that will sign requests, given the signing certificate, optional certificate chain, and the private key -func CreateSignFunction(privateKey crypto.PrivateKey, certificate x509.Certificate, certificateChain []x509.Certificate) func(*request.Request) { - v4x509 := RolesAnywhereSigner{privateKey, certificate, certificateChain} - return func(r *request.Request) { - v4x509.SignWithCurrTime(r) - } -} - -// Sign the request using the current time -func (v4x509 RolesAnywhereSigner) SignWithCurrTime(req *request.Request) error { - // Find the signing algorithm - var signingAlgorithm string - _, isRsaKey := v4x509.PrivateKey.(rsa.PrivateKey) - if isRsaKey { - signingAlgorithm = aws4_x509_rsa_sha256 - } - _, isEcKey := v4x509.PrivateKey.(ecdsa.PrivateKey) - if isEcKey { - signingAlgorithm = aws4_x509_ecdsa_sha256 - } - if signingAlgorithm == "" { - log.Println("unsupported algorithm") - return errors.New("unsupported algorithm") - } - - region := req.ClientInfo.SigningRegion - if region == "" { - region = aws.StringValue(req.Config.Region) - } - - name := req.ClientInfo.SigningName - if name == "" { - name = req.ClientInfo.ServiceName - } +func CreateRequestSignFunction(signer crypto.Signer, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate) func(*request.Request) { + return func(req *request.Request) { + region := req.ClientInfo.SigningRegion + if region == "" { + region = aws.StringValue(req.Config.Region) + } - signerParams := SignerParams{time.Now(), region, name, signingAlgorithm} + name := req.ClientInfo.SigningName + if name == "" { + name = req.ClientInfo.ServiceName + } - // Set headers that are necessary for signing - req.HTTPRequest.Header.Set(host, req.HTTPRequest.URL.Host) - req.HTTPRequest.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime()) - req.HTTPRequest.Header.Set(x_amz_x509, certificateToString(v4x509.Certificate)) - if v4x509.CertificateChain != nil { - req.HTTPRequest.Header.Set(x_amz_x509_chain, certificateChainToString(v4x509.CertificateChain)) - } + signerParams := SignerParams{time.Now(), region, name, signingAlgorithm} - contentSha256 := calculateContentHash(req.HTTPRequest, req.Body) - if req.HTTPRequest.Header.Get(x_amz_content_sha256) == "required" { - req.HTTPRequest.Header.Set(x_amz_content_sha256, contentSha256) - } + // Set headers that are necessary for signing + req.HTTPRequest.Header.Set(host, req.HTTPRequest.URL.Host) + req.HTTPRequest.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime()) + req.HTTPRequest.Header.Set(x_amz_x509, certificateToString(certificate)) + if certificateChain != nil { + req.HTTPRequest.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain)) + } - canonicalRequest, signedHeadersString := createCanonicalRequest(req.HTTPRequest, req.Body, contentSha256) + contentSha256 := calculateContentHash(req.HTTPRequest, req.Body) + if req.HTTPRequest.Header.Get(x_amz_content_sha256) == "required" { + req.HTTPRequest.Header.Set(x_amz_content_sha256, contentSha256) + } - stringToSign := CreateStringToSign(canonicalRequest, signerParams) + canonicalRequest, signedHeadersString := createCanonicalRequest(req.HTTPRequest, req.Body, contentSha256) - signingResult, _ := Sign([]byte(stringToSign), SigningOpts{v4x509.PrivateKey, crypto.SHA256}) + stringToSign := CreateStringToSign(canonicalRequest, signerParams) + signatureBytes, err := signer.Sign(rand.Reader, []byte(stringToSign), crypto.SHA256) + if err != nil { + log.Println(err.Error()) + os.Exit(1) + } + signature := hex.EncodeToString(signatureBytes) - req.HTTPRequest.Header.Set(authorization, BuildAuthorizationHeader(req.HTTPRequest, req.Body, signedHeadersString, signingResult.Signature, v4x509.Certificate, signerParams)) - req.SignedHeaderVals = req.HTTPRequest.Header - return nil + req.HTTPRequest.Header.Set(authorization, BuildAuthorizationHeader(req.HTTPRequest, req.Body, signedHeadersString, signature, certificate, signerParams)) + req.SignedHeaderVals = req.HTTPRequest.Header + } } // Find the SHA256 hash of the provided request body as a io.ReadSeeker @@ -370,7 +435,7 @@ func CreateStringToSign(canonicalRequest string, signerParams SignerParams) stri } // Builds the complete authorization header -func BuildAuthorizationHeader(request *http.Request, body io.ReadSeeker, signedHeadersString string, signature string, certificate x509.Certificate, signerParams SignerParams) string { +func BuildAuthorizationHeader(request *http.Request, body io.ReadSeeker, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string { signingCredentials := certificate.SerialNumber.String() + "/" + signerParams.GetScope() credential := "Credential=" + signingCredentials signerHeaders := "SignedHeaders=" + signedHeadersString @@ -388,44 +453,6 @@ func BuildAuthorizationHeader(request *http.Request, body io.ReadSeeker, signedH return authHeaderString } -// Sign the provided payload with the specified options. -func Sign(payload []byte, opts SigningOpts) (SigningResult, error) { - var hash []byte - switch opts.Digest { - case crypto.SHA256: - sum := sha256.Sum256(payload) - hash = sum[:] - case crypto.SHA384: - sum := sha512.Sum384(payload) - hash = sum[:] - case crypto.SHA512: - sum := sha512.Sum512(payload) - hash = sum[:] - default: - log.Println("unsupported digest") - return SigningResult{}, errors.New("unsupported digest") - } - - ecdsaPrivateKey, ok := opts.PrivateKey.(ecdsa.PrivateKey) - if ok { - sig, err := ecdsa.SignASN1(rand.Reader, &ecdsaPrivateKey, hash[:]) - if err == nil { - return SigningResult{hex.EncodeToString(sig)}, nil - } - } - - rsaPrivateKey, ok := opts.PrivateKey.(rsa.PrivateKey) - if ok { - sig, err := rsa.SignPKCS1v15(rand.Reader, &rsaPrivateKey, opts.Digest, hash[:]) - if err == nil { - return SigningResult{hex.EncodeToString(sig)}, nil - } - } - - log.Println("unsupported algorithm") - return SigningResult{}, errors.New("unsupported algorithm") -} - func encodeDer(der []byte) (string, error) { var buf bytes.Buffer encoder := base64.NewEncoder(base64.StdEncoding, &buf) diff --git a/aws_signing_helper/signer_test.go b/aws_signing_helper/signer_test.go index cbcae40..97189ce 100644 --- a/aws_signing_helper/signer_test.go +++ b/aws_signing_helper/signer_test.go @@ -3,17 +3,15 @@ package aws_signing_helper import ( "crypto" "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/sha512" - "crypto/x509" - "encoding/base64" - "encoding/hex" "errors" + "fmt" "io/ioutil" "log" + "math/big" "net/http" "net/http/httptest" "os" @@ -29,14 +27,8 @@ import ( const TestCredentialsFilePath = "/tmp/credentials" func setup() error { - generateCertsScript := exec.Command("/bin/bash", "../generate-certs.sh") - _, err := generateCertsScript.Output() - if err != nil { - return err - } - - generateCredentialProcessDataScript := exec.Command("/bin/bash", "../generate-credential-process-data.sh") - _, err = generateCredentialProcessDataScript.Output() + generateCredentialProcessDataScript := exec.Command("/bin/sh", "../generate-credential-process-data.sh") + _, err := generateCredentialProcessDataScript.Output() return err } @@ -129,28 +121,35 @@ func TestBuildAuthorizationHeader(t *testing.T) { t.Fail() } + certificateList, _ := ReadCertificateBundleData("../tst/certs/rsa-2048-sha256-cert.pem") + certificate := certificateList[0] privateKey, _ := ReadPrivateKeyData("../tst/certs/rsa-2048-key.pem") - certificateData, _ := ReadCertificateData("../tst/certs/rsa-2048-sha256-cert.pem") - certificateDerData, _ := base64.StdEncoding.DecodeString(certificateData.CertificateData) - certificate, _ := x509.ParseCertificate([]byte(certificateDerData)) awsRequest := request.Request{HTTPRequest: testRequest} - v4x509 := RolesAnywhereSigner{ - PrivateKey: privateKey, - Certificate: *certificate, + signer, signingAlgorithm, err := GetFileSystemSigner(privateKey, certificate, nil) + if err != nil { + t.Log(err) + t.Fail() } - err = v4x509.SignWithCurrTime(&awsRequest) + certificate, err = signer.Certificate() if err != nil { t.Log(err) t.Fail() } + certificateChain, err := signer.CertificateChain() + if err != nil { + t.Log(err) + t.Fail() + } + requestSignFunction := CreateRequestSignFunction(signer, signingAlgorithm, certificate, certificateChain) + requestSignFunction(&awsRequest) } // Verify that the provided payload was signed correctly with the provided options. // This function is specifically used for unit testing. -func Verify(payload []byte, opts SigningOpts, sig []byte) (bool, error) { +func Verify(payload []byte, publicKey crypto.PublicKey, digest crypto.Hash, sig []byte) (bool, error) { var hash []byte - switch opts.Digest { + switch digest { case crypto.SHA256: sum := sha256.Sum256(payload) hash = sum[:] @@ -161,27 +160,23 @@ func Verify(payload []byte, opts SigningOpts, sig []byte) (bool, error) { sum := sha512.Sum512(payload) hash = sum[:] default: - log.Fatal("Unsupported digest") - return false, errors.New("Unsupported digest") + log.Fatal("unsupported digest") + return false, errors.New("unsupported digest") } { - privateKey, ok := opts.PrivateKey.(ecdsa.PrivateKey) + publicKey, ok := publicKey.(*ecdsa.PublicKey) if ok { - valid := ecdsa.VerifyASN1(&privateKey.PublicKey, hash, sig) - if valid { - return valid, nil - } + valid := ecdsa.VerifyASN1(publicKey, hash, sig) + return valid, nil } } { - privateKey, ok := opts.PrivateKey.(rsa.PrivateKey) + publicKey, ok := publicKey.(*rsa.PublicKey) if ok { - err := rsa.VerifyPKCS1v15(&privateKey.PublicKey, opts.Digest, hash, sig) - if err == nil { - return true, nil - } + err := rsa.VerifyPKCS1v15(publicKey, digest, hash, sig) + return err == nil, nil } } @@ -190,37 +185,110 @@ func Verify(payload []byte, opts SigningOpts, sig []byte) (bool, error) { func TestSign(t *testing.T) { msg := "test message" - - var privateKeyList [2]crypto.PrivateKey - { - privateKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - privateKeyList[0] = *privateKey + testTable := []CredentialsOpts{} + + ec_digests := []string{"sha1", "sha256", "sha384", "sha512"} + ec_curves := []string{"prime256v1", "secp384r1"} + + for _, digest := range ec_digests { + for _, curve := range ec_curves { + cert := fmt.Sprintf("../tst/certs/ec-%s-%s-cert.pem", + curve, digest) + key := fmt.Sprintf("../tst/certs/ec-%s-key.pem", curve) + testTable = append(testTable, CredentialsOpts{ + CertificateId: cert, + PrivateKeyId: key, + }) + + key = fmt.Sprintf("../tst/certs/ec-%s-key-pkcs8.pem", curve) + testTable = append(testTable, CredentialsOpts{ + CertificateId: cert, + PrivateKeyId: key, + }) + + cert = fmt.Sprintf("../tst/certs/ec-%s-%s.p12", + curve, digest) + testTable = append(testTable, CredentialsOpts{ + CertificateId: cert, + }) + } } - { - privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) - privateKeyList[1] = *privateKey + + rsa_digests := []string{"md5", "sha1", "sha256", "sha384", "sha512"} + rsa_key_lengths := []string{"1024", "2048", "4096"} + + for _, digest := range rsa_digests { + for _, keylen := range rsa_key_lengths { + cert := fmt.Sprintf("../tst/certs/rsa-%s-%s-cert.pem", + keylen, digest) + key := fmt.Sprintf("../tst/certs/rsa-%s-key.pem", keylen) + testTable = append(testTable, CredentialsOpts{ + CertificateId: cert, + PrivateKeyId: key, + }) + + key = fmt.Sprintf("../tst/certs/rsa-%s-key-pkcs8.pem", keylen) + testTable = append(testTable, CredentialsOpts{ + CertificateId: cert, + PrivateKeyId: key, + }) + + cert = fmt.Sprintf("../tst/certs/rsa-%s-%s.p12", + keylen, digest) + testTable = append(testTable, CredentialsOpts{ + CertificateId: cert, + }) + + } } + digestList := []crypto.Hash{crypto.SHA256, crypto.SHA384, crypto.SHA512} - for _, privateKey := range privateKeyList { + for _, credOpts := range testTable { + signer, _, err := GetSigner(&credOpts) + if err != nil { + var logMsg string + if credOpts.CertificateId != "" || credOpts.PrivateKeyId != "" { + logMsg = fmt.Sprintf("Failed to get signer for '%s'/'%s'", + credOpts.CertificateId, credOpts.PrivateKeyId) + } else { + logMsg = fmt.Sprintf("Failed to get signer for '%s'", + credOpts.CertIdentifier.Subject) + } + t.Log(logMsg) + t.Fail() + return + } + defer signer.Close() + + pubKey := signer.Public() + if credOpts.CertificateId != "" && pubKey == nil { + t.Log(fmt.Sprintf("Signer didn't provide public key for '%s'/'%s'", + credOpts.CertificateId, credOpts.PrivateKeyId)) + t.Fail() + return + } + for _, digest := range digestList { - signingResult, err := Sign([]byte(msg), SigningOpts{privateKey, digest}) + signatureBytes, err := signer.Sign(rand.Reader, []byte(msg), digest) if err != nil { t.Log("Failed to sign the input message") t.Fail() + return } - sig, err := hex.DecodeString(signingResult.Signature) - if err != nil { - t.Log("Failed to decode the hex-encoded signature") - t.Fail() - } - valid, _ := Verify([]byte(msg), SigningOpts{privateKey, digest}, sig) - if !valid { - t.Log("Failed to verify the signature") - t.Fail() + if pubKey != nil { + valid, _ := Verify([]byte(msg), pubKey, digest, signatureBytes) + if !valid { + t.Log(fmt.Sprintf("Failed to verify the signature for '%s'/'%s'", + credOpts.CertificateId, credOpts.PrivateKeyId)) + t.Fail() + return + } } } + + signer.Close() } } @@ -246,7 +314,13 @@ func TestCredentialProcess(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { defer tc.server.Close() - resp, err := GenerateCredentials(&credentialsOpts) + signer, signatureAlgorithm, err := GetSigner(&credentialsOpts) + if err != nil { + t.Log("Failed to get signer") + t.Fail() + return + } + resp, err := GenerateCredentials(&credentialsOpts, signer, signatureAlgorithm) if err != nil { t.Log(err) @@ -270,6 +344,43 @@ func TestCredentialProcess(t *testing.T) { } } +func TestCertStoreSignerCreationFails(t *testing.T) { + testTable := []CredentialsOpts{} + + randomLargeSerial := new(big.Int) + randomLargeSerial.SetString("123456719012345678901234567890", 10) + + testTable = append(testTable, CredentialsOpts{ + CertIdentifier: CertIdentifier{ + Subject: "invalid-subject", + }, + }) + testTable = append(testTable, CredentialsOpts{ + CertIdentifier: CertIdentifier{ + Issuer: "invalid-issuer", + }, + }) + testTable = append(testTable, CredentialsOpts{ + CertIdentifier: CertIdentifier{ + SerialNumber: randomLargeSerial, + }, + }) + testTable = append(testTable, CredentialsOpts{ + CertIdentifier: CertIdentifier{ + Subject: "CN=roles-anywhere-rsa-2048-sha25", + SerialNumber: randomLargeSerial, + }, + }) + + for _, credOpts := range testTable { + _, _, err := GetSigner(&credOpts) + if err == nil { + t.Log("Expected failure when creating certificate store signer, but received none") + t.Fail() + } + } +} + func TestUpdate(t *testing.T) { testTable := []struct { name string diff --git a/aws_signing_helper/update.go b/aws_signing_helper/update.go index b1049c0..c09c855 100644 --- a/aws_signing_helper/update.go +++ b/aws_signing_helper/update.go @@ -25,8 +25,16 @@ type TemporaryCredential struct { func Update(credentialsOptions CredentialsOpts, profile string, once bool) { var refreshableCred = TemporaryCredential{} var nextRefreshTime time.Time + + signer, signatureAlgorithm, err := GetSigner(&credentialsOptions) + if err != nil { + log.Println(err) + os.Exit(1) + } + defer signer.Close() + for { - credentialProcessOutput, err := GenerateCredentials(&credentialsOptions) + credentialProcessOutput, err := GenerateCredentials(&credentialsOptions, signer, signatureAlgorithm) if err != nil { log.Fatal(err) } @@ -108,7 +116,7 @@ func GetWriteOnlyCredentialsFile() (*os.File, error) { return os.OpenFile(awsCredentialsPath, os.O_WRONLY|os.O_TRUNC, 0200) } -// Function that will get the new conents of the credentials file after a +// Function that will get the new conents of the credentials file after a // refresh has been done func GetNewCredentialsFileContents(profileName string, readLines []string, cred *TemporaryCredential) []string { var profileExist = false @@ -170,7 +178,7 @@ func GetNewCredentialsFileContents(profileName string, readLines []string, cred writeLines = append(writeLines[:], writeCredential+"\n") } - return writeLines + return writeLines } // Function to write existing credentials and newly-created credentials to a destination file @@ -182,7 +190,7 @@ func WriteTo(profileName string, readLines []string, cred *TemporaryCredential) } defer destFile.Close() - // Create buffered writer + // Create buffered writer destFileWriter := bufio.NewWriterSize(destFile, BufferSize) for _, line := range GetNewCredentialsFileContents(profileName, readLines, cred) { _, err := destFileWriter.WriteString(line) diff --git a/aws_signing_helper/windows_cert_store_signer.go b/aws_signing_helper/windows_cert_store_signer.go new file mode 100644 index 0000000..6fb24ee --- /dev/null +++ b/aws_signing_helper/windows_cert_store_signer.go @@ -0,0 +1,621 @@ +//go:build windows + +package aws_signing_helper + +// This code is based on the smimesign repository at +// https://github.com/github/smimesign + +/* +#cgo windows LDFLAGS: -lcrypt32 -lncrypt +#include +#include +#include +#include + +// +// Go complains about LPCWSTR constants and the MAKELANGID function not being +// defined, so we define methods for them. +// + +LPCWSTR GET_BCRYPT_SHA1_ALGORITHM() { return BCRYPT_SHA1_ALGORITHM; } +LPCWSTR GET_BCRYPT_SHA256_ALGORITHM() { return BCRYPT_SHA256_ALGORITHM; } +LPCWSTR GET_BCRYPT_SHA384_ALGORITHM() { return BCRYPT_SHA384_ALGORITHM; } +LPCWSTR GET_BCRYPT_SHA512_ALGORITHM() { return BCRYPT_SHA512_ALGORITHM; } + +int MAKE_LANG_ID(int p, int s) { + return MAKELANGID(p, s); +} + +*/ +import "C" + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "errors" + "fmt" + "golang.org/x/sys/windows" + "io" + "os" + "strconv" + "strings" + "unsafe" +) + +// winPrivateKey is a wrapper around a HCRYPTPROV_OR_NCRYPT_KEY_HANDLE. +type winPrivateKey struct { + publicKey crypto.PublicKey + mustFree bool + + // CryptoAPI fields + cspHandle windows.Handle + keySpec uint32 + + // CNG fields + cngKeyHandle windows.Handle +} + +type WindowsCertStoreSigner struct { + store windows.Handle + certCtx *windows.CertContext + cert *x509.Certificate + certChain []*x509.Certificate + privateKey *winPrivateKey +} + +const ( + WIN_FALSE C.WINBOOL = 0 + + // ERROR_SUCCESS — The call succeeded + ERROR_SUCCESS = 0x00000000 + + // NTE_BAD_ALGID — Invalid algorithm specified + NTE_BAD_ALGID = 0x80090008 + + // WIN_API_FLAG specifies the flags that should be passed to + // CryptAcquireCertificatePrivateKey. This impacts whether the CryptoAPI or CNG + // API will be used. + // + // Possible values are: + // + // 0x00000000 — — Only use CryptoAPI. + // 0x00010000 — CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG — Prefer CryptoAPI. + // 0x00020000 — CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG — Prefer CNG. + // 0x00040000 — CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG — Only use CNG. + WIN_API_FLAG = windows.CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG +) + +// Error codes for Windows APIs - implements the error interface +type errCode uint64 + +// Security status for Windows APIs - implements the error interface +// Go representation of the C SECURITY_STATUS +type securityStatus uint64 + +// Gets the certificates that match the given CertIdentifier within the user's "MY" certificate store. +// If there is only a single matching certificate, then its chain will be returned too +func GetMatchingCertsAndChain(certIdentifier CertIdentifier) (store windows.Handle, certCtx *windows.CertContext, certChain []*x509.Certificate, certContainers []CertificateContainer, err error) { + storeName, err := windows.UTF16PtrFromString("MY") + if err != nil { + return 0, nil, nil, nil, errors.New("unable to UTF-16 encode personal certificate store name") + } + + store, err = windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM_W, 0, 0, windows.CERT_SYSTEM_STORE_CURRENT_USER, uintptr(unsafe.Pointer(storeName))) + if err != nil { + return 0, nil, nil, nil, errors.New("failed to open system cert store") + } + + var ( + // CertFindChainInStore parameters + encoding = uint32(windows.X509_ASN_ENCODING) + flags = uint32(windows.CERT_CHAIN_FIND_BY_ISSUER_CACHE_ONLY_FLAG | windows.CERT_CHAIN_FIND_BY_ISSUER_CACHE_ONLY_URL_FLAG) + findType = uint32(windows.CERT_CHAIN_FIND_BY_ISSUER) + params windows.CertChainFindByIssuerPara + paramsPtr unsafe.Pointer + chainCtx *windows.CertChainContext = nil + ) + params.Size = uint32(unsafe.Sizeof(params)) + paramsPtr = unsafe.Pointer(¶ms) + + var curCertCtx *windows.CertContext + for { + // Previous chainCtx should be freed here if it isn't nil + chainCtx, err = windows.CertFindChainInStore(store, encoding, flags, findType, paramsPtr, chainCtx) + if err != nil { + if strings.Contains(err.Error(), "Cannot find object or property.") { + break + } + err = errors.New("unable to find certificate chain in store") + goto fail + } + + if chainCtx.ChainCount < 1 { + err = errors.New("bad chain") + goto fail + } + + // When multiple valid certification paths that are found for a given + // certificate, only the first one is considered + simpleChain := *chainCtx.Chains + if simpleChain.NumElements < 1 { + err = errors.New("bad chain") + goto fail + } + + // Convert the array into a pointer + chainElts := unsafe.Slice(simpleChain.Elements, simpleChain.NumElements) + + // Build chain of certificates from each element's certificate context. + x509CertChain := make([]*x509.Certificate, len(chainElts)) + for j := range chainElts { + curCertCtx = chainElts[j].CertContext + x509CertChain[j], err = exportCertContext(curCertCtx) + if err != nil { + goto fail + } + } + + curCert := x509CertChain[0] + if certMatches(certIdentifier, *curCert) { + certContainers = append(certContainers, CertificateContainer{curCert, ""}) + + // Assign to certChain and certCtx at most once in the loop. + // The value is only useful if there is exactly one match in the certificate store. + // When creating a signer, there has to be exactly one matching certificate. + if certChain == nil { + certChain = x509CertChain[:] + certCtx = chainElts[0].CertContext + // This is required later on when creating the WindowsCertStoreSigner + // If this method isn't being called in order to create a WindowsCertStoreSigner, + // this return value will have to be freed explicitly. + windows.CertDuplicateCertificateContext(certCtx) + } + } + } + + if Debug { + fmt.Fprintf(os.Stderr, "found %d matching identities\n", len(certContainers)) + } + + return store, certCtx, certChain, certContainers, nil + +fail: + if chainCtx != nil { + windows.CertFreeCertificateChain(chainCtx) + } + if certCtx != nil { + windows.CertFreeCertificateContext(certCtx) + } + windows.CertCloseStore(store, 0) + + return 0, nil, nil, nil, err +} + +// Gets the certificates that match a CertIdentifier +func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, error) { + store, certCtx, _, certContainers, err := GetMatchingCertsAndChain(certIdentifier) + if certCtx != nil { + windows.CertFreeCertificateContext(certCtx) + } + windows.CertCloseStore(store, 0) + + return certContainers, err +} + +// Gets a WindowsCertStoreSigner based on the CertIdentifier +func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) { + var privateKey *winPrivateKey + store, certCtx, certChain, certContainers, err := GetMatchingCertsAndChain(certIdentifier) + if err != nil { + goto fail + } + if len(certContainers) > 1 { + err = errors.New("more than one matching cert found in cert store") + goto fail + } + if len(certContainers) == 0 { + err = errors.New("no matching certs found in cert store") + goto fail + } + + signer = &WindowsCertStoreSigner{store: store, cert: certContainers[0].Cert, certCtx: certCtx, certChain: certChain} + + privateKey, err = signer.(*WindowsCertStoreSigner).getPrivateKey() + if err != nil { + goto fail + } + + // Find the signing algorithm + switch privateKey.publicKey.(type) { + case *ecdsa.PublicKey: + signingAlgorithm = aws4_x509_ecdsa_sha256 + case *rsa.PublicKey: + signingAlgorithm = aws4_x509_rsa_sha256 + default: + err = errors.New("unsupported algorithm") + goto fail + } + + return signer, signingAlgorithm, err + +fail: + if certCtx != nil { + windows.CertFreeCertificateContext(certCtx) + } + if signer != nil { + signer.Close() + } + if store != 0 { + windows.CertCloseStore(store, 0) + } + + return nil, "", err +} + +// Certificate implements the aws_signing_helper.Signer interface and returns a pointer +// to the x509.Certificate associated with this signer +func (signer *WindowsCertStoreSigner) Certificate() (cert *x509.Certificate, err error) { + return signer.cert, nil +} + +// CertificateChain implements the aws_signing_helper.Signer interface and returns +// the certificate chain associated with this signer +func (signer *WindowsCertStoreSigner) CertificateChain() ([]*x509.Certificate, error) { + return signer.certChain, nil +} + +// Close implements the aws_signing_helper.Signer interface and closes the signer +func (signer *WindowsCertStoreSigner) Close() { + if signer.privateKey != nil && signer.privateKey.mustFree { + if signer.privateKey.cngKeyHandle != 0 { + cngHandle := (*C.NCRYPT_KEY_HANDLE)(unsafe.Pointer(&signer.privateKey.cngKeyHandle)) + C.NCryptFreeObject(*cngHandle) + } + if signer.privateKey.cspHandle != 0 { + windows.CryptReleaseContext(signer.privateKey.cspHandle, 0) + } + } + signer.privateKey = nil + + // If signer.privateKey.mustFree is false, key handles will be released on the + // last free action of the certificate context. + // See https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-cryptacquirecertificateprivatekey + if signer.certCtx != nil { + windows.CertFreeCertificateContext(signer.certCtx) + signer.certCtx = nil + } + + windows.CertCloseStore(signer.store, 0) + signer.store = 0 +} + +// getPrivateKey gets this identity's private *winPrivateKey +func (signer *WindowsCertStoreSigner) getPrivateKey() (*winPrivateKey, error) { + if signer.privateKey != nil { + return signer.privateKey, nil + } + + cert, err := signer.Certificate() + if err != nil { + return nil, fmt.Errorf("failed to get identity certificate: %w", err) + } + + signer.privateKey, err = newWinPrivateKey(signer.certCtx, cert.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to load identity private key: %w", err) + } + + return signer.privateKey, nil +} + +// Gets a *winPrivateKey for the given certificate +func newWinPrivateKey(certCtx *windows.CertContext, publicKey crypto.PublicKey) (*winPrivateKey, error) { + var ( + cspHandleOrCngKey windows.Handle + keySpec uint32 + mustFree bool + ) + + if publicKey == nil { + return nil, errors.New("nil public key") + } + + // Get a handle for the found private key + if err := windows.CryptAcquireCertificatePrivateKey(certCtx, WIN_API_FLAG, nil, &cspHandleOrCngKey, &keySpec, &mustFree); err != nil { + return nil, err + } + + if keySpec == C.CERT_NCRYPT_KEY_SPEC { + return &winPrivateKey{ + publicKey: publicKey, + cngKeyHandle: cspHandleOrCngKey, + mustFree: mustFree, + }, nil + } else { + return &winPrivateKey{ + publicKey: publicKey, + cspHandle: cspHandleOrCngKey, + keySpec: keySpec, + mustFree: mustFree, + }, nil + } +} + +// Public implements the crypto.Signer interface. +func (signer *WindowsCertStoreSigner) Public() crypto.PublicKey { + privateKey, err := signer.getPrivateKey() + if err != nil { + return nil + } + + return privateKey.publicKey +} + +// Sign implements the crypto.Signer interface and signs the digest +func (signer *WindowsCertStoreSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + var hash []byte + switch opts.HashFunc() { + case crypto.SHA256: + sum := sha256.Sum256(digest) + hash = sum[:] + case crypto.SHA384: + sum := sha512.Sum384(digest) + hash = sum[:] + case crypto.SHA512: + sum := sha512.Sum512(digest) + hash = sum[:] + default: + return nil, ErrUnsupportedHash + } + + privateKey, err := signer.getPrivateKey() + if err != nil { + return nil, err + } + + if privateKey.cspHandle != 0 { + return signer.cryptoSignHash(hash, opts.HashFunc()) + } else if privateKey.cngKeyHandle != 0 { + return signer.cngSignHash(hash, opts.HashFunc()) + } else { + return nil, errors.New("bad private key") + } +} + +// cngSignHash signs a digest using CNG APIs +func (signer *WindowsCertStoreSigner) cngSignHash(digest []byte, hash crypto.Hash) ([]byte, error) { + if len(digest) != hash.Size() { + return nil, errors.New("bad digest for hash") + } + + var ( + // Input + padPtr = unsafe.Pointer(nil) + digestPtr = (*C.BYTE)(&digest[0]) + digestLen = C.DWORD(len(digest)) + flags = C.DWORD(0) + + // Output + sigLen = C.DWORD(0) + ) + + // Set up pkcs1v1.5 padding for RSA + privateKey, _ := signer.getPrivateKey() + if _, isRSA := privateKey.publicKey.(*rsa.PublicKey); isRSA { + flags |= C.BCRYPT_PAD_PKCS1 + padInfo := C.BCRYPT_PKCS1_PADDING_INFO{} + padPtr = unsafe.Pointer(&padInfo) + + switch hash { + case crypto.SHA1: + padInfo.pszAlgId = C.GET_BCRYPT_SHA1_ALGORITHM() + case crypto.SHA256: + padInfo.pszAlgId = C.GET_BCRYPT_SHA256_ALGORITHM() + case crypto.SHA384: + padInfo.pszAlgId = C.GET_BCRYPT_SHA384_ALGORITHM() + case crypto.SHA512: + padInfo.pszAlgId = C.GET_BCRYPT_SHA512_ALGORITHM() + default: + return nil, ErrUnsupportedHash + } + } + + // Get C.NCRYPT_KEY_HANDLE in order to do the signature + cngKeyHandle := (*C.NCRYPT_KEY_HANDLE)(unsafe.Pointer(&privateKey.cngKeyHandle)) + + // Get signature length + if err := checkStatus(C.NCryptSignHash(*cngKeyHandle, padPtr, digestPtr, digestLen, nil, 0, &sigLen, flags)); err != nil { + return nil, fmt.Errorf("failed to get signature length: %w", err) + } + + // Get signature + sig := make([]byte, sigLen) + sigPtr := (*C.BYTE)(&sig[0]) + if err := checkStatus(C.NCryptSignHash(*cngKeyHandle, padPtr, digestPtr, digestLen, sigPtr, sigLen, &sigLen, flags)); err != nil { + return nil, fmt.Errorf("failed to sign digest: %w", err) + } + + // CNG returns a raw ECDSA signature, but we want ASN.1 DER encoding + if _, isEC := privateKey.publicKey.(*ecdsa.PublicKey); isEC { + if len(sig)%2 != 0 { + return nil, errors.New("bad ecdsa signature from CNG") + } + + return encodeEcdsaSigValue(sig) + } + + return sig, nil +} + +// Signs a digest using CryptoAPI +func (signer *WindowsCertStoreSigner) cryptoSignHash(digest []byte, hash crypto.Hash) ([]byte, error) { + if len(digest) != hash.Size() { + return nil, errors.New("bad digest for hash") + } + + // Figure out which CryptoAPI hash algorithm we're using + var hash_alg C.ALG_ID + + switch hash { + case crypto.SHA1: + hash_alg = C.CALG_SHA1 + case crypto.SHA256: + hash_alg = C.CALG_SHA_256 + case crypto.SHA384: + hash_alg = C.CALG_SHA_384 + case crypto.SHA512: + hash_alg = C.CALG_SHA_512 + default: + return nil, ErrUnsupportedHash + } + + // Instantiate a CryptoAPI hash object + var cryptoHash C.HCRYPTHASH + + privateKey, _ := signer.getPrivateKey() + cspHandle := (*C.HCRYPTPROV)(unsafe.Pointer(&privateKey.cspHandle)) + if ok := C.CryptCreateHash(*cspHandle, hash_alg, 0, 0, &cryptoHash); ok == WIN_FALSE { + if err := lastError("failed to create hash"); errCause(err) == errCode(NTE_BAD_ALGID) { + return nil, ErrUnsupportedHash + } else { + return nil, err + } + } + defer C.CryptDestroyHash(cryptoHash) + + // Make sure the hash size matches + var ( + hashSize C.DWORD + hashSizePtr = (*C.BYTE)(unsafe.Pointer(&hashSize)) + hashSizeLen = C.DWORD(unsafe.Sizeof(hashSize)) + ) + + if ok := C.CryptGetHashParam(cryptoHash, C.HP_HASHSIZE, hashSizePtr, &hashSizeLen, 0); ok == WIN_FALSE { + return nil, lastError("failed to get hash size") + } + + if hash.Size() != int(hashSize) { + return nil, errors.New("invalid CryptoAPI hash") + } + + // Put our digest into the hash object + digestPtr := (*C.BYTE)(unsafe.Pointer(&digest[0])) + if ok := C.CryptSetHashParam(cryptoHash, C.HP_HASHVAL, digestPtr, 0); ok == WIN_FALSE { + return nil, lastError("failed to set hash digest") + } + + // Get signature length + var sigLen C.DWORD + + if ok := C.CryptSignHash(cryptoHash, C.ulong(privateKey.keySpec), nil, 0, nil, &sigLen); ok == WIN_FALSE { + return nil, lastError("failed to get signature length") + } + + // Get signature + var ( + sig = make([]byte, int(sigLen)) + sigPtr = (*C.BYTE)(unsafe.Pointer(&sig[0])) + ) + + if ok := C.CryptSignHash(cryptoHash, C.ulong(privateKey.keySpec), nil, 0, sigPtr, &sigLen); ok == WIN_FALSE { + return nil, lastError("failed to sign digest") + } + + // Reversing signature since it is little endian, but we want big endian + for i := len(sig)/2 - 1; i >= 0; i-- { + opp := len(sig) - 1 - i + sig[i], sig[opp] = sig[opp], sig[i] + } + + return sig, nil +} + +// Exports a windows.CertContext as an *x509.Certificate. +func exportCertContext(certCtx *windows.CertContext) (*x509.Certificate, error) { + // Technically, we should never throw here, since the exportCertContext function + // is only called when searching for certificates + if certCtx.EncodingType != windows.X509_ASN_ENCODING { + return nil, errors.New("unknown certificate encoding type") + } + + der := unsafe.Slice(certCtx.EncodedCert, certCtx.Length) + return x509.ParseCertificate(der) +} + +// Finds the error code for the given error +func errCause(err error) errCode { + msg := err.Error() + codeStr := msg[strings.LastIndex(msg, " ")+1:] + code, _ := strconv.ParseUint(codeStr, 16, 64) + return errCode(code) +} + +// Gets the last error from the current thread. If there isn't one, it +// returns a new error +func lastError(msg string) error { + if err := checkError(msg); err != nil { + return err + } + + return errors.New(msg) +} + +// checkError tries to get the last error from the current thread. If there +// isn't one, it returns nil +func checkError(msg string) error { + if code := errCode(C.GetLastError()); code != 0 { + return fmt.Errorf("%s: %w", msg, code) + } + + return nil +} + +// Implements the error interface for errCode and returns a string +// version of the errCode +func (c errCode) Error() string { + var cMsg C.LPSTR + ret := C.FormatMessage( + C.FORMAT_MESSAGE_ALLOCATE_BUFFER| + C.FORMAT_MESSAGE_FROM_SYSTEM| + C.FORMAT_MESSAGE_IGNORE_INSERTS, + nil, + C.DWORD(c), + C.ulong(C.MAKE_LANG_ID(C.LANG_NEUTRAL, C.SUBLANG_DEFAULT)), + cMsg, + 0, nil) + if ret == 0 { + return fmt.Sprintf("Error %X", int(c)) + } + + if cMsg == nil { + return fmt.Sprintf("Error %X", int(c)) + } + + goMsg := C.GoString(cMsg) + + return fmt.Sprintf("Error: %X %s", int(c), goMsg) +} + +// Converts a SECURITY_STATUS into a securityStatus +func checkStatus(s C.SECURITY_STATUS) error { + secStatus := securityStatus(s) + + if secStatus == ERROR_SUCCESS { + return nil + } + + if secStatus == NTE_BAD_ALGID { + return ErrUnsupportedHash + } + + return secStatus +} + +// Implements the error interface +func (secStatus securityStatus) Error() string { + return fmt.Sprintf("SECURITY_STATUS %d", int(secStatus)) +} diff --git a/cmd/aws_signing_helper/main.go b/cmd/aws_signing_helper/main.go deleted file mode 100644 index cbe95f4..0000000 --- a/cmd/aws_signing_helper/main.go +++ /dev/null @@ -1,280 +0,0 @@ -package main - -import ( - "bufio" - "crypto" - "encoding/binary" - "encoding/hex" - "encoding/json" - "flag" - "fmt" - "io/ioutil" - "log" - "os" - "strings" - - helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" -) - -// Common flags that must be contained in all flag sets -var ( - privateKeyId string - certificateId string - certificateBundleId string - digestArg string - roleArnStr string - profileArnStr string - trustAnchorArnStr string - sessionDuration int - - region string - endpoint string - noVerifySSL bool - withProxy bool - debug bool - format string - - profile string - once bool - - port int - - credentialProcessCmd = flag.NewFlagSet("credential-process", flag.ExitOnError) - signStringCmd = flag.NewFlagSet("sign-string", flag.ExitOnError) - readCertificateDataCmd = flag.NewFlagSet("read-certificate-data", flag.ExitOnError) - updateCmd = flag.NewFlagSet("update", flag.ExitOnError) - serveCmd = flag.NewFlagSet("serve", flag.ExitOnError) - versionCmd = flag.NewFlagSet("version", flag.ExitOnError) -) - -var Version string -var globalOptSet = map[string]bool{"--region": true, "--endpoint": true} -var credentialCommands = map[string]struct{}{"credential-process": {}, "update": {}, "serve": {}} - -// Maps each command name to a flagset -var commands = map[string]*flag.FlagSet{ - credentialProcessCmd.Name(): credentialProcessCmd, - signStringCmd.Name(): signStringCmd, - readCertificateDataCmd.Name(): readCertificateDataCmd, - updateCmd.Name(): updateCmd, - serveCmd.Name(): serveCmd, - versionCmd.Name(): versionCmd, -} - -// Finds global parameters that can appear in any position -// Return a map that maps the name of global parameter to its value -// and a list of remaining arguments -func findGlobalVar(argList []string) (map[string]string, []string) { - globalVars := make(map[string]string) - - parseList := []string{} - - for i := 0; i < len(argList); i++ { - - if globalOptSet[argList[i]] { - - if !strings.HasPrefix(argList[i+1], "--") { - globalVars[argList[i]] = argList[i+1] - i = i + 1 - } else { - log.Println("Invalid value for ", argList[i]) - os.Exit(1) - } - } else { - parseList = append(parseList, argList[i]) - } - } - - return globalVars, parseList -} - -// Assigns different flags to different commands -func setupFlags() { - for command, fs := range commands { - // Common flags for all credential-related commands - if _, ok := credentialCommands[command]; ok { - fs.StringVar(&certificateId, "certificate", "", "Path to certificate file") - fs.StringVar(&privateKeyId, "private-key", "", "Path to private key file") - fs.StringVar(&roleArnStr, "role-arn", "", "Target role to assume") - fs.StringVar(&profileArnStr, "profile-arn", "", "Profile to to pull policies from") - fs.StringVar(&trustAnchorArnStr, "trust-anchor-arn", "", "Trust anchor to to use for authentication") - fs.IntVar(&sessionDuration, "session-duration", 3600, "Duration, in seconds, for the resulting session") - fs.StringVar(®ion, "region", "", "Signing region") - fs.StringVar(&endpoint, "endpoint", "", "Endpoint to retrieve session from") - fs.StringVar(&certificateBundleId, "intermediates", "", "Path to intermediate certificate bundle") - fs.BoolVar(&noVerifySSL, "no-verify-ssl", false, "To disable SSL verification") - fs.BoolVar(&withProxy, "with-proxy", false, "To use credential-process with a proxy") - fs.BoolVar(&debug, "debug", false, "To print debug output when SDK calls are made") - } - - if command == "read-certificate-data" { - fs.StringVar(&certificateId, "certificate", "", "Path to certificate file") - } else if command == "sign-string" { - fs.StringVar(&privateKeyId, "private-key", "", "Path to private key file") - fs.StringVar(&format, "format", "json", "Output format. One of json, text, and bin") - fs.StringVar(&digestArg, "digest", "SHA256", "One of SHA256, SHA384 and SHA512") - } else if command == "update" { - fs.StringVar(&profile, "profile", "default", "The aws profile to use (default 'default')") - fs.BoolVar(&once, "once", false, "Update the credentials once") - } else if command == "serve" { - fs.IntVar(&port, "port", helper.DefaultPort, "The port used to run local server (default: 9911)") - } - } -} - -func main() { - setupFlags() - - // find and remove global variables - globalVars, parseList := findGlobalVar(os.Args[1:]) - tmpRegion, regionDetected := globalVars["--region"] - tmpEndpoint, endpointDetected := globalVars["--endpoint"] - if len(parseList) == 0 || strings.HasPrefix(parseList[0], "--") { - log.Println("No command provided") - os.Exit(1) - } - - command := parseList[0] - commandFs, valid := commands[command] - // if the command does not exist in the command list - if !valid { - log.Println("Unrecognized command") - os.Exit(1) - } - - commandFs.Parse(parseList[1:]) - - // assign global variables if they have been detected - if regionDetected { - region = tmpRegion - } - if endpointDetected { - endpoint = tmpEndpoint - } - credentialsOptions := helper.CredentialsOpts{ - PrivateKeyId: privateKeyId, - CertificateId: certificateId, - CertificateBundleId: certificateBundleId, - RoleArn: roleArnStr, - ProfileArnStr: profileArnStr, - TrustAnchorArnStr: trustAnchorArnStr, - SessionDuration: sessionDuration, - Region: region, - Endpoint: endpoint, - NoVerifySSL: noVerifySSL, - WithProxy: withProxy, - Debug: debug, - Version: Version, - } - - switch command { - case "credential-process": - // First check whether required arguments are present - if privateKeyId == "" || certificateId == "" || profileArnStr == "" || - trustAnchorArnStr == "" || roleArnStr == "" { - msg := `Usage: aws_signing_helper credential-process - --private-key - --certificate - --profile-arn - --trust-anchor-arn - --role-arn - [--endpoint ] - [--region ] - [--session-duration ] - [--with-proxy] - [--no-verify-ssl] - [--debug] - [--intermediates ]` - log.Println(msg) - os.Exit(1) - } - credentialProcessOutput, err := helper.GenerateCredentials(&credentialsOptions) - if err != nil { - log.Println(err) - os.Exit(1) - } - buf, _ := json.Marshal(credentialProcessOutput) - fmt.Print(string(buf[:])) - case "sign-string": - stringToSign, _ := ioutil.ReadAll(bufio.NewReader(os.Stdin)) - privateKey, _ := helper.ReadPrivateKeyData(privateKeyId) - var digest crypto.Hash - switch strings.ToUpper(digestArg) { - case "SHA256": - digest = crypto.SHA256 - case "SHA384": - digest = crypto.SHA384 - case "SHA512": - digest = crypto.SHA512 - default: - digest = crypto.SHA256 - } - signingResult, _ := helper.Sign(stringToSign, helper.SigningOpts{PrivateKey: privateKey, Digest: digest}) - switch strings.ToLower(format) { - case "text": - fmt.Print(signingResult.Signature) - case "json": - buf, _ := json.Marshal(signingResult) - fmt.Print(string(buf[:])) - case "bin": - buf, _ := hex.DecodeString(signingResult.Signature) - binary.Write(os.Stdout, binary.BigEndian, buf[:]) - default: - fmt.Print(signingResult.Signature) - } - case "read-certificate-data": - data, _ := helper.ReadCertificateData(certificateId) - buf, _ := json.Marshal(data) - fmt.Print(string(buf[:])) - case "version": - fmt.Println(Version) - case "update": - if privateKeyId == "" || certificateId == "" || - profileArnStr == "" || trustAnchorArnStr == "" || roleArnStr == "" { - msg := `Usage: aws_signing_helper update - --private-key - --certificate - --profile-arn - --trust-anchor-arn - --role-arn - [--endpoint ] - [--region ] - [--session-duration ] - [--with-proxy] - [--no-verify-ssl] - [--intermediates ] - [--profile ] - [--once]` - log.Println(msg) - os.Exit(1) - } - helper.Update(credentialsOptions, profile, once) - case "serve": - // First check whether required arguments are present - if privateKeyId == "" || certificateId == "" || profileArnStr == "" || - trustAnchorArnStr == "" || roleArnStr == "" { - msg := `Usage: aws_signing_helper serve - --private-key - --certificate - --profile-arn - --trust-anchor-arn - --role-arn - [--endpoint ] - [--region ] - [--session-duration ] - [--with-proxy] - [--no-verify-ssl] - [--debug] - [--intermediates ] - [--port ]` - log.Println(msg) - os.Exit(1) - } - helper.Serve(port, credentialsOptions) - case "": - log.Println("No command provided") - os.Exit(1) - default: - log.Fatalf("Unrecognized command %s", command) - } -} diff --git a/cmd/aws_signing_helper/main_test.go b/cmd/aws_signing_helper/main_test.go deleted file mode 100644 index 38417d2..0000000 --- a/cmd/aws_signing_helper/main_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package main - -import ( - "testing" -) - -func TestParseArgs(t *testing.T) { - args := []string{ - "read-certificate-data", - "--certificate", - "/path/to/cert.pem", - } - setupFlags() - var command = commands[args[0]] - command.Parse(args[1:]) - - if certificateId != "/path/to/cert.pem" { - t.Errorf("Expected %s, got %s", "/path/to/cert.pem", certificateId) - } -} diff --git a/cmd/credential_process.go b/cmd/credential_process.go new file mode 100644 index 0000000..13c99e7 --- /dev/null +++ b/cmd/credential_process.go @@ -0,0 +1,46 @@ +package cmd + +import ( + "encoding/json" + "fmt" + "log" + "os" + + helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/spf13/cobra" +) + +func init() { + initCredentialsSubCommand(credentialProcessCmd) +} + +var credentialProcessCmd = &cobra.Command{ + Use: "credential-process [flags]", + Short: "Retrieve AWS credentials in the appropriate format for external credential processes", + Long: `To retrieve AWS credentials in the appropriate format for external +credential processes, as determined by the SDK/CLI. More information can be +found at: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sourcing-external.html`, + Run: func(cmd *cobra.Command, args []string) { + err := PopulateCredentialsOptions() + if err != nil { + log.Println(err) + os.Exit(1) + } + + helper.Debug = credentialsOptions.Debug + + signer, signingAlgorithm, err := helper.GetSigner(&credentialsOptions) + if err != nil { + log.Println(err) + os.Exit(1) + } + defer signer.Close() + credentialProcessOutput, err := helper.GenerateCredentials(&credentialsOptions, signer, signingAlgorithm) + if err != nil { + log.Println(err) + os.Exit(1) + } + buf, _ := json.Marshal(credentialProcessOutput) + fmt.Print(string(buf[:])) + }, +} diff --git a/cmd/credentials.go b/cmd/credentials.go new file mode 100644 index 0000000..99749ca --- /dev/null +++ b/cmd/credentials.go @@ -0,0 +1,222 @@ +package cmd + +import ( + "encoding/json" + "errors" + "io/ioutil" + "math/big" + "strings" + + helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/spf13/cobra" +) + +var ( + roleArnStr string + profileArnStr string + trustAnchorArnStr string + sessionDuration int + region string + endpoint string + noVerifySSL bool + withProxy bool + debug bool + + certificateId string + privateKeyId string + certificateBundleId string + certSelector string + + libPkcs11 string + pinPkcs11 string + slotPkcs11 uint + checkPkcs11 bool + + credentialsOptions helper.CredentialsOpts + + X509_SUBJECT_KEY = "x509Subject" + X509_ISSUER_KEY = "x509Issuer" + X509_SERIAL_KEY = "x509Serial" + + validCertSelectorKeys = []string{ + X509_SUBJECT_KEY, + X509_ISSUER_KEY, + X509_SERIAL_KEY, + } +) + +type MapEntry struct { + Key string + Value string +} + +// Parses common flags for commands that vend credentials +func initCredentialsSubCommand(subCmd *cobra.Command) { + rootCmd.AddCommand(subCmd) + subCmd.PersistentFlags().StringVar(&roleArnStr, "role-arn", "", "Target role to assume") + subCmd.PersistentFlags().StringVar(&profileArnStr, "profile-arn", "", "Profile to pull policies from") + subCmd.PersistentFlags().StringVar(&trustAnchorArnStr, "trust-anchor-arn", "", "Trust anchor to use for authentication") + subCmd.PersistentFlags().IntVar(&sessionDuration, "session-duration", 3600, "Duration, in seconds, for the resulting session") + subCmd.PersistentFlags().StringVar(®ion, "region", "", "Signing region") + subCmd.PersistentFlags().StringVar(&endpoint, "endpoint", "", "Endpoint used to call CreateSession") + subCmd.PersistentFlags().BoolVar(&noVerifySSL, "no-verify-ssl", false, "To disable SSL verification") + subCmd.PersistentFlags().BoolVar(&withProxy, "with-proxy", false, "To make the CreateSession call with a proxy") + subCmd.PersistentFlags().BoolVar(&debug, "debug", false, "To print debug output") + subCmd.PersistentFlags().StringVar(&certificateId, "certificate", "", "Path to certificate file") + subCmd.PersistentFlags().StringVar(&privateKeyId, "private-key", "", "Path to private key file") + subCmd.PersistentFlags().StringVar(&certificateBundleId, "intermediates", "", "Path to intermediate certificate bundle file") + subCmd.PersistentFlags().StringVar(&certSelector, "cert-selector", "", "JSON structure to identify a certificate from a certificate store. "+ + "Can be passed in either as string or a file name (prefixed by \"file://\")") + + subCmd.MarkFlagsMutuallyExclusive("private-key", "cert-selector") +} + +// Parses a cert selector string to a map +func getStringMap(s string) (map[string]string, error) { + entries := strings.Split(s, " ") + + m := make(map[string]string) + for _, e := range entries { + tokens := strings.SplitN(e, ",", 2) + keyTokens := strings.Split(tokens[0], "=") + if keyTokens[0] != "Key" { + return nil, errors.New("invalid cert selector map key") + } + key := strings.TrimSpace(strings.Join(keyTokens[1:], "=")) + + isValidKey := false + for _, validKey := range validCertSelectorKeys { + if validKey == key { + isValidKey = true + break + } + } + if !isValidKey { + return nil, errors.New("cert selector contained invalid key") + } + + valueTokens := strings.Split(tokens[1], "=") + if valueTokens[0] != "Value" { + return nil, errors.New("invalid cert selector map value") + } + value := strings.TrimSpace(strings.Join(valueTokens[1:], "=")) + m[key] = value + } + + return m, nil +} + +// Parses a JSON cert selector string into a map +func getMapFromJsonEntries(jsonStr string) (map[string]string, error) { + m := make(map[string]string) + var mapEntries []MapEntry + err := json.Unmarshal([]byte(jsonStr), &mapEntries) + if err != nil { + return nil, errors.New("unable to parse JSON map entries") + } + for _, mapEntry := range mapEntries { + isValidKey := false + for _, validKey := range validCertSelectorKeys { + if validKey == mapEntry.Key { + isValidKey = true + break + } + } + if !isValidKey { + return nil, errors.New("cert selector contained invalid key") + } + m[mapEntry.Key] = mapEntry.Value + } + return m, nil +} + +func createCertSelectorFromMap(certSelectorMap map[string]string) helper.CertIdentifier { + var certIdentifier helper.CertIdentifier + + for key, value := range certSelectorMap { + switch key { + case X509_SUBJECT_KEY: + certIdentifier.Subject = value + case X509_ISSUER_KEY: + certIdentifier.Issuer = value + case X509_SERIAL_KEY: + certSerial := new(big.Int) + certSerial.SetString(value, 16) + certIdentifier.SerialNumber = certSerial + } + } + + return certIdentifier +} + +func PopulateCertIdentifierFromJsonStr(jsonStr string) (helper.CertIdentifier, error) { + certSelectorMap, err := getMapFromJsonEntries(jsonStr) + if err != nil { + return helper.CertIdentifier{}, err + } + return createCertSelectorFromMap(certSelectorMap), nil +} + +// Populates a CertIdentifier object using a cert selector string +func PopulateCertIdentifierFromCertSelectorStr(certSelectorStr string) (helper.CertIdentifier, error) { + certSelectorMap, err := getStringMap(certSelectorStr) + if err != nil { + return helper.CertIdentifier{}, err + } + + return createCertSelectorFromMap(certSelectorMap), nil +} + +// Populates a CertIdentifier using a cert selector +// Note that this method can take in a file name as a the cert selector +func PopulateCertIdentifier(certSelector string) (helper.CertIdentifier, error) { + var certIdentifier helper.CertIdentifier + var err error + if certSelector != "" { + if strings.HasPrefix(certSelector, "file://") { + certSelectorFile, err := ioutil.ReadFile(strings.TrimPrefix(certSelector, "file://")) + if err != nil { + return helper.CertIdentifier{}, errors.New("unable to read cert selector file") + } + certIdentifier, err = PopulateCertIdentifierFromJsonStr(string(certSelectorFile[:])) + if err != nil { + return helper.CertIdentifier{}, errors.New("unable to parse JSON cert selector") + } + } else { + certIdentifier, err = PopulateCertIdentifierFromCertSelectorStr(certSelector) + if err != nil { + return helper.CertIdentifier{}, errors.New("unable to parse cert selector string") + } + } + } + + return certIdentifier, err +} + +// Populate CredentialsOpts that is used to aggregate all the information required to call CreateSession +func PopulateCredentialsOptions() error { + certIdentifier, err := PopulateCertIdentifier(certSelector) + if err != nil { + return err + } + + credentialsOptions = helper.CredentialsOpts{ + PrivateKeyId: privateKeyId, + CertificateId: certificateId, + CertificateBundleId: certificateBundleId, + CertIdentifier: certIdentifier, + RoleArn: roleArnStr, + ProfileArnStr: profileArnStr, + TrustAnchorArnStr: trustAnchorArnStr, + SessionDuration: sessionDuration, + Region: region, + Endpoint: endpoint, + NoVerifySSL: noVerifySSL, + WithProxy: withProxy, + Debug: debug, + Version: Version, + LibPkcs11: libPkcs11, + } + + return nil +} diff --git a/cmd/credentials_test.go b/cmd/credentials_test.go new file mode 100644 index 0000000..b6b992b --- /dev/null +++ b/cmd/credentials_test.go @@ -0,0 +1,45 @@ +package cmd + +import ( + "os" + "testing" +) + +func TestMain(m *testing.M) { + code := m.Run() + os.Exit(code) +} + +func TestValidSelectorParsing(t *testing.T) { + fixtures := []string{ + "file://../tst/selectors/valid-all-attributes-selector.json", + "file://../tst/selectors/valid-some-attributes-selector.json", + "Key=x509Subject,Value=CN=Subject Key=x509Issuer,Value=CN=Issuer Key=x509Serial,Value=15D19632234BF759A32802C0DA88F9E8AFC8702D", + "Key=x509Issuer,Value=CN=Issuer", + } + for _, fixture := range fixtures { + _, err := PopulateCertIdentifier(fixture) + if err != nil { + t.Log("Unable to populate cert identifier from selector") + t.Fail() + } + } +} + +func TestInvalidSelectorParsing(t *testing.T) { + fixtures := []string{ + "file://../tst/selectors/invalid-selector.json", + "file://../tst/selectors/invalid-selector-2.json", + "file://../tst/selectors/invalid-selector-3.json", + "laksdjadf", + "Key=laksdjf,Valalsd", + "Key=aljsdf,Value=aljsdfadsf", + } + for _, fixture := range fixtures { + _, err := PopulateCertIdentifier(fixture) + if err == nil { + t.Log("Expected parsing failure, but received none") + t.Fail() + } + } +} diff --git a/cmd/read_certificate_data.go b/cmd/read_certificate_data.go new file mode 100644 index 0000000..f76e286 --- /dev/null +++ b/cmd/read_certificate_data.go @@ -0,0 +1,82 @@ +package cmd + +import ( + "crypto/sha1" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "os" + + helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(readCertificateDataCmd) + readCertificateDataCmd.PersistentFlags().StringVar(&certificateId, "certificate", "", "Path to certificate file") + readCertificateDataCmd.PersistentFlags().StringVar(&certSelector, "cert-selector", "", "JSON structure to identify a certificate from a certificate store."+ + " Can be passed in either as string or a file name (prefixed by \"file://\")") + readCertificateDataCmd.PersistentFlags().BoolVar(&debug, "debug", false, "To print debug output") +} + +type PrintCertificate func(int, helper.CertificateContainer) + +func DefaultPrintCertificate(index int, certContainer helper.CertificateContainer) { + cert := certContainer.Cert + + fingerprint := sha1.Sum(cert.Raw) // nosemgrep + fingerprintHex := hex.EncodeToString(fingerprint[:]) + fmt.Printf("%d) %s \"%s\"\n", index+1, fingerprintHex, cert.Subject.String()) + + // Only for PKCS#11 + if certContainer.Uri != "" { + fmt.Printf("\tURI: %s\n", certContainer.Uri) + } +} + +var readCertificateDataCmd = &cobra.Command{ + Use: "read-certificate-data [flags]", + Short: "Diagnostic command to read certificate data", + Long: `Diagnostic command to read certificate data, either from files or + from a certificate store`, + Run: func(cmd *cobra.Command, args []string) { + certIdentifier, err := PopulateCertIdentifier(certSelector) + if err != nil { + log.Println("unable to populate CertIdentifier") + os.Exit(1) + } + + var certContainers []helper.CertificateContainer + // In case there is information that needs to be conditionally printed + // based on the type of integration being used (which can't be taken + // from the CertificateContainer), a function that implements the + // PrintCertificate interface can be assigned to this variable. + var printFunction PrintCertificate = DefaultPrintCertificate + + if certificateId != "" && certIdentifier == (helper.CertIdentifier{}) { + data, err := helper.ReadCertificateData(certificateId) + if err != nil { + os.Exit(1) + } + buf, err := json.Marshal(data) + if err != nil { + os.Exit(1) + } + + fmt.Print(string(buf[:])) + // Exit after printing out the certificate data + os.Exit(0) + } else { + certContainers, err = helper.GetMatchingCerts(certIdentifier) + if err != nil { + log.Println(err) + os.Exit(1) + } + } + fmt.Printf("Matching identities\n") + for index, certContainer := range certContainers { + printFunction(index, certContainer) + } + }, +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..6c569b9 --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "github.com/spf13/cobra" + "log" + "os" +) + +var rootCmd = &cobra.Command{ + Use: "aws_signing_helper [command]", + Short: "The credential helper is a tool to retrieve temporary AWS credentials", + Long: `A tool that utilizes certificates and their associated private keys to +sign requests to AWS IAM Roles Anywhere's CreateSession API and retrieve temporary +AWS security credentials. This tool exposes multiple commands to make credential +retrieval and rotation more convenient.`, + Run: func(cmd *cobra.Command, args []string) { + + }, +} + +func Execute() { + if err := rootCmd.Execute(); err != nil { + log.Println(err) + os.Exit(1) + } +} diff --git a/cmd/serve.go b/cmd/serve.go new file mode 100644 index 0000000..10cb4e6 --- /dev/null +++ b/cmd/serve.go @@ -0,0 +1,35 @@ +package cmd + +import ( + "log" + "os" + + helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/spf13/cobra" +) + +var ( + port int +) + +func init() { + initCredentialsSubCommand(serveCmd) + serveCmd.PersistentFlags().IntVar(&port, "port", helper.DefaultPort, "The port used to run the local server (default: 9911)") +} + +var serveCmd = &cobra.Command{ + Use: "serve [flags]", + Short: "Serve AWS credentials through a local endpoint", + Long: "Serve AWS credentials through a local endpoint that is compatible with IMDSv2", + Run: func(cmd *cobra.Command, args []string) { + err := PopulateCredentialsOptions() + if err != nil { + log.Println(err) + os.Exit(1) + } + + helper.Debug = credentialsOptions.Debug + + helper.Serve(port, credentialsOptions) + }, +} diff --git a/cmd/sign_string.go b/cmd/sign_string.go new file mode 100644 index 0000000..00401c9 --- /dev/null +++ b/cmd/sign_string.go @@ -0,0 +1,169 @@ +package cmd + +import ( + "bufio" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "os" + "strconv" + "strings" + + helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/spf13/cobra" +) + +var ( + format *enum + digestArg *enum +) + +var ( + SIGN_STRING_TEST_VERSION uint16 = 1 + signFixedString bool = true +) + +type enum struct { + Allowed []string + Value string +} + +func newEnum(allowed []string, d string) *enum { + return &enum{ + Allowed: allowed, + Value: d, + } +} + +func (e enum) String() string { + return e.Value +} + +func (a *enum) Set(p string) error { + isIncluded := func(opts []string, val string) bool { + for _, opt := range opts { + if val == opt { + return true + } + } + return false + } + if !isIncluded(a.Allowed, p) { + return fmt.Errorf("%s is not included in %s", p, strings.Join(a.Allowed, ",")) + } + a.Value = p + return nil +} + +func (a *enum) Type() string { + return "string" +} + +func init() { + rootCmd.AddCommand(signStringCmd) + format = newEnum([]string{"json", "text", "bin"}, "json") + digestArg = newEnum([]string{"SHA256", "SHA384", "SHA512"}, "SHA256") + signStringCmd.PersistentFlags().StringVar(&certificateId, "certificate", "", "Path to certificate file") + signStringCmd.PersistentFlags().StringVar(&privateKeyId, "private-key", "", "Path to private key file") + signStringCmd.PersistentFlags().BoolVar(&debug, "debug", false, "To print debug output") + signStringCmd.PersistentFlags().StringVar(&certSelector, "cert-selector", "", "JSON structure to identify a certificate from a certificate store. "+ + "Can be passed in either as string or a file name (prefixed by \"file://\")") + signStringCmd.PersistentFlags().Var(format, "format", "Output format. One of json, text, and bin") + signStringCmd.PersistentFlags().Var(digestArg, "digest", "One of SHA256, SHA384, and SHA512") +} + +func getFixedStringToSign(publicKey crypto.PublicKey) string { + var digestSuffix []byte + ecdsaPublicKey, isEcKey := publicKey.(*ecdsa.PublicKey) + if isEcKey { + digestSuffixArr := sha256.Sum256(append([]byte("IAM RA"), elliptic.Marshal(ecdsaPublicKey, ecdsaPublicKey.X, ecdsaPublicKey.Y)...)) + digestSuffix = digestSuffixArr[:] + } + + rsaPublicKey, isRsaKey := publicKey.(*rsa.PublicKey) + if isRsaKey { + digestSuffixArr := sha256.Sum256(append([]byte("IAM RA"), x509.MarshalPKCS1PublicKey(rsaPublicKey)...)) + digestSuffix = digestSuffixArr[:] + } + + // "AWS Roles Anywhere Credential Helper Signing Test" || SIGN_STRING_TEST_VERSION || + // SHA256("IAM RA" || PUBLIC_KEY_BYTE_ARRAY) + fixedStringToSign := "AWS Roles Anywhere Credential Helper Signing Test" + + strconv.Itoa(int(SIGN_STRING_TEST_VERSION)) + string(digestSuffix) + + return fixedStringToSign +} + +var signStringCmd = &cobra.Command{ + Use: "sign-string [flags]", + Short: "Signs a fixed string using the passed-in private key (or reference to private key)", + Run: func(cmd *cobra.Command, args []string) { + var digest crypto.Hash + switch strings.ToUpper(digestArg.String()) { + case "SHA256": + digest = crypto.SHA256 + case "SHA384": + digest = crypto.SHA384 + case "SHA512": + digest = crypto.SHA512 + default: + digest = crypto.SHA256 + } + err := PopulateCredentialsOptions() + if err != nil { + log.Println(err) + os.Exit(1) + } + + helper.Debug = credentialsOptions.Debug + + var signer helper.Signer + signer, _, err = helper.GetSigner(&credentialsOptions) + if err != nil { + log.Println(err) + os.Exit(1) + } + defer signer.Close() + + var stringToSignBytes []byte + if signFixedString { + stringToSign := getFixedStringToSign(signer.Public()) + stringToSignBytes = []byte(stringToSign) + + if credentialsOptions.Debug { + fmt.Fprintln(os.Stderr, "Signing fixed string of the form: \"AWS Roles Anywhere "+ + "Credential Helper Signing Test\" || SIGN_STRING_TEST_VERSION || SHA256(\"IAM RA\" || PUBLIC_KEY_BYTE_ARRAY)\"") + } + } else { + stringToSignBytes, _ = ioutil.ReadAll(bufio.NewReader(os.Stdin)) + } + + sigBytes, err := signer.Sign(rand.Reader, stringToSignBytes, digest) + if err != nil { + log.Println("unable to sign the digest") + os.Exit(1) + } + sigStr := hex.EncodeToString(sigBytes) + switch strings.ToLower(format.String()) { + case "text": + fmt.Print(sigStr) + case "json": + buf, _ := json.Marshal(sigStr) + fmt.Print(string(buf[:])) + case "bin": + binary.Write(os.Stdout, binary.BigEndian, sigBytes[:]) + default: + fmt.Print(sigStr) + } + }, +} diff --git a/cmd/update.go b/cmd/update.go new file mode 100644 index 0000000..e7a1c72 --- /dev/null +++ b/cmd/update.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "log" + "os" + + helper "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/spf13/cobra" +) + +var ( + profile string + once bool +) + +func init() { + initCredentialsSubCommand(updateCmd) + updateCmd.PersistentFlags().StringVar(&profile, "profile", "default", "profile to update") + updateCmd.PersistentFlags().BoolVar(&once, "once", false, "to update the profile just once") +} + +var updateCmd = &cobra.Command{ + Use: "update [flags]", + Short: "Updates a profile in the AWS credentials file with new AWS credentials", + Long: "Updates a profile in the AWS credentials file with new AWS credentials", + Run: func(cmd *cobra.Command, args []string) { + err := PopulateCredentialsOptions() + if err != nil { + log.Println(err) + os.Exit(1) + } + + helper.Debug = credentialsOptions.Debug + + helper.Update(credentialsOptions, profile, once) + }, +} diff --git a/cmd/version.go b/cmd/version.go new file mode 100644 index 0000000..526e831 --- /dev/null +++ b/cmd/version.go @@ -0,0 +1,24 @@ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +var ( + Version string +) + +func init() { + rootCmd.AddCommand(versionCmd) +} + +var versionCmd = &cobra.Command{ + Use: "version", + Short: "Prints the version number of the credential helper", + Long: "Prints the version number of the credential helper", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println(Version) + }, +} diff --git a/generate-certs.sh b/generate-certs.sh deleted file mode 100755 index d948d45..0000000 --- a/generate-certs.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash - -# Simple script to generate key/digest permutations for testing -# keys are shared across certificates with the same algorithm, -# but different digests - -ec_digests="sha1 sha256 sha384 sha512" -ec_curves="prime256v1 secp384r1" - -rsa_digests="md5 sha1 sha256 sha384 sha512" -rsa_key_lengths="1024 2048 4096" - -script=$(readlink -f "$0") -basedir=$(dirname "$script") - -for c in $ec_curves; do - key_file="${basedir}/tst/certs/ec-${c}-key.pem" - openssl ecparam -name $c -genkey -out $key_file - for d in $ec_digests; do - cert_file="${basedir}/tst/certs/ec-${c}-${d}-cert.pem" - openssl req -x509 -new \ - -key $key_file \ - -out $cert_file \ - -days 365 \ - -subj "/CN=roles-anywhere-${c}-${d}" \ - -${d} - openssl pkcs8 -topk8 -inform PEM -outform PEM \ - -in ${basedir}/tst/certs/ec-${c}-key.pem \ - -out ${basedir}/tst/certs/ec-${c}-key-pkcs8.pem \ - -nocrypt - done; -done; - -for l in $rsa_key_lengths; do - key_file="${basedir}/tst/certs/rsa-${l}-key.pem" - openssl genrsa -out $key_file $l - for d in $rsa_digests; do - cert_file="${basedir}/tst/certs/rsa-${l}-${d}-cert.pem" - openssl req -x509 -new \ - -key $key_file \ - -out $cert_file \ - -days 365 \ - -subj "/CN=roles-anywhere-rsa-${l}" - openssl pkcs8 -topk8 -inform PEM -outform PEM \ - -in ${basedir}/tst/certs/rsa-${l}-key.pem \ - -out ${basedir}/tst/certs/rsa-${l}-key-pkcs8.pem \ - -nocrypt - done; -done; - -# Create certificate bundle -cp ${basedir}/tst/certs/rsa-2048-sha256-cert.pem ${basedir}/tst/certs/cert-bundle.pem -cat ${basedir}/tst/certs/ec-prime256v1-sha256-cert.pem >> ${basedir}/tst/certs/cert-bundle.pem diff --git a/go.mod b/go.mod index 1d7409f..f1e738c 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,17 @@ module github.com/aws/rolesanywhere-credential-helper go 1.18 -require github.com/aws/aws-sdk-go v1.44.57 +require ( + github.com/aws/aws-sdk-go v1.44.57 + github.com/spf13/cobra v1.6.1 + golang.org/x/crypto v0.10.0 + golang.org/x/sys v0.9.0 +) -require github.com/jmespath/go-jmespath v0.4.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.0.1 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect +) diff --git a/go.sum b/go.sum index 3217287..885c299 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,11 @@ github.com/aws/aws-sdk-go v1.44.57 h1:Dx1QD+cA89LE0fVQWSov22tpnTa0znq2Feyaa/myVjg= github.com/aws/aws-sdk-go v1.44.57/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc= +github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -9,13 +13,24 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= +github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..2613c34 --- /dev/null +++ b/main.go @@ -0,0 +1,9 @@ +package main + +import ( + "github.com/aws/rolesanywhere-credential-helper/cmd" +) + +func main() { + cmd.Execute() +} diff --git a/tst/selectors/invalid-selector-2.json b/tst/selectors/invalid-selector-2.json new file mode 100644 index 0000000..345acec --- /dev/null +++ b/tst/selectors/invalid-selector-2.json @@ -0,0 +1,3 @@ +{ + "lajsdf": "lajsdlfjadf" +} diff --git a/tst/selectors/invalid-selector-3.json b/tst/selectors/invalid-selector-3.json new file mode 100644 index 0000000..0386a80 --- /dev/null +++ b/tst/selectors/invalid-selector-3.json @@ -0,0 +1,4 @@ +{ + "Key": "x509Serial", + "Value": "alskjdfa" +} diff --git a/tst/selectors/invalid-selector.json b/tst/selectors/invalid-selector.json new file mode 100644 index 0000000..82a1da2 --- /dev/null +++ b/tst/selectors/invalid-selector.json @@ -0,0 +1 @@ +asldjf;laj;laweijncaljda diff --git a/tst/selectors/valid-all-attributes-selector.json b/tst/selectors/valid-all-attributes-selector.json new file mode 100644 index 0000000..82ecac7 --- /dev/null +++ b/tst/selectors/valid-all-attributes-selector.json @@ -0,0 +1,14 @@ +[ + { + "Key": "x509Subject", + "Value": "CN=Subject" + }, + { + "Key": "x509Issuer", + "Value": "CN=Issuer" + }, + { + "Key": "x509Serial", + "Value": "15D19632234BF759A32802C0DA88F9E8AFC8702D" + } +] diff --git a/tst/selectors/valid-some-attributes-selector.json b/tst/selectors/valid-some-attributes-selector.json new file mode 100644 index 0000000..1f13f72 --- /dev/null +++ b/tst/selectors/valid-some-attributes-selector.json @@ -0,0 +1,10 @@ +[ + { + "Key": "x509Subject", + "Value": "CN=Subject" + }, + { + "Key": "x509Serial", + "Value": "15D19632234BF759A32802C0DA88F9E8AFC8702D" + } +] diff --git a/tst/softhsm/.gitkeep b/tst/softhsm/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/tst/softhsm2.conf.template b/tst/softhsm2.conf.template new file mode 100644 index 0000000..24bd81a --- /dev/null +++ b/tst/softhsm2.conf.template @@ -0,0 +1,6 @@ +# Assumes that this template is only referenced from the root of the project +directories.tokendir = @top_srcdir@/tst/softhsm +objectstore.backend = file + +loglevel = INFO +