105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
import asyncio
|
|
import logging
|
|
import math
|
|
import secrets
|
|
from pathlib import Path
|
|
|
|
import aiohttp
|
|
|
|
# Module-level logger; download_file() derives its own child logger from it.
_logger = logging.getLogger("worthless.helper")
|
|
|
|
|
|
async def download_file(
    file_url: str,
    file_name: str,
    file_path: Path | str,
    file_len: int | None = None,
    overwrite: bool = False,
    chunks: int | None = None,
    threads_num: int | None = None,
) -> Path:
    """
    Download file_url into the file ``file_name`` inside ``file_path``.

    You should implement your own download method instead of using this.

    An existing target file is resumed from its current size unless
    ``overwrite`` is set. When ``file_len`` is known and more than one
    thread is requested, the remaining bytes are split into contiguous,
    non-overlapping ranges that are fetched concurrently into temporary
    part files and then appended to the target in order.

    Args:
        file_url: Url to download the file from
        file_name: File name to download into
        file_path: Directory to download the file into
        file_len: Total file length in bytes; enables threaded downloading
        overwrite: Whether to delete an existing file before downloading
        chunks: Number of bytes read from the response per write
            (defaults to 8192)
        threads_num: Number of concurrent range requests (defaults to 8)
    Return:
        Downloaded file as a Path object
    Raises:
        aiohttp.ClientResponseError: If the server answers with an error
            status other than 416 (raised by ``raise_for_status``).
    """
    logger = _logger.getChild("download_file")
    if not chunks:
        chunks = 8192
    if not threads_num:
        threads_num = 8
    logger.debug("Downloading chunks %s with %s thread", chunks, threads_num)
    target = Path(file_path).joinpath(file_name)

    async def _fetch(
        session: aiohttp.ClientSession,
        first_byte: int,
        last_byte: int | None,
        dest: Path,
    ) -> None:
        """Append bytes first_byte..last_byte (both inclusive) to dest.

        ``last_byte=None`` requests everything from first_byte to the end.
        """
        range_end = "" if last_byte is None else last_byte
        headers = {"Range": f"bytes={first_byte}-{range_end}"}
        dest.touch(exist_ok=True)
        # `async with` guarantees the connection is released on every path,
        # including the 416 early return and errors while streaming.
        async with session.get(file_url, headers=headers, timeout=None) as rsp:
            if rsp.status == 416:
                # Range not satisfiable: there is nothing (more) to fetch.
                return
            rsp.raise_for_status()
            # NOTE(review): a server that ignores Range replies 200 with the
            # full body; that case is not detected here and would be appended.
            with dest.open("ab") as out:
                async for chunk in rsp.content.iter_chunked(chunks):
                    await asyncio.to_thread(out.write, chunk)

    if overwrite:
        target.unlink(missing_ok=True)
    if target.exists():
        cur_len = target.stat().st_size
    else:
        target.touch()
        cur_len = 0

    if file_len and cur_len >= file_len:
        # Everything is already on disk; avoid requests that can only 416.
        return target

    if not file_len or threads_num == 1:
        # HTTP byte ranges are inclusive, so the last wanted byte is len - 1.
        last_byte = file_len - 1 if file_len else None
        async with aiohttp.ClientSession() as s:
            await _fetch(s, cur_len, last_byte, target)
        return target

    remaining = file_len - cur_len
    # Round up so that per_part * threads_num always covers `remaining`.
    per_part = math.ceil(remaining / threads_num)
    # Inclusive, non-overlapping (first, last) pairs; empty ranges are never
    # produced, so fewer than threads_num requests may be issued.
    ranges = [
        (first, min(first + per_part, file_len) - 1)
        for first in range(cur_len, file_len, per_part)
    ]
    # Part files are named up front so they can be cleaned up on any outcome.
    parts = [
        target.parent.joinpath(secrets.token_urlsafe(16)) for _ in ranges
    ]

    def _merge_parts() -> None:
        """Append every part file to the target, in range order."""
        copy_block = 1024 * 1024
        with target.open("ab") as out:
            for part in parts:
                with part.open("rb") as src:
                    while block := src.read(copy_block):
                        out.write(block)

    try:
        async with aiohttp.ClientSession() as s:
            jobs = [
                asyncio.create_task(_fetch(s, first, last, part))
                for (first, last), part in zip(ranges, parts)
            ]
            try:
                await asyncio.gather(*jobs)
            except BaseException:
                # gather() does not cancel siblings when one job fails; stop
                # them before the session closes and the parts are removed.
                for job in jobs:
                    job.cancel()
                await asyncio.gather(*jobs, return_exceptions=True)
                raise
        # Merge off the event loop; the copy is plain blocking file I/O.
        await asyncio.to_thread(_merge_parts)
    finally:
        for part in parts:
            part.unlink(missing_ok=True)
    return target
|