[improvement] support two phases commit in structured streaming (#156)
gnehil authored Nov 7, 2023
1 parent 38c2718 commit df4f107
Showing 7 changed files with 262 additions and 113 deletions.
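Context for this change: Stream Load's two-phase commit splits each load into a pre-commit (data written, transaction left open) and a final commit or abort. This commit extends that to Structured Streaming by collecting pre-committed transaction ids in an accumulator during each micro-batch and finishing them from a StreamingQueryListener. A minimal streaming job that would exercise this path is sketched below; the doris.sink.enable-2pc option name is assumed from the connector's batch-mode 2PC support and may differ at this version.

// Hypothetical usage sketch (not part of this commit). Assumes the Doris
// connector is on the classpath and that "doris.sink.enable-2pc" is the 2PC flag.
import org.apache.spark.sql.SparkSession

object StreamToDorisDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("doris-2pc-demo").getOrCreate()

    val source = spark.readStream
      .format("rate") // demo source: (timestamp, value) rows
      .option("rowsPerSecond", "10")
      .load()

    source.writeStream
      .format("doris")
      .option("doris.fenodes", "fe-host:8030")
      .option("doris.table.identifier", "db.tbl")
      .option("user", "root")
      .option("password", "")
      .option("doris.sink.enable-2pc", "true") // assumed flag for the 2PC path
      .option("checkpointLocation", "/tmp/doris-ckpt")
      .start()
      .awaitTermination()
  }
}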

(One of the changed files was deleted; its diff is collapsed and not shown here.)

org/apache/doris/spark/sql/DorisSourceProvider.scala
@@ -67,8 +67,10 @@ private[sql] class DorisSourceProvider extends DataSourceRegister
    case _: SaveMode => // do nothing
  }

+ // accumulator for transaction handling
+ val acc = sqlContext.sparkContext.collectionAccumulator[Long]("BatchTxnAcc")
  // init stream loader
- val writer = new DorisWriter(sparkSettings)
+ val writer = new DorisWriter(sparkSettings, acc)
  writer.write(data)

  new BaseRelation {
org/apache/doris/spark/sql/DorisStreamLoadSink.scala
@@ -17,7 +17,9 @@

package org.apache.doris.spark.sql

- import org.apache.doris.spark.cfg.SparkSettings
+ import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
+ import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
+ import org.apache.doris.spark.txn.listener.DorisTxnStreamingQueryListener
import org.apache.doris.spark.writer.DorisWriter
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext}
@@ -28,7 +30,12 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSettings)
  private val logger: Logger = LoggerFactory.getLogger(classOf[DorisStreamLoadSink].getName)
  @volatile private var latestBatchId = -1L

- private val writer = new DorisWriter(settings)
+ // accumulator for transaction handling
+ private val acc = sqlContext.sparkContext.collectionAccumulator[Long]("StreamTxnAcc")
+ private val writer = new DorisWriter(settings, acc)
+
+ // add listener for structured streaming
+ sqlContext.streams.addListener(new DorisTxnStreamingQueryListener(acc, settings))

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
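The CollectionAccumulator is the bridge between executors and driver here: write tasks add each pre-committed transaction id as they finish, and the driver-side listener drains and resets the accumulator to commit or abort. A self-contained sketch of that pattern, with a toy stand-in for the pre-commit call:

// Standalone illustration of the accumulator pattern (not part of the commit).
import org.apache.spark.sql.SparkSession
import scala.collection.JavaConverters._

object AccumulatorPattern {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("acc-demo").getOrCreate()
    val sc = spark.sparkContext

    // driver creates the accumulator; tasks on executors add to it
    val txnAcc = sc.collectionAccumulator[Long]("TxnAcc")

    sc.parallelize(1 to 4, 2).foreachPartition { rows =>
      val txnId = rows.sum.toLong // stand-in for a pre-committed stream load txn id
      txnAcc.add(txnId)
    }

    // driver reads the collected ids once the job has finished, then resets
    val txnIds = txnAcc.value.asScala.toList
    println(s"pre-committed txn ids: $txnIds")
    txnAcc.reset()
    spark.stop()
  }
}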
org/apache/doris/spark/txn/TransactionHandler.scala
@@ -0,0 +1,100 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.txn

import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.doris.spark.sql.Utils
import org.apache.spark.internal.Logging

import java.time.Duration
import scala.collection.mutable
import scala.util.{Failure, Success}

/**
 * Stream load transaction handler
 *
 * @param settings job settings
 */
class TransactionHandler(settings: SparkSettings) extends Logging {

  private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
    ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT)
  private val sinkTxnRetries: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
    ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT)
  private val dorisStreamLoad: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)

  /**
   * commit transactions
   *
   * @param txnIds transaction id list
   */
  def commitTransactions(txnIds: List[Long]): Unit = {
    log.debug(s"start to commit transactions, count ${txnIds.size}")
    // fold the per-transaction results into (failed txn ids, last exception seen)
    val (failedTxnIds, ex) = txnIds.map(commitTransaction).filter(_._1.nonEmpty)
      .map(e => (e._1.get, e._2.get))
      .aggregate((mutable.Buffer[Long](), new Exception))(
        (z, r) => ((z._1 += r._1).asInstanceOf[mutable.Buffer[Long]], r._2), (r1, r2) => (r1._1 ++ r2._1, r2._2))
    if (failedTxnIds.nonEmpty) {
      log.error("uncommitted txn ids: {}", failedTxnIds.mkString("[", ",", "]"))
      throw ex
    }
  }

  /**
   * commit a single transaction
   *
   * @param txnId transaction id
   * @return (failed txn id, exception), both empty when the commit succeeds
   */
  private def commitTransaction(txnId: Long): (Option[Long], Option[Exception]) = {
    Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), log) {
      dorisStreamLoad.commit(txnId)
    }() match {
      case Success(_) => (None, None)
      case Failure(e: Exception) => (Option(txnId), Option(e))
    }
  }

  /**
   * abort transactions
   *
   * @param txnIds transaction id list
   */
  def abortTransactions(txnIds: List[Long]): Unit = {
    log.debug(s"start to abort transactions, count ${txnIds.size}")
    var ex: Option[Exception] = None
    val failedTxnIds = txnIds.map(txnId =>
      Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), log) {
        dorisStreamLoad.abortById(txnId)
      }() match {
        case Success(_) => None
        case Failure(e: Exception) =>
          ex = Option(e)
          Option(txnId)
      }).filter(_.nonEmpty).map(_.get)
    if (failedTxnIds.nonEmpty) {
      log.error("not aborted txn ids: {}", failedTxnIds.mkString("[", ",", "]"))
    }
  }

}

object TransactionHandler {
  def apply(settings: SparkSettings): TransactionHandler = new TransactionHandler(settings)
}
org/apache/doris/spark/txn/listener/DorisTransactionListener.scala
@@ -0,0 +1,66 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.txn.listener

import org.apache.doris.spark.cfg.SparkSettings
import org.apache.doris.spark.txn.TransactionHandler
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler._
import org.apache.spark.util.CollectionAccumulator

import scala.collection.JavaConverters._
import scala.collection.mutable

class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long], settings: SparkSettings)
  extends SparkListener with Logging {

  val txnHandler: TransactionHandler = TransactionHandler(settings)

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
    jobEnd.jobResult match {
      // if the job succeeded, commit all pre-committed transactions
      case JobSucceeded =>
        if (txnIds.isEmpty) {
          log.debug("job run succeed, but there is no pre-committed txn ids")
          return
        }
        log.info("job run succeed, start committing transactions")
        try txnHandler.commitTransactions(txnIds.toList)
        catch {
          case e: Exception => throw e
        }
        finally preCommittedTxnAcc.reset()
        log.info("commit transaction success")
      // if the job failed, abort all pre-committed transactions
      case _ =>
        if (txnIds.isEmpty) {
          log.debug("job run failed, but there is no pre-committed txn ids")
          return
        }
        log.info("job run failed, start aborting transactions")
        try txnHandler.abortTransactions(txnIds.toList)
        catch {
          case e: Exception => throw e
        }
        finally preCommittedTxnAcc.reset()
        log.info("abort transaction success")
    }
  }

}
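For the batch path, onJobEnd fires on the driver after every Spark job, successful or not, so matching on JobSucceeded is enough to split the commit and abort branches. The registration site is not shown in this diff; presumably the writer attaches the listener via SparkContext#addSparkListener. A toy listener with the same shape (illustrative names):

// Minimal SparkListener with the same success/failure split (not part of the commit).
import org.apache.spark.scheduler.{JobSucceeded, SparkListener, SparkListenerJobEnd}

class TxnOutcomeListener(onCommit: () => Unit, onAbort: () => Unit) extends SparkListener {
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobEnd.jobResult match {
    case JobSucceeded => onCommit() // job finished normally: finalize the transactions
    case _            => onAbort()  // any other result: roll the transactions back
  }
}
// registered with e.g. sc.addSparkListener(new TxnOutcomeListener(commitAll, abortAll))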
org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala
@@ -0,0 +1,69 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.spark.txn.listener

import org.apache.doris.spark.cfg.SparkSettings
import org.apache.doris.spark.txn.TransactionHandler
import org.apache.spark.internal.Logging
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.CollectionAccumulator

import scala.collection.JavaConverters._
import scala.collection.mutable

class DorisTxnStreamingQueryListener(preCommittedTxnAcc: CollectionAccumulator[Long], settings: SparkSettings)
  extends StreamingQueryListener with Logging {

  private val txnHandler = TransactionHandler(settings)

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    // commit the pre-committed transactions when each micro-batch ends
    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
    if (txnIds.isEmpty) {
      log.warn("job run succeed, but there is no pre-committed txn ids")
      return
    }
    log.info(s"batch[${event.progress.batchId}] run succeed, start committing transactions")
    try txnHandler.commitTransactions(txnIds.toList)
    catch {
      case e: Exception => throw e
    } finally preCommittedTxnAcc.reset()
    log.info(s"batch[${event.progress.batchId}] commit transaction success")
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
    // if the query failed, abort all pre-committed transactions
    if (event.exception.nonEmpty) {
      if (txnIds.isEmpty) {
        log.warn("job run failed, but there is no pre-committed txn ids")
        return
      }
      log.info("job run failed, start aborting transactions")
      try txnHandler.abortTransactions(txnIds.toList)
      catch {
        case e: Exception => throw e
      } finally preCommittedTxnAcc.reset()
      log.info("abort transaction success")
    }
  }

}
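Structured Streaming delivers QueryProgressEvent on the driver once per completed micro-batch and QueryTerminatedEvent when the query stops, which is what lets this listener commit each batch's pre-committed transactions and abort leftovers on failure. A runnable toy listener showing when those hooks fire:

// Standalone demo of the StreamingQueryListener event flow (not part of the commit).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

object ListenerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("listener-demo").getOrCreate()

    spark.streams.addListener(new StreamingQueryListener {
      override def onQueryStarted(e: QueryStartedEvent): Unit =
        println(s"query ${e.id} started")
      override def onQueryProgress(e: QueryProgressEvent): Unit =
        println(s"batch ${e.progress.batchId} finished: ${e.progress.numInputRows} rows") // commit point
      override def onQueryTerminated(e: QueryTerminatedEvent): Unit =
        println(s"query terminated, exception: ${e.exception}") // abort point on failure
    })

    val q = spark.readStream.format("rate").option("rowsPerSecond", "5").load()
      .writeStream.format("console").start()
    q.awaitTermination(15000) // run a few batches, then exit
    spark.stop()
  }
}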
