Files
lewm/AMD_SETUP.md
2026-05-14 03:52:50 +00:00

242 lines
5.3 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# AMD ROCm 环境配置说明
这份文档记录了在 AMD ROCm 环境下运行 LeWM 的可复现配置,重点是保留
`torch.compile` 时的 PyTorch 版本选择。
目标运行命令:
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```
## 已验证环境
本次验证通过的环境:
- Ubuntu 24.04
- AMD Radeon PRO W7900D (`gfx1100`)
- 系统 ROCm 7.1.1
- Python 3.10
- `torch==2.10.0+rocm7.1`
- `torchvision==0.25.0+rocm7.1`
- `triton-rocm==3.6.0`
注意:`torch==2.12.0+rocm7.1` 可以正常导入,也能识别 GPU但在本项目里开启
`torch.compile` 后会崩溃,错误类似:
```text
HSA_STATUS_ERROR_INVALID_PACKET_FORMAT
CUDA error: unspecified launch failure
```
降级到 `torch==2.10.0+rocm7.1` 后,`torch.compile` 路径可以正常跑通。
## 检查系统 ROCm
在新 AMD 机器上,先确认系统能识别 GPU
```bash
rocminfo
amd-smi version
hipcc --version
```
`rocminfo` 里应该能看到 AMD GPU agent例如 `gfx1100`
## 创建 Python 环境
使用 `uv` 创建 Python 3.10 虚拟环境:
```bash
cd /path/to/lewm
uv venv --python 3.10 --allow-existing .venv
source .venv/bin/activate
```
给 uv 创建的 venv 补上 pip。ROCm 版 PyTorch wheel 很大,如果 uv 解析或下载卡住,
用 pip 安装大 wheel 更容易观察进度。
```bash
uv pip install pip
```
## 安装 ROCm 版 PyTorch
安装本项目已验证可用的 ROCm PyTorch 组合:
```bash
python -m pip install --force-reinstall \
--index-url https://download.pytorch.org/whl/rocm7.1 \
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"torch==2.10.0" \
"torchvision==0.25.0"
```
PyTorch wheel 有数 GB。如果网络慢不要频繁中断重试尽量等它下载完成。
## 安装项目依赖
普通 Python 包建议走国内 PyPI 镜像:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"gymnasium[all]==1.2.2" \
"stable-baselines3==2.8.0" \
"stable-worldmodel[train,env]"
```
然后修正两个容易被 pip 带偏的依赖版本:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"fsspec==2025.3.0" \
"pillow==11.3.0"
```
检查环境:
```bash
python -m pip check
python - <<'PY'
import torch
import torchvision
print("torch:", torch.__version__)
print("hip:", torch.version.hip)
print("cuda available:", torch.cuda.is_available())
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)
print("torchvision:", torchvision.__version__)
PY
```
期望看到类似输出:
```text
torch: 2.10.0+rocm7.1
cuda available: True
torchvision: 0.25.0+rocm7.1
```
## 恢复本仓库里的 stable-worldmodel 修改
这个仓库把一些本地修改后的 `stable_worldmodel` 文件纳入了 git 管控,路径在:
```text
.venv/lib/python3.10/site-packages/stable_worldmodel/
```
从 PyPI 安装 `stable-worldmodel` 时可能会覆盖这些文件。安装依赖后执行:
```bash
git restore -- \
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
```
然后确认没有意外修改:
```bash
git status --short
```
## 数据和 checkpoint 路径
`eval.py` 会从 `$STABLEWM_HOME` 里找数据和 checkpoint。
PushT 评估至少需要:
```text
$STABLEWM_HOME/pusht_expert_train.h5
$STABLEWM_HOME/pusht/lewm_object.ckpt
```
例如本机使用:
```bash
export STABLEWM_HOME=/mnt/ASC1637/stablewm
```
如果没有正确设置,运行时会报找不到 `pusht_expert_train.h5`
## 运行评估
默认 PushT 评估,保留 `torch.compile`
```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```
快速 smoke test
```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
eval.num_eval=1 \
world.num_envs=1 \
output.filename=/tmp/lewm_smoke_test.txt
```
smoke test 应该能正常结束,并打印类似:
```text
{'success_rate': 100.0, ...}
```
## 常见问题
### `HSA_STATUS_ERROR_INVALID_PACKET_FORMAT`
如果开启 `torch.compile` 时出现这个错误,先检查 torch 版本:
```bash
python -c "import torch; print(torch.__version__, torch.version.hip)"
```
如果是 `2.12.0+rocm7.1`,建议降级到本项目验证通过的组合:
```bash
python -m pip install --force-reinstall \
--index-url https://download.pytorch.org/whl/rocm7.1 \
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"torch==2.10.0" \
"torchvision==0.25.0"
```
### 找不到 `pusht_expert_train.h5`
设置 `STABLEWM_HOME` 到包含数据和 checkpoint 的目录:
```bash
export STABLEWM_HOME=/path/to/stablewm
```
### pip 尝试构建旧版 `gym==0.21`
这是依赖解析回退导致的。先显式安装兼容版本:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"gymnasium[all]==1.2.2" \
"stable-baselines3==2.8.0"
```
### uv 或 pip 访问海外源很慢
普通 Python 包使用国内 PyPI 镜像:
```bash
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
```
PyTorch ROCm wheel 继续使用 PyTorch 官方 ROCm 源:
```bash
--index-url https://download.pytorch.org/whl/rocm7.1
```