Skip to content

Commit

Permalink
Make MessageChannelPartitionHandler extend AbstractPartitionHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
fmbenhassine committed Aug 3, 2022
1 parent 2fdb68d commit 9211c2d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2006-2021 the original author or authors.
* Copyright 2006-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,7 +35,7 @@
*/
public abstract class AbstractPartitionHandler implements PartitionHandler {

private int gridSize = 1;
protected int gridSize = 1;

/**
* Executes the specified {@link StepExecution} instances and returns an updated view
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2009-2021 the original author or authors.
* Copyright 2009-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,14 +15,14 @@
*/
package org.springframework.batch.integration.partition;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import javax.sql.DataSource;

Expand All @@ -35,6 +35,7 @@
import org.springframework.batch.core.explore.support.JobExplorerFactoryBean;
import org.springframework.batch.core.partition.PartitionHandler;
import org.springframework.batch.core.partition.StepExecutionSplitter;
import org.springframework.batch.core.partition.support.AbstractPartitionHandler;
import org.springframework.batch.core.repository.JobRepository;
import org.springframework.batch.poller.DirectPoller;
import org.springframework.batch.poller.Poller;
Expand Down Expand Up @@ -85,12 +86,10 @@
*
*/
@MessageEndpoint
public class MessageChannelPartitionHandler implements PartitionHandler, InitializingBean {
public class MessageChannelPartitionHandler extends AbstractPartitionHandler implements InitializingBean {

private static Log logger = LogFactory.getLog(MessageChannelPartitionHandler.class);

private int gridSize = 1;

private MessagingTemplate messagingGateway;

private String stepName;
Expand Down Expand Up @@ -187,18 +186,6 @@ public void setMessagingOperations(MessagingTemplate messagingGateway) {
this.messagingGateway = messagingGateway;
}

/**
* Passed to the {@link StepExecutionSplitter} in the
* {@link #handle(StepExecutionSplitter, StepExecution)} method, instructing it how
* many {@link StepExecution} instances are required, ideally. The
* {@link StepExecutionSplitter} is allowed to ignore the grid size in the case of a
* restart, since the input data partitions must be preserved.
* @param gridSize the number of step executions that will be created
*/
public void setGridSize(int gridSize) {
this.gridSize = gridSize;
}

/**
* The name of the {@link Step} that will be used to execute the partitioned
* {@link StepExecution}. This is a regular Spring Batch step, with all the business
Expand Down Expand Up @@ -234,19 +221,17 @@ public void setReplyChannel(PollableChannel replyChannel) {
*
* @see PartitionHandler#handle(StepExecutionSplitter, StepExecution)
*/
public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplitter,
final StepExecution managerStepExecution) throws Exception {

final Set<StepExecution> split = stepExecutionSplitter.split(managerStepExecution, gridSize);
@Override
protected Set<StepExecution> doHandle(StepExecution managerStepExecution, Set<StepExecution> partitionStepExecutions) throws Exception {

if (CollectionUtils.isEmpty(split)) {
return split;
if (CollectionUtils.isEmpty(partitionStepExecutions)) {
return partitionStepExecutions;
}

int count = 0;

for (StepExecution stepExecution : split) {
Message<StepExecutionRequest> request = createMessage(count++, split.size(),
for (StepExecution stepExecution : partitionStepExecutions) {
Message<StepExecutionRequest> request = createMessage(count++, partitionStepExecutions.size(),
new StepExecutionRequest(stepName, stepExecution.getJobExecutionId(), stepExecution.getId()),
replyChannel);
if (logger.isDebugEnabled()) {
Expand All @@ -259,17 +244,17 @@ public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplit
return receiveReplies(replyChannel);
}
else {
return pollReplies(managerStepExecution, split);
return pollReplies(managerStepExecution, partitionStepExecutions);
}
}

private Collection<StepExecution> pollReplies(final StepExecution managerStepExecution,
private Set<StepExecution> pollReplies(final StepExecution managerStepExecution,
final Set<StepExecution> split) throws Exception {
final Collection<StepExecution> result = new ArrayList<>(split.size());
final Set<StepExecution> result = new HashSet<>(split.size());

Callable<Collection<StepExecution>> callback = new Callable<Collection<StepExecution>>() {
Callable<Set<StepExecution>> callback = new Callable<Set<StepExecution>>() {
@Override
public Collection<StepExecution> call() throws Exception {
public Set<StepExecution> call() throws Exception {

for (Iterator<StepExecution> stepExecutionIterator = split.iterator(); stepExecutionIterator
.hasNext();) {
Expand Down Expand Up @@ -298,8 +283,8 @@ public Collection<StepExecution> call() throws Exception {
}
};

Poller<Collection<StepExecution>> poller = new DirectPoller<>(pollInterval);
Future<Collection<StepExecution>> resultsFuture = poller.poll(callback);
Poller<Set<StepExecution>> poller = new DirectPoller<>(pollInterval);
Future<Set<StepExecution>> resultsFuture = poller.poll(callback);

if (timeout >= 0) {
return resultsFuture.get(timeout, TimeUnit.MILLISECONDS);
Expand All @@ -309,9 +294,8 @@ public Collection<StepExecution> call() throws Exception {
}
}

private Collection<StepExecution> receiveReplies(PollableChannel currentReplyChannel) {
@SuppressWarnings("unchecked")
Message<Collection<StepExecution>> message = (Message<Collection<StepExecution>>) messagingGateway
private Set<StepExecution> receiveReplies(PollableChannel currentReplyChannel) {
Message<Set<StepExecution>> message = (Message<Set<StepExecution>>) messagingGateway
.receive(currentReplyChannel);

if (message == null) {
Expand All @@ -321,7 +305,7 @@ else if (logger.isDebugEnabled()) {
logger.debug("Received replies: " + message);
}

return message.getPayload();
return new HashSet<>(message.getPayload());
}

private Message<StepExecutionRequest> createMessage(int sequenceNumber, int sequenceSize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void testHandleNoReply() throws Exception {
HashSet<StepExecution> stepExecutions = new HashSet<>();
stepExecutions.add(new StepExecution("step1", new JobExecution(5L)));
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
when(message.getPayload()).thenReturn(Collections.emptyList());
when(message.getPayload()).thenReturn(Collections.emptySet());
when(operations.receive((PollableChannel) any())).thenReturn(message);
// set
messageChannelPartitionHandler.setMessagingOperations(operations);
Expand Down Expand Up @@ -112,7 +112,7 @@ void testHandleWithReplyChannel() throws Exception {
HashSet<StepExecution> stepExecutions = new HashSet<>();
stepExecutions.add(new StepExecution("step1", new JobExecution(5L)));
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
when(message.getPayload()).thenReturn(Collections.emptyList());
when(message.getPayload()).thenReturn(Collections.emptySet());
when(operations.receive(replyChannel)).thenReturn(message);
// set
messageChannelPartitionHandler.setMessagingOperations(operations);
Expand Down

0 comments on commit 9211c2d

Please sign in to comment.