#8 Make type defaults configurable
Zejnilovic authored Aug 11, 2022
1 parent 0b94b6c commit ecab7cb
Showing 41 changed files with 291 additions and 262 deletions.
@@ -18,7 +18,7 @@ package za.co.absa.standardization

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField}
import za.co.absa.standardization.types.{TypeDefaults, TypedStructField}
import za.co.absa.standardization.validation.field.FieldValidationIssue

import scala.collection.mutable.ListBuffer
@@ -27,15 +27,13 @@ import scala.collection.mutable.ListBuffer
* Object responsible for Spark schema validation against self inconsistencies (not against the actual data)
*/
object SchemaValidator {
private implicit val defaults: Defaults = GlobalDefaults

/**
* Validate a schema
*
* @param schema A Spark schema
* @return A list of ValidationErrors objects, each containing a column name and the list of errors and warnings
*/
def validateSchema(schema: StructType): List[FieldValidationIssue] = {
def validateSchema(schema: StructType)(implicit defaults: TypeDefaults): List[FieldValidationIssue] = {
var errorsAccumulator = new ListBuffer[FieldValidationIssue]
val flatSchema = flattenSchema(schema)
for {s <- flatSchema} {
10 changes: 5 additions & 5 deletions src/main/scala/za/co/absa/standardization/Standardization.scala
@@ -18,18 +18,17 @@ package za.co.absa.standardization

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructType, _}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.slf4j.{Logger, LoggerFactory}

import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
import za.co.absa.standardization.config.{DefaultStandardizationConfig, StandardizationConfig}
import za.co.absa.standardization.stages.{SchemaChecker, TypeParser}
import za.co.absa.standardization.types.{Defaults, GlobalDefaults, ParseOutput}
import za.co.absa.standardization.types.{CommonTypeDefaults, ParseOutput, TypeDefaults}
import za.co.absa.standardization.udf.{UDFLibrary, UDFNames}

object Standardization {
private implicit val defaults: Defaults = GlobalDefaults
private val logger: Logger = LoggerFactory.getLogger(this.getClass)
final val DefaultColumnNameOfCorruptRecord = "_corrupt_record"

@@ -41,6 +40,7 @@ object Standardization {
(implicit sparkSession: SparkSession): DataFrame = {
implicit val udfLib: UDFLibrary = new UDFLibrary(standardizationConfig)
implicit val hadoopConf: Configuration = sparkSession.sparkContext.hadoopConfiguration
implicit val defaults: TypeDefaults = standardizationConfig.typeDefaults

logger.info(s"Step 1: Schema validation")
validateSchemaAgainstSelfInconsistencies(schema)
@@ -67,15 +67,15 @@


private def validateSchemaAgainstSelfInconsistencies(expSchema: StructType)
(implicit spark: SparkSession): Unit = {
(implicit spark: SparkSession, defaults: TypeDefaults): Unit = {
val validationErrors = SchemaChecker.validateSchemaAndLog(expSchema)
if (validationErrors._1.nonEmpty) {
throw new ValidationException("A fatal schema validation error occurred.", validationErrors._1)
}
}

private def standardizeDataset(df: DataFrame, expSchema: StructType, stdConfig: StandardizationConfig)
(implicit spark: SparkSession, udfLib: UDFLibrary): DataFrame = {
(implicit spark: SparkSession, udfLib: UDFLibrary, defaults: TypeDefaults): DataFrame = {

val rowErrors: List[Column] = gatherRowErrors(df.schema)
val (stdCols, errorCols, oldErrorColumn) = expSchema.fields.foldLeft(List.empty[Column], rowErrors, None: Option[Column]) {
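A usage sketch, not part of this commit: the type defaults now travel inside the StandardizationConfig instead of the removed implicit GlobalDefaults, so a caller only has to pick a config. This assumes the standardize(dataFrame, schema, config) entry point of this object and a BasicStandardizationConfig.fromDefault() factory mirroring the fromDefault() calls visible above; the input data and schema are hypothetical.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import za.co.absa.standardization.Standardization
import za.co.absa.standardization.config.BasicStandardizationConfig

implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Hypothetical input and target schema, for illustration only.
val df: DataFrame = Seq("Alice", "Bob").toDF("name")
val schema = StructType(Seq(StructField("name", StringType, nullable = false)))

// The default config carries CommonTypeDefaults as its typeDefaults member;
// standardize exposes it as the implicit TypeDefaults used for schema
// validation (Step 1 above) and by TypeParser.
val config = BasicStandardizationConfig.fromDefault()
val standardized = Standardization.standardize(df, schema, config)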
@@ -16,9 +16,12 @@

package za.co.absa.standardization.config

import za.co.absa.standardization.types.{CommonTypeDefaults, TypeDefaults}

case class BasicStandardizationConfig(failOnInputNotPerSchema: Boolean,
errorCodes: ErrorCodesConfig,
metadataColumns: MetadataColumnsConfig,
typeDefaults: TypeDefaults,
errorColumn: String,
timezone: String) extends StandardizationConfig

@@ -28,6 +31,7 @@ object BasicStandardizationConfig {
DefaultStandardizationConfig.failOnInputNotPerSchema,
BasicErrorCodesConfig.fromDefault(),
BasicMetadataColumnsConfig.fromDefault(),
CommonTypeDefaults,
DefaultStandardizationConfig.errorColumn,
DefaultStandardizationConfig.timezone
)
@@ -16,10 +16,13 @@

package za.co.absa.standardization.config

import za.co.absa.standardization.types.{CommonTypeDefaults, TypeDefaults}

object DefaultStandardizationConfig extends StandardizationConfig {
val errorCodes: ErrorCodesConfig = DefaultErrorCodesConfig
val metadataColumns: MetadataColumnsConfig = DefaultMetadataColumnsConfig
val failOnInputNotPerSchema: Boolean = false
val typeDefaults: TypeDefaults = CommonTypeDefaults
val errorColumn: String = "errCol"
val timezone: String = "UTC"
}
@@ -16,10 +16,13 @@

package za.co.absa.standardization.config

import za.co.absa.standardization.types.TypeDefaults

trait StandardizationConfig {
val failOnInputNotPerSchema: Boolean
val errorCodes: ErrorCodesConfig
val metadataColumns: MetadataColumnsConfig
val typeDefaults: TypeDefaults
val errorColumn: String
val timezone: String
}
@@ -20,6 +20,7 @@ import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import za.co.absa.standardization.SchemaValidator.{validateErrorColumn, validateSchema}
import za.co.absa.standardization.types.TypeDefaults
import za.co.absa.standardization.{ValidationError, ValidationIssue, ValidationWarning}

object SchemaChecker {
@@ -32,7 +33,7 @@ object SchemaChecker {
* @param schema A Spark schema
*/
def validateSchemaAndLog(schema: StructType)
(implicit spark: SparkSession): (Seq[String], Seq[String]) = {
(implicit spark: SparkSession, defaults: TypeDefaults): (Seq[String], Seq[String]) = {
val failures = validateSchema(schema) ::: validateErrorColumn(schema)

type ColName = String
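A usage sketch, not part of this commit: validateSchemaAndLog now resolves its defaults implicitly, so call sites outside Standardization need a TypeDefaults in scope. A minimal example with the stock CommonTypeDefaults and a hypothetical schema:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import za.co.absa.standardization.stages.SchemaChecker
import za.co.absa.standardization.types.{CommonTypeDefaults, TypeDefaults}

implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
implicit val defaults: TypeDefaults = CommonTypeDefaults

// Returns (fatal errors, warnings); Standardization throws a ValidationException
// when the first element is non-empty.
val schema = StructType(Seq(StructField("id", IntegerType, nullable = false)))
val (errors, warnings) = SchemaChecker.validateSchemaAndLog(schema)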
38 changes: 19 additions & 19 deletions src/main/scala/za/co/absa/standardization/stages/TypeParser.scala
@@ -37,7 +37,7 @@ import za.co.absa.standardization.schema.StdSchemaUtils.FieldWithSource
import za.co.absa.standardization.time.DateTimePattern
import za.co.absa.standardization.typeClasses.{DoubleLike, LongLike}
import za.co.absa.standardization.types.TypedStructField._
import za.co.absa.standardization.types.{Defaults, ParseOutput, TypedStructField}
import za.co.absa.standardization.types.{TypeDefaults, ParseOutput, TypedStructField}
import za.co.absa.standardization.udf.{UDFBuilder, UDFLibrary, UDFNames}

import scala.reflect.runtime.universe._
@@ -135,7 +135,7 @@ object TypeParser {
origSchema: StructType,
stdConfig: StandardizationConfig,
failOnInputNotPerSchema: Boolean = true)
(implicit udfLib: UDFLibrary, defaults: Defaults): ParseOutput = {
(implicit udfLib: UDFLibrary, defaults: TypeDefaults): ParseOutput = {
// udfLib implicit is present for error column UDF implementation
val sourceName = SchemaUtils.appendPath(path, field.sourceName)
val origField = origSchema.getField(sourceName)
@@ -162,7 +162,7 @@
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean = false)
(implicit defaults: Defaults): TypeParser[_] = {
(implicit defaults: TypeDefaults): TypeParser[_] = {
val parserClass: (String, Column, DataType, Boolean, Boolean) => TypeParser[_] = field.dataType match {
case _: ArrayType => ArrayParser(TypedStructField.asArrayTypeStructField(field), _, _, _, _, _)
case _: StructType => StructParser(TypedStructField.asStructTypeStructField(field), _, _, _, _, _)
@@ -191,7 +191,7 @@
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults) extends TypeParser[Any] {
(implicit defaults: TypeDefaults) extends TypeParser[Any] {

override def fieldType: ArrayType = {
field.dataType
@@ -226,7 +226,7 @@
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults) extends TypeParser[Any] {
(implicit defaults: TypeDefaults) extends TypeParser[Any] {
override def fieldType: StructType = {
field.dataType
}
@@ -260,7 +260,7 @@
}
}

private abstract class PrimitiveParser[T](implicit defaults: Defaults) extends TypeParser[T] {
private abstract class PrimitiveParser[T](implicit defaults: TypeDefaults) extends TypeParser[T] {
override protected def standardizeAfterCheck(stdConfig: StandardizationConfig)(implicit logger: Logger): ParseOutput = {
val castedCol: Column = assemblePrimitiveCastLogic
val castHasError: Column = assemblePrimitiveCastErrorLogic(castedCol)
@@ -298,12 +298,12 @@
}
}

private abstract class ScalarParser[T](implicit defaults: Defaults) extends PrimitiveParser[T] {
private abstract class ScalarParser[T](implicit defaults: TypeDefaults) extends PrimitiveParser[T] {
override def assemblePrimitiveCastLogic: Column = column.cast(field.dataType)
}

private abstract class NumericParser[N: TypeTag](override val field: NumericTypeStructField[N])
(implicit defaults: Defaults) extends ScalarParser[N] {
(implicit defaults: TypeDefaults) extends ScalarParser[N] {
override protected def standardizeAfterCheck(stdConfig: StandardizationConfig)(implicit logger: Logger): ParseOutput = {
if (field.needsUdfParsing) {
standardizeUsingUdf(stdConfig)
@@ -355,7 +355,7 @@
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean,
overflowableTypes: Set[DataType])
(implicit defaults: Defaults) extends NumericParser[N](field) {
(implicit defaults: TypeDefaults) extends NumericParser[N](field) {
override protected def assemblePrimitiveCastErrorLogic(castedCol: Column): Column = {
val basicLogic: Column = super.assemblePrimitiveCastErrorLogic(castedCol)

@@ -385,7 +385,7 @@
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults)
(implicit defaults: TypeDefaults)
extends NumericParser[BigDecimal](field)
// NB! loss of precision is not addressed for any DecimalType
// e.g. 3.141592 will be Standardized to Decimal(10,2) as 3.14
@@ -396,7 +396,7 @@
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults)
(implicit defaults: TypeDefaults)
extends NumericParser[N](field) {
override protected def assemblePrimitiveCastErrorLogic(castedCol: Column): Column = {
//NB! loss of precision is not addressed for any fractional type
Expand All @@ -414,15 +414,15 @@ object TypeParser {
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults) extends ScalarParser[String]
(implicit defaults: TypeDefaults) extends ScalarParser[String]

private final case class BinaryParser(field: BinaryTypeStructField,
path: String,
column: Column,
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults) extends PrimitiveParser[Array[Byte]] {
(implicit defaults: TypeDefaults) extends PrimitiveParser[Array[Byte]] {
override protected def assemblePrimitiveCastLogic: Column = {
origType match {
case BinaryType => column
Expand Down Expand Up @@ -450,7 +450,7 @@ object TypeParser {
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults) extends ScalarParser[Boolean]
(implicit defaults: TypeDefaults) extends ScalarParser[Boolean]

/**
* Timestamp conversion logic
@@ -474,7 +474,7 @@
* Date | O | ->to_utc_timestamp->to_date
* Other | ->String->to_date | ->String->to_timestamp->to_utc_timestamp->to_date
*/
private abstract class DateTimeParser[T](implicit defaults: Defaults) extends PrimitiveParser[T] {
private abstract class DateTimeParser[T](implicit defaults: TypeDefaults) extends PrimitiveParser[T] {
override val field: DateTimeTypeStructField[T]
protected val pattern: DateTimePattern = field.pattern.get.get

@@ -551,11 +551,11 @@
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults) extends DateTimeParser[Date] {
(implicit defaults: TypeDefaults) extends DateTimeParser[Date] {
private val defaultTimeZone: Option[String] = field
.defaultTimeZone
.map(Option(_))
.getOrElse(defaults.getDefaultDateTimeZone)
.getOrElse(defaults.defaultDateTimeZone)

private def applyPatternToStringColumn(column: Column, pattern: String): Column = {
defaultTimeZone.map(tz =>
@@ -605,12 +605,12 @@
origType: DataType,
failOnInputNotPerSchema: Boolean,
isArrayElement: Boolean)
(implicit defaults: Defaults) extends DateTimeParser[Timestamp] {
(implicit defaults: TypeDefaults) extends DateTimeParser[Timestamp] {

private val defaultTimeZone: Option[String] = field
.defaultTimeZone
.map(Option(_))
.getOrElse(defaults.getDefaultTimestampTimeZone)
.getOrElse(defaults.defaultTimestampTimeZone)

private def applyPatternToStringColumn(column: Column, pattern: String): Column = {
val interim: Column = to_timestamp(column, pattern)
90 changes: 90 additions & 0 deletions src/main/scala/za/co/absa/standardization/types/CommonTypeDefaults.scala
@@ -0,0 +1,90 @@
/*
* Copyright 2021 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.standardization.types

import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampType}
import za.co.absa.standardization.numeric.DecimalSymbols

import java.sql.{Date, Timestamp}
import java.util.Locale
import scala.util.{Success, Try}

class CommonTypeDefaults extends TypeDefaults {
val integerTypeDefault: Int = 0
val floatTypeDefault: Float = 0f
val byteTypeDefault: Byte = 0.toByte
val shortTypeDefault: Short = 0.toShort
val doubleTypeDefault: Double = 0.0d
val longTypeDefault: Long = 0L
val stringTypeDefault: String = ""
val binaryTypeDefault: Array[Byte] = Array.empty[Byte]
val dateTypeDefault: Date = new Date(0) // Linux epoch
val timestampTypeDefault: Timestamp = new Timestamp(0)
val booleanTypeDefault: Boolean = false
val decimalTypeDefault: (Int, Int) => BigDecimal = { (precision, scale) =>
val beforeFloatingPoint = "0" * (precision - scale)
val afterFloatingPoint = "0" * scale
BigDecimal(s"$beforeFloatingPoint.$afterFloatingPoint")
}

override def defaultTimestampTimeZone: Option[String] = None
override def defaultDateTimeZone: Option[String] = None

override def getDecimalSymbols: DecimalSymbols = DecimalSymbols(Locale.US)

override def getDataTypeDefaultValue(dt: DataType): Any =
dt match {
case _: IntegerType => integerTypeDefault
case _: FloatType => floatTypeDefault
case _: ByteType => byteTypeDefault
case _: ShortType => shortTypeDefault
case _: DoubleType => doubleTypeDefault
case _: LongType => longTypeDefault
case _: StringType => stringTypeDefault
case _: BinaryType => binaryTypeDefault
case _: DateType => dateTypeDefault
case _: TimestampType => timestampTypeDefault
case _: BooleanType => booleanTypeDefault
case t: DecimalType => decimalTypeDefault(t.precision, t.scale)
case _ => throw new IllegalStateException(s"No default value defined for data type ${dt.typeName}")
}

override def getDataTypeDefaultValueWithNull(dt: DataType, nullable: Boolean): Try[Option[Any]] = {
if (nullable) {
Success(None)
} else {
Try{
getDataTypeDefaultValue(dt)
}.map(Some(_))
}
}

override def getStringPattern(dt: DataType): String = dt match {
case DateType => "yyyy-MM-dd"
case TimestampType => "yyyy-MM-dd HH:mm:ss"
case _: IntegerType
| FloatType
| ByteType
| ShortType
| DoubleType
| LongType => ""
case _: DecimalType => ""
case _ => throw new IllegalStateException(s"No default format defined for data type ${dt.typeName}")
}
}

object CommonTypeDefaults extends CommonTypeDefaults
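A usage sketch, not part of this commit: since CommonTypeDefaults is a non-final class, individual defaults can be adjusted by extending it and handing the instance to the configuration. The overridden values below are arbitrary examples; the copy call relies on BasicStandardizationConfig being a case class and on a fromDefault() factory mirroring those visible in this commit.

import za.co.absa.standardization.config.BasicStandardizationConfig
import za.co.absa.standardization.types.CommonTypeDefaults

// Example overrides: non-nullable strings fall back to "N/A" and timestamps
// without an explicit time zone are read as Europe/Prague.
object MyTypeDefaults extends CommonTypeDefaults {
  override val stringTypeDefault: String = "N/A"
  override def defaultTimestampTimeZone: Option[String] = Some("Europe/Prague")
}

val customConfig = BasicStandardizationConfig.fromDefault().copy(typeDefaults = MyTypeDefaults)

Any object implementing TypeDefaults directly would work the same way; extending CommonTypeDefaults just keeps the remaining defaults unchanged.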