修改训练代码
This commit is contained in:
277
ENDOSCOPE_ACT_ADAPTATION_PLAN.md
Normal file
277
ENDOSCOPE_ACT_ADAPTATION_PLAN.md
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
# ACT 仓库适配内镜机器人(2-DOF + 图像/qpos/Text 指令)修改清单(仅训练版)
|
||||||
|
|
||||||
|
## 1. 目标与约束
|
||||||
|
|
||||||
|
### 目标
|
||||||
|
将当前标准 ACT 仓库改造成可用于你的内镜机器人离线训练,支持:
|
||||||
|
- **动作维度仅 2**(2 个电机)
|
||||||
|
- **不依赖 Gym / 仿真环境**
|
||||||
|
- 输入为 **图像 + qpos + text instruction**
|
||||||
|
- 以离线数据训练为主(本阶段不包含真实机器人在线接口)
|
||||||
|
|
||||||
|
### 约束
|
||||||
|
当前代码默认是 ALOHA 双臂(14 维状态/动作)和 sim/real ALOHA 环境接口,且**没有 text 分支**,存在大量硬编码,需要系统性改造。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 现有代码中的关键硬编码(必须改)
|
||||||
|
|
||||||
|
1. **状态/动作维度硬编码为 14**
|
||||||
|
- `imitate_episodes.py` 中 `state_dim = 14`
|
||||||
|
- `detr/models/detr_vae.py` 中多个 `nn.Linear(14, ...)`
|
||||||
|
- `record_sim_episodes.py` 的数据写入 shape 固定 `(T, 14)`
|
||||||
|
|
||||||
|
2. **训练/评估流程绑定 sim 或 aloha_scripts real_env**
|
||||||
|
- `imitate_episodes.py` 的 `eval_bc()` 依赖 `sim_env.make_sim_env()` 或 `aloha_scripts.real_env`
|
||||||
|
|
||||||
|
3. **数据加载器默认字段是 qpos/qvel/action + images**
|
||||||
|
- `utils.py` 的 `EpisodicDataset` 仅加载 `qpos`、`action`、`images`,无 text
|
||||||
|
|
||||||
|
4. **模型只融合图像 + 状态,无文本编码**
|
||||||
|
- `policy.py` 与 `detr/models/detr_vae.py` 当前无 text 输入通道
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 必改模块清单(按文件)
|
||||||
|
|
||||||
|
## A. 配置与任务定义
|
||||||
|
|
||||||
|
### 文件
|
||||||
|
- `constants.py`
|
||||||
|
|
||||||
|
### 修改点
|
||||||
|
- 新增内镜任务配置(建议新字典 `ENDOSCOPE_TASK_CONFIGS`):
|
||||||
|
- `dataset_dir`
|
||||||
|
- `num_episodes`
|
||||||
|
- `episode_len`
|
||||||
|
- `camera_names`
|
||||||
|
- `state_dim=2`
|
||||||
|
- `action_dim=2`
|
||||||
|
- `use_text_instruction=True`
|
||||||
|
- `instruction_mode`(episode-level / timestep-level)
|
||||||
|
- `text_encoder_type="distilbert"`
|
||||||
|
- `text_feature_dim=768`
|
||||||
|
- `text_fusion_type="concat_transformer_input"`
|
||||||
|
- 避免继续依赖 `sim_` 前缀来判断任务类型。
|
||||||
|
|
||||||
|
### 目的
|
||||||
|
把任务参数从 ALOHA 默认值中解耦,作为后续训练与模型构建的统一入口。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## B. 数据协议与数据集加载
|
||||||
|
|
||||||
|
### 文件
|
||||||
|
- `utils.py`
|
||||||
|
- (新增)`dataset_tools/` 下的数据转换脚本
|
||||||
|
|
||||||
|
### 修改点
|
||||||
|
1. **数据协议统一定义(建议 HDF5)**
|
||||||
|
- `/observations/images/<cam_name>`: `(T, H, W, C)`
|
||||||
|
- `/observations/qpos`: `(T, 2)`
|
||||||
|
- `/action`: `(T, 2)`
|
||||||
|
- `/instruction`(字符串或 token)
|
||||||
|
- 可选:`/instruction_timestep`(若每步指令不同)
|
||||||
|
|
||||||
|
2. **重构 `EpisodicDataset`**
|
||||||
|
- 保持 `qpos` 命名,不做重命名
|
||||||
|
- 加载 text instruction
|
||||||
|
- 返回训练样本改为:
|
||||||
|
- `image_data`
|
||||||
|
- `qpos_data`
|
||||||
|
- `action_data`
|
||||||
|
- `is_pad`
|
||||||
|
- `text_input_ids`
|
||||||
|
- `text_attention_mask`
|
||||||
|
|
||||||
|
3. **归一化统计扩展**
|
||||||
|
- `get_norm_stats()` 支持 `qpos/action` 任意维度(本任务均为 2)
|
||||||
|
- text 采用在线 DistilBERT 编码(默认),可选缓存特征
|
||||||
|
|
||||||
|
4. **兼容性策略**
|
||||||
|
- 保持对旧字段 `qpos` 的直接兼容
|
||||||
|
- 支持多相机或单相机
|
||||||
|
|
||||||
|
### 目的
|
||||||
|
构建与你真实数据一致的数据管线,彻底摆脱 14 维与仿真字段假设。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## C. 训练入口与流程控制
|
||||||
|
|
||||||
|
### 文件
|
||||||
|
- `imitate_episodes.py`
|
||||||
|
|
||||||
|
### 修改点
|
||||||
|
1. **配置读取改造**
|
||||||
|
- 使用新任务配置读取 `state_dim=2/action_dim=2/camera_names/use_text_instruction`
|
||||||
|
|
||||||
|
2. **移除/隔离仿真耦合逻辑(训练范围内)**
|
||||||
|
- `main()` 保留纯离线训练路径
|
||||||
|
- `eval_bc()` 仅保留离线评估路径
|
||||||
|
|
||||||
|
3. **前向输入变更**
|
||||||
|
- `forward_pass()` 改为支持 text
|
||||||
|
- dataloader batch 解包增加 text 分量
|
||||||
|
|
||||||
|
4. **命令行参数补充**
|
||||||
|
- `--task_config` 或 `--task_name` 对应新配置
|
||||||
|
- `--text_encoder_type`
|
||||||
|
- `--freeze_text_encoder`
|
||||||
|
|
||||||
|
### 目的
|
||||||
|
让训练脚本成为“真实机器人离线模仿学习”的统一入口,而不是 sim demo 入口。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## D. 策略封装层(Policy API)
|
||||||
|
|
||||||
|
### 文件
|
||||||
|
- `policy.py`
|
||||||
|
|
||||||
|
### 修改点
|
||||||
|
1. `ACTPolicy.__call__()` 和 `CNNMLPPolicy.__call__()` 签名扩展:
|
||||||
|
- 现有:`(qpos, image, actions=None, is_pad=None)`
|
||||||
|
- 目标:`(qpos, image, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
|
||||||
|
|
||||||
|
2. 图像归一化保留,但要确保支持任意相机数量。
|
||||||
|
|
||||||
|
3. loss 计算保持一致,同时确保 text 缺失时可降级运行(便于 ablation)。
|
||||||
|
|
||||||
|
### 目的
|
||||||
|
把 text 从数据层顺畅传递到模型层。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## E. 模型构建参数与入口
|
||||||
|
|
||||||
|
### 文件
|
||||||
|
- `detr/main.py`
|
||||||
|
|
||||||
|
### 修改点
|
||||||
|
- 新增模型参数:
|
||||||
|
- `state_dim`
|
||||||
|
- `action_dim`
|
||||||
|
- `use_text`
|
||||||
|
- `text_encoder_type="distilbert"`
|
||||||
|
- `text_feature_dim=768`
|
||||||
|
- `text_fusion_type="concat_transformer_input"`
|
||||||
|
- 删除或弱化与原脚本无关的占位参数。
|
||||||
|
|
||||||
|
### 目的
|
||||||
|
将所有硬编码维度下放为可配置项,便于后续迭代。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## F. ACT 主干网络(核心改造)
|
||||||
|
|
||||||
|
### 文件
|
||||||
|
- `detr/models/detr_vae.py`
|
||||||
|
|
||||||
|
### 修改点
|
||||||
|
1. **去掉 14 维硬编码**
|
||||||
|
- `input_proj_robot_state = nn.Linear(state_dim, hidden_dim)`
|
||||||
|
- `encoder_action_proj = nn.Linear(action_dim, hidden_dim)`
|
||||||
|
- `encoder_joint_proj = nn.Linear(state_dim, hidden_dim)`
|
||||||
|
- `action_head = nn.Linear(hidden_dim, action_dim)`
|
||||||
|
|
||||||
|
2. **加入 text 分支**
|
||||||
|
- 使用 DistilBERT 输出特征(768 维)
|
||||||
|
- 新增 text 投影层:`nn.Linear(768, hidden_dim)`
|
||||||
|
- 融合策略固定为:**将 text token/特征作为额外 token,直接 concat 到 Transformer 输入序列**
|
||||||
|
|
||||||
|
3. **前向函数增加 text 输入**
|
||||||
|
- `forward(self, qpos, image, env_state, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
|
||||||
|
|
||||||
|
4. **保持训练/推理双模式一致**
|
||||||
|
- 训练:动作序列 + text 条件 VAE
|
||||||
|
- 推理:先验采样 + text 条件生成
|
||||||
|
|
||||||
|
### 目的
|
||||||
|
把 ACT 从“图像+14 维关节”模型改造成“图像+2 维 qpos+文本条件”模型。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## G. 真实机器人接口与采集(暂不纳入本轮)
|
||||||
|
|
||||||
|
本轮仅做训练侧改造,以下内容延期:
|
||||||
|
- 在线推理接口
|
||||||
|
- 真实机器人数据采集脚本
|
||||||
|
- 在线安全控制与频率控制
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## H. 文档与脚本
|
||||||
|
|
||||||
|
### 文件
|
||||||
|
- `README.md`
|
||||||
|
- (新增)`docs/endoscope_data_format.md`
|
||||||
|
- (新增)`docs/endoscope_train_eval.md`
|
||||||
|
|
||||||
|
### 修改点
|
||||||
|
- 给出最小可运行流程:
|
||||||
|
1) 准备数据
|
||||||
|
2) 训练命令
|
||||||
|
3) 离线评估
|
||||||
|
- 明确 text instruction 的格式规范。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 建议新增文件(清单)
|
||||||
|
|
||||||
|
- `configs/endoscope_task.yaml`(或继续用 python dict)
|
||||||
|
- `dataset_tools/convert_endoscope_to_act_hdf5.py`
|
||||||
|
- `dataset_tools/validate_endoscope_dataset.py`
|
||||||
|
- `models/text_encoder.py`(DistilBERT 封装)
|
||||||
|
- `docs/endoscope_data_format.md`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 分阶段实施顺序(建议)
|
||||||
|
|
||||||
|
### Phase 1:先跑通“无 text”2-DOF 版本
|
||||||
|
- 改 `state_dim/action_dim`
|
||||||
|
- 跑通数据加载 + 训练 + 离线验证
|
||||||
|
|
||||||
|
### Phase 2:加入 text instruction
|
||||||
|
- 数据协议加入 instruction
|
||||||
|
- 接入 DistilBERT(768)
|
||||||
|
- 按 `concat_transformer_input` 完成 text 融合训练
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 验收标准(Definition of Done)
|
||||||
|
|
||||||
|
1. 可用你的 HDF5 数据直接训练,不依赖 sim/gym。
|
||||||
|
2. 模型输入同时支持图像、2D qpos、text instruction。
|
||||||
|
3. Text 编码器使用 DistilBERT,输出特征维度为 768。
|
||||||
|
4. Text 融合方式为 Transformer 输入级 concat。
|
||||||
|
5. README 有完整训练与离线评估命令示例,团队可复现。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 风险点与提前规避
|
||||||
|
|
||||||
|
1. **text 与动作时序对齐问题**
|
||||||
|
- 需明确 instruction 是 episode-level 还是 timestep-level。
|
||||||
|
|
||||||
|
2. **小维度控制下的动作抖动**
|
||||||
|
- 可在后处理中加入 low-pass / action smoothing。
|
||||||
|
|
||||||
|
3. **多模态尺度不平衡**
|
||||||
|
- 需关注图像/状态/text 融合后梯度主导问题(可加 modality dropout 或 loss 权重调节)。
|
||||||
|
|
||||||
|
4. **文本编码开销导致训练变慢**
|
||||||
|
- 可选缓存 DistilBERT 特征,或冻结 text encoder。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 你接下来只需提供的最小信息(进入代码改造前)
|
||||||
|
|
||||||
|
1. 2 个电机各自的物理含义与取值范围(单位、上下限)。
|
||||||
|
2. 你当前数据中 `qpos` 和 `action` 的实际定义(是否相同)。
|
||||||
|
3. text instruction 是每个 episode 一条,还是每个 timestep 一条。
|
||||||
|
4. 相机数量、分辨率、帧率。
|
||||||
|
5. 是否在训练时冻结 DistilBERT(`freeze_text_encoder=True/False`)。
|
||||||
|
|
||||||
|
> 有了这 5 项,即可进入下一步代码改造。
|
||||||
Reference in New Issue
Block a user