Files
aloha/ENDOSCOPE_ACT_ADAPTATION_PLAN.md
2026-02-19 15:32:28 +08:00

297 lines
9.5 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ACT 仓库适配内镜机器人2-DOF + 图像/qpos/Text 指令)修改清单(仅训练版)
## 1. 目标与约束
### 目标
将当前标准 ACT 仓库改造成可用于你的内镜机器人离线训练,支持:
- **动作维度仅 2**2 个电机)
- **不依赖 Gym / 仿真环境**
- 输入为 **图像 + qpos + text instruction**
- 以离线数据训练为主(本阶段不包含真实机器人在线接口)
### 约束
当前代码默认是 ALOHA 双臂14 维状态/动作)和 sim/real ALOHA 环境接口,且**没有 text 分支**,存在大量硬编码,需要系统性改造。
---
## 2. 现有代码中的关键硬编码(必须改)
1. **状态/动作维度硬编码为 14**
- `imitate_episodes.py``state_dim = 14`
- `detr/models/detr_vae.py` 中多个 `nn.Linear(14, ...)`
- `record_sim_episodes.py` 的数据写入 shape 固定 `(T, 14)`
2. **训练/评估流程绑定 sim 或 aloha_scripts real_env**
- `imitate_episodes.py``eval_bc()` 依赖 `sim_env.make_sim_env()``aloha_scripts.real_env`
3. **数据加载器默认字段是 qpos/qvel/action + images**
- `utils.py``EpisodicDataset` 仅加载 `qpos``action``images`,无 text
4. **模型只融合图像 + 状态,无文本编码**
- `policy.py``detr/models/detr_vae.py` 当前无 text 输入通道
---
## 3. 必改模块清单(按文件)
## A. 配置与任务定义
### 文件
- `constants.py`
### 修改点
- 新增内镜任务配置(建议新字典 `ENDOSCOPE_TASK_CONFIGS`
- `dataset_dir`
- `num_episodes`
- `episode_len`
- `camera_names`
- `state_dim=2`
- `action_dim=2`
- `use_text_instruction=True`
- `instruction_mode`episode-level / timestep-level
- `text_encoder_type="distilbert"`
- `text_feature_dim=768`
- `text_fusion_type="concat_transformer_input"`
- 避免继续依赖 `sim_` 前缀来判断任务类型。
### 目的
把任务参数从 ALOHA 默认值中解耦,作为后续训练与模型构建的统一入口。
---
## B. 数据协议与数据集加载
### 文件
- `utils.py`
- (新增)`dataset_tools/` 下的数据转换脚本
### 修改点
1. **数据协议统一定义(建议 HDF5**
- `/observations/images/<cam_name>`: `(T, H, W, C)`
- `/observations/qpos`: `(T, 2)`
- `/action`: `(T, 2)`
- `/instruction`(字符串或 token
- 可选:`/instruction_timestep`(若每步指令不同)
2. **重构 `EpisodicDataset`**
- 保持 `qpos` 命名,不做重命名
- 加载 text instruction
- 返回训练样本改为:
- `image_data`
- `qpos_data`
- `action_data`
- `is_pad`
- `text_input_ids`
- `text_attention_mask`
3. **归一化统计扩展**
- `get_norm_stats()` 支持 `qpos/action` 任意维度(本任务均为 2
- text 采用在线 DistilBERT 编码(默认),可选缓存特征
4. **兼容性策略**
- 保持对旧字段 `qpos` 的直接兼容
- 支持多相机或单相机
### 目的
构建与你真实数据一致的数据管线,彻底摆脱 14 维与仿真字段假设。
---
## C. 训练入口与流程控制
### 文件
- `imitate_episodes.py`
### 修改点
1. **配置读取改造**
- 使用新任务配置读取 `state_dim=2/action_dim=2/camera_names/use_text_instruction`
2. **移除/隔离仿真耦合逻辑(训练范围内)**
- `main()` 保留纯离线训练路径
- `eval_bc()` 仅保留离线评估路径
3. **前向输入变更**
- `forward_pass()` 改为支持 text
- dataloader batch 解包增加 text 分量
4. **命令行参数补充**
- `--task_config``--task_name` 对应新配置
- `--text_encoder_type`
- `--freeze_text_encoder`
### 目的
让训练脚本成为“真实机器人离线模仿学习”的统一入口,而不是 sim demo 入口。
---
## D. 策略封装层Policy API
### 文件
- `policy.py`
### 修改点
1. `ACTPolicy.__call__()``CNNMLPPolicy.__call__()` 签名扩展:
- 现有:`(qpos, image, actions=None, is_pad=None)`
- 目标:`(qpos, image, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
2. 图像归一化保留,但要确保支持任意相机数量。
3. loss 计算保持一致,同时确保 text 缺失时可降级运行(便于 ablation
### 目的
把 text 从数据层顺畅传递到模型层。
---
## E. 模型构建参数与入口
### 文件
- `detr/main.py`
### 修改点
- 新增模型参数:
- `state_dim`
- `action_dim`
- `use_text`
- `text_encoder_type="distilbert"`
- `text_feature_dim=768`
- `text_fusion_type="concat_transformer_input"`
- 删除或弱化与原脚本无关的占位参数。
### 目的
将所有硬编码维度下放为可配置项,便于后续迭代。
---
## F. ACT 主干网络(核心改造)
### 文件
- `detr/models/detr_vae.py`
### 修改点
1. **去掉 14 维硬编码**
- `input_proj_robot_state = nn.Linear(state_dim, hidden_dim)`
- `encoder_action_proj = nn.Linear(action_dim, hidden_dim)`
- `encoder_joint_proj = nn.Linear(state_dim, hidden_dim)`
- `action_head = nn.Linear(hidden_dim, action_dim)`
2. **加入 text 分支**
- 使用 DistilBERT 输出特征768 维)
- 新增 text 投影层:`nn.Linear(768, hidden_dim)`
- 融合策略固定为:**将 text token/特征作为额外 token直接 concat 到 Transformer 输入序列**
3. **前向函数增加 text 输入**
- `forward(self, qpos, image, env_state, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
4. **保持训练/推理双模式一致**
- 训练:动作序列 + text 条件 VAE
- 推理:先验采样 + text 条件生成
### 目的
把 ACT 从“图像+14 维关节”模型改造成“图像+2 维 qpos+文本条件”模型。
---
## G. 真实机器人接口与采集(暂不纳入本轮)
本轮仅做训练侧改造,以下内容延期:
- 在线推理接口
- 真实机器人数据采集脚本
- 在线安全控制与频率控制
---
## H. 文档与脚本
### 文件
- `README.md`
- (新增)`docs/endoscope_data_format.md`
- (新增)`docs/endoscope_train_eval.md`
### 修改点
- 给出最小可运行流程:
1) 准备数据
2) 训练命令
3) 离线评估
- 明确 text instruction 的格式规范。
---
## 4. 建议新增文件(清单)
- `configs/endoscope_task.yaml`(或继续用 python dict
- `dataset_tools/convert_endoscope_to_act_hdf5.py`
- `dataset_tools/validate_endoscope_dataset.py`
- `models/text_encoder.py`DistilBERT 封装)
- `docs/endoscope_data_format.md`
---
## 5. 分阶段实施顺序(建议)
### Phase 1先跑通“无 text”2-DOF 版本
-`state_dim/action_dim`
- 跑通数据加载 + 训练 + 离线验证
### Phase 2加入 text instruction
- 数据协议加入 instruction
- 接入 DistilBERT768
-`concat_transformer_input` 完成 text 融合训练
---
## 6. 验收标准Definition of Done
1. 可用你的 HDF5 数据直接训练,不依赖 sim/gym。
2. 模型输入同时支持图像、2D qpos、text instruction。
3. Text 编码器使用 DistilBERT输出特征维度为 768。
4. Text 融合方式为 Transformer 输入级 concat。
5. README 有完整训练与离线评估命令示例,团队可复现。
---
## 7. 风险点与提前规避
1. **text 与动作时序对齐问题**
- 需明确 instruction 是 episode-level 还是 timestep-level。
2. **小维度控制下的动作抖动**
- 可在后处理中加入 low-pass / action smoothing。
3. **多模态尺度不平衡**
- 需关注图像/状态/text 融合后梯度主导问题(可加 modality dropout 或 loss 权重调节)。
4. **文本编码开销导致训练变慢**
- 可选缓存 DistilBERT 特征,或冻结 text encoder。
---
## 8. 你接下来只需提供的最小信息(进入代码改造前)
1. 2 个电机各自的物理含义与取值范围(单位、上下限):电机分别为 motor_x 和 motor_yx 的范围为 7000-17384y 的范围为 8000-18884。对应的 action_x 和 action_y 都为 0~65535 之间
2. 你当前数据中 `qpos``action` 的实际定义是否相同action 和 qpos 定义接近,只不过是将 0~65535 分别映射到电机磁编码器数值上。
3. text instruction 是每个 episode 一条,还是每个 timestep 一条text instruction 是每个 timestep 一条。
4. 相机数量、分辨率、帧率:相机数量为 1分辨率为 224*224帧率为 30Hz对应的电机控制频率也为 30Hz
5. 是否在训练时冻结 DistilBERT`freeze_text_encoder=True/False`DistilBERT 完全冻结。构建训练集时,先将每一个 frame 的 text instruction 用 DistilBERT 编码以后再保存。这样训练过程中不需要调用 DistilBERT。
> 有了这 5 项,即可进入下一步代码改造。
text_input_ids、text_attention_mask 什么意思;'instruction_mode': 'episode-level'没有用到;
```python
instruction = ''
if self.use_text_instruction:
if '/instruction_timestep' in root:
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
elif '/instruction' in root:
instruction_node = root['/instruction']
if getattr(instruction_node, 'shape', ()) == ():
instruction = self._decode_instruction(instruction_node[()])
else:
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
instruction = self._decode_instruction(instruction_node[start_ts])
else:
instruction = self._decode_instruction(instruction_node[0])
```
为什么修改了 Transformer 的定义?这里是否会生效?