Implement importing progress with SSE

This commit is contained in:
2026-01-03 17:50:44 +05:00
parent b6cd60a041
commit 2821635462
10 changed files with 478 additions and 0 deletions

View File

@ -6,3 +6,5 @@ S3_ENDPOINT=http://s3:9000
S3_ACCESS_KEY=minioadmin S3_ACCESS_KEY=minioadmin
S3_SECRET_KEY=minioadmin S3_SECRET_KEY=minioadmin
S3_BUCKET=composer-dev S3_BUCKET=composer-dev
REDIS_HOST_NAME=redis

View File

@ -90,6 +90,14 @@
<groupId>software.amazon.awssdk</groupId> <groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId> <artifactId>s3</artifactId>
</dependency> </dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-webflux</artifactId>
</dependency>
<dependency> <dependency>
<groupId>me.paulschwarz</groupId> <groupId>me.paulschwarz</groupId>
@ -123,6 +131,7 @@
<scope>runtime</scope> <scope>runtime</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-devtools</artifactId> <artifactId>spring-boot-devtools</artifactId>

View File

@ -0,0 +1,36 @@
package com.bivashy.backend.composer.config;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisStandaloneConfiguration;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
@Configuration
public class RedisConfig {
@Value("${spring.redis.host-name}")
private String hostName;
@Bean
public LettuceConnectionFactory lettuceConnectionFactory() {
RedisStandaloneConfiguration config = new RedisStandaloneConfiguration(hostName);
return new LettuceConnectionFactory(config);
}
@Bean
public RedisTemplate<String, Object> redisTemplate() {
RedisTemplate<String, Object> template = new RedisTemplate<>();
template.setConnectionFactory(lettuceConnectionFactory());
return template;
}
@Bean
public RedisMessageListenerContainer redisContainer() {
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
container.setConnectionFactory(lettuceConnectionFactory());
return container;
}
}

View File

@ -0,0 +1,116 @@
package com.bivashy.backend.composer.controller.importing;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.web.bind.annotation.*;
import com.bivashy.backend.composer.auth.CustomUserDetails;
import com.bivashy.backend.composer.dto.importing.TrackProgressDTO;
import com.bivashy.backend.composer.service.importing.RedisMessageSubscriber;
import com.bivashy.backend.composer.service.importing.RedisProgressService;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletResponse;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@RestController
public class ProgressSSEController {
private final RedisProgressService redisProgressService;
private final RedisMessageSubscriber redisSubscriber;
private final Map<String, Sinks.Many<String>> sinks = new ConcurrentHashMap<>();
public ProgressSSEController(RedisProgressService redisProgressService,
RedisMessageSubscriber redisSubscriber) {
this.redisProgressService = redisProgressService;
this.redisSubscriber = redisSubscriber;
}
@GetMapping("/importing/test/{playlistId}")
public void test(@PathVariable String playlistId, @AuthenticationPrincipal CustomUserDetails user) {
var userId = user.getId();
redisProgressService.saveProgress(new TrackProgressDTO(
playlistId,
"test",
userId));
}
@GetMapping(value = "/importing/stream/{playlistId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<String>> streamProgress(
@PathVariable String playlistId,
@AuthenticationPrincipal CustomUserDetails user,
HttpServletResponse response) {
var userId = user.getId();
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");
response.setCharacterEncoding("UTF-8");
String connectionKey = getConnectionKey(playlistId, userId);
Sinks.Many<String> sink = sinks.computeIfAbsent(connectionKey, k -> {
Sinks.Many<String> newSink = Sinks.many().replay().latest();
redisSubscriber.subscribeToPlaylist(playlistId, userId, message -> {
newSink.tryEmitNext(message);
});
return newSink;
});
redisProgressService.addActiveConnection(playlistId, userId);
return sink.asFlux()
.map(data -> ServerSentEvent.<String>builder()
.data(data)
.event("progress-update")
.build())
.doFirst(() -> {
try {
List<TrackProgressDTO> existingProgresses = redisProgressService.getPlaylistProgress(playlistId,
userId);
System.out.println(existingProgresses);
ObjectMapper mapper = new ObjectMapper();
for (TrackProgressDTO progress : existingProgresses) {
sink.tryEmitNext(mapper.writeValueAsString(progress));
}
} catch (Exception e) {
e.printStackTrace();
}
})
.doOnCancel(() -> {
cleanupConnection(playlistId, userId, sink, connectionKey);
})
.doOnTerminate(() -> {
cleanupConnection(playlistId, userId, sink, connectionKey);
})
.timeout(Duration.ofHours(2))
.onErrorResume(e -> {
cleanupConnection(playlistId, userId, sink, connectionKey);
return Flux.empty();
});
}
private void cleanupConnection(String playlistId, long userId,
Sinks.Many<String> sink, String connectionKey) {
try {
redisProgressService.removeActiveConnection(playlistId, userId);
redisSubscriber.unsubscribeFromPlaylist(playlistId, userId);
sinks.remove(connectionKey);
sink.tryEmitComplete();
} catch (Exception e) {
e.printStackTrace();
}
}
private String getConnectionKey(String playlistId, long userId) {
return String.format("%s:%s", Long.toString(userId), playlistId);
}
}

View File

@ -0,0 +1,27 @@
package com.bivashy.backend.composer.dto;
public enum SourceType {
AUDIO("AUDIO"),
PLAYLIST("PLAYLIST"),
FILE("FILE"),
URL("URL");
private final String value;
SourceType(String value) {
this.value = value;
}
public String getValue() {
return value;
}
public static SourceType fromValue(String value) {
for (SourceType type : values()) {
if (type.value.equalsIgnoreCase(value)) {
return type;
}
}
throw new IllegalArgumentException("Unknown source type: " + value);
}
}

View File

@ -0,0 +1,11 @@
package com.bivashy.backend.composer.dto.importing;
public class ImportTrackKey {
public static String progressKey(String playlistId, long userId) {
return String.format("progress:%s:%s", Long.toString(userId), playlistId);
}
public static String trackKey(String playlistId, String trackId, long userId) {
return String.format("track:%s:%s:%s", Long.toString(userId), playlistId, trackId);
}
}

View File

@ -0,0 +1,116 @@
package com.bivashy.backend.composer.dto.importing;
public class TrackProgressDTO {
private String playlistId;
private String trackId;
private String trackTitle;
private String format;
private String sourceType;
private int progress;
private String metadata;
private Long timestamp;
private long userId;
public TrackProgressDTO() {
}
public TrackProgressDTO(String playlistId, String trackId, long userId) {
this.playlistId = playlistId;
this.trackId = trackId;
this.userId = userId;
this.timestamp = System.currentTimeMillis();
}
public TrackProgressDTO(String playlistId,
String trackId,
String trackTitle,
String format,
String sourceType,
int progress,
String metadata,
Long timestamp,
long userId) {
this.playlistId = playlistId;
this.trackId = trackId;
this.trackTitle = trackTitle;
this.format = format;
this.sourceType = sourceType;
this.progress = progress;
this.metadata = metadata;
this.timestamp = timestamp;
this.userId = userId;
}
public String getPlaylistId() {
return playlistId;
}
public void setPlaylistId(String playlistId) {
this.playlistId = playlistId;
}
public String getTrackId() {
return trackId;
}
public void setTrackId(String trackId) {
this.trackId = trackId;
}
public String getTrackTitle() {
return trackTitle;
}
public void setTrackTitle(String trackTitle) {
this.trackTitle = trackTitle;
}
public String getFormat() {
return format;
}
public void setFormat(String format) {
this.format = format;
}
public String getSourceType() {
return sourceType;
}
public void setSourceType(String sourceType) {
this.sourceType = sourceType;
}
public int getProgress() {
return progress;
}
public void setProgress(int progress) {
this.progress = progress;
}
public String getMetadata() {
return metadata;
}
public void setMetadata(String metadata) {
this.metadata = metadata;
}
public Long getTimestamp() {
return timestamp;
}
public void setTimestamp(Long timestamp) {
this.timestamp = timestamp;
}
public long getUserId() {
return userId;
}
public void setUserId(long userId) {
this.userId = userId;
}
}

View File

@ -0,0 +1,50 @@
package com.bivashy.backend.composer.service.importing;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
@Component
public class RedisMessageSubscriber {
private final RedisMessageListenerContainer container;
private final Map<String, Consumer<String>> subscriptions = new ConcurrentHashMap<>();
public RedisMessageSubscriber(RedisMessageListenerContainer container) {
this.container = container;
}
public void subscribeToPlaylist(String playlistId, long userId, Consumer<String> messageHandler) {
String channel = String.format("progress_updates:%s:%s", userId, playlistId);
String subscriptionKey = getSubscriptionKey(playlistId, userId);
if (!subscriptions.containsKey(subscriptionKey)) {
container.addMessageListener(new MessageListener() {
@Override
public void onMessage(Message message, byte[] pattern) {
String receivedMessage = new String(message.getBody());
if (subscriptions.containsKey(subscriptionKey)) {
messageHandler.accept(receivedMessage);
}
}
}, new ChannelTopic(channel));
subscriptions.put(subscriptionKey, messageHandler);
}
}
public void unsubscribeFromPlaylist(String playlistId, long userId) {
String subscriptionKey = getSubscriptionKey(playlistId, userId);
subscriptions.remove(subscriptionKey);
}
private String getSubscriptionKey(String playlistId, long userId) {
return String.format("%s:%s", Long.toString(userId), playlistId);
}
}

View File

@ -0,0 +1,109 @@
package com.bivashy.backend.composer.service.importing;
import com.bivashy.backend.composer.dto.importing.ImportTrackKey;
import com.bivashy.backend.composer.dto.importing.TrackProgressDTO;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
@Service
public class RedisProgressService {
private final StringRedisTemplate redisTemplate;
private final ObjectMapper objectMapper;
private final Map<String, Set<String>> activeConnections = new ConcurrentHashMap<>();
public RedisProgressService(StringRedisTemplate redisTemplate,
ObjectMapper objectMapper) {
this.redisTemplate = redisTemplate;
this.objectMapper = objectMapper;
}
public void saveProgress(TrackProgressDTO progress) {
try {
String key = ImportTrackKey.progressKey(progress.getPlaylistId(), progress.getUserId());
String trackKey = ImportTrackKey.trackKey(
progress.getPlaylistId(),
progress.getTrackId(),
progress.getUserId());
String progressJson = objectMapper.writeValueAsString(progress);
redisTemplate.opsForHash().put(key, progress.getTrackId(), progressJson);
redisTemplate.opsForValue().set(trackKey, progressJson);
redisTemplate.expire(key, 24, java.util.concurrent.TimeUnit.HOURS);
redisTemplate.expire(trackKey, 24, java.util.concurrent.TimeUnit.HOURS);
publishProgressUpdate(progress);
} catch (Exception e) {
throw new RuntimeException("Failed to save progress to Redis", e);
}
}
public List<TrackProgressDTO> getPlaylistProgress(String playlistId, long userId) {
try {
String key = ImportTrackKey.progressKey(playlistId, userId);
Map<Object, Object> progressMap = redisTemplate.opsForHash().entries(key);
List<TrackProgressDTO> progressList = new ArrayList<>();
for (Object value : progressMap.values()) {
TrackProgressDTO progress = objectMapper.readValue(
(String) value,
TrackProgressDTO.class);
progressList.add(progress);
}
progressList.sort(Comparator.comparingLong(TrackProgressDTO::getTimestamp));
return progressList;
} catch (Exception e) {
throw new RuntimeException("Failed to get progress from Redis", e);
}
}
public TrackProgressDTO getTrackProgress(String playlistId, String trackId, long userId) {
try {
String key = ImportTrackKey.trackKey(playlistId, trackId, userId);
String progressJson = redisTemplate.opsForValue().get(key);
if (progressJson != null) {
return objectMapper.readValue(progressJson, TrackProgressDTO.class);
}
return null;
} catch (Exception e) {
throw new RuntimeException("Failed to get track progress", e);
}
}
private void publishProgressUpdate(TrackProgressDTO progress) {
try {
String channel = String.format("progress_updates:%s:%s",
progress.getUserId(),
progress.getPlaylistId());
String message = objectMapper.writeValueAsString(progress);
redisTemplate.convertAndSend(channel, message);
} catch (Exception e) {
e.printStackTrace();
}
}
public void addActiveConnection(String playlistId, long userId) {
String connectionKey = String.format("%s:%s", Long.toString(userId), playlistId);
activeConnections.computeIfAbsent(connectionKey, k -> ConcurrentHashMap.newKeySet()).add(connectionKey);
}
public void removeActiveConnection(String playlistId, long userId) {
String connectionKey = String.format("%s:%s", Long.toString(userId), playlistId);
activeConnections.remove(connectionKey);
}
public boolean hasActiveConnections(String playlistId, long userId) {
String connectionKey = String.format("%s:%s", Long.toString(userId), playlistId);
return activeConnections.containsKey(connectionKey);
}
}

View File

@ -11,6 +11,8 @@ spring:
access-key: ${S3_ACCESS_KEY} access-key: ${S3_ACCESS_KEY}
secret-key: ${S3_SECRET_KEY} secret-key: ${S3_SECRET_KEY}
bucket: ${S3_BUCKET} bucket: ${S3_BUCKET}
redis:
host-name: ${REDIS_HOST_NAME}
servlet: servlet:
multipart: multipart:
max-file-size: 8096MB max-file-size: 8096MB