8.1 KiB
8.1 KiB
ACT 仓库适配内镜机器人(2-DOF + 图像/qpos/Text 指令)修改清单(仅训练版)
1. 目标与约束
目标
将当前标准 ACT 仓库改造成可用于你的内镜机器人离线训练,支持:
- 动作维度仅 2(2 个电机)
- 不依赖 Gym / 仿真环境
- 输入为 图像 + qpos + text instruction
- 以离线数据训练为主(本阶段不包含真实机器人在线接口)
约束
当前代码默认是 ALOHA 双臂(14 维状态/动作)和 sim/real ALOHA 环境接口,且没有 text 分支,存在大量硬编码,需要系统性改造。
2. 现有代码中的关键硬编码(必须改)
-
状态/动作维度硬编码为 14
imitate_episodes.py中state_dim = 14detr/models/detr_vae.py中多个nn.Linear(14, ...)record_sim_episodes.py的数据写入 shape 固定(T, 14)
-
训练/评估流程绑定 sim 或 aloha_scripts real_env
imitate_episodes.py的eval_bc()依赖sim_env.make_sim_env()或aloha_scripts.real_env
-
数据加载器默认字段是 qpos/qvel/action + images
utils.py的EpisodicDataset仅加载qpos、action、images,无 text
-
模型只融合图像 + 状态,无文本编码
policy.py与detr/models/detr_vae.py当前无 text 输入通道
3. 必改模块清单(按文件)
A. 配置与任务定义
文件
constants.py
修改点
- 新增内镜任务配置(建议新字典
ENDOSCOPE_TASK_CONFIGS):dataset_dirnum_episodesepisode_lencamera_namesstate_dim=2action_dim=2use_text_instruction=Trueinstruction_mode(episode-level / timestep-level)text_encoder_type="distilbert"text_feature_dim=768text_fusion_type="concat_transformer_input"
- 避免继续依赖
sim_前缀来判断任务类型。
目的
把任务参数从 ALOHA 默认值中解耦,作为后续训练与模型构建的统一入口。
B. 数据协议与数据集加载
文件
utils.py- (新增)
dataset_tools/下的数据转换脚本
修改点
-
数据协议统一定义(建议 HDF5)
/observations/images/<cam_name>:(T, H, W, C)/observations/qpos:(T, 2)/action:(T, 2)/instruction(字符串或 token)- 可选:
/instruction_timestep(若每步指令不同)
-
重构
EpisodicDataset- 保持
qpos命名,不做重命名 - 加载 text instruction
- 返回训练样本改为:
image_dataqpos_dataaction_datais_padtext_input_idstext_attention_mask
- 保持
-
归一化统计扩展
get_norm_stats()支持qpos/action任意维度(本任务均为 2)- text 采用在线 DistilBERT 编码(默认),可选缓存特征
-
兼容性策略
- 保持对旧字段
qpos的直接兼容 - 支持多相机或单相机
- 保持对旧字段
目的
构建与你真实数据一致的数据管线,彻底摆脱 14 维与仿真字段假设。
C. 训练入口与流程控制
文件
imitate_episodes.py
修改点
-
配置读取改造
- 使用新任务配置读取
state_dim=2/action_dim=2/camera_names/use_text_instruction
- 使用新任务配置读取
-
移除/隔离仿真耦合逻辑(训练范围内)
main()保留纯离线训练路径eval_bc()仅保留离线评估路径
-
前向输入变更
forward_pass()改为支持 text- dataloader batch 解包增加 text 分量
-
命令行参数补充
--task_config或--task_name对应新配置--text_encoder_type--freeze_text_encoder
目的
让训练脚本成为“真实机器人离线模仿学习”的统一入口,而不是 sim demo 入口。
D. 策略封装层(Policy API)
文件
policy.py
修改点
-
ACTPolicy.__call__()和CNNMLPPolicy.__call__()签名扩展:- 现有:
(qpos, image, actions=None, is_pad=None) - 目标:
(qpos, image, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)
- 现有:
-
图像归一化保留,但要确保支持任意相机数量。
-
loss 计算保持一致,同时确保 text 缺失时可降级运行(便于 ablation)。
目的
把 text 从数据层顺畅传递到模型层。
E. 模型构建参数与入口
文件
detr/main.py
修改点
- 新增模型参数:
state_dimaction_dimuse_texttext_encoder_type="distilbert"text_feature_dim=768text_fusion_type="concat_transformer_input"
- 删除或弱化与原脚本无关的占位参数。
目的
将所有硬编码维度下放为可配置项,便于后续迭代。
F. ACT 主干网络(核心改造)
文件
detr/models/detr_vae.py
修改点
-
去掉 14 维硬编码
input_proj_robot_state = nn.Linear(state_dim, hidden_dim)encoder_action_proj = nn.Linear(action_dim, hidden_dim)encoder_joint_proj = nn.Linear(state_dim, hidden_dim)action_head = nn.Linear(hidden_dim, action_dim)
-
加入 text 分支
- 使用 DistilBERT 输出特征(768 维)
- 新增 text 投影层:
nn.Linear(768, hidden_dim) - 融合策略固定为:将 text token/特征作为额外 token,直接 concat 到 Transformer 输入序列
-
前向函数增加 text 输入
forward(self, qpos, image, env_state, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)
-
保持训练/推理双模式一致
- 训练:动作序列 + text 条件 VAE
- 推理:先验采样 + text 条件生成
目的
把 ACT 从“图像+14 维关节”模型改造成“图像+2 维 qpos+文本条件”模型。
G. 真实机器人接口与采集(暂不纳入本轮)
本轮仅做训练侧改造,以下内容延期:
- 在线推理接口
- 真实机器人数据采集脚本
- 在线安全控制与频率控制
H. 文档与脚本
文件
README.md- (新增)
docs/endoscope_data_format.md - (新增)
docs/endoscope_train_eval.md
修改点
- 给出最小可运行流程:
- 准备数据
- 训练命令
- 离线评估
- 明确 text instruction 的格式规范。
4. 建议新增文件(清单)
configs/endoscope_task.yaml(或继续用 python dict)dataset_tools/convert_endoscope_to_act_hdf5.pydataset_tools/validate_endoscope_dataset.pymodels/text_encoder.py(DistilBERT 封装)docs/endoscope_data_format.md
5. 分阶段实施顺序(建议)
Phase 1:先跑通“无 text”2-DOF 版本
- 改
state_dim/action_dim - 跑通数据加载 + 训练 + 离线验证
Phase 2:加入 text instruction
- 数据协议加入 instruction
- 接入 DistilBERT(768)
- 按
concat_transformer_input完成 text 融合训练
6. 验收标准(Definition of Done)
- 可用你的 HDF5 数据直接训练,不依赖 sim/gym。
- 模型输入同时支持图像、2D qpos、text instruction。
- Text 编码器使用 DistilBERT,输出特征维度为 768。
- Text 融合方式为 Transformer 输入级 concat。
- README 有完整训练与离线评估命令示例,团队可复现。
7. 风险点与提前规避
-
text 与动作时序对齐问题
- 需明确 instruction 是 episode-level 还是 timestep-level。
-
小维度控制下的动作抖动
- 可在后处理中加入 low-pass / action smoothing。
-
多模态尺度不平衡
- 需关注图像/状态/text 融合后梯度主导问题(可加 modality dropout 或 loss 权重调节)。
-
文本编码开销导致训练变慢
- 可选缓存 DistilBERT 特征,或冻结 text encoder。
8. 你接下来只需提供的最小信息(进入代码改造前)
- 2 个电机各自的物理含义与取值范围(单位、上下限)。
- 你当前数据中
qpos和action的实际定义(是否相同)。 - text instruction 是每个 episode 一条,还是每个 timestep 一条。
- 相机数量、分辨率、帧率。
- 是否在训练时冻结 DistilBERT(
freeze_text_encoder=True/False)。
有了这 5 项,即可进入下一步代码改造。