diff --git a/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/GoogleConfiguration.scala b/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/GoogleConfiguration.scala index 2b4a183c121..fd6527a548c 100644 --- a/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/GoogleConfiguration.scala +++ b/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/GoogleConfiguration.scala @@ -82,6 +82,11 @@ object GoogleConfiguration { UserServiceAccountMode(name) } + def userServiceAccountImpersonationAuth(authConfig: Config, name: String): ErrorOr[GoogleAuthMode] = validate { + val jsonFileOpt: Option[JsonFileFormat] = authConfig.getAs[String]("json-file").map(JsonFileFormat) + UserServiceAccountImpersonationMode(name, jsonFileOpt) + } + val name = authConfig.getString("name") val scheme = authConfig.getString("scheme") scheme match { @@ -89,6 +94,7 @@ object GoogleConfiguration { case "user_account" => userAccountAuth(authConfig, name) case "application_default" => applicationDefaultAuth(name) case "user_service_account" => userServiceAccountAuth(name) + case "user_service_account_impersonation" => userServiceAccountImpersonationAuth(authConfig, name) case "mock" => MockAuthMode(name).validNel case wut => s"Unsupported authentication scheme: $wut".invalidNel } diff --git a/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthMode.scala b/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthMode.scala index e850b53807a..99a4ac4b56d 100644 --- a/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthMode.scala +++ b/cloudSupport/src/main/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthMode.scala @@ -3,14 +3,13 @@ package cromwell.cloudsupport.gcp.auth import java.io.{ByteArrayInputStream, FileNotFoundException, InputStream} import java.net.HttpURLConnection._ import java.nio.charset.StandardCharsets - import better.files.File import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport import com.google.api.client.http.{HttpResponseException, HttpTransport} import com.google.api.client.json.gson.GsonFactory import com.google.auth.Credentials import com.google.auth.http.HttpTransportFactory -import com.google.auth.oauth2.{GoogleCredentials, OAuth2Credentials, ServiceAccountCredentials, UserCredentials} +import com.google.auth.oauth2.{GoogleCredentials, ImpersonatedCredentials, OAuth2Credentials, ServiceAccountCredentials, UserCredentials} import com.google.cloud.NoCredentials import com.typesafe.scalalogging.LazyLogging import cromwell.cloudsupport.gcp.auth.ApplicationDefaultMode.applicationDefaultCredentials @@ -43,6 +42,7 @@ object GoogleAuthMode { lazy val HttpTransportFactory: HttpTransportFactory = () => httpTransport val UserServiceAccountKey = "user_service_account_json" + val UserServiceAccountEmailKey = "user_service_account_email" val DockerCredentialsEncryptionKeyNameKey = "docker_credentials_key_name" val DockerCredentialsTokenKey = "docker_credentials_token" @@ -73,6 +73,18 @@ object GoogleAuthMode { private def refreshCredentials(credentials: Credentials): Unit = { credentials.refresh() } + + def createServiceAccountCredentials(fileFormat: CredentialFileFormat): ServiceAccountCredentials = { + val credentialsFile = File(fileFormat.file) + checkReadable(credentialsFile) + + fileFormat match { + case PemFileFormat(accountId, _) => + ServiceAccountCredentials.fromPkcs8(accountId, accountId, credentialsFile.contentAsString, null, null) + case _: JsonFileFormat => ServiceAccountCredentials.fromStream(credentialsFile.newInputStream) + } + } + } sealed trait GoogleAuthMode extends LazyLogging { @@ -115,12 +127,18 @@ sealed trait GoogleAuthMode extends LazyLogging { */ private[auth] var credentialsValidation: CredentialsValidation = refreshCredentials - protected def validateCredentials[A <: GoogleCredentials](credential: A, - scopes: Iterable[String]): GoogleCredentials = { - val scopedCredentials = credential.createScoped(scopes.asJavaCollection) - Try(credentialsValidation(scopedCredentials)) match { - case Failure(ex) => throw new RuntimeException(s"Google credentials are invalid: ${ex.getMessage}", ex) - case Success(_) => scopedCredentials + protected def validateCredentials[A <: GoogleCredentials]( + credential: A, + scopes: Iterable[String] + ): GoogleCredentials = { + val credentialsToValidate = + if (scopes != null) credential.createScoped(scopes.asJavaCollection) + else credential + Try(credentialsValidation(credentialsToValidate)) match { + case Failure(ex) => + throw new RuntimeException(s"Google credentials are invalid: ${ex.getMessage}", ex) + case Success(_) => + credentialsToValidate } } } @@ -149,14 +167,10 @@ final case class ServiceAccountMode(override val name: String, private val credentialsFile = File(fileFormat.file) checkReadable(credentialsFile) - private lazy val serviceAccountCredentials: ServiceAccountCredentials = { - fileFormat match { - case PemFileFormat(accountId, _) => - logger.warn("The PEM file format will be deprecated in the upcoming Cromwell version. Please use JSON instead.") - ServiceAccountCredentials.fromPkcs8(accountId, accountId, credentialsFile.contentAsString, null, null) - case _: JsonFileFormat => ServiceAccountCredentials.fromStream(credentialsFile.newInputStream) - } + if (fileFormat.isInstanceOf[PemFileFormat]) { + logger.warn("The PEM file format will be deprecated in the upcoming Cromwell version. Please use JSON instead.") } + private lazy val serviceAccountCredentials: ServiceAccountCredentials = createServiceAccountCredentials(fileFormat) override def credentials(unusedOptions: OptionLookup, scopes: Iterable[String]): GoogleCredentials = { @@ -206,6 +220,37 @@ final case class ApplicationDefaultMode(name: String) extends GoogleAuthMode { } } +final case class UserServiceAccountImpersonationMode( + override val name: String, + jsonFileFormat: Option[JsonFileFormat] = None // Optional credential file format +) extends GoogleAuthMode { + + private def extractServiceAccount(options: OptionLookup): String = { + extract(options, UserServiceAccountEmailKey) + } + + override def credentials(options: OptionLookup, scopes: Iterable[String]): GoogleCredentials = { + // Credentials for the source service account that should have + // roles/iam.serviceAccountTokenCreator on the target service account + val credentials = jsonFileFormat match { + case Some(format) => createServiceAccountCredentials(format) + case None => GoogleCredentials.getApplicationDefault + } + + val impersonatedCredentials = ImpersonatedCredentials.create( + credentials, + extractServiceAccount(options), + null, + scopes.toList.asJava, + 3600 + ) + + // We don't pass in scopes because they are added to the credentials + // when we create ImpersonatedCredentials above. + validateCredentials(impersonatedCredentials, null) + } +} + sealed trait ClientSecrets { val clientId: String val clientSecret: String diff --git a/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/GoogleConfigurationSpec.scala b/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/GoogleConfigurationSpec.scala index 95a2380034a..124c811df3b 100644 --- a/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/GoogleConfigurationSpec.scala +++ b/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/GoogleConfigurationSpec.scala @@ -54,7 +54,16 @@ class GoogleConfigurationSpec extends AnyFlatSpec with CromwellTimeoutSpec with | { | name = "name-user-service-account" | scheme = "user_service_account" - | } + | }, + | { + | name = "name-user-service-account-impersonation" + | scheme = "user_service_account_impersonation" + | }, + | { + | name = "name-user-service-account-impersonation_json" + | scheme = "user_service_account_impersonation" + | json-file = "${jsonMockFile.pathAsString}" + | }, | ] |} | @@ -63,7 +72,7 @@ class GoogleConfigurationSpec extends AnyFlatSpec with CromwellTimeoutSpec with val gconf = GoogleConfiguration(ConfigFactory.parseString(righteousGoogleConfig)) gconf.applicationName shouldBe "cromwell" - gconf.authsByName should have size 5 + gconf.authsByName should have size 7 val auths = gconf.authsByName.values @@ -87,6 +96,17 @@ class GoogleConfigurationSpec extends AnyFlatSpec with CromwellTimeoutSpec with serviceJson.fileFormat.isInstanceOf[JsonFileFormat] shouldBe true serviceJson.fileFormat.file shouldBe jsonMockFile.pathAsString + val serviceImpersonation = (auths collectFirst { case a: UserServiceAccountImpersonationMode if a.name == "name-user-service-account-impersonation" => a }).get + serviceImpersonation.name shouldBe "name-user-service-account-impersonation" + + val serviceImpersonationJson = (auths collectFirst { case a: UserServiceAccountImpersonationMode if a.name == "name-user-service-account-impersonation_json" => a }).get + serviceImpersonationJson.name shouldBe "name-user-service-account-impersonation_json" + serviceImpersonationJson.jsonFileFormat.isInstanceOf[Option[JsonFileFormat]] shouldBe true + serviceImpersonationJson.jsonFileFormat match { + case Some(format) => format.file shouldBe jsonMockFile.pathAsString + case None => fail("jsonFileFormat should be Some") + } + pemMockFile.delete(swallowIOExceptions = true) jsonMockFile.delete(swallowIOExceptions = true) } diff --git a/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthModeSpec.scala b/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthModeSpec.scala index 368ce5d2472..5d078bdbdd8 100644 --- a/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthModeSpec.scala +++ b/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/auth/GoogleAuthModeSpec.scala @@ -66,4 +66,5 @@ object GoogleAuthModeSpec extends ServiceAccountTestSupport { lazy val refreshTokenOptions: OptionLookup = Map("refresh_token" -> "the_refresh_token") lazy val userServiceAccountOptions: OptionLookup = Map("user_service_account_json" -> serviceAccountJsonContents) + lazy val userServiceAccountImpersonationOptions: OptionLookup = Map("user_service_account_email" -> "the email") } diff --git a/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/auth/UserServiceAccountImpersonationModeSpec.scala b/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/auth/UserServiceAccountImpersonationModeSpec.scala new file mode 100644 index 00000000000..4d897d0db29 --- /dev/null +++ b/cloudSupport/src/test/scala/cromwell/cloudsupport/gcp/auth/UserServiceAccountImpersonationModeSpec.scala @@ -0,0 +1,24 @@ +package cromwell.cloudsupport.gcp.auth + +import common.assertion.CromwellTimeoutSpec +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class UserServiceAccountImpersonationModeSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matchers { + + behavior of "UserServiceAccountImpersonationMode" + + it should "generate a non-validated credential" in { + val impersonationMode = UserServiceAccountImpersonationMode("user-service-account-impersonation") + val workflowOptions = GoogleAuthModeSpec.userServiceAccountImpersonationOptions + impersonationMode.credentialsValidation = GoogleAuthMode.NoCredentialsValidation + val credentials = impersonationMode.credentials(workflowOptions) + credentials.getAuthenticationType should be("OAuth2") + } + + it should "fail to generate credentials without a user_service_account_email workflow option" in { + val impersonationMode = UserServiceAccountImpersonationMode("user-service-account-impersonation") + val exception = intercept[OptionLookupException](impersonationMode.credentials()) + exception.getMessage should be("user_service_account_email") + } +}