Files
lewm/AMD_SETUP.md
2026-05-14 03:52:50 +00:00

5.3 KiB
Raw Permalink Blame History

AMD ROCm 环境配置说明

这份文档记录了在 AMD ROCm 环境下运行 LeWM 的可复现配置,重点是保留 torch.compile 时的 PyTorch 版本选择。

目标运行命令:

python eval.py --config-name=pusht.yaml policy=pusht/lewm

已验证环境

本次验证通过的环境:

  • Ubuntu 24.04
  • AMD Radeon PRO W7900D (gfx1100)
  • 系统 ROCm 7.1.1
  • Python 3.10
  • torch==2.10.0+rocm7.1
  • torchvision==0.25.0+rocm7.1
  • triton-rocm==3.6.0

注意:torch==2.12.0+rocm7.1 可以正常导入,也能识别 GPU但在本项目里开启 torch.compile 后会崩溃,错误类似:

HSA_STATUS_ERROR_INVALID_PACKET_FORMAT
CUDA error: unspecified launch failure

降级到 torch==2.10.0+rocm7.1 后,torch.compile 路径可以正常跑通。

检查系统 ROCm

在新 AMD 机器上,先确认系统能识别 GPU

rocminfo
amd-smi version
hipcc --version

rocminfo 里应该能看到 AMD GPU agent例如 gfx1100

创建 Python 环境

使用 uv 创建 Python 3.10 虚拟环境:

cd /path/to/lewm
uv venv --python 3.10 --allow-existing .venv
source .venv/bin/activate

给 uv 创建的 venv 补上 pip。ROCm 版 PyTorch wheel 很大,如果 uv 解析或下载卡住, 用 pip 安装大 wheel 更容易观察进度。

uv pip install pip

安装 ROCm 版 PyTorch

安装本项目已验证可用的 ROCm PyTorch 组合:

python -m pip install --force-reinstall \
  --index-url https://download.pytorch.org/whl/rocm7.1 \
  --extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "torch==2.10.0" \
  "torchvision==0.25.0"

PyTorch wheel 有数 GB。如果网络慢不要频繁中断重试尽量等它下载完成。

安装项目依赖

普通 Python 包建议走国内 PyPI 镜像:

python -m pip install \
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "gymnasium[all]==1.2.2" \
  "stable-baselines3==2.8.0" \
  "stable-worldmodel[train,env]"

然后修正两个容易被 pip 带偏的依赖版本:

python -m pip install \
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "fsspec==2025.3.0" \
  "pillow==11.3.0"

检查环境:

python -m pip check
python - <<'PY'
import torch
import torchvision

print("torch:", torch.__version__)
print("hip:", torch.version.hip)
print("cuda available:", torch.cuda.is_available())
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)
print("torchvision:", torchvision.__version__)
PY

期望看到类似输出:

torch: 2.10.0+rocm7.1
cuda available: True
torchvision: 0.25.0+rocm7.1

恢复本仓库里的 stable-worldmodel 修改

这个仓库把一些本地修改后的 stable_worldmodel 文件纳入了 git 管控,路径在:

.venv/lib/python3.10/site-packages/stable_worldmodel/

从 PyPI 安装 stable-worldmodel 时可能会覆盖这些文件。安装依赖后执行:

git restore -- \
  .venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \
  .venv/lib/python3.10/site-packages/stable_worldmodel/world.py \
  .venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \
  .venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py

然后确认没有意外修改:

git status --short

数据和 checkpoint 路径

eval.py 会从 $STABLEWM_HOME 里找数据和 checkpoint。

PushT 评估至少需要:

$STABLEWM_HOME/pusht_expert_train.h5
$STABLEWM_HOME/pusht/lewm_object.ckpt

例如本机使用:

export STABLEWM_HOME=/mnt/ASC1637/stablewm

如果没有正确设置,运行时会报找不到 pusht_expert_train.h5

运行评估

默认 PushT 评估,保留 torch.compile

export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm

快速 smoke test

export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
  eval.num_eval=1 \
  world.num_envs=1 \
  output.filename=/tmp/lewm_smoke_test.txt

smoke test 应该能正常结束,并打印类似:

{'success_rate': 100.0, ...}

常见问题

HSA_STATUS_ERROR_INVALID_PACKET_FORMAT

如果开启 torch.compile 时出现这个错误,先检查 torch 版本:

python -c "import torch; print(torch.__version__, torch.version.hip)"

如果是 2.12.0+rocm7.1,建议降级到本项目验证通过的组合:

python -m pip install --force-reinstall \
  --index-url https://download.pytorch.org/whl/rocm7.1 \
  --extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "torch==2.10.0" \
  "torchvision==0.25.0"

找不到 pusht_expert_train.h5

设置 STABLEWM_HOME 到包含数据和 checkpoint 的目录:

export STABLEWM_HOME=/path/to/stablewm

pip 尝试构建旧版 gym==0.21

这是依赖解析回退导致的。先显式安装兼容版本:

python -m pip install \
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "gymnasium[all]==1.2.2" \
  "stable-baselines3==2.8.0"

uv 或 pip 访问海外源很慢

普通 Python 包使用国内 PyPI 镜像:

--index-url https://pypi.tuna.tsinghua.edu.cn/simple

PyTorch ROCm wheel 继续使用 PyTorch 官方 ROCm 源:

--index-url https://download.pytorch.org/whl/rocm7.1