用 aria2 加速 pip —— 记一次对 pip 的 monkey patch
xhz123456789 · · 科技·工程
用 aria2 加速 pip —— 记一次对 pip 的 monkey patch
0. 背景
蒟蒻在使用 pip 安装 torch 时,被 20h+ 的预计下载时间搞到破防。
蒟蒻对 aria2 的下载速度感到十分震撼。
要是能用 aria2 代替 pip 的下载过程……
(考虑到篇幅,文中代码有省略,完整代码请于文末查看,也可访问云剪贴板,或Gist)。
1.一个大胆的思路
我们先要弄清楚 pip 下载文件时的逻辑。阅读 pip 的源码可以发现,pip/_internal/network/download.py 中定义了这样两个类:
# Excerpt from pip/_internal/network/download.py; "..." marks code elided by the article.
...
class Downloader:
    # Single-link downloader: an instance is called with one Link and a target
    # directory and returns a Tuple[str, str].
    def __init__(
        self,
        session: PipSession,
        progress_bar: str,
        resume_retries: int,
    ) -> None:
        ...
    def __call__(self, link: Link, location: str) -> Tuple[str, str]:
        """Download the file given by link into location."""
        ...
    ...
# Batch downloader: same constructor arguments as Downloader, which is why a
# single replacement class can stand in for both.
class BatchDownloader:
    def __init__(
        self,
        session: PipSession,
        progress_bar: str,
        resume_retries: int,
    ) -> None:
        ...
    # Returns an iterable of (Link, (str, str)) pairs, one per input link.
    def __call__(
        self, links: Iterable[Link], location: str
    ) -> Iterable[Tuple[Link, Tuple[str, str]]]:
        """Download the files given by links into location."""
        ...
    ...
...
这两个类正是 pip 的下载器:Downloader 负责下载单个文件,BatchDownloader 负责批量下载。
于是我们便有一个简单的思路:在 pip 运行前把它们换掉。
2.第一次尝试
我们先搓一个“假的” Downloader:
...
# First sketch of the replacement: same constructor signature as pip's
# Downloader ("..." marks code elided by the article).
class Aria2Downloader:
    def __init__(self,session,progress_bar: str,resume_retries: int|None=None):
        ...
    # Mirrors Downloader.__call__: one link plus a target directory -> (str, str).
    def __call__(self, link, location)->tuple[str,str]:
        ...
    ...
然后在 pip 启动前换掉 pip 的下载器:
# First attempt: rebind the Downloader name in pip's download module before pip runs.
from pip._internal.network import download as pip_download
pip_download.Downloader = Aria2Downloader
最后调用 pip:
# Run pip's CLI entry point and exit with its return code.
from pip._internal.cli.main import main as pip_main
sys.exit(pip_main())
现在,激动人心的时刻到了,让我们运行它:
python aria2_pip.py install transformers
然后,屏幕上弹出了冷冰冰的报错:
ImportError: cannot import name 'get_install_progress_renderer' from partially initialized module 'pip._internal.cli.progress_bars' (most likely due to a circular import)
3.重新思考
在与 DeepSeek 深入交流后,我们得知这是一个“循环导入”(circular import)问题:直接导入 pip 的内部模块很容易引发这类问题,它因此建议我们避免直接导入 pip 内部模块。
同样的模块 pip 用得好好的啊?!! (崩溃
凭什么它就不会循环引用?!! (崩溃
等等,pip 导入的时候不会循环引用!
不如我们看看 pip 怎么导入的。
# pip/__main__.py
...
if __name__ == "__main__":
    from pip._internal.cli.main import main as _main
    sys.exit(_main())
# pip/_internal/cli/main.py
...
# Binds create_command into this module's namespace when the module is imported.
from pip._internal.commands import create_command
...
def main(args: Optional[List[str]] = None) -> int:
    ...
    command = create_command(cmd_name, isolated=("--isolated" in cmd_args))
    return command.main(cmd_args)
# pip/_internal/commands/__init__.py
...
def create_command(name: str, **kwargs: Any) -> Command:
    """
    Create an instance of the Command class with the given name.
    """
    module_path, class_name, summary = commands_dict[name]
    module = importlib.import_module(module_path)
    # importlib loads the module of the command actually being run here; that
    # command module then imports the modules it needs.
    command_class = getattr(module, class_name)
    command = command_class(name=name, summary=summary, **kwargs)
    return command
...
显然,pip 内部加载的逻辑有些不同。如果我们先让 pip 加载完,再进行修改呢?
4.再次尝试
我们先 patch 掉 pip._internal.commands.create_command,然后在 pip 加载完命令后 patch 掉 Downloader:
def patch_downloader()->bool:
    # Imported lazily (only after pip has created the command) to sidestep the
    # circular-import error seen when importing pip internals up front.
    from pip._internal.network import download as pip_download
    # NOTE(review): annotated -> bool but falls through and returns None.
    # This still rebinds the name only in pip._internal.network.download; the
    # next section shows why that alone is not enough.
    pip_download.Downloader = Aria2Downloader
def patch_create_command()->bool:
    from pip._internal import commands as pip_commands
    old_create_commands=pip_commands.create_command
    # Wrapper: let pip build the command first (which loads that command's
    # modules), then patch the downloader.
    def patched_create_commands(name: str, **kwargs: Any):
        command=old_create_commands(name,**kwargs)
        patch_downloader()
        return command
    pip_commands.create_command=patched_create_commands
    return True
5.成功了?
现在,我们的代码真的跑起来了,但是……
不是哥们,怎么还有 pip 下载的输出啊???
分别在我们的下载器和原有下载器的 __init__ 中打个断点,它们竟然都被调用了一次。
它们像这样被调用:
RequirementPreparer.__init__ -> Downloader.__init__
RequirementPreparer.__init__ -> BatchDownloader.__init__ -> Aria2Downloader.__init__
让我们来看看 pip._internal.operations.prepare.RequirementPreparer 到底干了什么:
# Excerpt from pip/_internal/operations/prepare.py ("..." marks elided code).
...
# Binds both classes into this module's own namespace at import time.
from pip._internal.network.download import BatchDownloader, Downloader
...
class RequirementPreparer:
    """Prepares a Requirement"""
    def __init__(
        ...
    ) -> None:
        ...
        # These names resolve through this module's globals, so rebinding them
        # in pip._internal.network.download later does not affect these calls.
        self._download = Downloader(session, progress_bar, resume_retries)
        self._batch_download = BatchDownloader(session, progress_bar, resume_retries)
        ...
    ...
...
这里特别感谢 @t7424fd 在这篇帖子中的帮助。
罪魁祸首在这一句:
from pip._internal.network.download import BatchDownloader, Downloader
关键在于:用 from xxx import yyy 导入时,当前模块得到的是“导入那一刻 yyy 所指向的对象”(函数、类等),而不是每次使用时再去原模块里查找 yyy。也就是说,当前命名空间里的 yyy 这个名字已经直接指向了原来的对象。
而我们 patch 时,只是把原模块命名空间里的名字重新指向了新的对象:其他模块通过 from ... import 得到的名字仍然指向旧对象,不受影响;只有在原模块内部(或通过“模块.名字”的方式)查找这个名字时,才会拿到新的对象。
这种行为像是在导入时把“引用”复制了一份,而不是一直跟随原模块中的那个名字。
就像这个简单的例子:
# a.py
def fun1():
    print("fun1")
def call_fun1():
    fun1()  # the name fun1 is looked up in a.py's globals at call time
# b.py
from a import fun1,call_fun1
def fun2():
    print("fun2")
def patch():
    import a
    a.fun1=fun2
    # a.fun1 now refers to fun2, but the name fun1 in this module still refers to the original function
patch()
fun1()  # not replaced: still prints "fun1"
call_fun1()  # replaced: prints "fun2"
所以,正确的办法应该是修改 pip._internal.operations.prepare 模块(RequirementPreparer 所在的模块)中已经导入的 BatchDownloader 和 Downloader,而不是修改 pip._internal.network.download 中的。
现在我们的代码终于能正常跑起来了。
完整代码
"""
@author: xu-haozhe
@date: 2025/07/28
aria2-pip
用 aria2 加速 pip 下载!!
环境:
ubuntu 24.04
python 3.12.3
pip 24.0 | 25.1.1
requests 2.32.4
"""
import requests
import os
import logging
import time
import importlib
import sys
from typing import Dict, Any, TYPE_CHECKING, Iterable, Tuple
# --- Configuration ---
# pip sub-commands whose downloads get routed through aria2.
PATCH_COMMANDS=['install']
# aria2c JSON-RPC endpoint and optional secret (sent as "token:<secret>"),
# both overridable through environment variables.
ARIA2_RPC_URL = os.getenv("ARIA2_RPC_URL", "http://localhost:16800/jsonrpc")
ARIA2_RPC_SECRET = os.getenv("ARIA2_RPC_SECRET", "")
# --- End of configuration ---
logging.basicConfig(
    format='[aria2-pip %(levelname)s] %(message)s',
    level=logging.INFO,
)
# The script logs through the root logger (same object as logging.getLogger()).
logger = logging.root
if sys.stdout.isatty():
COLOR_GREEN = '\033[92m'
COLOR_RED = '\033[91m'
COLOR_BLUE = '\033[94m'
COLOR_DIM = '\033[2m'
COLOR_RESET = '\033[0m'
def _format_size(size_in_bytes: float) -> str:
if size_in_bytes < 1024: return f"{int(size_in_bytes)} B"
if size_in_bytes < 1024**2: return f"{size_in_bytes/1024:.1f} KB"
if size_in_bytes < 1024**3: return f"{size_in_bytes/(1024**2):.1f} MB"
return f"{size_in_bytes/(1024**3):.1f} GB"
def _format_eta(seconds: float|None) -> str:
if seconds is None: return "--:--:--"
m, s = divmod(seconds, 60)
h, m = divmod(m, 60)
return f"{int(h):02d}:{int(m):02d}:{int(s):02d}"
def _progress_bar(status)->str:
total_length = int(status.get("totalLength", 0))
completed_length = int(status.get("completedLength", 0))
download_speed = int(status.get("downloadSpeed", 0))
eta_seconds = (total_length - completed_length) / download_speed if download_speed > 0 else None
bar_width=40
filled_len=int(40*completed_length/total_length) if total_length > 0 else 40
bar=(f"{COLOR_GREEN}{'━' * filled_len}{COLOR_RESET}"
f"{COLOR_DIM}{'━' * (bar_width - filled_len)}{COLOR_RESET}")
size_str = f"{COLOR_GREEN}{_format_size(completed_length)}/{_format_size(total_length)}{COLOR_RESET}"
speed_str = f"{COLOR_RED}{_format_size(download_speed)}/s{COLOR_RESET}"
eta_str = f"{COLOR_BLUE}{_format_eta(eta_seconds)}{COLOR_RESET}"
return f" {bar} {size_str.rjust(18)} {speed_str.rjust(12)} eta {eta_str}"
def _show_progress_bar(meta_info:Iterable[Tuple[str,str,str]],status:Iterable[Dict])->None:
if len(meta_info)==0:return
if len(meta_info)!=len(status):
raise ValueError("meta_info and status must have the same length")
sys.stdout.write(f"\033[{len(meta_info)*2}A")
for (name,_,_),status in zip(meta_info,status):
sys.stdout.write(name+"\033[K\n")
sys.stdout.write(_progress_bar(status[0])+"\033[K\n")
pass
else:
COLOR_GREEN, COLOR_RED, COLOR_BLUE, COLOR_DIM, COLOR_RESET = "", "", "", "", ""
def _show_progress_bar(meta_info:Iterable[Tuple[str,str,str]],status:Iterable[Dict])->None:
pass
def _aria2_rpc_call(method: str, params: list) -> Dict[str, Any]:
    """Send a single aria2 JSON-RPC request and return the decoded reply envelope.

    Args:
        method: aria2 method name, e.g. ``"aria2.getVersion"``.
        params: positional parameters of the call.  The list is not modified;
            when ``ARIA2_RPC_SECRET`` is set, ``token:<secret>`` is prepended
            to a copy.
    Returns:
        The whole JSON reply (containing either ``result`` or ``error``).
    Raises:
        requests.exceptions.RequestException: on connection errors, timeout
            (10 s) or a non-2xx HTTP status.
    """
    if ARIA2_RPC_SECRET:
        # Build a new list instead of params.insert(): inserting would mutate the
        # caller's list and re-using it would prepend the token a second time.
        params = [f"token:{ARIA2_RPC_SECRET}", *params]
    payload = {
        "jsonrpc": "2.0",
        "id": f"fastpip-{method}",
        "method": method,
        "params": params,
    }
    response = requests.post(ARIA2_RPC_URL, json=payload, timeout=10)
    response.raise_for_status()
    return response.json()
def _aria2_rpc_multicall(method:str,params:list[list])->list[Any]:
    """Invoke one aria2 method several times in a single ``system.multicall`` request.

    Args:
        method: aria2 method name, e.g. ``"aria2.tellStatus"``.
        params: one positional-parameter list per call; ``token:<secret>`` is
            prepended to each when ``ARIA2_RPC_SECRET`` is set.
    Returns:
        The multicall ``result`` array, one entry per call, in call order.
    Raises:
        requests.exceptions.RequestException: on connection errors, timeout
            (10 s) or a non-2xx HTTP status.
        RuntimeError: when the reply carries a JSON-RPC ``error`` object
            instead of a ``result`` (previously an opaque ``KeyError``).
    """
    token = [f"token:{ARIA2_RPC_SECRET}"] if ARIA2_RPC_SECRET else []
    payload = {
        "jsonrpc": "2.0",
        "id": f"fastpip-multicall-{method}",
        "method": "system.multicall",
        "params": [[
            {"methodName": method, "params": token + param}
            for param in params
        ]],
    }
    response = requests.post(ARIA2_RPC_URL, json=payload, timeout=10)
    response.raise_for_status()
    body = response.json()
    if "error" in body:
        error = body["error"]
        message = error.get("message", error) if isinstance(error, dict) else error
        raise RuntimeError(f"aria2 RPC 调用失败: {message}")
    return body["result"]
def _check_aria2c_connection():
    """Probe the aria2c RPC endpoint with ``aria2.getVersion``.

    Returns:
        True when aria2c answered with a result (its version is logged);
        False when it answered with a JSON-RPC error or could not be reached
        at all (the reason is logged in both cases).
    """
    logger.info(f"正在检查 aria2c RPC 服务于 {ARIA2_RPC_URL}...")
    try:
        reply = _aria2_rpc_call("aria2.getVersion", [])
    except requests.exceptions.RequestException as exc:
        logger.error(f"无法连接到 aria2c RPC 服务: {exc}")
        logger.error("请确保 aria2c 正在后台以 RPC 模式运行。")
        return False

    if "error" in reply:
        logger.error(f"连接 aria2c RPC 失败: {reply['error']['message']}")
        return False

    version = reply.get("result", {}).get("version", "未知")
    logger.info(f"成功连接到 aria2c (版本: {version})。")
    return True
def _get_http_response_filename(resp, link):
    """Fallback filename resolver: use the link's own filename, ignoring ``resp``.

    patch_downloader() rebinds this module-level name to pip's own
    ``_get_http_response_filename`` once that can be imported.
    """
    filename = link.filename
    return filename
def _raise_for_aria2_faults(results: list) -> None:
    """Raise RuntimeError if any entry of a ``tellStatus`` multicall reply reports a failure.

    Each multicall entry is either a one-item list holding the call's result,
    or a fault struct (a dict with ``code``/``message``) when that call itself
    failed.  A successful ``tellStatus`` whose download failed or was removed
    reports ``status`` ``"error"`` / ``"removed"``; such downloads never reach
    ``"complete"``, so they must abort the wait instead of polling forever.
    """
    for entry in results:
        if isinstance(entry, dict):
            # Fault struct: indexing it with [0] (as before) raised KeyError.
            raise RuntimeError(f"aria2 报告错误: {entry.get('message')}")
        status = entry[0]
        state = status.get("status")
        if state in ("error", "removed"):
            detail = status.get("errorMessage") or f"status={state}"
            raise RuntimeError(f"aria2 报告错误: {detail}")


def _all_complete(results: list) -> bool:
    """Return True when every ``tellStatus`` entry reports ``status == "complete"``."""
    return all(entry[0].get("status") == "complete" for entry in results)


def _wait_aria2(meta_infos:Iterable[Tuple[str,str,str]],gids:Iterable[str]):
    """Block until all aria2 downloads identified by ``gids`` complete, polling once a second.

    Args:
        meta_infos: ``(name, content_type, url)`` per download, same order as gids.
        gids: aria2 GIDs returned by ``aria2.addUri``.
    Raises:
        RuntimeError: if aria2 reports a fault, or a download ends in the
            ``error`` / ``removed`` state.
        requests.exceptions.RequestException: if polling aria2 fails (logged first).
    """
    # Both are traversed on every poll, so one-shot iterators must be materialized.
    meta_infos = list(meta_infos)
    gids = list(gids)
    interactive = sys.stdout.isatty()
    for name, _, _ in meta_infos:
        if interactive:
            # Reserve a name line and a bar line per download; the TTY progress
            # renderer moves the cursor back up over these and redraws them.
            sys.stdout.write(name + "\033[K\n\033[K\n")
        else:
            # No cursor control off a TTY: print the names once, no escape codes.
            sys.stdout.write(name + "\n")
    while True:
        try:
            results = _aria2_rpc_multicall("aria2.tellStatus", [[gid] for gid in gids])
        except requests.exceptions.RequestException as e:
            logger.error(f"轮询 aria2 状态时出错: {e}")
            raise
        _raise_for_aria2_faults(results)
        _show_progress_bar(meta_infos, results)
        if _all_complete(results):
            break
        time.sleep(1)
    logger.info("下载完成")
class Aria2Downloader:
    """Stand-in for pip's ``Downloader`` *and* ``BatchDownloader`` that hands downloads to aria2.

    pip constructs both classes as ``cls(session, progress_bar, resume_retries)``,
    so this one class can replace both.  Called with a single ``Link`` it
    behaves like ``Downloader`` and returns ``(filepath, content_type)``;
    called with an iterable of links it behaves like ``BatchDownloader`` and
    returns a list of ``(link, (filepath, content_type))``.
    """

    # Class attribute so __call__ can tell a single Link from a batch.
    from pip._internal.models.link import Link

    def __init__(self,session,progress_bar: str,resume_retries: int|None=None):
        # ``session`` is pip's session object, used only for the HEAD request
        # that resolves each file's final URL.  ``progress_bar`` is stored but
        # unused and ``resume_retries`` is ignored; both exist for signature
        # compatibility with pip's downloaders.
        self._session = session
        self._progress_bar = progress_bar

    @staticmethod
    def _unwrap_multicall_item(item, error_prefix: str):
        """Return the value of one multicall entry, raising RuntimeError if it is a fault struct.

        A successful entry is a one-item list; a failed one is a dict with
        ``code``/``message`` (indexing it with ``[0]`` would raise KeyError).
        """
        if isinstance(item, dict):
            raise RuntimeError(f"{error_prefix}: {item.get('message')}")
        return item[0]

    def __call__(self, links:Iterable|Link, location)->Iterable[Tuple[Link, Tuple[str, str]]]|tuple[str,str]:
        """Download one link or a batch of links into ``location`` via aria2.

        Args:
            links: a single pip ``Link`` or an iterable of them.
            location: directory aria2 writes the files into.
        Returns:
            ``(filepath, content_type)`` for a single link, or a list of
            ``(link, (filepath, content_type))`` in input order for a batch.
        Raises:
            requests.exceptions.RequestException / RuntimeError: if the tasks
                cannot be added to aria2 (logged first).  Errors from the HEAD
                request or from waiting on aria2 propagate unchanged.
        """
        is_batch = not isinstance(links, self.Link)
        # The links are iterated twice (metadata now, the result zip later);
        # a one-shot iterator would silently produce an empty batch result.
        links = list(links) if is_batch else [links]
        meta_info = [self._get_meta_info(link) for link in links]
        try:
            result = _aria2_rpc_multicall("aria2.addUri", [
                [[final_url], {"dir": location, "out": filename}]
                for filename, _content_type, final_url in meta_info
            ])
            gids = [self._unwrap_multicall_item(item, "aria2 添加任务失败") for item in result]
        # Must be a tuple: ``A | B`` is a types.UnionType, which an except
        # clause rejects with TypeError the moment an exception reaches it.
        except (requests.exceptions.RequestException, RuntimeError) as e:
            logger.error(f"无法将下载任务添加到 aria2: {e}")
            raise
        _wait_aria2(meta_info, gids)
        # Drop the finished entries from aria2's result list.
        _aria2_rpc_multicall("aria2.removeDownloadResult", [[gid] for gid in gids])
        paths = [
            (os.path.join(location, filename), content_type)
            for filename, content_type, _url in meta_info
        ]
        if is_batch:
            return list(zip(links, paths))
        return paths[0]

    def _get_meta_info(self,link)->tuple[str,str,str]:
        """Resolve ``(filename, content_type, final_url)`` for ``link`` with a HEAD request.

        The request goes through pip's session and follows redirects, so
        ``final_url`` is the URL aria2 should fetch.  Any error is logged and
        re-raised.
        """
        logger.info(f"正在为 {link.filename} 获取下载元信息...")
        try:
            # ``with`` closes the response: with stream=True the pooled
            # connection is not released until the response is closed.
            with self._session.head(link.url, allow_redirects=True, stream=True) as head_response:
                head_response.raise_for_status()
                return (
                    _get_http_response_filename(head_response, link),
                    head_response.headers.get("Content-Type", ""),
                    head_response.url,
                )
        except Exception as e:
            logger.error(f"获取元信息失败: {e}")
            raise
def patch_downloader()->bool:
    """Swap pip's downloaders for Aria2Downloader in pip's preparer module.

    Returns:
        True when the patch was applied; False when aria2c is unreachable, in
        which case pip keeps its own downloaders.
    Raises:
        ImportError: if pip's internal modules do not match what this script
            expects (logged first).
    """
    global _get_http_response_filename
    if not _check_aria2c_connection():
        logger.warning("aria2c 服务不可用,将回退到 pip 默认下载器。")
        return False
    try:
        from pip._internal.operations import prepare as pip_prepare
        from pip._internal.network.download import _get_http_response_filename as pip_get_http_response_filename
    except ImportError as e:
        logger.error(f"无法从 pip 导入必要模块: {e}。脚本可能与 pip 版本不兼容。")
        raise
    # Use pip's own filename resolution instead of the link-filename fallback.
    _get_http_response_filename = pip_get_http_response_filename
    # The prepare module did ``from pip._internal.network.download import
    # BatchDownloader, Downloader``, so RequirementPreparer resolves these names
    # in prepare's namespace; rebinding them there is what takes effect.
    pip_prepare.BatchDownloader = Aria2Downloader
    pip_prepare.Downloader = Aria2Downloader
    logger.info("Patch 成功!下载将由 aria2c 处理。")
    return True
def patch_create_command()->bool:
    """Wrap pip's ``create_command`` so the downloader is patched once a command is built.

    Patching is deferred until pip has created the command (and thereby
    imported that command's modules), and only happens for sub-commands listed
    in ``PATCH_COMMANDS``.

    Returns:
        True once the wrapper is installed.
    Raises:
        ImportError: if ``pip._internal.commands`` cannot be imported (logged first).
    """
    try:
        from pip._internal import commands as pip_commands
    except ImportError as e:
        logger.error(f"无法从 pip 导入必要模块: {e}。脚本可能与 pip 版本不兼容。")
        raise
    old_create_commands = pip_commands.create_command

    def patched_create_commands(name: str, **kwargs: Any):
        command = old_create_commands(name, **kwargs)
        if name.strip() in PATCH_COMMANDS:
            logger.info(f"检测到命令 {name},尝试进行 Patch...")
            patch_downloader()
        return command

    pip_commands.create_command = patched_create_commands
    # pip._internal.cli.main does ``from pip._internal.commands import
    # create_command``: if it is already imported, it holds its own reference to
    # the original function, so rebind that name too.
    cli_main = sys.modules.get("pip._internal.cli.main")
    if cli_main is not None and getattr(cli_main, "create_command", None) is old_create_commands:
        cli_main.create_command = patched_create_commands
    return True
def main()->None:
    """Install the create_command hook, run pip's CLI, and exit with its status code.

    Any unexpected exception from pip is logged with its traceback, and the
    process exits with status 1.
    """
    patch_create_command()
    try:
        # Imported only after the hook is installed: pip._internal.cli.main binds
        # create_command via ``from ... import`` when it is first imported.
        from pip._internal.cli.main import main as pip_main
        exit_code = pip_main()
    except Exception as e:
        # logger.exception also records the traceback, so failures can be diagnosed.
        logger.exception(f"pip 运行时出现未知错误: {e}")
        sys.exit(1)
    sys.exit(exit_code)


if __name__ == '__main__':
    main()