Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWS Batch - add support for custom logs group and tags #7219

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
*/
package cromwell.backend.impl.aws

import java.security.MessageDigest

import cats.data.ReaderT._
import cats.data.{Kleisli, ReaderT}
import cats.effect.{Async, Timer}
Expand All @@ -53,8 +51,9 @@ import software.amazon.awssdk.services.s3.S3Client
import software.amazon.awssdk.services.s3.model.{GetObjectRequest, HeadObjectRequest, NoSuchKeyException, PutObjectRequest}
import wdl4s.parser.MemoryUnit

import scala.jdk.CollectionConverters._
import java.security.MessageDigest
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.util.{Random, Try}

/**
Expand Down Expand Up @@ -256,7 +255,6 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
SubmitJobRequest.builder()
.jobName(sanitize(jobDescriptor.taskCall.fullyQualifiedName))
.parameters(parameters.collect({ case i: AwsBatchInput => i.toStringString }).toMap.asJava)

//provide job environment variables, vcpu and memory
.containerOverrides(
ContainerOverrides.builder
Expand All @@ -276,6 +274,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
)
.build()
)
.tags(runtimeAttributes.resourceTags.asJava)
.jobQueue(runtimeAttributes.queueArn)
.jobDefinition(definitionArn)
.build
Expand Down Expand Up @@ -462,7 +461,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
def output(detail: JobDetail): String = {
val events: Seq[OutputLogEvent] = cloudWatchLogsClient.getLogEvents(GetLogEventsRequest.builder
// http://aws-java-sdk-javadoc.s3-website-us-west-2.amazonaws.com/latest/software/amazon/awssdk/services/batch/model/ContainerDetail.html#logStreamName--
.logGroupName("/aws/batch/job")
.logGroupName(runtimeAttributes.logsGroup)
.logStreamName(detail.container.logStreamName)
.startFromHead(true)
.build).events.asScala.toList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ package cromwell.backend.impl.aws
import scala.collection.mutable.ListBuffer
import cromwell.backend.BackendJobDescriptor
import cromwell.backend.io.JobPaths
import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, Volume}
import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, LogConfiguration, MountPoint, ResourceRequirement, ResourceType, Volume}
import cromwell.backend.impl.aws.io.AwsBatchVolume

import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -128,8 +128,8 @@ trait AwsBatchJobDefinitionBuilder {
)
}

def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair]): String = {
val str = s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}"
def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair], logsGroup: String): String = {
val str = s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}:$logsGroup"

val sha1 = MessageDigest.getInstance("SHA-1")
.digest( str.getBytes("UTF-8") )
Expand All @@ -148,26 +148,36 @@ trait AwsBatchJobDefinitionBuilder {
val packedCommand = packCommand("/bin/bash", "-c", cmdName)
val volumes = buildVolumes( context.runtimeAttributes.disks )
val mountPoints = buildMountPoints( context.runtimeAttributes.disks)
val logConfiguration = LogConfiguration.builder()
.logDriver("awslogs")
.options(
Map(
"awslogs-group" -> context.runtimeAttributes.logsGroup
).asJava
)
.build()
val jobDefinitionName = buildName(
context.runtimeAttributes.dockerImage,
packedCommand.mkString(","),
volumes,
mountPoints,
environment
environment,
context.runtimeAttributes.logsGroup
)

(builder
.command(packedCommand.asJava)
.resourceRequirements(
ResourceRequirement.builder()
.`type`(ResourceType.MEMORY)
.value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString)
.build(),
ResourceRequirement.builder()
.`type`(ResourceType.VCPU)
.value(context.runtimeAttributes.cpu.value.toString)
.build(),
)
.resourceRequirements(
ResourceRequirement.builder()
.`type`(ResourceType.MEMORY)
.value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString)
.build(),
ResourceRequirement.builder()
.`type`(ResourceType.VCPU)
.value(context.runtimeAttributes.cpu.value.toString)
.build(),
)
.logConfiguration(logConfiguration)
.volumes( volumes.asJava)
.mountPoints( mountPoints.asJava)
.environment(environment.asJava),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import wom.format.MemorySize
import wom.types._
import wom.values._

import scala.jdk.CollectionConverters._
import scala.util.matching.Regex

/**
Expand All @@ -60,6 +61,8 @@ import scala.util.matching.Regex
* @param noAddress is there no address
* @param scriptS3BucketName the s3 bucket where the execution command or script will be written and, from there, fetched into the container and executed
* @param fileSystem the filesystem type, default is "s3"
* @param logsGroup the CloudWatch log group name to write logs to
* @param resourceTags a map of tags to add to the AWS Batch job submission
*/
case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
zones: Vector[String],
Expand All @@ -71,7 +74,10 @@ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
continueOnReturnCode: ContinueOnReturnCode,
noAddress: Boolean,
scriptS3BucketName: String,
fileSystem:String= "s3")
logsGroup: String,
resourceTags: Map[String, String],
fileSystem: String = "s3") {
}

object AwsBatchRuntimeAttributes {

Expand All @@ -92,6 +98,12 @@ object AwsBatchRuntimeAttributes {

private val MemoryDefaultValue = "2 GB"

private val logsGroupKey = "logsGroup"
private val logsGroupValidationInstance = new StringRuntimeAttributesValidation(logsGroupKey)
private val LogsGroupDefaultValue = WomString("/aws/batch/job")

private val resourceTagsKey = "resourceTags"

/**
  * Builds the CPU runtime-attribute validation with a default attached.
  *
  * The default is whatever `CpuValidation.configDefaultWomValue` yields for `runtimeConfig`;
  * when that is empty, `CpuValidation.defaultMin` is used instead.
  *
  * @param runtimeConfig optional config passed to `configDefaultWomValue` to look up a default
  * @return the CPU validation carrying the resolved default
  */
private def cpuValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int Refined Positive] = {
  val baseValidation = CpuValidation.instance
  val configuredDefault = CpuValidation.configDefaultWomValue(runtimeConfig)
  baseValidation.withDefault(configuredDefault.getOrElse(CpuValidation.defaultMin))
}

Expand Down Expand Up @@ -123,6 +135,9 @@ object AwsBatchRuntimeAttributes {
/**
  * Builds the `noAddress` runtime-attribute validation with a default attached.
  *
  * The default is whatever `noAddressValidationInstance.configDefaultWomValue` yields for
  * `runtimeConfig`; when that is empty, `NoAddressDefaultValue` is used instead.
  *
  * @param runtimeConfig optional config passed to `configDefaultWomValue` to look up a default
  * @return the `noAddress` validation carrying the resolved default
  */
private def noAddressValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Boolean] = {
  val baseValidation = noAddressValidationInstance
  val configuredDefault = baseValidation.configDefaultWomValue(runtimeConfig)
  baseValidation.withDefault(configuredDefault.getOrElse(NoAddressDefaultValue))
}

/**
  * Builds the `logsGroup` runtime-attribute validation with a default attached.
  *
  * The default is whatever `logsGroupValidationInstance.configDefaultWomValue` yields for
  * `runtimeConfig`; when that is empty, `LogsGroupDefaultValue` is used instead.
  *
  * @param runtimeConfig optional config passed to `configDefaultWomValue` to look up a default
  * @return the `logsGroup` validation carrying the resolved default
  */
private def logsGroupValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[String] = {
  val baseValidation = logsGroupValidationInstance
  val configuredDefault = baseValidation.configDefaultWomValue(runtimeConfig)
  baseValidation.withDefault(configuredDefault.getOrElse(LogsGroupDefaultValue))
}

private def scriptS3BucketNameValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[String] = {
ScriptS3BucketNameValidation(scriptS3BucketKey).withDefault(ScriptS3BucketNameValidation(scriptS3BucketKey)
.configDefaultWomValue(runtimeConfig).getOrElse( throw new RuntimeException( "scriptBucketName is required" )))
Expand All @@ -146,7 +161,8 @@ object AwsBatchRuntimeAttributes {
noAddressValidation(runtimeConfig),
dockerValidation,
queueArnValidation(runtimeConfig),
scriptS3BucketNameValidation(runtimeConfig)
scriptS3BucketNameValidation(runtimeConfig),
logsGroupValidation(runtimeConfig)
)
def validationsLocalBackend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
cpuValidation(runtimeConfig),
Expand All @@ -157,7 +173,8 @@ object AwsBatchRuntimeAttributes {
memoryMinValidation(runtimeConfig),
noAddressValidation(runtimeConfig),
dockerValidation,
queueArnValidation(runtimeConfig)
queueArnValidation(runtimeConfig),
logsGroupValidation(runtimeConfig)
)

configuration.fileSystem match {
Expand All @@ -181,7 +198,14 @@ object AwsBatchRuntimeAttributes {
case AWSBatchStorageSystems.s3 => RuntimeAttributesValidation.extract(scriptS3BucketNameValidation(runtimeAttrsConfig) , validatedRuntimeAttributes)
case _ => ""
}
val logsGroup: String = RuntimeAttributesValidation.extract(logsGroupValidation(runtimeAttrsConfig), validatedRuntimeAttributes)

val resourceTags: Map[String, String] = runtimeAttrsConfig.collect {
case config if config.hasPath(resourceTagsKey) =>
config.getObject(resourceTagsKey).entrySet().asScala
.map(e => e.getKey -> e.getValue.unwrapped().toString)
.toMap
}.getOrElse(Map.empty[String, String])

new AwsBatchRuntimeAttributes(
cpu,
Expand All @@ -194,6 +218,8 @@ object AwsBatchRuntimeAttributes {
continueOnReturnCode,
noAddress,
scriptS3BucketName,
logsGroup,
resourceTags,
fileSystem
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
continueOnReturnCode = ContinueOnReturnCodeFlag(false),
noAddress = false,
scriptS3BucketName = "script-bucket",
fileSystem = "s3")
fileSystem = "s3",
logsGroup = "/aws/batch/job",
resourceTags = Map("tag" -> "value"))

private def generateBasicJob: AwsBatchJob = {
val job = AwsBatchJob(null, runtimeAttributes, "commandLine", script,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,31 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
)))
}

val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),

val expectedDefaults = new AwsBatchRuntimeAttributes(
refineMV[Positive](1),
Vector("us-east-1a", "us-east-1b"),
MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
"arn:aws:batch:us-east-1:111222333444:job-queue/job-queue",
false,
ContinueOnReturnCodeSet(Set(0)),
false,
"my-stuff")

val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
"my-stuff",
"/Cromwell/job/",
Map("tag1" -> "value1"))

val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(
refineMV[Positive](1),
Vector("us-east-1a", "us-east-1b"),
MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
"arn:aws:batch:us-east-1:111222333444:job-queue/job-queue",
false,
ContinueOnReturnCodeSet(Set(0)),
false,
"",
"/Cromwell/job/",
Map(),
"local")

"AwsBatchRuntimeAttributes" should {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ object AwsBatchTestConfig {
| zones:["us-east-1a", "us-east-1b"]
| queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue"
| scriptBucketName: "my-bucket"
| logsGroup: "/Cromwell/job/"
| resourceTags {
| tag1: "value1"
| }
|}
|
|""".stripMargin
Expand Down Expand Up @@ -140,6 +144,7 @@ object AwsBatchTestConfigForLocalFS {
| zones:["us-east-1a", "us-east-1b"]
| queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue"
| scriptBucketName: ""
| logsGroup: "/Cromwell/job/"
|}
|
|""".stripMargin
Expand Down Expand Up @@ -190,4 +195,4 @@ object AwsBatchTestConfigForLocalFS {
val AwsBatchBackendNoDefaultConfig = ConfigFactory.parseString(NoDefaultsConfigString)
val AwsBatchBackendConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendConfig, AwsBatchGlobalConfig)
val NoDefaultsConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendNoDefaultConfig, AwsBatchGlobalConfig)
}
}