안녕하세요! GoseKose입니다.

 

스프링 배치에서 파티션 단위로 처리하는 과정은 대규모 데이터 처리를 병렬로 분할하여 성능을 최적화할 수 있습니다.

특히, 파티션을 독립적으로 처리할 수 있는데, 이는 partiton 혹은 step 단위로 실행 매개변수를 다룰 수 있음을 의미합니다.

 

파티션 구성 요소는 다음과 같습니다.

 

  • Partitioner: 데이터를 여러 파티션으로 나누는 역할을 합니다.
  • PartitionHandler: 파티션을 각 replica 스텝에 분배하고 병렬로 실행합니다.
  • Step: step 단위 실행 플로우를 정의합니다.

이 세 가지 구성요소를 바탕으로 토이 프로젝트에서 Partition 단위로 병렬 처리한 과정을 정리하도록 하겠습니다.

 

 

 

1. 목표 아키텍처

 

 

 

2. 목표 플로우 

순서 제목 설명
1 Job 실행 Batch Job을 시작합니다.
2 데이터 범위 계산 목표 타겟 데이터 id(pk)를 기반으로 min / max 값을 구합니다.
3 파티션 나누기 minId / maxId의 차이를 구한 후, 파티션 개수 (ex: 5)로 범위를 나눕니다.
4 Partiton 병렬 실행 partition 단위로 병렬로 step을 실행합니다.
5 Step 실행 각 step은 chunk 지향 처리를 수행합니다.
6 예외 처리 Reader 문제 발생 시, Step Listener 에서 chunk 단위로 조회 실패한 min / max id를 저장합니다.
7 Job 종료 Batch Job을 종료합니다.

 

 

 

3. 데이터 범위 계산 및 파티션 나누기

@Bean
fun rangePartitioner(): RangePartitioner {
    val (minId, maxId) = partitionResultJdbcQuery()
    return RangePartitioner(
        minId = minId,
        maxId = maxId,
    )
}

private fun partitionResultJdbcQuery(): Pair<Long?, Long?> {
    val minSQL =
        """
        SELECT MIN(id)
        FROM memory_marbles
        WHERE store_type = 'DAILY'
        AND created_at >= '$startTimeStamp' AND created_at < '$endTimeStamp'
        """.trimIndent()

    val maxSQL =
        """
        SELECT MAX(id)
        FROM memory_marbles
        WHERE store_type = 'DAILY'
        AND created_at >= '$startTimeStamp' AND created_at < '$endTimeStamp'
        """.trimIndent()

    val minId = jdbcTemplate.queryForObject(minSQL, Long::class.java)
    val maxId = jdbcTemplate.queryForObject(maxSQL, Long::class.java)

    return Pair(minId, maxId)
}

class RangePartitioner(
    private val minId: Long?,
    private val maxId: Long?,
) : Partitioner {
    override fun partition(gridSize: Int): MutableMap<String, ExecutionContext> {
        val result = mutableMapOf<String, ExecutionContext>()
        if (minId == null || maxId == null) {
            return result
        }

        val targetSize = (maxId - minId + 1) / gridSize

        var start = minId.toLong()
        var end = start + targetSize - 1

        for (i in 0 until gridSize - 1) {
            val context = ExecutionContext()
            context.putLong("minValue", start)
            context.putLong("maxValue", end)
            result["partition$i"] = context
            start += targetSize
            end = start + targetSize - 1
        }

        val context = ExecutionContext()
        context.putLong("minValue", start)
        context.putLong("maxValue", maxId)
        result["partition${gridSize - 1}"] = context

        return result
    }
}

 

배치 Job을 수행하고자 하는 최소 최대 범위를 구한 후, Partition 개수만큼 범위를 구분합니다.

 

만약 minId = 100, maxId =  200, partitonSize(gridSize) = 5 라면, 다음처럼 비교적 균등하게 파티션을 나눌 수 있습니다.

partition 조회할 where 범위 개수
1 100 <= id <= 119 20
2 120 <= id <= 139 20
3 140 <= id <= 159 20
4 160 <= id <= 179 20
5 180 <= id <= 200 21

 

 

 

4. 파티션 병렬 실행

@Bean
fun memoryMarbleDailyToPermanentUpdaterJob(): Job {
    return JobBuilder(batchProperties.job.name, jobRepository)
        .incrementer(RunIdIncrementer())
        .start(primaryMemoryMarbleDailyToPermanentUpdaterStep())
        .listener(batchJobExecutionListener)
        .preventRestart()
        .build()
}

@Bean
fun primaryMemoryMarbleDailyToPermanentUpdaterStep(): Step {
    return StepBuilder("primaryMemoryMarbleDailyToPermanentUpdaterStep", jobRepository)
        .partitioner("replicaMemoryMarbleDailyToPermanentUpdaterStep", rangePartitioner())
        .step(replicaMemoryMarbleDailyToPermanentUpdaterStep())
        .partitionHandler(partitionHandler())
        .build()
}

@Bean
fun replicaMemoryMarbleDailyToPermanentUpdaterStep(): Step {
    return StepBuilder("replicaMemoryMarbleDailyToPermanentUpdaterStep", jobRepository)
        .chunk<MemoryMarbleJpaEntity, MemoryMarbleJpaEntity>(CHUNK_SIZE, transactionManager)
        .reader(memoryMarbleReader(null, null))
        .processor(memoryMarbleProcessor())
        .writer(memoryMarbleWriter())
        .listener(batchStepExecutionListener())
        .transactionManager(transactionManager)
        .build()
}

@Bean
fun partitionHandler(): TaskExecutorPartitionHandler {
    val partitionHandler = TaskExecutorPartitionHandler()
    partitionHandler.setTaskExecutor(simpleAsyncTaskExecutor)
    partitionHandler.step = replicaMemoryMarbleDailyToPermanentUpdaterStep()

    val (minId, maxId) = partitionResultJdbcQuery()
    if ((minId == null || maxId == null) || (maxId - minId) < PARTITION_SIZE) {
        partitionHandler.gridSize = 1
    } else {
        partitionHandler.gridSize = PARTITION_SIZE
    }

    return partitionHandler
}

 

partition을 활용할 때, StepBuilder()의 partitioner, partitionerHandler를 정의해야 합니다.

각 함수의 역할은 다음과 같습니다.

 

함수 설명
partitioner 위에서 정의한 파티셔널 (RangePartitioner)을 바탕으로 데이터 범위를 나누는 역할을 합니다.
각 파티션은 별도의 ExecutionContext를 가지는데, 이를 바탕으로 실행 매개 변수를 독립적으로 관리할 수 있습니다.
partitionerHandler partitioner에 의해 나뉘어진 파티션을 병렬로 처리합니다.
각 핸들러는 파티션에 정의된 step을 실행합니다.

 

 

 

5-1. Step 정의하기: ItemReader의 Chunk 지향 처리

저는 Chunk 지향 처리 방법으로 Step을 구성하였습니다. Chunk 지향 처리는 각 청크마다 트랜잭션을 관리하므로, 데이터베이스 커넥션 시간을 효율적으로 관리할 수 있습니다.

 

Reader를 구현하는 방법은 ItemReader 구현, QueryDsl로 ItemReader 확장하기 등  다양한 방법이 존재합니다.

전 회사에서 QueryDsl로 AbstractPagingItemReader를 확장해서 배치 시스템을 개선했던 경험이 있어서, 

이번 토이 프로젝트는 JDBC를 활용하여 AbstractPagingItemReader를 확장하는 방법을 선택해 보았습니다.

 

Spring Batch에서 제공하는 JdbcPagingItemReader는 pageSize을 limit으로 설정하되,

내부적으로 커서 기반 페이지네이션으로 동작합니다.

 

 

JdbcPagingItemReader는 다음의 firstPageSql, startAfterValues라는 필드를 가지고 있습니다.

private String firstPageSql;

private Map<String, Object> startAfterValues;

 

firstPageSql이 동작한 후, startAfterValues가 업데이트되면, query 생성 시 cursor에 해당할 id를 (id >?)에 바인딩해줍니다.

이를 바탕으로, offset을 정의하지 않아도, limit와 cursor 기반으로 데이터를 빠르게 조회하고 처리할 수 있습니다.

 

 

5-1. Step 정의하기: CustomJdbcPagingItemReader 정의하기

청크 단위 Reader를 수행할 때, 저장된 데이터를 DAO/VO/DTO로 변환할 때 타입 에러가 발생하곤 합니다.

 

이 경우 스탭이 종료되거나, startAfterValues가 업데이트되지 않아 query가 정상 동작하지 않고 emptyList()를 출력하여
Batch Step은 더 이상 읽을 데이터가 없다고 판단하여 종료할 수 있습니다.

 

해당 문제를 해결하기 위해, AbstractPagingItemReader를 구현한 CustomJdbcPagingItemReader를 정의하였습니다.

기존 JdbcPagingItemReader의 로직을 그대로 가져오되,

doReadePage()와 실패한 페이지 정보를 ExecutionContext에 넘겨 StepListener에서 실패한 데이터의 범위를 저장할 수 있도록 하였습니다.

 

override fun doReadPage() {
    results = results?.apply { clear() } ?: CopyOnWriteArrayList()

    val rowCallback = PagingRowMapper()
    val query: List<T> = try {
        when {
            page == 0 -> {
                logger.info("SQL used for reading first page: [$firstPageSql]")
                executeQuery(firstPageSql, rowCallback, parameterValues)
            }

            startAfterValues != null -> {
                previousStartAfterValues = startAfterValues
                logger.info("SQL used for reading remaining pages: [$remainingPagesSql]")
                executeQuery(remainingPagesSql, rowCallback, startAfterValues)
            }

            else -> emptyList()
        }
    } catch (e: Exception) {
        failedValues["page_$page"] = firstPageSql // 페이지 정보 추가
        logger.error("Error occurred while reading page: ", e)
        retryWithPage(page, 1, 1, rowCallback) // 리트라이 할 수 있도록 로직 추가
    }
    results.addAll(query)
}

private fun executeQuery(sql: String, rowCallback: PagingRowMapper, parameters: Map<String, Any>?): List<T> {
    return if (!parameters.isNullOrEmpty()) {
        if (queryProvider.isUsingNamedParameters) {
            namedParameterJdbcTemplate?.query(sql, getParameterMap(parameters, null), rowCallback) ?: emptyList()
        } else {
            getJdbcTemplate().query(sql, rowCallback, *getParameterList(parameters, null).toTypedArray())
        }
    } else {
        getJdbcTemplate().query(sql, rowCallback)
    }
}

private fun retryWithPage(
    page: Int,
    offsetCount: Int,
    retryCount: Int,
    rowCallback: PagingRowMapper
): List<T> {
    val adjustedMinValue = pageSize * offsetCount // 페이지와 오프셋 기반으로 실패한 데이터는 건너뛰도록 조치
    val sqlWithOffset = "$firstPageSql OFFSET $adjustedMinValue"

    logger.info("Retry SQL used for reading page $page with adjusted offset: [$sqlWithOffset]")

    return try {
        executeQuery(sqlWithOffset, rowCallback, parameterValues)
    } catch (e: Exception) {
        logger.error("Error occurred while retrying page: ", e)
        if (retryCount < 5) { // 리트라이는 최대 5번
            failedValues["page_$page"] = firstPageSql
            retryWithPage(page + 1, offsetCount + 1, retryCount + 1, rowCallback)
        } else {
            throw e
        }
    }
}

@Throws(ItemStreamException::class)
override fun update(executionContext: ExecutionContext) {
    super.update(executionContext)
    if (isSaveState) {
        if (isAtEndOfPage() && startAfterValues != null) {
            executionContext.put(getExecutionContextKey(START_AFTER_VALUE), startAfterValues)
        } else if (previousStartAfterValues != null) {
            executionContext.put(getExecutionContextKey(START_AFTER_VALUE), previousStartAfterValues)
        }
    }
    if (failedValues.isNotEmpty()) { // 실패한 범위가 존재할 경우 executionContext로 넘기기
        executionContext.put(getExecutionContextKey(FAIL_VALUE), failedValues)
    }
}

 

 

만약 doPageReader()가 특정 예외가 발생할 경우,

retry를 하되 offset / limit으로 실패한 범위는 건너뛰고 데이터를 조회할 수 있도록 조치하였습니다.

그리고 실패한 failedValues는 executionContext에 저장하여 후처리를 할 수 있습니다.

 

 

6. 예외 처리

PartitionStepExecutionListener는 후처리 로직으로 두 가지 분기가 수행됩니다.

 

1. Step 전체가 실패한 경우 해당 Step의 파티션 범위를 failedPartition으로 저장

2. 부분적으로 Failed 된 경우, 각 chunk 범위를 failedPartition으로 저장

 

class PartitionStepExecutionListener(
    private val failedPartitionJdbcRepository: FailedPartitionJdbcRepository,
) : BatchStepExecutionListener() {
    override fun addAfterStep(stepExecution: StepExecution) {
        val minId = stepExecution.executionContext.getLong("minValue", -1L)
        val maxId = stepExecution.executionContext.getLong("maxValue", -1L)

        if (minId == -1L || maxId == -1L) return

        if (stepExecution.exitStatus.exitCode == ExitStatus.FAILED.exitCode) {
            failedPartitionJdbcRepository.save(
                FailedPartition.of(
                    minId = minId,
                    maxId = maxId,
                    jobExecutionId = stepExecution.jobExecutionId
                )
            )
        } else if (stepExecution.exitStatus.exitCode == ExitStatus.COMPLETED.exitCode) {
            val failedValue =
                stepExecution.executionContext["$MEMORY_MARBLE_JDBC_PAGING_ITEM_READER.$FAIL_VALUE"] as? Map<String, String>?

            if (failedValue != null) {
                failedValue.keys.filter { it.contains("page_") }
                    .map { key ->
                        val page = key.substring(5).toInt()
                        val newMinId = minId + page * MemoryMarbleDailyToPermanentUpdateJobConfig.PAGE_SIZE
                        FailedPartition.of(
                            minId = newMinId,
                            maxId = min(maxId, newMinId + MemoryMarbleDailyToPermanentUpdateJobConfig.PAGE_SIZE),
                            jobExecutionId = stepExecution.jobExecutionId,
                        )
                    }.let { failedPartitionJdbcRepository.saveAll(it) }
            }
        }
    }
}

interface FailedPartitionJdbcRepository {
    fun save(failedPartition: FailedPartition)
    fun saveAll(failedPartitions: List<FailedPartition>)
}

@Repository
class FailedPartitionJdbcRepositoryImpl(
    @Qualifier("batchSimpleJdbcInsert") private val simpleJdbcInsert: SimpleJdbcInsert,
) : FailedPartitionJdbcRepository {
    init {
        simpleJdbcInsert
            .withTableName("FAILED_PARTITIONS")
            .usingGeneratedKeyColumns("id")
            .usingColumns("min_id", "max_id", "step_execution_id", "created_at", "last_modified_at", "status")
    }

    override fun save(failedPartition: FailedPartition) {
        simpleJdbcInsert.executeBatch(generateMapSqlParameterSource(failedPartition))
    }

    override fun saveAll(failedPartitions: List<FailedPartition>) {
        simpleJdbcInsert.executeBatch(*generateMapSqlParameterSource(failedPartitions))
    }

    private fun generateMapSqlParameterSource(failedPartition: FailedPartition): SqlParameterSource {
        return failedPartition.let { DaoRowMapper.mapSqlParameterSourceWith(it) }
    }

    private fun generateMapSqlParameterSource(failedPartitions: List<FailedPartition>): Array<SqlParameterSource> {
        return failedPartitions.map { DaoRowMapper.mapSqlParameterSourceWith(it) }.toTypedArray()
    }
}

 

 

 

이 두 가지 저장 방식을 토대로, 실패한 데이터 범위를 최대한 줄여서 보정 배치를 수행할 수 있습니다.

 

 

 

CustomJdbcPagingItemReader는 아직 개선할 부분이 많이 있습니다!

추가로 개선되는 부분은 다음 블로그 글로 작성하도록 하겠습니다.!

잘못된 부분이나 개선할 부분 말씀 부탁드립니다!

이상으로 Spring Batch Partition 단위로 병렬 처리하기를 마치도록 하겠습니다!

 

감사합니다!

+ Recent posts