AWS Batch backend: configurable CloudWatch log group and job resource tags

What this patch changes (all paths are under
supportedBackends/aws/src/{main,test}/scala/cromwell/backend/impl/aws/):

  * AwsBatchRuntimeAttributes.scala
      - adds the `logsGroup` runtime attribute, validated as a string with the
        default "/aws/batch/job";
      - adds `resourceTags`, built from the `resourceTags` object of
        `runtimeAttrsConfig` when that path exists, otherwise an empty map.
  * AwsBatchJobDefinition.scala
      - builds a LogConfiguration with log driver "awslogs" whose "awslogs-group"
        option is the configured log group and attaches it to the container
        properties;
      - appends the log group to the string that buildName hashes.
  * AwsBatchJob.scala
      - passes `resourceTags` to SubmitJobRequest.tags;
      - reads job output from `runtimeAttributes.logsGroup` instead of the
        literal "/aws/batch/job".
  * AwsBatchJobSpec.scala, AwsBatchRuntimeAttributesSpec.scala,
    AwsBatchTestConfig.scala
      - fixtures and test configs supply the two new attributes.

Review notes:

  * resourceTags entries whose HOCON value is `null` are skipped.
    ConfigValue.unwrapped() returns a Java null for such values, so calling
    toString on the result would throw a NullPointerException. Because of this
    amendment the post-image blob id on the `index` line of
    AwsBatchRuntimeAttributes.scala is no longer exact.
  * NOTE(review): this patch was recovered from a copy whose newlines and
    indentation had been collapsed. Hunk line counts agree with the hunk
    headers, but the whitespace (and a few line breaks inside context lines) is
    reconstructed and may not match the target files byte for byte. Regenerate
    the patch from version control before applying it.

diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala
index de4617ad786..3bb4c4f7acb 100755
--- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala
+++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJob.scala
@@ -30,8 +30,6 @@
  */
 package cromwell.backend.impl.aws
 
-import java.security.MessageDigest
-
 import cats.data.ReaderT._
 import cats.data.{Kleisli, ReaderT}
 import cats.effect.{Async, Timer}
@@ -53,8 +51,9 @@ import software.amazon.awssdk.services.s3.S3Client
 import software.amazon.awssdk.services.s3.model.{GetObjectRequest, HeadObjectRequest, NoSuchKeyException, PutObjectRequest}
 import wdl4s.parser.MemoryUnit
 
-import scala.jdk.CollectionConverters._
+import java.security.MessageDigest
 import scala.concurrent.duration._
+import scala.jdk.CollectionConverters._
 import scala.util.{Random, Try}
 
 /**
@@ -256,7 +255,6 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
     SubmitJobRequest.builder()
       .jobName(sanitize(jobDescriptor.taskCall.fullyQualifiedName))
       .parameters(parameters.collect({ case i: AwsBatchInput => i.toStringString }).toMap.asJava)
-
       //provide job environment variables, vcpu and memory
       .containerOverrides(
         ContainerOverrides.builder
@@ -276,6 +274,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
           )
           .build()
       )
+      .tags(runtimeAttributes.resourceTags.asJava)
       .jobQueue(runtimeAttributes.queueArn)
       .jobDefinition(definitionArn)
       .build
@@ -462,7 +461,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
   def output(detail: JobDetail): String = {
     val events: Seq[OutputLogEvent] = cloudWatchLogsClient.getLogEvents(GetLogEventsRequest.builder
       // http://aws-java-sdk-javadoc.s3-website-us-west-2.amazonaws.com/latest/software/amazon/awssdk/services/batch/model/ContainerDetail.html#logStreamName--
-      .logGroupName("/aws/batch/job")
+      .logGroupName(runtimeAttributes.logsGroup)
       .logStreamName(detail.container.logStreamName)
       .startFromHead(true)
       .build).events.asScala.toList
diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala
index 549d3c65185..757dc382156 100755
--- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala
+++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchJobDefinition.scala
@@ -34,7 +34,7 @@ package cromwell.backend.impl.aws
 import scala.collection.mutable.ListBuffer
 import cromwell.backend.BackendJobDescriptor
 import cromwell.backend.io.JobPaths
-import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, Volume}
+import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, LogConfiguration, MountPoint, ResourceRequirement, ResourceType, Volume}
 import cromwell.backend.impl.aws.io.AwsBatchVolume
 
 import scala.jdk.CollectionConverters._
@@ -128,8 +128,8 @@ trait AwsBatchJobDefinitionBuilder {
     )
   }
 
-  def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair]): String = {
-    val str = s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}"
+  def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair], logsGroup: String): String = {
+    val str = s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}:$logsGroup"
 
     val sha1 = MessageDigest.getInstance("SHA-1")
       .digest( str.getBytes("UTF-8") )
@@ -148,26 +148,36 @@ trait AwsBatchJobDefinitionBuilder {
     val packedCommand = packCommand("/bin/bash", "-c", cmdName)
     val volumes = buildVolumes( context.runtimeAttributes.disks )
     val mountPoints = buildMountPoints( context.runtimeAttributes.disks)
+    val logConfiguration = LogConfiguration.builder()
+      .logDriver("awslogs")
+      .options(
+        Map(
+          "awslogs-group" -> context.runtimeAttributes.logsGroup
+        ).asJava
+      )
+      .build()
     val jobDefinitionName = buildName(
       context.runtimeAttributes.dockerImage,
       packedCommand.mkString(","),
       volumes,
       mountPoints,
-      environment
+      environment,
+      context.runtimeAttributes.logsGroup
     )
 
     (builder
       .command(packedCommand.asJava)
-        .resourceRequirements(
-          ResourceRequirement.builder()
-            .`type`(ResourceType.MEMORY)
-            .value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString)
-            .build(),
-          ResourceRequirement.builder()
-            .`type`(ResourceType.VCPU)
-            .value(context.runtimeAttributes.cpu.value.toString)
-            .build(),
-        )
+      .resourceRequirements(
+        ResourceRequirement.builder()
+          .`type`(ResourceType.MEMORY)
+          .value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString)
+          .build(),
+        ResourceRequirement.builder()
+          .`type`(ResourceType.VCPU)
+          .value(context.runtimeAttributes.cpu.value.toString)
+          .build(),
+      )
+      .logConfiguration(logConfiguration)
       .volumes( volumes.asJava)
       .mountPoints( mountPoints.asJava)
       .environment(environment.asJava),
diff --git a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala
index ca0eeb9b10c..9578a093769 100755
--- a/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala
+++ b/supportedBackends/aws/src/main/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributes.scala
@@ -45,6 +45,7 @@ import wom.format.MemorySize
 import wom.types._
 import wom.values._
 
+import scala.jdk.CollectionConverters._
 import scala.util.matching.Regex
 
 /**
@@ -60,6 +61,8 @@ import scala.util.matching.Regex
  * @param noAddress is there no address
  * @param scriptS3BucketName the s3 bucket where the execution command or script will be written and, from there, fetched into the container and executed
  * @param fileSystem the filesystem type, default is "s3"
+ * @param logsGroup the CloudWatch log group name to write logs to
+ * @param resourceTags a map of tags to add to the AWS Batch job submission
  */
 case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
                                      zones: Vector[String],
@@ -71,7 +74,10 @@ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
                                      continueOnReturnCode: ContinueOnReturnCode,
                                      noAddress: Boolean,
                                      scriptS3BucketName: String,
-                                     fileSystem:String= "s3")
+                                     logsGroup: String,
+                                     resourceTags: Map[String, String],
+                                     fileSystem: String = "s3") {
+}
 
 object AwsBatchRuntimeAttributes {
 
@@ -92,6 +98,12 @@ object AwsBatchRuntimeAttributes {
 
   private val MemoryDefaultValue = "2 GB"
 
+  private val logsGroupKey = "logsGroup"
+  private val logsGroupValidationInstance = new StringRuntimeAttributesValidation(logsGroupKey)
+  private val LogsGroupDefaultValue = WomString("/aws/batch/job")
+
+  private val resourceTagsKey = "resourceTags"
+
   private def cpuValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instance
     .withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin)
 
@@ -123,6 +135,9 @@ object AwsBatchRuntimeAttributes {
   private def noAddressValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Boolean] = noAddressValidationInstance
     .withDefault(noAddressValidationInstance.configDefaultWomValue(runtimeConfig) getOrElse NoAddressDefaultValue)
 
+  private def logsGroupValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[String] = logsGroupValidationInstance
+    .withDefault(logsGroupValidationInstance.configDefaultWomValue(runtimeConfig) getOrElse LogsGroupDefaultValue)
+
   private def scriptS3BucketNameValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[String] = {
     ScriptS3BucketNameValidation(scriptS3BucketKey).withDefault(ScriptS3BucketNameValidation(scriptS3BucketKey)
       .configDefaultWomValue(runtimeConfig).getOrElse( throw new RuntimeException( "scriptBucketName is required" )))
@@ -146,7 +161,8 @@ object AwsBatchRuntimeAttributes {
       noAddressValidation(runtimeConfig),
       dockerValidation,
       queueArnValidation(runtimeConfig),
-      scriptS3BucketNameValidation(runtimeConfig)
+      scriptS3BucketNameValidation(runtimeConfig),
+      logsGroupValidation(runtimeConfig)
     )
     def validationsLocalBackend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
       cpuValidation(runtimeConfig),
@@ -157,7 +173,8 @@ object AwsBatchRuntimeAttributes {
       memoryMinValidation(runtimeConfig),
       noAddressValidation(runtimeConfig),
       dockerValidation,
-      queueArnValidation(runtimeConfig)
+      queueArnValidation(runtimeConfig),
+      logsGroupValidation(runtimeConfig)
     )
 
     configuration.fileSystem match {
@@ -181,7 +198,15 @@ object AwsBatchRuntimeAttributes {
       case AWSBatchStorageSystems.s3 => RuntimeAttributesValidation.extract(scriptS3BucketNameValidation(runtimeAttrsConfig) , validatedRuntimeAttributes)
       case _ => ""
     }
+    val logsGroup: String = RuntimeAttributesValidation.extract(logsGroupValidation(runtimeAttrsConfig), validatedRuntimeAttributes)
 
+    val resourceTags: Map[String, String] = runtimeAttrsConfig.collect {
+      case config if config.hasPath(resourceTagsKey) =>
+        // A HOCON `null` value unwraps to a Java null; skip such entries instead of throwing a NullPointerException.
+        config.getObject(resourceTagsKey).entrySet().asScala
+          .flatMap(e => Option(e.getValue.unwrapped()).map(v => e.getKey -> v.toString))
+          .toMap
+    }.getOrElse(Map.empty[String, String])
 
     new AwsBatchRuntimeAttributes(
       cpu,
@@ -194,6 +219,8 @@ object AwsBatchRuntimeAttributes {
       continueOnReturnCode,
       noAddress,
       scriptS3BucketName,
+      logsGroup,
+      resourceTags,
       fileSystem
     )
   }
diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala
index e9a5f1f86d9..a45b3fbda74 100644
--- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala
+++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchJobSpec.scala
@@ -112,7 +112,9 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
     continueOnReturnCode = ContinueOnReturnCodeFlag(false),
     noAddress = false,
     scriptS3BucketName = "script-bucket",
-    fileSystem = "s3")
+    fileSystem = "s3",
+    logsGroup = "/aws/batch/job",
+    resourceTags = Map("tag" -> "value"))
 
   private def generateBasicJob: AwsBatchJob = {
     val job = AwsBatchJob(null, runtimeAttributes, "commandLine", script,
diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala
index fbab5a81a23..5988dda3e7e 100644
--- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala
+++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchRuntimeAttributesSpec.scala
@@ -56,18 +56,22 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
     )))
   }
 
-  val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
-
+  val expectedDefaults = new AwsBatchRuntimeAttributes(
+    refineMV[Positive](1),
+    Vector("us-east-1a", "us-east-1b"),
     MemorySize(2, MemoryUnit.GB),
     Vector(AwsBatchWorkingDisk()),
     "ubuntu:latest",
     "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue",
     false,
     ContinueOnReturnCodeSet(Set(0)),
     false,
-    "my-stuff")
-
-  val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
+    "my-stuff",
+    "/Cromwell/job/",
+    Map("tag1" -> "value1"))
+  val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(
+    refineMV[Positive](1),
+    Vector("us-east-1a", "us-east-1b"),
     MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
     "ubuntu:latest",
     "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue",
@@ -75,6 +79,8 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
     ContinueOnReturnCodeSet(Set(0)),
     false,
     "",
+    "/Cromwell/job/",
+    Map(),
     "local")
 
   "AwsBatchRuntimeAttributes" should {
diff --git a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala
index 38545c7e472..45974ef6c90 100644
--- a/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala
+++ b/supportedBackends/aws/src/test/scala/cromwell/backend/impl/aws/AwsBatchTestConfig.scala
@@ -61,6 +61,10 @@ object AwsBatchTestConfig {
       |    zones:["us-east-1a", "us-east-1b"]
       |    queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue"
       |    scriptBucketName: "my-bucket"
+      |    logsGroup: "/Cromwell/job/"
+      |    resourceTags {
+      |      tag1: "value1"
+      |    }
       |}
       |
       |""".stripMargin
@@ -140,6 +144,7 @@ object AwsBatchTestConfigForLocalFS {
       |    zones:["us-east-1a", "us-east-1b"]
       |    queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue"
       |    scriptBucketName: ""
+      |    logsGroup: "/Cromwell/job/"
       |}
       |
       |""".stripMargin
@@ -190,4 +195,4 @@ object AwsBatchTestConfigForLocalFS {
   val AwsBatchBackendNoDefaultConfig = ConfigFactory.parseString(NoDefaultsConfigString)
   val AwsBatchBackendConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendConfig, AwsBatchGlobalConfig)
   val NoDefaultsConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendNoDefaultConfig, AwsBatchGlobalConfig)
-}
\ No newline at end of file
+}