
Capturing data pipeline errors functionally with Writer Monads

As businesses grapple with vast quantities of data from both batch-based and streaming sources, it’s truly exciting to see the dominant data processing frameworks embrace the Kappa Architecture, which unifies batch and stream processing. With Kappa Architecture, the separate batch layer is removed entirely, and batch processing is treated as a special case of streaming.

As a data engineer, this means that I don't need to build two separate data pipelines — one for batch and one for streaming.  Now that it’s going to be a single unified pipeline, I also needn’t worry about how to reuse the code or where to draw modular boundaries.  

For Spark engineers, Spark Structured Streaming proves to be an invaluable tool for implementing Kappa Architecture, letting engineers focus on the business logic.

In this article, I’ll be covering one specific issue that arises when writing data pipelines using Spark: capturing transformation errors. The article also suggests two ways to solve the problem. This write-up is structured as follows:
  1. Capturing transformation errors on streaming applications isn’t easy
  2. Using Datasets for error collection
  3. Introducing Writer Monads in a rush
  4. Functionally collecting errors in Datasets using Writer monads

Capturing transformation errors on streaming applications isn’t easy

When I say "logging is painful," I’m not talking about the tracing or job failure logs that we use log4j or other log frameworks for. I’m talking about capturing the transformation/validation errors while preparing and structuring the source data.

If you’re reading this, it’s highly likely you’ve already written your fair share of data pipelines — possibly with Spark or other distributed processing frameworks. You may be loading or streaming from a bunch of sources, applying a variety of transformations and finally pushing the "refined" data to a few sinks.

Typically, we expect transformation/validation errors to be written to a different store — either as a table/structured data format in case of a batch or into a topic/error stream in case of a stream.
Historically, this has been accumulated in a variety of ways including:
  1. Using Accumulators or materialized collection that's collected at the Driver
  2. Appending an error column in the source Dataframe and collecting at the end of each transformation/transformation chain
  3. Making side-effecting IO calls to log the errors to an external datastore
The problem with collecting errors in Accumulators or any other non-distributed collection at the Driver program is that once the volume of errors exceeds the memory of the Driver, the entire job fails.

With an extra error column on the DataFrame, the question is whether a new column gets added for every single transformation stage, and how we’re going to chain the transformation steps so that only the successful records of one stage flow into the next.

Also, collecting transformation errors per partition and pushing to a datastore isn’t that terrible compared to the other options, but then it’s neither idiomatic Scala nor Spark.

Using Datasets for error collection

What if we could use Spark Dataset for collection? Then we wouldn’t need to worry about lack of memory on the Driver, right? Surprisingly, it is not that difficult to achieve this. Let's take a first stab at this.
Let's assume that we have the following transformation/validation stages:

val pipelineStages = List(
 new AddRowKeyStage(EvergreenSchema),
 new WriteToHBaseForLanding(hBaseCatalog),
 new ReplaceCharDataStage(DoubleColsReplaceMap, EvergreenSchema, DoubleCols),
 new ReplaceCharDataStage(SpecialCharMap, EvergreenSchema, StringCols),
 new DataTypeValidatorStage(EvergreenSchema),
 new DataTypeCastStage(sourceRawDf.schema, EvergreenSchema)
)


Each DataStage needs to implement an apply function that accepts two parameters:
  1. An error Dataset
  2. The DataFrame to which the transformation is applied
import com.thoughtworks.awayday.ingest.models.ErrorModels.DataError
import org.apache.spark.sql.{DataFrame, Dataset}

trait DataStage[T <: Dataset[_]] extends Serializable {
 def apply(errors: Dataset[DataError], dataRecords: T): (Dataset[DataError], DataFrame)
 def stage: String
}


The DataError itself is just a simple case class that wraps all the essential information that a transformation error must have.
case class DataError(rowKey: String, stage: String, fieldName: String, fieldValue: String, error: String, severity: String, addlInfo: String = "")


A typical implementation of a DataStage would look something like this:

import com.thoughtworks.awayday.ingest.DataFrameOps
import com.thoughtworks.awayday.ingest.UDFs.generateUUID
import com.thoughtworks.awayday.ingest.models.ErrorModels.DataError
import com.thoughtworks.awayday.ingest.stages.StageConstants.RowKey
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, SparkSession}


class AddRowKeyStage(schemaWithRowKey: StructType)(implicit spark: SparkSession, encoder: Encoder[DataError]) extends DataStage[DataFrame] {

  override val stage: String = getClass.getSimpleName

  def apply(errors: Dataset[DataError], dataRecords: DataFrame): (Dataset[DataError],DataFrame) = addRowKeys(errors, dataRecords)

  def addRowKeys(errors: Dataset[DataError], data: DataFrame): (Dataset[DataError],DataFrame) = {
    val colOrder = schemaWithRowKey.fields.map(_.name)
    val returnDf = data.withColumn(RowKey, lit(generateUUID())).select(colOrder.map(col): _*)
    (errors.union(DataFrameOps.emptyErrorStream(spark)), returnDf)
  }
}


Now that we have our fixtures ready, we could happily sequence the transformation with the stages.
val (initErr, initDf) = (DataFrameOps.emptyErrorStream(spark), sourceRawDf)

val validRecords = pipelineStages.foldLeft((initErr,initDf)) { case ((err, df), stage) =>
   stage(err, df)
}


Figure 1: Output through Method One (tuple result)
This seems alright, but there are two issues that bother me with this approach:
  1. Passing the previous errors to the next transformation doesn't look good, because a transformation is an independent unit of code and needn’t be aware of the previous transformations
  2. I kept forgetting the union with the previous errors every time I returned the (error, value) tuple from a DataStage. I just care about the errors of this stage, and I shouldn’t need to worry about merging them with the errors of the previous transformations
Yes, there's a way to get rid of both issues. The Writer monad does exactly that.
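
The second issue is easy to reproduce in miniature. The sketch below is only an illustration under my own assumptions: plain Vectors stand in for Dataset[DataError] and the DataFrame, and the names TupleStages, dropNegatives and dropLarge are made up for this example. One stage forgets to merge the incoming errors, and nothing in the types stops it.

```scala
object TupleStages {
  // Plain Vectors stand in for Dataset[DataError] and DataFrame (an assumption of this sketch).
  type Errors = Vector[String]
  type Rows = Vector[Int]
  type Stage = (Errors, Rows) => (Errors, Rows)

  // Well-behaved stage: reports negatives and remembers to merge with the earlier errors.
  val dropNegatives: Stage = (errors, rows) => {
    val (bad, good) = rows.partition(_ < 0)
    (errors ++ bad.map(r => s"negative value: $r"), good)
  }

  // Buggy stage: returns only its own errors and silently discards `errors`.
  val dropLarge: Stage = (errors, rows) => {
    val (bad, good) = rows.partition(_ > 100)
    (bad.map(r => s"value too large: $r"), good)
  }

  // Same shape as the foldLeft over pipelineStages.
  def run(stages: List[Stage], rows: Rows): (Errors, Rows) =
    stages.foldLeft((Vector.empty[String], rows)) { case ((err, data), stage) =>
      stage(err, data)
    }
}
```

Running both stages over Vector(1, -2, 300) keeps only the row 1, as expected, but the error list contains just the "value too large" entry. The error reported by the first stage is gone, and the code compiled without complaint.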

Introducing Writer Monads in a rush

(If you are familiar with Monads and Writer Monads, you can skip right through to the next section)
Monads are the Functional Programming way to sequence computations within a context. To quote Mastering Advanced Scala, "They are often used to build data processing pipelines."
Say, we have three functions:

def getCurrentTemperature():Future[Double] = ??? //1
def getTomorrowsTempFromPredictionAPI(curr: Double): Future[Double] = ??? //2
def publishItInOurWebsite(pred: Double):Future[Double] = ??? //3
  1. The getCurrentTemperature fetches the current temperature using a temperature service
  2. The getTomorrowsTempFromPredictionAPI function uses the current temperature to make a call to a Prediction service to predict tomorrow's temperature
  3. The publishItInOurWebsite function publishes this predicted temperature to our website
Each of these functions makes an async call and is wrapped inside a Future, a popular Monadic type. We’d want to chain these function calls, which could be done using the flatMap and map functions, but the code looks much cleaner with the for comprehension syntactic sugar, like so:

 val published2:Future[Double] =
     for {
       curr <- getCurrentTemperature()
       pred <- getTomorrowsTempFromPredictionAPI(curr)
       pubw <- publishItInOurWebsite(pred)
     } yield pubw
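
To show what the syntactic sugar stands for, here is a small sketch that writes the same chain both ways. The three functions are given stub bodies with made-up values so that the example runs; the real ones would call the respective services. The object name FutureChaining and the await helper are only there for the illustration.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object FutureChaining {
  // Stub implementations with made-up values, standing in for real service calls.
  def getCurrentTemperature(): Future[Double] = Future.successful(10.0)
  def getTomorrowsTempFromPredictionAPI(curr: Double): Future[Double] = Future.successful(curr + 10.0)
  def publishItInOurWebsite(pred: Double): Future[Double] = Future.successful(pred)

  // The for comprehension ...
  def viaFor: Future[Double] =
    for {
      curr <- getCurrentTemperature()
      pred <- getTomorrowsTempFromPredictionAPI(curr)
      pubw <- publishItInOurWebsite(pred)
    } yield pubw

  // ... is rewritten by the compiler into nested flatMap calls ending in a map.
  def viaFlatMap: Future[Double] =
    getCurrentTemperature().flatMap { curr =>
      getTomorrowsTempFromPredictionAPI(curr).flatMap { pred =>
        publishItInOurWebsite(pred).map(pubw => pubw)
      }
    }

  // Blocking helper, used only to inspect the result in this sketch.
  def await(f: Future[Double]): Double = Await.result(f, 5.seconds)
}
```

Both versions produce the same Future; the for comprehension just avoids the rightward drift of the nested lambdas.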

     
Now, Writer Monads are just specialized monads. They’re used for sequencing computations but at the same time, these Monads can not only return the result of the computation but also give some "extra information." A typical use would be "logs," but they can be anything. To reiterate, the Writer monad has a "collection part" and a "value" part.

For this example and subsequent examples, I am using the functional programming library Cats, but this could easily be swapped with ScalaZ. Let's go ahead and modify the functions to return a Writer instead of a Future.

For the sake of simplicity, I have the function calls just return Double directly instead of a Future[Double]. If you are keen to see the non-simplified version of the code with Future[Double] wrapped in Writer, please refer to this gist.

import cats.data.Writer

def getCurrentTemperatureW(): Writer[List[String], Double] = {
   Writer(List("Thermometer isn't broken yet"), 10.0)
}

def getTomorrowsTempFromPredictionAPIW(curr: Double): Writer[List[String], Double] = {
   Writer(List("Yay, the Prediction API works too"), 20.0)
}

def publishItInOurWebsiteW(pred: Double): Writer[List[String], Double] = {
   Writer(List("Published to our website"), 20.0)
}


Now, we can chain the functions just like before with Future

val publishedWriter: Writer[List[String], Double] =
 for {
   curr <- getCurrentTemperatureW()
   pred <- getTomorrowsTempFromPredictionAPIW(curr)
   pubw <- publishItInOurWebsiteW(pred)
 } yield pubw


Note that the publishedWriter is of type Writer[List[String], Double]. The way that we could get the final value and the logs is by calling the run function on the Writer.

val (logs, value) = publishedWriter.run

logs.foreach(println)
println (value)


As we see from the output below, at the end of the chain of computations, along with the value, the logs get accumulated as a collection.

Thermometer isn't broken yet
Yay, the Prediction API works too
Published to our website
20.0

Functionally collecting errors in Datasets using Writer Monads

Coming back to the original discussion, we’d like to collect the transformation errors of the pipeline into a collection. With the Writer Monad example above, if we were to use Lists for the "collection part," the problem becomes the same as with the Accumulator: all the logs are collected on the Driver. What we want instead is a distributed "collection part," such as a Dataset.


How do we do this?

If we open up the Writer Monad's flatMap function, the expectation of the "collection part" is that there is evidence of a Semigroup instance for it. A Semigroup is a type class with a single combine function: an associative operation that takes two values and merges them into one.

def flatMap[U](f: V => WriterT[F, L, U])(implicit flatMapF: FlatMap[F], semigroupL: Semigroup[L]): WriterT[F, L, U] =
WriterT {
 flatMapF.flatMap(run) { lv =>
   flatMapF.map(f(lv._2).run) { lv2 =>
     (semigroupL.combine(lv._1, lv2._1), lv2._2)
   }
 }
}
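
To make the role of combine concrete, here is a stripped-down sketch of the same idea in plain Scala. This is not the Cats code: MiniSemigroup and MiniWriter are hypothetical names I made up, and the effect type F is left out entirely. It only shows that flatMap never looks inside the "collection part"; it hands both pieces to whatever Semigroup instance is in scope.

```scala
object MiniWriterDemo {
  // A stripped-down Semigroup: one associative operation that merges two values.
  trait MiniSemigroup[L] { def combine(x: L, y: L): L }

  // A stripped-down Writer: a log part paired with a value part.
  final case class MiniWriter[L, V](log: L, value: V) {
    def map[U](f: V => U): MiniWriter[L, U] = MiniWriter(log, f(value))

    // flatMap runs the next step, then merges both logs through the Semigroup.
    def flatMap[U](f: V => MiniWriter[L, U])(implicit sg: MiniSemigroup[L]): MiniWriter[L, U] = {
      val next = f(value)
      MiniWriter(sg.combine(log, next.log), next.value)
    }
  }

  // The instance decides what "merge" means; for a Vector it is concatenation.
  implicit val vectorSemigroup: MiniSemigroup[Vector[String]] =
    new MiniSemigroup[Vector[String]] {
      def combine(x: Vector[String], y: Vector[String]): Vector[String] = x ++ y
    }

  // Each step contributes only its own log entry.
  val result: MiniWriter[Vector[String], Int] =
    for {
      a <- MiniWriter(Vector("read 1"), 1)
      b <- MiniWriter(Vector("doubled"), a * 2)
    } yield b + 1
}
```

Because the merging strategy lives in the instance rather than in flatMap, swapping the Vector for another type only requires another instance, which is the property the rest of this section relies on.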


Now, all that we need to do is write a Semigroup instance for the Dataset. The implementation couldn't be easier: we just delegate the call to the union function of the Dataset. That's it!

object DataFrameOps {
   ...
   ...
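   // Semigroup is invariant, so the instance is generic in the element type:
   // a Semigroup[Dataset[_]] would not be picked up for Dataset[DataError].
   implicit def dataFrameSemigroup[A]: Semigroup[Dataset[A]] = new Semigroup[Dataset[A]] {
       override def combine(x: Dataset[A], y: Dataset[A]): Dataset[A] = x.union(y)
   }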
}


Tying things together

There are some minor changes that we would need to make on the DataStages, but the pipeline would remain as is.
DataStage would now return a Writer instead of a tuple (Dataset[DataError], DataFrame). I am type aliasing the Writer as DataSetWithErrors so that it's easier on the eyes:

type DataSetWithErrors[A] = Writer[Dataset[DataError], A]
trait DataStage[T <: Dataset[_]] extends Serializable {
 def apply(data: T): DataSetWithErrors[T]
 def stage: String
}


As for the DataStage, each implementation's apply function accepts a Dataframe and returns the DataSetWithErrors.

Here's an example implementation of the DataTypeValidatorStage that validates each cell against the expected datatype defined in the StructType. Note that there's neither a union call nor the errors accepted as a parameter in the DataStage.

import cats.data.Writer
import com.thoughtworks.awayday.ingest.UDFs._
import com.thoughtworks.awayday.ingest.models.ErrorModels.{DataError, DataSetWithErrors}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import StageConstants._

class DataTypeValidatorStage(schema: StructType)(implicit val spark: SparkSession) extends DataStage[DataFrame] {
 override val stage = getClass.getSimpleName
 def apply(dataRecords: DataFrame): DataSetWithErrors[DataFrame] = validateTypes(dataRecords)

 def validateTypes(data: DataFrame): DataSetWithErrors[DataFrame] = {
   val withErrorsDF = data.withColumn(RowLevelErrorListCol, validateRowUDF(schema, stage)(struct(data.columns.map(data(_)): _*)))

   import spark.implicits._

   val errorRecords =
     withErrorsDF
       .select(RowLevelErrorListCol)
       .select(explode(col(RowLevelErrorListCol)))
       .select("col.*")
       .map(row => DataError(row))

   Writer(errorRecords, withErrorsDF.drop(RowLevelErrorListCol))
 }
}


Now that the stages and their implementations are done, all we need to do is chain them away like the awesome Monads that they are.

import DataFrameOps._

val initDf = Writer(DataFrameOps.emptyErrorStream(spark), sourceRawDf)

val validRecords = pipelineStages.foldLeft(initDf) { case (dfWithErrors, stage) =>
 for {
   df <- dfWithErrors
   applied <- stage.apply(df)
 } yield applied
}
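
The for comprehension inside the fold binds the accumulated value and applies the next stage to it, which amounts to calling flatMap with the stage. The sketch below shows that shape in plain Scala so that it runs without Spark or Cats. Logged is a hypothetical, minimal Writer whose "collection part" is a Vector, and the two stages are made up for the illustration.

```scala
object FoldStages {
  // Logged is a hypothetical, minimal Writer whose log part is a Vector of messages.
  final case class Logged[A](errors: Vector[String], value: A) {
    def map[B](f: A => B): Logged[B] = Logged(errors, f(value))
    def flatMap[B](f: A => Logged[B]): Logged[B] = {
      val next = f(value)
      Logged(errors ++ next.errors, next.value)
    }
  }

  type Stage = Vector[Int] => Logged[Vector[Int]]

  // Each stage reports only its own errors and knows nothing about earlier stages.
  val dropNegatives: Stage = rows => {
    val (bad, good) = rows.partition(_ < 0)
    Logged(bad.map(r => s"negative value: $r"), good)
  }

  val dropLarge: Stage = rows => {
    val (bad, good) = rows.partition(_ > 100)
    Logged(bad.map(r => s"value too large: $r"), good)
  }

  // The fold from the article, written with flatMap instead of a for comprehension.
  def run(stages: List[Stage], rows: Vector[Int]): Logged[Vector[Int]] =
    stages.foldLeft(Logged(Vector.empty[String], rows))((acc, stage) => acc.flatMap(stage))
}
```

Running both stages over Vector(1, -2, 300) leaves the row 1 and accumulates both error messages in stage order, even though neither stage ever sees the errors of the other.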


With the pipeline set, we can just invoke the run of the Writer monad to get back the errors and the return value when the stream is started.

val (errors, processedDf) = validRecords.run

val query = processedDf
     .writeStream
     .format("console")
     .outputMode(OutputMode.Append())
     .start()

     
Figure 2: Output through Method Two (final output)

Empty error stream

Unfortunately, I couldn't find a better way to create an emptyErrorStream in Spark than using a MemoryStream, as shown below. Using emptyDataset and attempting a union against a streaming DataFrame (isStreaming=true) throws an interesting error:

org.apache.spark.sql.AnalysisException: Union between streaming and batch DataFrames/Datasets is not supported;;


To get around this issue, for this example, I’ve used a memory stream, which I understand is to be used only in non-production environments because of "infinite in-memory collection of lines read and no fault recovery." But since we’re just using it to create an empty stream, I thought I could get away with this.

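import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

val emptyErrorStream = (spark: SparkSession) => {
   import spark.implicits._              // supplies the Encoder[DataError] that MemoryStream needs
   implicit val sqlC = spark.sqlContext  // MemoryStream takes an implicit SQLContext
   MemoryStream[DataError].toDS()
}

Note: For data-type-specific row-level error handling for CSV and JSON, you could optionally consider using the "mode" option of the DataFrameReader.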


Writing transformation errors and warnings in a data pipeline has predominantly been done through side-effecting writes to a persistent datastore or by collecting them at the Driver. This article proposes a functional and scalable way to approach the problem. While this article uses Spark and Spark Datasets to collect the errors, I believe this mechanism could be used with other distributed collections and frameworks.