diff --git a/.env.example b/.env.example index 2d5ea2d..daf58a3 100644 --- a/.env.example +++ b/.env.example @@ -6,3 +6,5 @@ S3_ENDPOINT=http://s3:9000 S3_ACCESS_KEY=minioadmin S3_SECRET_KEY=minioadmin S3_BUCKET=composer-dev + +REDIS_HOST_NAME=redis diff --git a/pom.xml b/pom.xml index c0ec893..047fca2 100644 --- a/pom.xml +++ b/pom.xml @@ -90,6 +90,14 @@ software.amazon.awssdk s3 + + org.springframework.boot + spring-boot-starter-data-redis + + + org.springframework.boot + spring-boot-starter-webflux + me.paulschwarz @@ -123,6 +131,7 @@ runtime + org.springframework.boot spring-boot-devtools diff --git a/src/main/java/com/bivashy/backend/composer/config/RedisConfig.java b/src/main/java/com/bivashy/backend/composer/config/RedisConfig.java new file mode 100644 index 0000000..363c86e --- /dev/null +++ b/src/main/java/com/bivashy/backend/composer/config/RedisConfig.java @@ -0,0 +1,36 @@ +package com.bivashy.backend.composer.config; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.redis.connection.RedisStandaloneConfiguration; +import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.listener.RedisMessageListenerContainer; + +@Configuration +public class RedisConfig { + @Value("${spring.redis.host-name}") + private String hostName; + + @Bean + public LettuceConnectionFactory lettuceConnectionFactory() { + RedisStandaloneConfiguration config = new RedisStandaloneConfiguration(hostName); + return new LettuceConnectionFactory(config); + } + + @Bean + public RedisTemplate redisTemplate() { + RedisTemplate template = new RedisTemplate<>(); + template.setConnectionFactory(lettuceConnectionFactory()); + return template; + } + + @Bean + public RedisMessageListenerContainer redisContainer() { + RedisMessageListenerContainer container = new RedisMessageListenerContainer(); + container.setConnectionFactory(lettuceConnectionFactory()); + return container; + } + +} diff --git a/src/main/java/com/bivashy/backend/composer/controller/importing/ProgressSSEController.java b/src/main/java/com/bivashy/backend/composer/controller/importing/ProgressSSEController.java new file mode 100644 index 0000000..6805f8c --- /dev/null +++ b/src/main/java/com/bivashy/backend/composer/controller/importing/ProgressSSEController.java @@ -0,0 +1,116 @@ +package com.bivashy.backend.composer.controller.importing; + +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.security.core.annotation.AuthenticationPrincipal; +import org.springframework.web.bind.annotation.*; + +import com.bivashy.backend.composer.auth.CustomUserDetails; +import com.bivashy.backend.composer.dto.importing.TrackProgressDTO; +import com.bivashy.backend.composer.service.importing.RedisMessageSubscriber; +import com.bivashy.backend.composer.service.importing.RedisProgressService; +import com.fasterxml.jackson.databind.ObjectMapper; + +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +@RestController +public class ProgressSSEController { + + private final RedisProgressService redisProgressService; + private final RedisMessageSubscriber redisSubscriber; + private final Map> sinks = new ConcurrentHashMap<>(); + + public ProgressSSEController(RedisProgressService redisProgressService, + RedisMessageSubscriber redisSubscriber) { + this.redisProgressService = redisProgressService; + this.redisSubscriber = redisSubscriber; + } + + @GetMapping("/importing/test/{playlistId}") + public void test(@PathVariable String playlistId, @AuthenticationPrincipal CustomUserDetails user) { + var userId = user.getId(); + redisProgressService.saveProgress(new TrackProgressDTO( + playlistId, + "test", + userId)); + } + + @GetMapping(value = "/importing/stream/{playlistId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux> streamProgress( + @PathVariable String playlistId, + @AuthenticationPrincipal CustomUserDetails user, + HttpServletResponse response) { + var userId = user.getId(); + + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setCharacterEncoding("UTF-8"); + + String connectionKey = getConnectionKey(playlistId, userId); + + Sinks.Many sink = sinks.computeIfAbsent(connectionKey, k -> { + Sinks.Many newSink = Sinks.many().replay().latest(); + + redisSubscriber.subscribeToPlaylist(playlistId, userId, message -> { + newSink.tryEmitNext(message); + }); + + return newSink; + }); + + redisProgressService.addActiveConnection(playlistId, userId); + + return sink.asFlux() + .map(data -> ServerSentEvent.builder() + .data(data) + .event("progress-update") + .build()) + .doFirst(() -> { + try { + List existingProgresses = redisProgressService.getPlaylistProgress(playlistId, + userId); + System.out.println(existingProgresses); + + ObjectMapper mapper = new ObjectMapper(); + for (TrackProgressDTO progress : existingProgresses) { + sink.tryEmitNext(mapper.writeValueAsString(progress)); + } + } catch (Exception e) { + e.printStackTrace(); + } + }) + .doOnCancel(() -> { + cleanupConnection(playlistId, userId, sink, connectionKey); + }) + .doOnTerminate(() -> { + cleanupConnection(playlistId, userId, sink, connectionKey); + }) + .timeout(Duration.ofHours(2)) + .onErrorResume(e -> { + cleanupConnection(playlistId, userId, sink, connectionKey); + return Flux.empty(); + }); + } + + private void cleanupConnection(String playlistId, long userId, + Sinks.Many sink, String connectionKey) { + try { + redisProgressService.removeActiveConnection(playlistId, userId); + redisSubscriber.unsubscribeFromPlaylist(playlistId, userId); + sinks.remove(connectionKey); + sink.tryEmitComplete(); + } catch (Exception e) { + e.printStackTrace(); + } + } + + private String getConnectionKey(String playlistId, long userId) { + return String.format("%s:%s", Long.toString(userId), playlistId); + } +} diff --git a/src/main/java/com/bivashy/backend/composer/dto/SourceType.java b/src/main/java/com/bivashy/backend/composer/dto/SourceType.java new file mode 100644 index 0000000..d0409c7 --- /dev/null +++ b/src/main/java/com/bivashy/backend/composer/dto/SourceType.java @@ -0,0 +1,27 @@ +package com.bivashy.backend.composer.dto; + +public enum SourceType { + AUDIO("AUDIO"), + PLAYLIST("PLAYLIST"), + FILE("FILE"), + URL("URL"); + + private final String value; + + SourceType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static SourceType fromValue(String value) { + for (SourceType type : values()) { + if (type.value.equalsIgnoreCase(value)) { + return type; + } + } + throw new IllegalArgumentException("Unknown source type: " + value); + } +} diff --git a/src/main/java/com/bivashy/backend/composer/dto/importing/ImportTrackKey.java b/src/main/java/com/bivashy/backend/composer/dto/importing/ImportTrackKey.java new file mode 100644 index 0000000..1ec87ca --- /dev/null +++ b/src/main/java/com/bivashy/backend/composer/dto/importing/ImportTrackKey.java @@ -0,0 +1,11 @@ +package com.bivashy.backend.composer.dto.importing; + +public class ImportTrackKey { + public static String progressKey(String playlistId, long userId) { + return String.format("progress:%s:%s", Long.toString(userId), playlistId); + } + + public static String trackKey(String playlistId, String trackId, long userId) { + return String.format("track:%s:%s:%s", Long.toString(userId), playlistId, trackId); + } +} diff --git a/src/main/java/com/bivashy/backend/composer/dto/importing/TrackProgressDTO.java b/src/main/java/com/bivashy/backend/composer/dto/importing/TrackProgressDTO.java new file mode 100644 index 0000000..91b922d --- /dev/null +++ b/src/main/java/com/bivashy/backend/composer/dto/importing/TrackProgressDTO.java @@ -0,0 +1,116 @@ +package com.bivashy.backend.composer.dto.importing; + +public class TrackProgressDTO { + private String playlistId; + private String trackId; + private String trackTitle; + private String format; + private String sourceType; + private int progress; + private String metadata; + private Long timestamp; + private long userId; + + public TrackProgressDTO() { + } + + public TrackProgressDTO(String playlistId, String trackId, long userId) { + this.playlistId = playlistId; + this.trackId = trackId; + this.userId = userId; + this.timestamp = System.currentTimeMillis(); + } + + public TrackProgressDTO(String playlistId, + String trackId, + String trackTitle, + String format, + String sourceType, + int progress, + String metadata, + Long timestamp, + long userId) { + this.playlistId = playlistId; + this.trackId = trackId; + this.trackTitle = trackTitle; + this.format = format; + this.sourceType = sourceType; + this.progress = progress; + this.metadata = metadata; + this.timestamp = timestamp; + this.userId = userId; + } + + public String getPlaylistId() { + return playlistId; + } + + public void setPlaylistId(String playlistId) { + this.playlistId = playlistId; + } + + public String getTrackId() { + return trackId; + } + + public void setTrackId(String trackId) { + this.trackId = trackId; + } + + public String getTrackTitle() { + return trackTitle; + } + + public void setTrackTitle(String trackTitle) { + this.trackTitle = trackTitle; + } + + public String getFormat() { + return format; + } + + public void setFormat(String format) { + this.format = format; + } + + public String getSourceType() { + return sourceType; + } + + public void setSourceType(String sourceType) { + this.sourceType = sourceType; + } + + public int getProgress() { + return progress; + } + + public void setProgress(int progress) { + this.progress = progress; + } + + public String getMetadata() { + return metadata; + } + + public void setMetadata(String metadata) { + this.metadata = metadata; + } + + public Long getTimestamp() { + return timestamp; + } + + public void setTimestamp(Long timestamp) { + this.timestamp = timestamp; + } + + public long getUserId() { + return userId; + } + + public void setUserId(long userId) { + this.userId = userId; + } + +} diff --git a/src/main/java/com/bivashy/backend/composer/service/importing/RedisMessageSubscriber.java b/src/main/java/com/bivashy/backend/composer/service/importing/RedisMessageSubscriber.java new file mode 100644 index 0000000..b90101d --- /dev/null +++ b/src/main/java/com/bivashy/backend/composer/service/importing/RedisMessageSubscriber.java @@ -0,0 +1,50 @@ +package com.bivashy.backend.composer.service.importing; + +import org.springframework.data.redis.connection.Message; +import org.springframework.data.redis.connection.MessageListener; +import org.springframework.data.redis.listener.ChannelTopic; +import org.springframework.data.redis.listener.RedisMessageListenerContainer; +import org.springframework.stereotype.Component; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +@Component +public class RedisMessageSubscriber { + + private final RedisMessageListenerContainer container; + private final Map> subscriptions = new ConcurrentHashMap<>(); + + public RedisMessageSubscriber(RedisMessageListenerContainer container) { + this.container = container; + } + + public void subscribeToPlaylist(String playlistId, long userId, Consumer messageHandler) { + String channel = String.format("progress_updates:%s:%s", userId, playlistId); + String subscriptionKey = getSubscriptionKey(playlistId, userId); + + if (!subscriptions.containsKey(subscriptionKey)) { + container.addMessageListener(new MessageListener() { + @Override + public void onMessage(Message message, byte[] pattern) { + String receivedMessage = new String(message.getBody()); + if (subscriptions.containsKey(subscriptionKey)) { + messageHandler.accept(receivedMessage); + } + } + }, new ChannelTopic(channel)); + + subscriptions.put(subscriptionKey, messageHandler); + } + } + + public void unsubscribeFromPlaylist(String playlistId, long userId) { + String subscriptionKey = getSubscriptionKey(playlistId, userId); + subscriptions.remove(subscriptionKey); + } + + private String getSubscriptionKey(String playlistId, long userId) { + return String.format("%s:%s", Long.toString(userId), playlistId); + } +} diff --git a/src/main/java/com/bivashy/backend/composer/service/importing/RedisProgressService.java b/src/main/java/com/bivashy/backend/composer/service/importing/RedisProgressService.java new file mode 100644 index 0000000..81ce971 --- /dev/null +++ b/src/main/java/com/bivashy/backend/composer/service/importing/RedisProgressService.java @@ -0,0 +1,109 @@ +package com.bivashy.backend.composer.service.importing; + +import com.bivashy.backend.composer.dto.importing.ImportTrackKey; +import com.bivashy.backend.composer.dto.importing.TrackProgressDTO; +import com.fasterxml.jackson.databind.ObjectMapper; + +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.stereotype.Service; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; + +@Service +public class RedisProgressService { + private final StringRedisTemplate redisTemplate; + private final ObjectMapper objectMapper; + private final Map> activeConnections = new ConcurrentHashMap<>(); + + public RedisProgressService(StringRedisTemplate redisTemplate, + ObjectMapper objectMapper) { + this.redisTemplate = redisTemplate; + this.objectMapper = objectMapper; + } + + public void saveProgress(TrackProgressDTO progress) { + try { + String key = ImportTrackKey.progressKey(progress.getPlaylistId(), progress.getUserId()); + String trackKey = ImportTrackKey.trackKey( + progress.getPlaylistId(), + progress.getTrackId(), + progress.getUserId()); + + String progressJson = objectMapper.writeValueAsString(progress); + redisTemplate.opsForHash().put(key, progress.getTrackId(), progressJson); + + redisTemplate.opsForValue().set(trackKey, progressJson); + + redisTemplate.expire(key, 24, java.util.concurrent.TimeUnit.HOURS); + redisTemplate.expire(trackKey, 24, java.util.concurrent.TimeUnit.HOURS); + + publishProgressUpdate(progress); + } catch (Exception e) { + throw new RuntimeException("Failed to save progress to Redis", e); + } + } + + public List getPlaylistProgress(String playlistId, long userId) { + try { + String key = ImportTrackKey.progressKey(playlistId, userId); + Map progressMap = redisTemplate.opsForHash().entries(key); + + List progressList = new ArrayList<>(); + for (Object value : progressMap.values()) { + TrackProgressDTO progress = objectMapper.readValue( + (String) value, + TrackProgressDTO.class); + progressList.add(progress); + } + + progressList.sort(Comparator.comparingLong(TrackProgressDTO::getTimestamp)); + + return progressList; + } catch (Exception e) { + throw new RuntimeException("Failed to get progress from Redis", e); + } + } + + public TrackProgressDTO getTrackProgress(String playlistId, String trackId, long userId) { + try { + String key = ImportTrackKey.trackKey(playlistId, trackId, userId); + String progressJson = redisTemplate.opsForValue().get(key); + + if (progressJson != null) { + return objectMapper.readValue(progressJson, TrackProgressDTO.class); + } + return null; + } catch (Exception e) { + throw new RuntimeException("Failed to get track progress", e); + } + } + + private void publishProgressUpdate(TrackProgressDTO progress) { + try { + String channel = String.format("progress_updates:%s:%s", + progress.getUserId(), + progress.getPlaylistId()); + + String message = objectMapper.writeValueAsString(progress); + redisTemplate.convertAndSend(channel, message); + } catch (Exception e) { + e.printStackTrace(); + } + } + + public void addActiveConnection(String playlistId, long userId) { + String connectionKey = String.format("%s:%s", Long.toString(userId), playlistId); + activeConnections.computeIfAbsent(connectionKey, k -> ConcurrentHashMap.newKeySet()).add(connectionKey); + } + + public void removeActiveConnection(String playlistId, long userId) { + String connectionKey = String.format("%s:%s", Long.toString(userId), playlistId); + activeConnections.remove(connectionKey); + } + + public boolean hasActiveConnections(String playlistId, long userId) { + String connectionKey = String.format("%s:%s", Long.toString(userId), playlistId); + return activeConnections.containsKey(connectionKey); + } +} diff --git a/src/main/resources/application.yaml b/src/main/resources/application.yaml index cee3873..8260a01 100644 --- a/src/main/resources/application.yaml +++ b/src/main/resources/application.yaml @@ -11,6 +11,8 @@ spring: access-key: ${S3_ACCESS_KEY} secret-key: ${S3_SECRET_KEY} bucket: ${S3_BUCKET} + redis: + host-name: ${REDIS_HOST_NAME} servlet: multipart: max-file-size: 8096MB