"""
Parallel download management for concurrent video downloads.

Handles thread-safe operations, progress tracking, and error handling.
"""
|
|
|
|
import concurrent.futures
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
from dataclasses import dataclass, field
|
|
|
|
from karaoke_downloader.config_manager import AppConfig
|
|
from karaoke_downloader.download_pipeline import DownloadPipeline
|
|
from karaoke_downloader.tracking_manager import TrackingManager
|
|
|
|
|
|
@dataclass
class DownloadTask:
    """A single queued download, with priority and retry bookkeeping."""

    video_id: str
    artist: str
    title: str
    channel_name: str
    video_title: Optional[str] = None
    priority: int = 0  # Higher number = higher priority
    retry_count: int = 0
    max_retries: int = 3
    created_at: float = field(default_factory=time.time)

    def __post_init__(self) -> None:
        # A caller may explicitly pass created_at=0; re-stamp it with the
        # current time so every task carries a real creation timestamp.
        if not self.created_at:
            self.created_at = time.time()
|
|
|
|
|
|
@dataclass
class DownloadResult:
    """Result of a download operation."""

    task: DownloadTask  # the task this result belongs to
    success: bool  # True when the pipeline reported success
    error_message: Optional[str] = None  # set only when an exception was raised
    file_path: Optional[Path] = None  # expected output path (set on success; may not exist on disk)
    download_time: float = 0.0  # wall-clock seconds spent on this attempt
    file_size: Optional[int] = None  # size in bytes, when the output file exists
|
|
|
|
|
|
class ParallelDownloader:
    """
    Manages parallel downloads with thread-safe operations and progress tracking.

    Tasks are queued with add_download_task(s) and run concurrently by
    execute_downloads(); failed tasks get one retry pass at reduced concurrency.
    """

    def __init__(
        self,
        yt_dlp_path: str,
        config: Union[AppConfig, Dict[str, Any]],
        downloads_dir: Path,
        max_workers: int = 3,
        songlist_tracking: Optional[Dict] = None,
        tracker: Optional[TrackingManager] = None,
    ):
        """
        Initialize the parallel downloader.

        Args:
            yt_dlp_path: Path to yt-dlp executable
            config: Configuration object or dictionary
            downloads_dir: Base downloads directory
            max_workers: Maximum number of concurrent downloads
            songlist_tracking: Optional songlist tracking data
            tracker: Optional tracking manager
        """
        self.yt_dlp_path = yt_dlp_path
        self.config = config
        self.downloads_dir = downloads_dir
        self.max_workers = max_workers
        self.songlist_tracking = songlist_tracking or {}
        self.tracker = tracker

        # Thread-safe state management
        self._lock = threading.Lock()
        self._active_downloads = 0
        self._completed_downloads = 0
        self._failed_downloads = 0
        self._total_downloads = 0
        self._start_time = None

        # Stop signal for the progress-display thread.  Created here (not
        # inside the display thread) so the main thread can never lose a
        # stop request to the thread re-initializing the flag after it was
        # already set — the race the old bool attribute had.
        self._stop_progress = threading.Event()

        # Progress tracking
        self._progress_callbacks = []
        self._download_queue = []
        self._results = []

        # Create download pipeline
        self.pipeline = DownloadPipeline(
            yt_dlp_path=yt_dlp_path,
            config=config,
            downloads_dir=downloads_dir,
            songlist_tracking=songlist_tracking,
            tracker=tracker,
        )

    def add_progress_callback(self, callback) -> None:
        """Register a callback invoked as callback(message, **kwargs) on progress events."""
        with self._lock:
            self._progress_callbacks.append(callback)

    def _notify_progress(self, message: str, **kwargs) -> None:
        """Notify all progress callbacks.

        Callbacks are snapshotted under the lock but invoked OUTSIDE it, so a
        callback that calls back into this object (e.g. get_stats) cannot
        deadlock on the non-reentrant lock.
        """
        with self._lock:
            callbacks = list(self._progress_callbacks)
        for callback in callbacks:
            try:
                callback(message, **kwargs)
            except Exception as e:
                # A misbehaving callback must never kill a download thread.
                print(f"⚠️ Progress callback error: {e}")

    def add_download_task(self, task: DownloadTask) -> None:
        """Add a download task to the queue."""
        with self._lock:
            self._download_queue.append(task)
            self._total_downloads += 1

    def add_download_tasks(self, tasks: List[DownloadTask]) -> None:
        """Add multiple download tasks to the queue."""
        with self._lock:
            self._download_queue.extend(tasks)
            self._total_downloads += len(tasks)

    def _download_single_task(self, task: DownloadTask) -> DownloadResult:
        """Execute a single download task and return its DownloadResult.

        Runs inside a worker thread; never raises — pipeline exceptions are
        converted into a failed DownloadResult.
        """
        start_time = time.time()

        try:
            with self._lock:
                self._active_downloads += 1

            self._notify_progress(
                "Starting download",
                task=task,
                active_downloads=self._active_downloads,
                total_downloads=self._total_downloads,
            )

            # Execute the download pipeline
            success = self.pipeline.execute_pipeline(
                video_id=task.video_id,
                artist=task.artist,
                title=task.title,
                channel_name=task.channel_name,
                video_title=task.video_title,
            )

            download_time = time.time() - start_time

            # Determine file path (and size, if the file exists) on success
            file_path = None
            file_size = None
            if success:
                filename = f"{task.artist} - {task.title}.mp4"
                file_path = self.downloads_dir / task.channel_name / filename
                if file_path.exists():
                    file_size = file_path.stat().st_size

            result = DownloadResult(
                task=task,
                success=success,
                file_path=file_path,
                download_time=download_time,
                file_size=file_size,
            )

            with self._lock:
                if success:
                    self._completed_downloads += 1
                else:
                    self._failed_downloads += 1
                self._active_downloads -= 1

            # Notify outside the lock: _notify_progress acquires the same
            # non-reentrant lock and would deadlock if called while held.
            self._notify_progress(
                "Download completed" if success else "Download failed",
                result=result,
                active_downloads=self._active_downloads,
                completed_downloads=self._completed_downloads,
                failed_downloads=self._failed_downloads,
                total_downloads=self._total_downloads,
            )

            return result

        except Exception as e:
            download_time = time.time() - start_time

            with self._lock:
                self._failed_downloads += 1
                self._active_downloads -= 1

            result = DownloadResult(
                task=task,
                success=False,
                error_message=str(e),
                download_time=download_time,
            )

            self._notify_progress(
                "Download error",
                result=result,
                active_downloads=self._active_downloads,
                completed_downloads=self._completed_downloads,
                failed_downloads=self._failed_downloads,
                total_downloads=self._total_downloads,
            )

            return result

    def _retry_failed_downloads(self, failed_results: List[DownloadResult]) -> List[DownloadResult]:
        """Retry failed downloads up to their max retry count.

        Returns the results of the retry attempts (possibly empty).
        """
        retry_tasks = []

        for result in failed_results:
            if result.task.retry_count < result.task.max_retries:
                result.task.retry_count += 1
                retry_tasks.append(result.task)

        if not retry_tasks:
            return []

        print(f"🔄 Retrying {len(retry_tasks)} failed downloads...")

        # Each retry is a fresh attempt that will bump completed/failed
        # counters again, so grow the total too — otherwise the displayed
        # progress percentage could exceed 100%.
        with self._lock:
            self._total_downloads += len(retry_tasks)

        # Execute retries with reduced concurrency to avoid overwhelming the system
        retry_workers = max(1, self.max_workers // 2)

        with concurrent.futures.ThreadPoolExecutor(max_workers=retry_workers) as executor:
            futures = [
                executor.submit(self._download_single_task, task)
                for task in retry_tasks
            ]
            return [
                future.result()
                for future in concurrent.futures.as_completed(futures)
            ]

    def execute_downloads(self, show_progress: bool = True) -> List[DownloadResult]:
        """
        Execute all queued downloads in parallel.

        Args:
            show_progress: Whether to show progress information

        Returns:
            List of download results.  A task that failed and was retried
            contributes one result per attempt.
        """
        if not self._download_queue:
            print("📭 No downloads queued.")
            return []

        # Sort tasks by priority (higher priority first)
        with self._lock:
            self._download_queue.sort(key=lambda x: x.priority, reverse=True)
            tasks = self._download_queue.copy()
            self._download_queue.clear()

        self._start_time = time.time()
        self._results = []

        print(f"🚀 Starting parallel downloads with {self.max_workers} workers...")
        print(f"📋 Total tasks: {len(tasks)}")

        # Progress display thread
        progress_thread = None
        if show_progress:
            self._stop_progress.clear()
            progress_thread = threading.Thread(
                target=self._progress_display_loop,
                daemon=True,
            )
            progress_thread.start()

        try:
            # Execute downloads in parallel
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_task = {
                    executor.submit(self._download_single_task, task): task
                    for task in tasks
                }

                for future in concurrent.futures.as_completed(future_to_task):
                    self._results.append(future.result())

            # Handle retries for failed downloads
            failed_results = [r for r in self._results if not r.success]
            if failed_results:
                self._results.extend(self._retry_failed_downloads(failed_results))

        finally:
            # Stop progress display; set() wakes the display thread immediately.
            if progress_thread is not None:
                self._stop_progress.set()
                progress_thread.join(timeout=1)

        # Final summary
        total_time = time.time() - self._start_time
        successful = len([r for r in self._results if r.success])
        failed = len([r for r in self._results if not r.success])

        print(f"\n🎉 Parallel downloads completed!")
        print(f" ✅ Successful: {successful}")
        print(f" ❌ Failed: {failed}")
        print(f" ⏱️ Total time: {total_time:.1f}s")
        print(f" 📊 Average time per download: {total_time/len(tasks):.1f}s")

        return self._results

    def _progress_display_loop(self) -> None:
        """Display a one-line progress summary roughly once per second until stopped."""
        while not self._stop_progress.is_set():
            with self._lock:
                active = self._active_downloads
                completed = self._completed_downloads
                failed = self._failed_downloads
                total = self._total_downloads

            if total > 0:
                progress = (completed + failed) / total * 100
                print(f"\r📊 Progress: {progress:.1f}% | Active: {active} | Completed: {completed} | Failed: {failed} | Total: {total}", end="", flush=True)

            # Doubles as the 1-second sleep, but returns immediately on stop.
            self._stop_progress.wait(1)

        print()  # New line after progress display

    def get_stats(self) -> Dict[str, Any]:
        """Get current download statistics."""
        with self._lock:
            return {
                "active_downloads": self._active_downloads,
                "completed_downloads": self._completed_downloads,
                "failed_downloads": self._failed_downloads,
                "total_downloads": self._total_downloads,
                "queued_downloads": len(self._download_queue),
                "elapsed_time": time.time() - self._start_time if self._start_time else 0,
            }
|
|
|
|
|
|
def create_parallel_downloader(
    yt_dlp_path: str,
    config: Union[AppConfig, Dict[str, Any]],
    downloads_dir: Path,
    max_workers: int = 3,
    songlist_tracking: Optional[Dict] = None,
    tracker: Optional[TrackingManager] = None,
) -> ParallelDownloader:
    """
    Build and return a ParallelDownloader.

    Thin factory that forwards every argument unchanged.

    Args:
        yt_dlp_path: Path to yt-dlp executable
        config: Configuration object or dictionary
        downloads_dir: Base downloads directory
        max_workers: Maximum number of concurrent downloads
        songlist_tracking: Optional songlist tracking data
        tracker: Optional tracking manager

    Returns:
        A freshly constructed ParallelDownloader instance.
    """
    return ParallelDownloader(
        yt_dlp_path,
        config,
        downloads_dir,
        max_workers=max_workers,
        songlist_tracking=songlist_tracking,
        tracker=tracker,
    )