diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java index 7a2cb8981b78..3a6994cb7ee4 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebSocketConfigurationSupport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.web.socket.config.annotation; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; import org.springframework.context.annotation.Bean; import org.springframework.lang.Nullable; import org.springframework.scheduling.TaskScheduler; @@ -35,16 +37,13 @@ public class WebSocketConfigurationSupport { @Nullable private ServletWebSocketHandlerRegistry handlerRegistry; - @Nullable - private TaskScheduler scheduler; - @Bean - public HandlerMapping webSocketHandlerMapping(@Nullable TaskScheduler defaultSockJsTaskScheduler) { + public HandlerMapping webSocketHandlerMapping(DefaultSockJsSchedulerContainer schedulerContainer) { ServletWebSocketHandlerRegistry registry = initHandlerRegistry(); if (registry.requiresTaskScheduler()) { - TaskScheduler scheduler = defaultSockJsTaskScheduler; - Assert.notNull(scheduler, "Expected default TaskScheduler bean"); + TaskScheduler scheduler = schedulerContainer.getScheduler(); + Assert.notNull(scheduler, "TaskScheduler is required but not initialized"); registry.setTaskScheduler(scheduler); } return registry.getHandlerMapping(); @@ -62,8 +61,9 @@ protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { } /** - * The default TaskScheduler to use if none is registered explicitly via - * {@link SockJsServiceRegistration#setTaskScheduler}: + * A container of the default TaskScheduler to use if none was registered + * explicitly via {@link SockJsServiceRegistration#setTaskScheduler} as + * follows: *
 	 * @Configuration
 	 * @EnableWebSocket
@@ -80,16 +80,50 @@ protected void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
 	 * 
*/ @Bean - @Nullable - public TaskScheduler defaultSockJsTaskScheduler() { - if (initHandlerRegistry().requiresTaskScheduler()) { - ThreadPoolTaskScheduler threadPoolScheduler = new ThreadPoolTaskScheduler(); - threadPoolScheduler.setThreadNamePrefix("SockJS-"); - threadPoolScheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); - threadPoolScheduler.setRemoveOnCancelPolicy(true); - this.scheduler = threadPoolScheduler; + public DefaultSockJsSchedulerContainer defaultSockJsSchedulerContainer() { + return (initHandlerRegistry().requiresTaskScheduler() ? + new DefaultSockJsSchedulerContainer(initDefaultSockJsScheduler()) : + new DefaultSockJsSchedulerContainer(null)); + } + + private ThreadPoolTaskScheduler initDefaultSockJsScheduler() { + ThreadPoolTaskScheduler scheduler = new ThreadPoolTaskScheduler(); + scheduler.setThreadNamePrefix("SockJS-"); + scheduler.setPoolSize(Runtime.getRuntime().availableProcessors()); + scheduler.setRemoveOnCancelPolicy(true); + return scheduler; + } + + + private static class DefaultSockJsSchedulerContainer implements InitializingBean, DisposableBean { + + @Nullable + private final ThreadPoolTaskScheduler scheduler; + + DefaultSockJsSchedulerContainer(@Nullable ThreadPoolTaskScheduler scheduler) { + this.scheduler = scheduler; + } + + @Nullable + public ThreadPoolTaskScheduler getScheduler() { + return this.scheduler; + } + + @Override + public void afterPropertiesSet() throws Exception { + if (this.scheduler != null) { + this.scheduler.afterPropertiesSet(); + } } - return this.scheduler; + + @Override + public void destroy() throws Exception { + if (this.scheduler != null) { + this.scheduler.destroy(); + } + } + } + }