修改训练代码

This commit is contained in:
2026-02-17 19:16:09 +08:00
parent d4b4d554f8
commit ba006e14c4

View File

@@ -0,0 +1,277 @@
# ACT 仓库适配内镜机器人2-DOF + 图像/qpos/Text 指令)修改清单(仅训练版)
## 1. 目标与约束
### 目标
将当前标准 ACT 仓库改造成可用于你的内镜机器人离线训练,支持:
- **动作维度仅 2**2 个电机)
- **不依赖 Gym / 仿真环境**
- 输入为 **图像 + qpos + text instruction**
- 以离线数据训练为主(本阶段不包含真实机器人在线接口)
### 约束
当前代码默认是 ALOHA 双臂14 维状态/动作)和 sim/real ALOHA 环境接口,且**没有 text 分支**,存在大量硬编码,需要系统性改造。
---
## 2. 现有代码中的关键硬编码(必须改)
1. **状态/动作维度硬编码为 14**
- `imitate_episodes.py``state_dim = 14`
- `detr/models/detr_vae.py` 中多个 `nn.Linear(14, ...)`
- `record_sim_episodes.py` 的数据写入 shape 固定 `(T, 14)`
2. **训练/评估流程绑定 sim 或 aloha_scripts real_env**
- `imitate_episodes.py``eval_bc()` 依赖 `sim_env.make_sim_env()``aloha_scripts.real_env`
3. **数据加载器默认字段是 qpos/qvel/action + images**
- `utils.py``EpisodicDataset` 仅加载 `qpos``action``images`,无 text
4. **模型只融合图像 + 状态,无文本编码**
- `policy.py``detr/models/detr_vae.py` 当前无 text 输入通道
---
## 3. 必改模块清单(按文件)
## A. 配置与任务定义
### 文件
- `constants.py`
### 修改点
- 新增内镜任务配置(建议新字典 `ENDOSCOPE_TASK_CONFIGS`
- `dataset_dir`
- `num_episodes`
- `episode_len`
- `camera_names`
- `state_dim=2`
- `action_dim=2`
- `use_text_instruction=True`
- `instruction_mode`episode-level / timestep-level
- `text_encoder_type="distilbert"`
- `text_feature_dim=768`
- `text_fusion_type="concat_transformer_input"`
- 避免继续依赖 `sim_` 前缀来判断任务类型。
### 目的
把任务参数从 ALOHA 默认值中解耦,作为后续训练与模型构建的统一入口。
---
## B. 数据协议与数据集加载
### 文件
- `utils.py`
- (新增)`dataset_tools/` 下的数据转换脚本
### 修改点
1. **数据协议统一定义(建议 HDF5**
- `/observations/images/<cam_name>`: `(T, H, W, C)`
- `/observations/qpos`: `(T, 2)`
- `/action`: `(T, 2)`
- `/instruction`(字符串或 token
- 可选:`/instruction_timestep`(若每步指令不同)
2. **重构 `EpisodicDataset`**
- 保持 `qpos` 命名,不做重命名
- 加载 text instruction
- 返回训练样本改为:
- `image_data`
- `qpos_data`
- `action_data`
- `is_pad`
- `text_input_ids`
- `text_attention_mask`
3. **归一化统计扩展**
- `get_norm_stats()` 支持 `qpos/action` 任意维度(本任务均为 2
- text 采用在线 DistilBERT 编码(默认),可选缓存特征
4. **兼容性策略**
- 保持对旧字段 `qpos` 的直接兼容
- 支持多相机或单相机
### 目的
构建与你真实数据一致的数据管线,彻底摆脱 14 维与仿真字段假设。
---
## C. 训练入口与流程控制
### 文件
- `imitate_episodes.py`
### 修改点
1. **配置读取改造**
- 使用新任务配置读取 `state_dim=2/action_dim=2/camera_names/use_text_instruction`
2. **移除/隔离仿真耦合逻辑(训练范围内)**
- `main()` 保留纯离线训练路径
- `eval_bc()` 仅保留离线评估路径
3. **前向输入变更**
- `forward_pass()` 改为支持 text
- dataloader batch 解包增加 text 分量
4. **命令行参数补充**
- `--task_config``--task_name` 对应新配置
- `--text_encoder_type`
- `--freeze_text_encoder`
### 目的
让训练脚本成为“真实机器人离线模仿学习”的统一入口,而不是 sim demo 入口。
---
## D. 策略封装层Policy API
### 文件
- `policy.py`
### 修改点
1. `ACTPolicy.__call__()``CNNMLPPolicy.__call__()` 签名扩展:
- 现有:`(qpos, image, actions=None, is_pad=None)`
- 目标:`(qpos, image, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
2. 图像归一化保留,但要确保支持任意相机数量。
3. loss 计算保持一致,同时确保 text 缺失时可降级运行(便于 ablation
### 目的
把 text 从数据层顺畅传递到模型层。
---
## E. 模型构建参数与入口
### 文件
- `detr/main.py`
### 修改点
- 新增模型参数:
- `state_dim`
- `action_dim`
- `use_text`
- `text_encoder_type="distilbert"`
- `text_feature_dim=768`
- `text_fusion_type="concat_transformer_input"`
- 删除或弱化与原脚本无关的占位参数。
### 目的
将所有硬编码维度下放为可配置项,便于后续迭代。
---
## F. ACT 主干网络(核心改造)
### 文件
- `detr/models/detr_vae.py`
### 修改点
1. **去掉 14 维硬编码**
- `input_proj_robot_state = nn.Linear(state_dim, hidden_dim)`
- `encoder_action_proj = nn.Linear(action_dim, hidden_dim)`
- `encoder_joint_proj = nn.Linear(state_dim, hidden_dim)`
- `action_head = nn.Linear(hidden_dim, action_dim)`
2. **加入 text 分支**
- 使用 DistilBERT 输出特征768 维)
- 新增 text 投影层:`nn.Linear(768, hidden_dim)`
- 融合策略固定为:**将 text token/特征作为额外 token直接 concat 到 Transformer 输入序列**
3. **前向函数增加 text 输入**
- `forward(self, qpos, image, env_state, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
4. **保持训练/推理双模式一致**
- 训练:动作序列 + text 条件 VAE
- 推理:先验采样 + text 条件生成
### 目的
把 ACT 从“图像+14 维关节”模型改造成“图像+2 维 qpos+文本条件”模型。
---
## G. 真实机器人接口与采集(暂不纳入本轮)
本轮仅做训练侧改造,以下内容延期:
- 在线推理接口
- 真实机器人数据采集脚本
- 在线安全控制与频率控制
---
## H. 文档与脚本
### 文件
- `README.md`
- (新增)`docs/endoscope_data_format.md`
- (新增)`docs/endoscope_train_eval.md`
### 修改点
- 给出最小可运行流程:
1) 准备数据
2) 训练命令
3) 离线评估
- 明确 text instruction 的格式规范。
---
## 4. 建议新增文件(清单)
- `configs/endoscope_task.yaml`(或继续用 python dict
- `dataset_tools/convert_endoscope_to_act_hdf5.py`
- `dataset_tools/validate_endoscope_dataset.py`
- `models/text_encoder.py`DistilBERT 封装)
- `docs/endoscope_data_format.md`
---
## 5. 分阶段实施顺序(建议)
### Phase 1先跑通“无 text”2-DOF 版本
-`state_dim/action_dim`
- 跑通数据加载 + 训练 + 离线验证
### Phase 2加入 text instruction
- 数据协议加入 instruction
- 接入 DistilBERT768
-`concat_transformer_input` 完成 text 融合训练
---
## 6. 验收标准Definition of Done
1. 可用你的 HDF5 数据直接训练,不依赖 sim/gym。
2. 模型输入同时支持图像、2D qpos、text instruction。
3. Text 编码器使用 DistilBERT输出特征维度为 768。
4. Text 融合方式为 Transformer 输入级 concat。
5. README 有完整训练与离线评估命令示例,团队可复现。
---
## 7. 风险点与提前规避
1. **text 与动作时序对齐问题**
- 需明确 instruction 是 episode-level 还是 timestep-level。
2. **小维度控制下的动作抖动**
- 可在后处理中加入 low-pass / action smoothing。
3. **多模态尺度不平衡**
- 需关注图像/状态/text 融合后梯度主导问题(可加 modality dropout 或 loss 权重调节)。
4. **文本编码开销导致训练变慢**
- 可选缓存 DistilBERT 特征,或冻结 text encoder。
---
## 8. 你接下来只需提供的最小信息(进入代码改造前)
1. 2 个电机各自的物理含义与取值范围(单位、上下限)。
2. 你当前数据中 `qpos``action` 的实际定义(是否相同)。
3. text instruction 是每个 episode 一条,还是每个 timestep 一条。
4. 相机数量、分辨率、帧率。
5. 是否在训练时冻结 DistilBERT`freeze_text_encoder=True/False`)。
> 有了这 5 项,即可进入下一步代码改造。