From 7f68a942d934c1e8dec8b1588e5e6a958c5ec4bb Mon Sep 17 00:00:00 2001
From: james58899 <james59988@gmail.com>
Date: Tue, 3 Jan 2023 06:16:29 +0000
Subject: [PATCH] Download gallery files in parallel

---
 src/gallery_downloader.rs | 122 +++++++++++++++++++++-----------------
 src/rpc.rs                |   4 +-
 2 files changed, 70 insertions(+), 56 deletions(-)

diff --git a/src/gallery_downloader.rs b/src/gallery_downloader.rs
index 5e96d2c..2fe2524 100644
--- a/src/gallery_downloader.rs
+++ b/src/gallery_downloader.rs
@@ -9,11 +9,13 @@ use futures::StreamExt;
 use hex::FromHex;
 use log::{debug, error, info, warn};
 use openssl::sha::Sha1;
+use parking_lot::Mutex;
 use regex::Regex;
 use reqwest::Url;
 use tokio::{
     fs::{self, create_dir_all},
     io::{AsyncReadExt, AsyncWriteExt},
+    sync::Semaphore,
     time::{sleep, sleep_until, Instant},
 };
 
@@ -21,6 +23,8 @@ use crate::{error::Error, rpc::RPCClient, util};
 
 type BoxError = Box<dyn std::error::Error + Send + Sync>;
 
+static MAX_DOWNLOAD_TASK: u32 = 4; // Permit count of the download semaphore: max download tasks running at once.
+
 pub struct GalleryDownloader {
     client: Arc<RPCClient>,
     reqwest: reqwest::Client,
@@ -44,8 +48,8 @@ impl GalleryDownloader {
                 break;
             }
 
-            let mut meta = match meta {
-                Ok(meta) => meta,
+            let meta = match meta {
+                Ok(meta) => Arc::new(meta),
                 Err(e) => {
                     warn!("Failed to parse metadata for new gallery. {}", e);
 
@@ -74,20 +78,22 @@ impl GalleryDownloader {
                 }
             }
 
-            let mut downloaded_files = HashSet::new();
+            let downloaded_files = Arc::new(Mutex::new(HashSet::new()));
             'retry: for retry in 0..10 {
+                let semaphore = Arc::new(Semaphore::new(MAX_DOWNLOAD_TASK as usize));
                 for info in &meta.gallery_files {
+                    let info = info.clone();
                     if !self.client.is_running() {
                         break 'task;
                     }
 
                     // Check if file already downloaded
-                    if downloaded_files.contains(info) {
+                    if downloaded_files.lock().contains(&info) {
                         continue;
                     }
                     let path = dir.join(format!("{}.{}", info.filename, info.filetype));
                     if info.check_hash(&path).await {
-                        downloaded_files.insert(info);
+                        downloaded_files.lock().insert(info);
                         continue;
                     }
 
@@ -101,24 +107,31 @@ impl GalleryDownloader {
                     {
                         Some(url) => {
                             start_time = Instant::now();
-                            match self.download(url.clone(), &path, info.expected_sha1_hash).await {
-                                Ok(_) => {
-                                    info!(
-                                        "Finished downloading gid={} page={}: {}.{}",
-                                        meta.gid, info.page, info.filename, info.filetype
-                                    );
-                                    downloaded_files.insert(info);
-                                }
-                                Err(err) => {
-                                    warn!("Gallery file download error: {}", err);
-
-                                    if err.is::<reqwest::Error>() || err.is::<Error>() {
-                                        if let Some(host) = url.host_str() {
-                                            meta.failures.push(format!("{}-{}-{}", host, info.fileindex, info.xres))
+                            let permit = semaphore.clone().acquire_owned().await.expect("Semaphore closed");
+                            let meta = meta.clone();
+                            let downloaded_files = downloaded_files.clone();
+                            let reqwest = self.reqwest.clone();
+                            tokio::spawn(async move {
+                                for retry in 0..3 {
+                                    if let Err(err) = download(reqwest.clone(), url.clone(), &path, info.expected_sha1_hash).await {
+                                        warn!("Gallery file download error: {}", err);
+
+                                        if retry == 2 && (err.is::<reqwest::Error>() || err.is::<Error>()) {
+                                            if let Some(host) = url.host_str() {
+                                                meta.failures.lock().push(format!("{}-{}-{}", host, info.fileindex, info.xres))
+                                            }
                                         }
+                                    } else {
+                                        info!(
+                                            "Finished downloading gid={} page={}: {}.{}",
+                                            meta.gid, info.page, info.filename, info.filetype
+                                        );
+                                        downloaded_files.lock().insert(info);
+                                        break;
                                     }
                                 }
-                            }
+                                drop(permit);
+                            });
                         }
                         None => {
                             warn!(
@@ -128,27 +141,29 @@ impl GalleryDownloader {
                         }
                     };
 
-                    // Wait 1s before next download, or 5s if download not success
-                    sleep_until(start_time + Duration::from_secs(if downloaded_files.contains(info) { 1 } else { 5 })).await;
+                    sleep_until(start_time + Duration::from_secs(1)).await;
                 }
 
-                if downloaded_files.len() == meta.filecount {
+                drop(semaphore.acquire_many(MAX_DOWNLOAD_TASK).await); // Wait all task done
+
+                if downloaded_files.lock().len() == meta.filecount {
                     info!("Finished download of gallery: {}", meta.title);
 
                     if let Err(e) = fs::write(&dir.join("galleryinfo.txt"), &meta.information).await {
                         error!("Could not write galleryinfo.txt: {}", e);
                     }
 
-                    self.client.dl_fails(meta.failures.iter().collect()).await;
+                    let failures = meta.failures.lock().clone();
+                    self.client.dl_fails(&failures).await;
                     break 'retry;
                 }
             }
 
-            if downloaded_files.len() != meta.filecount {
+            if downloaded_files.lock().len() != meta.filecount {
                 warn!("Permanently failed downloading gallery: {}", meta.title);
             }
 
-            task = self.client.fetch_queue(Some(meta)).await.map(GalleryDownloader::parser);
+            task = self.client.fetch_queue(Some(&meta)).await.map(GalleryDownloader::parser);
         }
     }
 
@@ -161,7 +176,7 @@ impl GalleryDownloader {
         let mut xres_title = String::new();
         let mut title = String::new();
         let mut information = String::new();
-        let mut gallery_files: Vec<GalleryFile> = Vec::new();
+        let mut gallery_files: Vec<Arc<GalleryFile>> = Vec::new();
 
         let mut parse_state = 0;
         for s in raw_gallery {
@@ -233,14 +248,14 @@ impl GalleryDownloader {
                 let filetype = split[4].to_string();
                 let filename = split[5].to_string();
 
-                gallery_files.push(GalleryFile {
+                gallery_files.push(Arc::new(GalleryFile {
                     page,
                     fileindex,
                     xres,
                     expected_sha1_hash: sha1hash,
                     filetype,
                     filename,
-                })
+                }))
             } else {
                 // Gallery info
                 information += &(s + "\n");
@@ -255,35 +270,34 @@ impl GalleryDownloader {
             title,
             information,
             gallery_files,
-            failures: vec![],
+            failures: Mutex::new(vec![]),
         })
     }
+}
 
-    async fn download<P: AsRef<Path>>(&self, url: Url, path: P, hash: Option<[u8; 20]>) -> Result<(), BoxError> {
-        let mut file = fs::File::create(&path).await?;
-        let mut stream = self
-            .reqwest
-            .get(url)
-            .send()
-            .await
-            .and_then(|r| r.error_for_status())
-            .map(|r| r.bytes_stream())?;
-        let mut hasher = Sha1::new();
-        while let Some(bytes) = stream.next().await {
-            let bytes = &bytes?;
-            file.write_all(bytes).await?;
-            hasher.update(bytes);
-        }
+/// Streams `url` into a freshly created file at `path`, hashing the body with SHA-1 as it is written.
+/// Fails on I/O or HTTP errors, or with `Error::HashMismatch` when `hash` is `Some` and does not match.
+async fn download<P: AsRef<Path>>(reqwest: reqwest::Client, url: Url, path: P, hash: Option<[u8; 20]>) -> Result<(), BoxError> {
+    let mut file = fs::File::create(&path).await?;
+    let response = reqwest.get(url).send().await?.error_for_status()?;
+    let mut stream = response.bytes_stream();
+    let mut hasher = Sha1::new();
+    while let Some(bytes) = stream.next().await {
+        let bytes = &bytes?;
+        file.write_all(bytes).await?;
+        hasher.update(bytes);
+    }
+    // tokio's `File` completes writes in the background; flush so a failed final write is not lost.
+    file.flush().await?;
 
-        if let Some(expected) = hash {
-            let hash = hasher.finish();
-            if hash != expected {
-                return Err(Box::new(Error::HashMismatch { expected, actual: hash }));
-            }
+    if let Some(expected) = hash {
+        let hash = hasher.finish();
+        if hash != expected {
+            return Err(Box::new(Error::HashMismatch { expected, actual: hash }));
         }
-
-        Ok(())
     }
+
+    Ok(())
 }
 
 #[derive(Hash, Eq, PartialEq)]
@@ -335,8 +349,8 @@ pub struct GalleryMeta {
     xres_title: String,
     title: String,
     information: String,
-    gallery_files: Vec<GalleryFile>,
-    failures: Vec<String>,
+    gallery_files: Vec<Arc<GalleryFile>>,
+    failures: Mutex<Vec<String>>,
 }
 
 impl GalleryMeta {
diff --git a/src/rpc.rs b/src/rpc.rs
index cd50727..d9462e8 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -273,7 +273,7 @@ The program will now terminate.
         None
     }
 
-    pub async fn dl_fails<T: AsRef<str>>(&self, failures: Vec<T>) {
+    pub async fn dl_fails<T: AsRef<str>>(&self, failures: &Vec<T>) {
         if failures.is_empty() {
             return;
         }
@@ -295,7 +295,7 @@ The program will now terminate.
         );
     }
 
-    pub async fn fetch_queue(&self, gallery: Option<GalleryMeta>) -> Option<Vec<String>> {
+    pub async fn fetch_queue(&self, gallery: Option<&GalleryMeta>) -> Option<Vec<String>> {
         let additional = &gallery.map(|s| format!("{};{}", s.gid(), s.minxres())).unwrap_or_default();
         let url = self.build_url("fetchqueue", additional, Some("dl"));
         if let Ok(res) = self.send_request(url).await {
-- 
GitLab