Improve ProgressSSEController OpenAPI, and progress

This commit is contained in:
2026-01-08 01:09:41 +05:00
parent 2baa79c3a9
commit aef8980786
11 changed files with 159 additions and 55 deletions

View File

@ -35,7 +35,7 @@
<apache-tika.version>3.2.3</apache-tika.version>
<springdoc-openapi.version>2.8.5</springdoc-openapi.version>
<jaffree.version>2024.08.29</jaffree.version>
<yt-dlp-java.version>2.0.6</yt-dlp-java.version>
<yt-dlp-java.version>2.0.7</yt-dlp-java.version>
<record-builder.version>51</record-builder.version>
</properties>
<repositories>
@ -121,6 +121,11 @@
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
<version>${springdoc-openapi.version}</version>
</dependency>
<dependency>
<groupId>org.springdoc</groupId>
<artifactId>springdoc-openapi-starter-webflux-ui</artifactId>
<version>${springdoc-openapi.version}</version>
</dependency>
<dependency>
<groupId>com.github.kokorin.jaffree</groupId>
<artifactId>jaffree</artifactId>

View File

@ -4,8 +4,9 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
@ -16,7 +17,6 @@ import com.bivashy.backend.composer.dto.importing.BaseTrackProgress;
import com.bivashy.backend.composer.dto.importing.ImportTrackKey;
import com.bivashy.backend.composer.service.importing.RedisMessageSubscriber;
import com.bivashy.backend.composer.service.importing.RedisProgressService;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletResponse;
import reactor.core.publisher.Flux;
@ -24,10 +24,11 @@ import reactor.core.publisher.Sinks;
@RestController
public class ProgressSSEController {
private static final Logger logger = LoggerFactory.getLogger(ProgressSSEController.class);
private final RedisProgressService redisProgressService;
private final RedisMessageSubscriber redisSubscriber;
private final Map<String, Sinks.Many<String>> sinks = new ConcurrentHashMap<>();
private final Map<String, Sinks.Many<BaseTrackProgress>> sinks = new ConcurrentHashMap<>();
public ProgressSSEController(RedisProgressService redisProgressService,
RedisMessageSubscriber redisSubscriber) {
@ -36,7 +37,7 @@ public class ProgressSSEController {
}
@GetMapping(value = "/importing/stream/{playlistId}", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<String>> streamProgress(
public Flux<BaseTrackProgress> streamProgress(
@PathVariable long playlistId,
@AuthenticationPrincipal CustomUserDetails user,
HttpServletResponse response) {
@ -48,8 +49,8 @@ public class ProgressSSEController {
String connectionKey = ImportTrackKey.subscriptionKey(playlistId, userId);
Sinks.Many<String> sink = sinks.computeIfAbsent(connectionKey, k -> {
Sinks.Many<String> newSink = Sinks.many().replay().latest();
Sinks.Many<BaseTrackProgress> sink = sinks.computeIfAbsent(connectionKey, k -> {
Sinks.Many<BaseTrackProgress> newSink = Sinks.many().replay().latest();
redisSubscriber.subscribeToPlaylist(playlistId, userId, message -> {
newSink.tryEmitNext(message);
@ -61,19 +62,14 @@ public class ProgressSSEController {
redisProgressService.addActiveConnection(playlistId, userId);
return sink.asFlux()
.map(data -> ServerSentEvent.<String>builder()
.data(data)
.event("progress-update")
.build())
.doFirst(() -> {
try {
List<BaseTrackProgress> existingProgresses = redisProgressService.getPlaylistProgress(
playlistId,
userId);
ObjectMapper mapper = new ObjectMapper();
for (BaseTrackProgress progress : existingProgresses) {
sink.tryEmitNext(mapper.writeValueAsString(progress));
sink.tryEmitNext(progress);
}
} catch (Exception e) {
e.printStackTrace();
@ -92,7 +88,7 @@ public class ProgressSSEController {
}
private void cleanupConnection(Long playlistId, long userId,
Sinks.Many<String> sink, String connectionKey) {
Sinks.Many<BaseTrackProgress> sink, String connectionKey) {
try {
redisProgressService.removeActiveConnection(playlistId, userId);
redisSubscriber.unsubscribeFromPlaylist(playlistId, userId);

View File

@ -10,15 +10,15 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
})
public abstract class BaseTrackProgress {
protected long playlistId;
protected long trackId;
protected long trackSourceId;
protected long userId;
protected long timestamp;
private String type;
public BaseTrackProgress(long playlistId, long trackId, long userId) {
public BaseTrackProgress(long playlistId, long trackSourceId, long userId) {
this.playlistId = playlistId;
this.trackId = trackId;
this.trackSourceId = trackSourceId;
this.userId = userId;
this.timestamp = System.currentTimeMillis();
}
@ -39,8 +39,8 @@ public abstract class BaseTrackProgress {
return playlistId;
}
public long getTrackId() {
return trackId;
public long getTrackSourceId() {
return trackSourceId;
}
protected void setType(ProgressEntryType type) {

View File

@ -5,8 +5,8 @@ public class ImportTrackKey {
return String.format("progress:%d:%d", userId, playlistId);
}
public static String trackKey(long playlistId, long trackId, long userId) {
return String.format("track:%d:%d:%d", userId, playlistId, trackId);
public static String trackKey(long playlistId, long trackSourceId, long userId) {
return String.format("track:%d:%d:%d", userId, playlistId, trackSourceId);
}
public static String redisChannelKey(long playlistId, long userId) {

View File

@ -3,12 +3,18 @@ package com.bivashy.backend.composer.dto.importing;
public class PlaylistProgress extends BaseTrackProgress {
private String ytdlnStdout;
private int overallProgress;
private String status;
private int trackCount;
private ProgressStatus status;
public PlaylistProgress(long playlistId, long trackId, long userId) {
super(playlistId, trackId, userId);
PlaylistProgress() {
super(0, 0, 0);
}
public PlaylistProgress(long playlistId, long trackSourceId, long userId, int trackCount) {
super(playlistId, trackSourceId, userId);
this.setType(ProgressEntryType.PLAYLIST);
this.status = "LOADING";
this.status = ProgressStatus.LOADING;
this.trackCount = trackCount;
}
public String getYtdlnStdout() {
@ -27,11 +33,16 @@ public class PlaylistProgress extends BaseTrackProgress {
this.overallProgress = overallProgress;
}
public String getStatus() {
public ProgressStatus getStatus() {
return status;
}
public void setStatus(String status) {
public void setStatus(ProgressStatus status) {
this.status = status;
}
public int getTrackCount() {
return trackCount;
}
}

View File

@ -0,0 +1,5 @@
package com.bivashy.backend.composer.dto.importing;
public enum ProgressStatus {
LOADING, FINISHED
}

View File

@ -4,8 +4,12 @@ public class SingleTrackProgress extends BaseTrackProgress {
private String title;
private String format;
public SingleTrackProgress(long playlistId, long trackId, long userId, String title, String format) {
super(playlistId, trackId, userId);
SingleTrackProgress() {
super(0, 0, 0);
}
public SingleTrackProgress(long playlistId, long trackSourceId, long userId, String title, String format) {
super(playlistId, trackSourceId, userId);
this.setType(ProgressEntryType.TRACK);
this.title = title;
this.format = format;

View File

@ -104,7 +104,8 @@ public class TrackService {
if (params.includeProgressHistory()) {
redisProgressService
.saveProgress(new SingleTrackProgress(playlistId, track.getId(), user.getId(), title, fileFormat));
.saveProgress(
new SingleTrackProgress(playlistId, trackSource.getId(), user.getId(), title, fileFormat));
}
return new TrackResponse(
@ -120,7 +121,7 @@ public class TrackService {
@Transactional
public List<TrackResponse> refreshYoutubePlaylist(CustomUserDetails user, long playlistId, long sourceId)
throws ImportTrackException {
return youtubeTrackService.refreshYoutubePlaylist(playlistId, sourceId);
return youtubeTrackService.refreshYoutubePlaylist(user, playlistId, sourceId);
}
@Transactional
@ -163,7 +164,8 @@ public class TrackService {
TrackSource trackSource = trackSourceService.createYoutubeTrackSource(SourceType.PLAYLIST,
request.youtubeUrl());
return youtubeTrackService.refreshYoutubePlaylist(playlistId, trackSource, videoInfos, request.youtubeUrl());
return youtubeTrackService.refreshYoutubePlaylist(user.getId(), playlistId, trackSource, videoInfos,
request.youtubeUrl());
}
public List<PlaylistTrackResponse> getPlaylistTracks(CustomUserDetails user, Long playlistId) {

View File

@ -15,6 +15,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import com.bivashy.backend.composer.auth.CustomUserDetails;
import com.bivashy.backend.composer.dto.importing.PlaylistProgress;
import com.bivashy.backend.composer.dto.importing.ProgressStatus;
import com.bivashy.backend.composer.dto.track.TrackResponse;
import com.bivashy.backend.composer.dto.track.service.AddLocalTrackParams;
import com.bivashy.backend.composer.dto.track.service.AddLocalTrackParamsBuilder;
@ -25,6 +28,7 @@ import com.bivashy.backend.composer.model.TrackSource;
import com.bivashy.backend.composer.model.TrackSourceMetadata;
import com.bivashy.backend.composer.repository.TrackRepository;
import com.bivashy.backend.composer.service.MetadataParseService.Metadata;
import com.bivashy.backend.composer.service.importing.RedisProgressService;
import com.bivashy.backend.composer.util.SimpleBlob.PathBlob;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.jfposton.ytdlp.YtDlp;
@ -46,16 +50,19 @@ public class YoutubeTrackService {
private final TrackMetadataService trackMetadataService;
private final TrackPlaylistService trackPlaylistService;
private final TrackSourceService trackSourceService;
private final RedisProgressService redisProgressService;
public YoutubeTrackService(AudioS3StorageService s3StorageService, MetadataParseService metadataParseService,
TrackRepository trackRepository, TrackMetadataService trackMetadataService,
TrackPlaylistService trackPlaylistService, TrackSourceService trackSourceService) {
TrackPlaylistService trackPlaylistService, TrackSourceService trackSourceService,
RedisProgressService redisProgressService) {
this.s3StorageService = s3StorageService;
this.metadataParseService = metadataParseService;
this.trackRepository = trackRepository;
this.trackMetadataService = trackMetadataService;
this.trackPlaylistService = trackPlaylistService;
this.trackSourceService = trackSourceService;
this.redisProgressService = redisProgressService;
}
public AddLocalTrackParams downloadYoutubeTrack(Path temporaryFolder, VideoInfo videoInfo, String youtubeUrl)
@ -85,7 +92,8 @@ public class YoutubeTrackService {
throw new ImportTrackException("cannot download any youtube track");
}
public List<TrackResponse> refreshYoutubePlaylist(long playlistId, long sourceId) throws ImportTrackException {
public List<TrackResponse> refreshYoutubePlaylist(CustomUserDetails user, long playlistId, long sourceId)
throws ImportTrackException {
Optional<TrackSourceMetadata> trackSourceMetadataOpt = trackSourceService.findWithMetadata(sourceId);
if (trackSourceMetadataOpt.isEmpty())
throw new ImportTrackException("cannot find track source with metadata with id " + sourceId);
@ -98,10 +106,11 @@ public class YoutubeTrackService {
} catch (YtDlpException e) {
throw new ImportTrackException("cannot `yt-dlp --dump-json` from " + youtubeUrl, e);
}
return refreshYoutubePlaylist(playlistId, trackSourceMetadata.getSource(), videoInfos, youtubeUrl);
return refreshYoutubePlaylist(user.getId(), playlistId, trackSourceMetadata.getSource(), videoInfos,
youtubeUrl);
}
public List<TrackResponse> refreshYoutubePlaylist(long playlistId, TrackSource trackSource,
public List<TrackResponse> refreshYoutubePlaylist(long userId, long playlistId, TrackSource trackSource,
List<VideoInfo> videoInfos,
String youtubeUrl) throws ImportTrackException {
List<TrackResponse> result = new ArrayList<>();
@ -126,10 +135,29 @@ public class YoutubeTrackService {
ytDlpRequest.setOption("audio-quality", 0);
ytDlpRequest.setOption("audio-format", "best");
ytDlpRequest.setOption("no-overwrites");
var response = YtDlp.execute(ytDlpRequest);
logger.info("yt dlp response {}", response);
// TODO: write to RedisProgressService
PlaylistProgress playlistProgress = new PlaylistProgress(playlistId, trackSource.getId(), userId,
videoInfos.size());
redisProgressService.saveProgress(playlistProgress);
var response = YtDlp.execute(ytDlpRequest, (downloadProgress, ignored) -> {
redisProgressService.<PlaylistProgress>updateTrackProgressField(playlistId, trackSource.getId(), userId,
progress -> {
progress.setOverallProgress((int) downloadProgress);
});
}, stdoutLine -> {
redisProgressService.<PlaylistProgress>updateTrackProgressField(playlistId, trackSource.getId(), userId,
progress -> {
progress.setYtdlnStdout(String.join("\n", progress.getYtdlnStdout(), stdoutLine));
});
}, null);
redisProgressService.<PlaylistProgress>updateTrackProgressField(playlistId, trackSource.getId(), userId,
progress -> {
progress.setOverallProgress(100);
progress.setStatus(ProgressStatus.FINISHED);
});
logger.info("yt dlp response {}", response);
try (Stream<Path> pathStream = Files.walk(temporaryFolder)) {
List<Path> downloadedFiles = Files.walk(temporaryFolder).toList();

View File

@ -1,28 +1,35 @@
package com.bivashy.backend.composer.service.importing;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Component;
import com.bivashy.backend.composer.dto.importing.BaseTrackProgress;
import com.bivashy.backend.composer.dto.importing.ImportTrackKey;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
@Component
public class RedisMessageSubscriber {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final Logger logger = LoggerFactory.getLogger(Logger.class);
private final RedisMessageListenerContainer container;
private final Map<String, Consumer<String>> subscriptions = new ConcurrentHashMap<>();
private final Map<String, Consumer<BaseTrackProgress>> subscriptions = new ConcurrentHashMap<>();
public RedisMessageSubscriber(RedisMessageListenerContainer container) {
this.container = container;
}
public void subscribeToPlaylist(long playlistId, long userId, Consumer<String> messageHandler) {
public void subscribeToPlaylist(long playlistId, long userId, Consumer<BaseTrackProgress> messageHandler) {
String channel = ImportTrackKey.redisChannelKey(playlistId, userId);
String subscriptionKey = ImportTrackKey.subscriptionKey(playlistId, userId);
@ -32,7 +39,13 @@ public class RedisMessageSubscriber {
public void onMessage(Message message, byte[] pattern) {
String receivedMessage = new String(message.getBody());
if (subscriptions.containsKey(subscriptionKey)) {
messageHandler.accept(receivedMessage);
try {
BaseTrackProgress progress = OBJECT_MAPPER.readValue(receivedMessage,
BaseTrackProgress.class);
messageHandler.accept(progress);
} catch (JsonProcessingException e) {
logger.error("cannot deserialize message into BaseTrackProgress.class", e);
}
}
}
}, new ChannelTopic(channel));

View File

@ -1,14 +1,22 @@
package com.bivashy.backend.composer.service.importing;
import com.bivashy.backend.composer.dto.importing.BaseTrackProgress;
import com.bivashy.backend.composer.dto.importing.ImportTrackKey;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import com.bivashy.backend.composer.dto.importing.BaseTrackProgress;
import com.bivashy.backend.composer.dto.importing.ImportTrackKey;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
@Service
public class RedisProgressService {
@ -27,11 +35,11 @@ public class RedisProgressService {
String key = ImportTrackKey.progressKey(progress.getPlaylistId(), progress.getUserId());
String trackKey = ImportTrackKey.trackKey(
progress.getPlaylistId(),
progress.getTrackId(),
progress.getTrackSourceId(),
progress.getUserId());
String progressJson = objectMapper.writeValueAsString(progress);
redisTemplate.opsForHash().put(key, Long.toString(progress.getTrackId()), progressJson);
redisTemplate.opsForHash().put(key, Long.toString(progress.getTrackSourceId()), progressJson);
redisTemplate.opsForValue().set(trackKey, progressJson);
@ -44,6 +52,38 @@ public class RedisProgressService {
}
}
public <T extends BaseTrackProgress> void updateTrackProgressField(long playlistId, long trackSourceId, long userId,
Consumer<T> updater) {
try {
String trackKey = ImportTrackKey.trackKey(playlistId, trackSourceId, userId);
String hashKey = ImportTrackKey.progressKey(playlistId, userId);
String existingJson = redisTemplate.opsForValue().get(trackKey);
if (existingJson == null) {
throw new RuntimeException("Track progress not found");
}
JavaType progressType = objectMapper.getTypeFactory()
.constructType(BaseTrackProgress.class);
T progress = objectMapper.readValue(existingJson, progressType);
updater.accept(progress);
String updatedJson = objectMapper.writeValueAsString(progress);
redisTemplate.opsForHash().put(hashKey, Long.toString(trackSourceId), updatedJson);
redisTemplate.opsForValue().set(trackKey, updatedJson);
redisTemplate.expire(hashKey, 24, TimeUnit.HOURS);
redisTemplate.expire(trackKey, 24, TimeUnit.HOURS);
publishProgressUpdate(progress);
} catch (Exception e) {
throw new RuntimeException("Failed to update track progress", e);
}
}
public List<BaseTrackProgress> getPlaylistProgress(long playlistId, long userId) {
try {
String key = ImportTrackKey.progressKey(playlistId, userId);
@ -65,9 +105,9 @@ public class RedisProgressService {
}
}
public BaseTrackProgress getTrackProgress(long playlistId, long trackId, long userId) {
public BaseTrackProgress getTrackProgress(long playlistId, long trackSourceId, long userId) {
try {
String key = ImportTrackKey.trackKey(playlistId, trackId, userId);
String key = ImportTrackKey.trackKey(playlistId, trackSourceId, userId);
String progressJson = redisTemplate.opsForValue().get(key);
if (progressJson != null) {