mirror of https://github.com/tuna/tunasync-scripts.git
synced 2025-12-25 16:32:47 +00:00
github-release: rewrite threading with ThreadPoolExecutor [ci skip]
Signed-off-by: Shengqi Chen <harry-chen@outlook.com>
parent 894118f986
commit c0a0cd617e
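The change swaps the hand-rolled queue.Queue + threading.Thread worker pool for concurrent.futures.ThreadPoolExecutor. A minimal sketch of the pattern, separate from the script itself (download_file and urls here are hypothetical stand-ins, not the script's real objects):

    import concurrent.futures

    def download_file(url: str) -> None:
        # placeholder for the real download logic
        print(f"downloading {url}")

    urls = ["https://example.com/a", "https://example.com/b"]

    # Old style: spawn N threads looping on q.get(), then q.join() and
    # push None sentinels to stop them. New style: submit each task and
    # wait on the returned futures; the executor manages the threads.
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(download_file, u) for u in urls]
        concurrent.futures.wait(futures)  # block until all tasks finish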
@@ -1,10 +1,9 @@
 #!/usr/bin/env python3
+import concurrent.futures
 import json
 import logging
 import os
-import queue
 import tempfile
-import threading
 import traceback
 from datetime import datetime
 from pathlib import Path
@@ -25,13 +24,13 @@ BASE_URL = os.getenv("TUNASYNC_UPSTREAM_URL", "https://api.github.com/repos/")
 WORKING_DIR = os.getenv("TUNASYNC_WORKING_DIR")
 CONFIG = os.getenv("GITHUB_RELEASE_CONFIG", "github-release.json")
 REPOS = []
 UA = "tuna-github-release-mirror/0.0 (+https://github.com/tuna/tunasync-scripts)"
 
 # connect and read timeout value
 TIMEOUT_OPTION = (7, 10)
 total_size = 0
 
 
-def sizeof_fmt(num, suffix="iB"):
+def sizeof_fmt(num: float, suffix: str = "iB") -> str:
     for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
         if abs(num) < 1024.0:
             return "%3.2f%s%s" % (num, unit, suffix)
@@ -40,15 +39,18 @@ def sizeof_fmt(num, suffix="iB"):
 
 
 # wrap around requests.get to use token if available
-def github_get(*args, **kwargs):
+def github_get(*args, **kwargs) -> requests.Response:
     headers = kwargs["headers"] if "headers" in kwargs else {}
     if "GITHUB_TOKEN" in os.environ:
         headers["Authorization"] = "token {}".format(os.environ["GITHUB_TOKEN"])
     headers["User-Agent"] = UA
     kwargs["headers"] = headers
     return requests.get(*args, **kwargs)
 
 
-def do_download(remote_url: str, dst_file: Path, remote_ts: float, remote_size: int):
+def do_download(
+    remote_url: str, dst_file: Path, remote_ts: float, remote_size: int
+) -> None:
     # NOTE the stream=True parameter below
     with github_get(remote_url, stream=True) as r:
         r.raise_for_status()
@@ -80,34 +82,7 @@ def do_download(remote_url: str, dst_file: Path, remote_ts: float, remote_size:
                 tmp_dst_file.unlink()
 
 
-def downloading_worker(q):
-    while True:
-        item = q.get()
-        if item is None:
-            break
-
-        url, dst_file, working_dir, updated, remote_size = item
-
-        logger.info(f"downloading {url} to {dst_file.relative_to(working_dir)}")
-        try:
-            do_download(url, dst_file, updated, remote_size)
-        except Exception as e:
-            logger.error(f"Failed to download {url}: {e}", exc_info=True)
-            if dst_file.is_file():
-                dst_file.unlink()
-
-        q.task_done()
-
-
-def create_workers(n):
-    task_queue = queue.Queue()
-    for i in range(n):
-        t = threading.Thread(target=downloading_worker, args=(task_queue,))
-        t.start()
-    return task_queue
-
-
-def ensure_safe_name(filename):
+def ensure_safe_name(filename: str) -> str:
     filename = filename.replace("\0", " ")
     if filename == ".":
         return " ."
@@ -138,15 +113,22 @@ def main():
         raise Exception("Working Directory is None")
 
     working_dir = Path(args.working_dir)
-    task_queue = create_workers(args.workers)
     remote_filelist = []
    cleaning = False
 
     with open(args.config, "r") as f:
         REPOS = json.load(f)
 
-    def download(release, release_dir, tarball=False):
-        global total_size
+    executor = concurrent.futures.ThreadPoolExecutor(max_workers=args.workers)
+    futures = []
+
+    def process_release(
+        release: dict,
+        release_dir: Path,
+        tarball: bool,
+    ) -> int:
+        release_size = 0
+
         if tarball:
             url = release["tarball_url"]
@@ -161,7 +143,11 @@
             else:
                 dst_file.parent.mkdir(parents=True, exist_ok=True)
                 # tarball has no size information, use -1 to skip size check
-                task_queue.put((url, dst_file, working_dir, updated, -1))
+                futures.append(
+                    executor.submit(
+                        download_file, url, dst_file, working_dir, updated, -1
+                    )
+                )
 
         for asset in release["assets"]:
             url = asset["browser_download_url"]
@@ -171,7 +157,7 @@
             dst_file = release_dir / ensure_safe_name(asset["name"])
             remote_filelist.append(dst_file.relative_to(working_dir))
             remote_size = asset["size"]
-            total_size += remote_size
+            release_size += remote_size
 
             if dst_file.is_file():
                 if args.fast_skip:
@@ -191,9 +177,26 @@
             else:
                 dst_file.parent.mkdir(parents=True, exist_ok=True)
 
-                task_queue.put((url, dst_file, working_dir, updated, remote_size))
+                futures.append(
+                    executor.submit(
+                        download_file, url, dst_file, working_dir, updated, remote_size
+                    )
+                )
 
-    def link_latest(name, repo_dir):
+        return release_size
+
+    def download_file(
+        url: str, dst_file: Path, working_dir: Path, updated: float, remote_size: int
+    ) -> None:
+        logger.info(f"downloading {url} to {dst_file.relative_to(working_dir)}")
+        try:
+            do_download(url, dst_file, updated, remote_size)
+        except Exception as e:
+            logger.error(f"Failed to download {url}: {e}", exc_info=True)
+            if dst_file.is_file():
+                dst_file.unlink()
+
+    def link_latest(name: str, repo_dir: Path) -> None:
         try:
             os.unlink(repo_dir / "LatestRelease")
         except OSError:
@@ -203,6 +206,8 @@
         except OSError:
             pass
 
+    total_size = 0
+
     for cfg in REPOS:
         flat = False  # build a folder for each release
         versions = 1  # keep only one release
@@ -236,7 +241,7 @@
                 releases = r.json()
             except Exception as e:
                 logger.error(
-                    f"Error: cannot download metadata for {repo}: {e}\n{traceback.format_exc()}",
+                    f"Cannot download metadata for {repo}: {e}\n{traceback.format_exc()}",
                     exc_info=True,
                 )
                 break
@@ -246,9 +251,13 @@
             if not release["draft"] and (prerelease or not release["prerelease"]):
                 name = ensure_safe_name(release["name"] or release["tag_name"])
                 if len(name) == 0:
-                    logger.error("Error: Unnamed release")
+                    logger.error("Unnamed release")
                     continue
-                download(release, (repo_dir if flat else repo_dir / name), tarball)
+                total_size += process_release(
+                    release,
+                    (repo_dir if flat else repo_dir / name),
+                    tarball,
+                )
                 if n_downloaded == 0 and not flat:
                     # create a symbolic link to the latest release folder
                     link_latest(name, repo_dir)
@@ -256,20 +265,18 @@
             if versions > 0 and n_downloaded >= versions:
                 break
         if n_downloaded == 0:
-            logger.error(f"Error: No release version found for {repo}")
+            logger.error(f"No release version found for {repo}")
             continue
         else:
             cleaning = True
 
-    # block until all tasks are done
-    task_queue.join()
-    # stop workers
-    for i in range(args.workers):
-        task_queue.put(None)
+    # wait for all download tasks to complete
+    concurrent.futures.wait(futures)
+    executor.shutdown()
 
     # XXX: this does not work because `cleaning` is always False when `REPOS` is not empty
     if cleaning:
-        local_filelist = []
+        local_filelist: list[Path] = []
         for local_file in working_dir.glob("**/*"):
             if local_file.is_file():
                 local_filelist.append(local_file.relative_to(working_dir))
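A note on error handling in the new scheme: an exception raised inside a submitted callable does not propagate out of concurrent.futures.wait(); it is stored on the future and only resurfaces via result() or exception(). That is presumably why download_file keeps its own try/except and logs failures itself. A small illustration (task is a hypothetical function):

    import concurrent.futures

    def task() -> None:
        raise RuntimeError("boom")

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        fut = executor.submit(task)
        concurrent.futures.wait([fut])  # returns normally, does not raise
        print(fut.exception())          # RuntimeError('boom') is stored here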