KaraokeVideoDownloader/karaoke_downloader/parallel_downloader.py

376 lines
13 KiB
Python

"""
Parallel download management for concurrent video downloads.
Handles thread-safe operations, progress tracking, and error handling.
"""
import concurrent.futures
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field
from karaoke_downloader.config_manager import AppConfig
from karaoke_downloader.download_pipeline import DownloadPipeline
from karaoke_downloader.tracking_manager import TrackingManager
@dataclass
class DownloadTask:
    """A single queued download: video identity plus priority/retry bookkeeping."""
    video_id: str
    artist: str
    title: str
    channel_name: str
    video_title: Optional[str] = None
    priority: int = 0  # Higher number = higher priority
    retry_count: int = 0
    max_retries: int = 3
    created_at: float = field(default_factory=time.time)

    def __post_init__(self) -> None:
        # Defensive: a caller that explicitly passes created_at=0 still
        # ends up with a real creation timestamp.
        if self.created_at == 0:
            self.created_at = time.time()
@dataclass
class DownloadResult:
    """Result of a download operation."""
    # The task this result belongs to (carries its retry bookkeeping).
    task: DownloadTask
    # True when the download pipeline reported success.
    success: bool
    # Failure description (stringified exception); None on success.
    error_message: Optional[str] = None
    # Expected output path; populated only on success, and only if the
    # producing code computed it (may stay None on error paths).
    file_path: Optional[Path] = None
    # Wall-clock seconds spent on this attempt.
    download_time: float = 0.0
    # Size in bytes of the downloaded file, when it exists on disk.
    file_size: Optional[int] = None
class ParallelDownloader:
    """
    Manages parallel downloads with thread-safe operations and progress tracking.

    Locking discipline: shared counters and queues are only touched while
    holding ``self._lock`` (a non-reentrant ``threading.Lock``), and progress
    callbacks are always invoked *after* the lock is released.  This prevents
    the self-deadlock that occurs when a notification is issued while the
    lock is held, and lets callbacks safely call back into this object
    (e.g. ``get_stats``).
    """

    def __init__(
        self,
        yt_dlp_path: str,
        config: Union[AppConfig, Dict[str, Any]],
        downloads_dir: Path,
        max_workers: int = 3,
        songlist_tracking: Optional[Dict] = None,
        tracker: Optional[TrackingManager] = None,
    ):
        """
        Initialize the parallel downloader.

        Args:
            yt_dlp_path: Path to yt-dlp executable
            config: Configuration object or dictionary
            downloads_dir: Base downloads directory
            max_workers: Maximum number of concurrent downloads
            songlist_tracking: Optional songlist tracking data
            tracker: Optional tracking manager
        """
        self.yt_dlp_path = yt_dlp_path
        self.config = config
        self.downloads_dir = downloads_dir
        self.max_workers = max_workers
        self.songlist_tracking = songlist_tracking or {}
        self.tracker = tracker
        # Thread-safe state management
        self._lock = threading.Lock()
        self._active_downloads = 0
        self._completed_downloads = 0
        self._failed_downloads = 0
        self._total_downloads = 0
        self._start_time = None
        # Initialized here (not lazily inside the display thread) so that
        # execute_downloads() can never race the thread on an attribute
        # that does not exist yet, nor have its stop request overwritten.
        self._stop_progress = False
        # Progress tracking
        self._progress_callbacks = []
        self._download_queue = []
        self._results = []
        # Create download pipeline
        self.pipeline = DownloadPipeline(
            yt_dlp_path=yt_dlp_path,
            config=config,
            downloads_dir=downloads_dir,
            songlist_tracking=songlist_tracking,
            tracker=tracker,
        )

    def add_progress_callback(self, callback) -> None:
        """Register a callback invoked as ``callback(message, **kwargs)`` on progress events."""
        with self._lock:
            self._progress_callbacks.append(callback)

    def _notify_progress(self, message: str, **kwargs) -> None:
        """Notify all progress callbacks.

        The callback list is snapshotted under the lock, but the callbacks
        themselves run outside it (see class docstring).
        """
        with self._lock:
            callbacks = list(self._progress_callbacks)
        for callback in callbacks:
            try:
                callback(message, **kwargs)
            except Exception as e:
                # A misbehaving callback must not kill a download thread.
                print(f"⚠️ Progress callback error: {e}")

    def add_download_task(self, task: DownloadTask) -> None:
        """Add a download task to the queue."""
        with self._lock:
            self._download_queue.append(task)
            self._total_downloads += 1

    def add_download_tasks(self, tasks: List[DownloadTask]) -> None:
        """Add multiple download tasks to the queue."""
        with self._lock:
            self._download_queue.extend(tasks)
            self._total_downloads += len(tasks)

    def _counters_snapshot(self) -> Dict[str, int]:
        """Return the standard progress-notification kwargs. Caller must hold the lock."""
        return {
            "active_downloads": self._active_downloads,
            "completed_downloads": self._completed_downloads,
            "failed_downloads": self._failed_downloads,
            "total_downloads": self._total_downloads,
        }

    def _download_single_task(self, task: DownloadTask) -> DownloadResult:
        """Execute a single download task and return its result.

        Never raises: pipeline exceptions are converted into a failed
        DownloadResult carrying the stringified exception.
        """
        start_time = time.time()
        try:
            with self._lock:
                self._active_downloads += 1
                active = self._active_downloads
                total = self._total_downloads
            # Notify outside the lock to avoid re-acquiring it.
            self._notify_progress(
                "Starting download",
                task=task,
                active_downloads=active,
                total_downloads=total
            )
            # Execute the download pipeline
            success = self.pipeline.execute_pipeline(
                video_id=task.video_id,
                artist=task.artist,
                title=task.title,
                channel_name=task.channel_name,
                video_title=task.video_title,
            )
            download_time = time.time() - start_time
            # Determine file path/size if successful
            file_path = None
            file_size = None
            if success:
                filename = f"{task.artist} - {task.title}.mp4"
                file_path = self.downloads_dir / task.channel_name / filename
                if file_path.exists():
                    file_size = file_path.stat().st_size
            result = DownloadResult(
                task=task,
                success=success,
                file_path=file_path,
                download_time=download_time,
                file_size=file_size,
            )
            with self._lock:
                if success:
                    self._completed_downloads += 1
                else:
                    self._failed_downloads += 1
                self._active_downloads -= 1
                counters = self._counters_snapshot()
            self._notify_progress(
                "Download completed" if success else "Download failed",
                result=result,
                **counters
            )
            return result
        except Exception as e:
            download_time = time.time() - start_time
            with self._lock:
                self._failed_downloads += 1
                self._active_downloads -= 1
                counters = self._counters_snapshot()
            result = DownloadResult(
                task=task,
                success=False,
                error_message=str(e),
                download_time=download_time,
            )
            self._notify_progress(
                "Download error",
                result=result,
                **counters
            )
            return result

    def _retry_failed_downloads(self, failed_results: List[DownloadResult]) -> List[DownloadResult]:
        """Retry failed downloads up to their max retry count.

        Returns the results of the retry attempts (possibly empty).
        """
        retry_tasks = []
        for result in failed_results:
            if result.task.retry_count < result.task.max_retries:
                result.task.retry_count += 1
                retry_tasks.append(result.task)
        if not retry_tasks:
            return []
        # Un-count the failures being retried so the progress percentage
        # cannot exceed 100% when a retry attempt also completes.
        with self._lock:
            self._failed_downloads -= len(retry_tasks)
        print(f"🔄 Retrying {len(retry_tasks)} failed downloads...")
        # Execute retries with reduced concurrency to avoid overwhelming the system
        retry_workers = max(1, self.max_workers // 2)
        retry_results = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=retry_workers) as executor:
            future_to_task = {
                executor.submit(self._download_single_task, task): task
                for task in retry_tasks
            }
            for future in concurrent.futures.as_completed(future_to_task):
                retry_results.append(future.result())
        return retry_results

    def execute_downloads(self, show_progress: bool = True) -> List[DownloadResult]:
        """
        Execute all queued downloads in parallel.

        Args:
            show_progress: Whether to show progress information

        Returns:
            List of download results.  Tasks that failed and were retried
            contribute one result per attempt.
        """
        # Drain the queue atomically, highest priority first.
        with self._lock:
            self._download_queue.sort(key=lambda x: x.priority, reverse=True)
            tasks = self._download_queue.copy()
            self._download_queue.clear()
        if not tasks:
            print("📭 No downloads queued.")
            return []
        self._start_time = time.time()
        self._results = []
        print(f"🚀 Starting parallel downloads with {self.max_workers} workers...")
        print(f"📋 Total tasks: {len(tasks)}")
        # Progress display thread
        progress_thread = None
        if show_progress:
            # Reset the flag *before* the thread starts so a fast finish
            # cannot have its stop request overwritten by the thread.
            self._stop_progress = False
            progress_thread = threading.Thread(
                target=self._progress_display_loop,
                daemon=True
            )
            progress_thread.start()
        try:
            # Execute downloads in parallel
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_task = {
                    executor.submit(self._download_single_task, task): task
                    for task in tasks
                }
                for future in concurrent.futures.as_completed(future_to_task):
                    self._results.append(future.result())
            # Handle retries for failed downloads
            failed_results = [r for r in self._results if not r.success]
            if failed_results:
                self._results.extend(self._retry_failed_downloads(failed_results))
        finally:
            # Stop progress display
            if progress_thread and progress_thread.is_alive():
                self._stop_progress = True
                progress_thread.join(timeout=1)
        # Final summary
        total_time = time.time() - self._start_time
        successful = len([r for r in self._results if r.success])
        failed = len([r for r in self._results if not r.success])
        print(f"\n🎉 Parallel downloads completed!")
        print(f" ✅ Successful: {successful}")
        print(f" ❌ Failed: {failed}")
        print(f" ⏱️ Total time: {total_time:.1f}s")
        print(f" 📊 Average time per download: {total_time/len(tasks):.1f}s")
        return self._results

    def _progress_display_loop(self) -> None:
        """Display a one-line progress summary once per second until stopped.

        Only *reads* ``self._stop_progress``; the flag is owned by
        ``__init__`` / ``execute_downloads``.
        """
        while not self._stop_progress:
            with self._lock:
                active = self._active_downloads
                completed = self._completed_downloads
                failed = self._failed_downloads
                total = self._total_downloads
            if total > 0:
                progress = (completed + failed) / total * 100
                print(f"\r📊 Progress: {progress:.1f}% | Active: {active} | Completed: {completed} | Failed: {failed} | Total: {total}", end="", flush=True)
            time.sleep(1)
        print()  # New line after progress display

    def get_stats(self) -> Dict[str, Any]:
        """Get a consistent snapshot of current download statistics."""
        with self._lock:
            return {
                "active_downloads": self._active_downloads,
                "completed_downloads": self._completed_downloads,
                "failed_downloads": self._failed_downloads,
                "total_downloads": self._total_downloads,
                "queued_downloads": len(self._download_queue),
                "elapsed_time": time.time() - self._start_time if self._start_time else 0,
            }
def create_parallel_downloader(
    yt_dlp_path: str,
    config: Union[AppConfig, Dict[str, Any]],
    downloads_dir: Path,
    max_workers: int = 3,
    songlist_tracking: Optional[Dict] = None,
    tracker: Optional[TrackingManager] = None,
) -> ParallelDownloader:
    """Build and return a ParallelDownloader.

    Thin factory wrapper; every argument is forwarded unchanged to the
    ParallelDownloader constructor.

    Args:
        yt_dlp_path: Path to yt-dlp executable
        config: Configuration object or dictionary
        downloads_dir: Base downloads directory
        max_workers: Maximum number of concurrent downloads
        songlist_tracking: Optional songlist tracking data
        tracker: Optional tracking manager

    Returns:
        ParallelDownloader instance
    """
    return ParallelDownloader(
        yt_dlp_path,
        config,
        downloads_dir,
        max_workers=max_workers,
        songlist_tracking=songlist_tracking,
        tracker=tracker,
    )