Compare commits: 2133db326e...main (10 commits)

SHAs: 96d19c0ffc, 81e1bf8838, 88d0cc5ca2, d85cce8a52, ee257bcb6c, 7023d5dde4, 88d14221ae, b701d939c2, ba006e14c4, d4b4d554f8
.gitignore (vendored, 4 changed lines)

```
@@ -4,6 +4,10 @@ wandb
outputs
data
data_local
ckpt
*.ckpt
*.pt
*.pth
.vscode
_wandb
```
ENDOSCOPE_ACT_ADAPTATION_PLAN.md (new file, 297 lines)

@@ -0,0 +1,297 @@
# ACT Repository Adaptation Checklist for an Endoscope Robot (2-DOF + image/qpos/text instruction) — Training Only

## 1. Goals and Constraints

### Goals
Adapt the current stock ACT repository for offline training on your endoscope robot, supporting:
- **An action dimension of only 2** (2 motors)
- **No dependency on Gym / simulation environments**
- Inputs of **image + qpos + text instruction**
- Offline training as the primary workflow (no live-robot interface in this phase)

### Constraints
The current code defaults to the dual-arm ALOHA setup (14-dim state/action) and the sim/real ALOHA environment interfaces, and it has **no text branch**. It contains extensive hardcoding and needs systematic refactoring.

---

## 2. Key Hardcoded Values in the Current Code (Must Change)

1. **State/action dimensions hardcoded to 14**
   - `state_dim = 14` in `imitate_episodes.py`
   - Several `nn.Linear(14, ...)` layers in `detr/models/detr_vae.py`
   - `record_sim_episodes.py` writes data with a fixed `(T, 14)` shape

2. **Training/evaluation flow bound to sim or aloha_scripts real_env**
   - `eval_bc()` in `imitate_episodes.py` depends on `sim_env.make_sim_env()` or `aloha_scripts.real_env`

3. **Data loader assumes the qpos/qvel/action + images fields**
   - `EpisodicDataset` in `utils.py` loads only `qpos`, `action`, and `images`; no text

4. **Model fuses only image + state, with no text encoder**
   - `policy.py` and `detr/models/detr_vae.py` currently have no text input path

---

## 3. Modules That Must Change (By File)

## A. Configuration and Task Definition

### Files
- `constants.py`

### Changes
- Add an endoscope task configuration (suggested new dict `ENDOSCOPE_TASK_CONFIGS`) with:
  - `dataset_dir`
  - `num_episodes`
  - `episode_len`
  - `camera_names`
  - `state_dim=2`
  - `action_dim=2`
  - `use_text_instruction=True`
  - `instruction_mode` (episode-level / timestep-level)
  - `text_encoder_type="distilbert"`
  - `text_feature_dim=768`
  - `text_fusion_type="concat_transformer_input"`
- Stop relying on the `sim_` prefix to infer the task type.

### Purpose
Decouple task parameters from the ALOHA defaults so they become the single entry point for training and model construction.

---

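A minimal sketch of what such an entry in `constants.py` could look like. The field names and values follow the list above; the task name, `dataset_dir`, `num_episodes`, and `episode_len` are placeholders, not decided values:

```python
# Hypothetical ENDOSCOPE_TASK_CONFIGS entry following the fields listed above.
# Paths, episode counts, and episode length are illustrative placeholders.
ENDOSCOPE_TASK_CONFIGS = {
    "endoscope_follow": {
        "dataset_dir": "data/follow-no-text",  # placeholder path
        "num_episodes": 50,                    # placeholder count
        "episode_len": 300,                    # placeholder length
        "camera_names": ["top"],
        "state_dim": 2,
        "action_dim": 2,
        "use_text_instruction": True,
        "instruction_mode": "timestep-level",  # or "episode-level"
        "text_encoder_type": "distilbert",
        "text_feature_dim": 768,
        "text_fusion_type": "concat_transformer_input",
    }
}
```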
## B. Data Protocol and Dataset Loading

### Files
- `utils.py`
- (new) data-conversion scripts under `dataset_tools/`

### Changes
1. **Unified data protocol (HDF5 recommended)**
   - `/observations/images/<cam_name>`: `(T, H, W, C)`
   - `/observations/qpos`: `(T, 2)`
   - `/action`: `(T, 2)`
   - `/instruction` (string or tokens)
   - Optional: `/instruction_timestep` (if the instruction changes per step)

2. **Refactor `EpisodicDataset`**
   - Keep the `qpos` name; do not rename it
   - Load the text instruction
   - Return training samples as:
     - `image_data`
     - `qpos_data`
     - `action_data`
     - `is_pad`
     - `text_input_ids`
     - `text_attention_mask`

3. **Extend the normalization statistics**
   - `get_norm_stats()` should support arbitrary `qpos`/`action` dimensions (both are 2 for this task)
   - Encode text with DistilBERT online by default, with optional feature caching

4. **Compatibility strategy**
   - Remain directly compatible with the legacy `qpos` field
   - Support single or multiple cameras

### Purpose
Build a data pipeline that matches your real data and fully drops the 14-dim and simulation-field assumptions.

---

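The protocol above can be sketched with h5py. This writes and reads back a tiny dummy episode; the sizes, camera name `top`, and instruction string are illustrative only:

```python
import tempfile
from pathlib import Path

import h5py
import numpy as np

# Tiny dummy episode; real data would be e.g. (T, 224, 224, 3) images.
T, H, W, C = 4, 8, 8, 3
path = Path(tempfile.mkdtemp()) / "episode_demo.hdf5"
str_dtype = h5py.string_dtype(encoding="utf-8")

with h5py.File(path, "w") as root:
    obs = root.create_group("observations")
    obs.create_group("images").create_dataset(
        "top", data=np.zeros((T, H, W, C), dtype=np.uint8))
    obs.create_dataset("qpos", data=np.zeros((T, 2), dtype=np.float32))
    root.create_dataset("action", data=np.zeros((T, 2), dtype=np.float32))
    root.create_dataset("instruction", shape=(), dtype=str_dtype,
                        data="Move toward the target.")
    root.create_dataset("instruction_timestep", shape=(T,), dtype=str_dtype,
                        data=np.asarray(["Move toward the target."] * T, dtype=object))

with h5py.File(path, "r") as root:
    qpos_shape = root["/observations/qpos"].shape
    image_shape = root["/observations/images/top"].shape
```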
## C. Training Entry Point and Flow Control

### Files
- `imitate_episodes.py`

### Changes
1. **Config loading**
   - Read `state_dim=2`, `action_dim=2`, `camera_names`, and `use_text_instruction` from the new task config

2. **Remove/isolate simulation coupling (within the training scope)**
   - `main()` keeps only the pure offline training path
   - `eval_bc()` keeps only the offline evaluation path

3. **Forward-pass input changes**
   - `forward_pass()` accepts text
   - Dataloader batch unpacking gains the text components

4. **Additional command-line arguments**
   - `--task_config` or `--task_name` mapped to the new config
   - `--text_encoder_type`
   - `--freeze_text_encoder`

### Purpose
Turn the training script into the single entry point for real-robot offline imitation learning rather than a sim demo.

---

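The batch-unpacking change in item 3 could look like this sketch. The element names follow section B; the policy call signature is the plan's target API (see section D), not existing code:

```python
# Sketch of the extended batch layout for forward_pass(). The batch is the
# 6-tuple from section B when text is enabled, else the original 4-tuple.
def forward_pass(data, policy, use_text=True):
    if use_text:
        image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask = data
    else:
        image_data, qpos_data, action_data, is_pad = data
        text_input_ids = text_attention_mask = None
    # Target policy API from section D; degrades to no-text when both are None.
    return policy(qpos_data, image_data,
                  text_input_ids=text_input_ids,
                  text_attention_mask=text_attention_mask,
                  actions=action_data, is_pad=is_pad)
```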
## D. Policy Wrapper Layer (Policy API)

### Files
- `policy.py`

### Changes
1. Extend the signatures of `ACTPolicy.__call__()` and `CNNMLPPolicy.__call__()`:
   - Current: `(qpos, image, actions=None, is_pad=None)`
   - Target: `(qpos, image, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`

2. Keep the image normalization, but make sure it supports any number of cameras.

3. Keep the loss computation unchanged, and make sure the model degrades gracefully when text is missing (useful for ablations).

### Purpose
Pass text cleanly from the data layer through to the model layer.

---

## E. Model-Construction Arguments and Entry Point

### Files
- `detr/main.py`

### Changes
- Add model arguments:
  - `state_dim`
  - `action_dim`
  - `use_text`
  - `text_encoder_type="distilbert"`
  - `text_feature_dim=768`
  - `text_fusion_type="concat_transformer_input"`
- Remove or demote placeholder arguments unrelated to these scripts.

### Purpose
Turn every hardcoded dimension into a configurable option for easier iteration.

---

## F. ACT Backbone (Core Refactor)

### Files
- `detr/models/detr_vae.py`

### Changes
1. **Remove the 14-dim hardcoding**
   - `input_proj_robot_state = nn.Linear(state_dim, hidden_dim)`
   - `encoder_action_proj = nn.Linear(action_dim, hidden_dim)`
   - `encoder_joint_proj = nn.Linear(state_dim, hidden_dim)`
   - `action_head = nn.Linear(hidden_dim, action_dim)`

2. **Add a text branch**
   - Use DistilBERT output features (768-dim)
   - Add a text projection layer: `nn.Linear(768, hidden_dim)`
   - Fix the fusion strategy: **treat the text token/feature as an extra token and concat it directly into the Transformer input sequence**

3. **Add text inputs to the forward function**
   - `forward(self, qpos, image, env_state, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`

4. **Keep the training/inference modes consistent**
   - Training: action sequence + text-conditioned VAE
   - Inference: prior sampling + text-conditioned generation

### Purpose
Turn ACT from an "image + 14-dim joints" model into an "image + 2-dim qpos + text-conditioned" model.

---

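The input-level concat fusion in item 2 can be sketched in PyTorch. Dimensions follow the plan (a 768-dim DistilBERT feature projected to `hidden_dim` and appended as one extra token); the token count and `hidden_dim` value are illustrative:

```python
import torch
import torch.nn as nn

hidden_dim, text_feature_dim = 512, 768
text_proj = nn.Linear(text_feature_dim, hidden_dim)  # the new text projection layer

batch, num_src_tokens = 2, 300                    # image + qpos tokens (illustrative count)
src = torch.randn(batch, num_src_tokens, hidden_dim)
text_feat = torch.randn(batch, text_feature_dim)  # pooled DistilBERT feature per sample

# Project the text feature to hidden_dim and append it as one extra token to
# the Transformer input sequence (the "concat_transformer_input" strategy).
text_token = text_proj(text_feat).unsqueeze(1)       # (batch, 1, hidden_dim)
src_with_text = torch.cat([src, text_token], dim=1)  # (batch, num_src_tokens + 1, hidden_dim)
```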
## G. Real-Robot Interface and Data Collection (Out of Scope This Round)

This round covers only the training side; the following are deferred:
- Online inference interface
- Real-robot data-collection scripts
- Online safety and control-rate handling

---

## H. Documentation and Scripts

### Files
- `README.md`
- (new) `docs/endoscope_data_format.md`
- (new) `docs/endoscope_train_eval.md`

### Changes
- Provide a minimal runnable workflow:
  1) Prepare data
  2) Training command
  3) Offline evaluation
- Specify the text-instruction format precisely.

---

## 4. Suggested New Files (Checklist)

- `configs/endoscope_task.yaml` (or keep using a Python dict)
- `dataset_tools/convert_endoscope_to_act_hdf5.py`
- `dataset_tools/validate_endoscope_dataset.py`
- `models/text_encoder.py` (DistilBERT wrapper)
- `docs/endoscope_data_format.md`

---

## 5. Suggested Implementation Phases

### Phase 1: Get a "no-text" 2-DOF version running first
- Change `state_dim`/`action_dim`
- Get data loading + training + offline validation working end to end

### Phase 2: Add the text instruction
- Add the instruction to the data protocol
- Plug in DistilBERT (768-dim)
- Complete text-fusion training via `concat_transformer_input`

---

## 6. Acceptance Criteria (Definition of Done)

1. Training runs directly on your HDF5 data, with no sim/gym dependency.
2. The model accepts image, 2D qpos, and text instruction simultaneously.
3. The text encoder is DistilBERT with 768-dim output features.
4. Text fusion is Transformer input-level concat.
5. The README contains complete training and offline-evaluation command examples that the team can reproduce.

---

## 7. Risks and Mitigations

1. **Text-action temporal alignment**
   - Decide explicitly whether the instruction is episode-level or timestep-level.

2. **Action jitter under low-dimensional control**
   - Add low-pass filtering / action smoothing in post-processing.

3. **Multimodal scale imbalance**
   - Watch for one modality dominating the gradients after image/state/text fusion (consider modality dropout or loss-weight tuning).

4. **Text-encoding overhead slowing training**
   - Optionally cache DistilBERT features, or freeze the text encoder.

---

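The low-pass / action smoothing mentioned for risk 2 could be as simple as an exponential moving average over the predicted action stream. A post-processing sketch; the function name and smoothing factor are illustrative:

```python
import numpy as np

def smooth_actions(actions: np.ndarray, alpha: float = 0.3) -> np.ndarray:
    """Exponential moving average over time; alpha=1.0 disables smoothing."""
    smoothed = np.empty_like(actions, dtype=np.float64)
    smoothed[0] = actions[0]
    for t in range(1, len(actions)):
        # Blend the new action with the previous smoothed value.
        smoothed[t] = alpha * actions[t] + (1.0 - alpha) * smoothed[t - 1]
    return smoothed

# A spike in the raw 2-DOF action stream gets damped:
raw = np.array([[0.0, 0.0], [1.0, -1.0], [0.0, 0.0]])
out = smooth_actions(raw, alpha=0.5)  # -> [[0, 0], [0.5, -0.5], [0.25, -0.25]]
```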
## 8. Minimal Information Needed Before Code Changes

1. Physical meaning and value range of each of the 2 motors (units, limits): the motors are motor_x and motor_y; x ranges over 7000-17384 and y over 8000-18884. The corresponding action_x and action_y are both in 0~65535.
2. The actual definitions of `qpos` and `action` in your current data (whether they are the same): action and qpos are defined similarly; the difference is that the 0~65535 command range is mapped onto each motor's magnetic-encoder readings.
3. Whether the text instruction is one per episode or one per timestep: one per timestep.
4. Camera count, resolution, frame rate: 1 camera, 224*224 resolution, 30 Hz; the motor control rate is also 30 Hz.
5. Whether DistilBERT is frozen during training (`freeze_text_encoder=True/False`): DistilBERT is fully frozen. When building the training set, each frame's text instruction is first encoded with DistilBERT and then saved, so training never needs to call DistilBERT.

> With these 5 items, code changes can begin.

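The command-to-encoder mapping described in items 1-2 above amounts to a linear map from 0~65535 into each motor's encoder range. A sketch with the ranges from item 1; the function name is illustrative:

```python
def command_to_encoder(cmd: float, enc_min: float, enc_max: float) -> float:
    """Map a 0~65535 action command linearly onto a motor's encoder range."""
    return enc_min + (cmd / 65535.0) * (enc_max - enc_min)

# Ranges from item 1: motor_x 7000-17384, motor_y 8000-18884.
x_pos = command_to_encoder(0, 7000, 17384)      # -> 7000.0 (lower limit)
y_pos = command_to_encoder(65535, 8000, 18884)  # -> 18884.0 (upper limit)
```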
Questions: what do `text_input_ids` and `text_attention_mask` mean? Also, `'instruction_mode': 'episode-level'` is never used.

```python
instruction = ''
if self.use_text_instruction:
    if '/instruction_timestep' in root:
        instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
    elif '/instruction' in root:
        instruction_node = root['/instruction']
        if getattr(instruction_node, 'shape', ()) == ():
            instruction = self._decode_instruction(instruction_node[()])
        else:
            if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
                instruction = self._decode_instruction(instruction_node[start_ts])
            else:
                instruction = self._decode_instruction(instruction_node[0])
```

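On `text_input_ids` and `text_attention_mask`: the former are the tokenizer's integer token IDs for an instruction; the latter marks real tokens (1) versus padding (0) when variable-length instructions are batched to a common length. A hand-built illustration without a real tokenizer; the token IDs here are made up, whereas a real tokenizer (e.g. DistilBERT's) would produce them from the instruction strings:

```python
import numpy as np

# Two instructions of different token lengths, padded to the same length.
seqs = [[101, 2693, 2000, 1996, 2187, 102],  # e.g. "move to the left" (6 tokens)
        [101, 2644, 2693, 102]]              # e.g. "stop move" (4 tokens)
max_len = max(len(s) for s in seqs)

text_input_ids = np.zeros((len(seqs), max_len), dtype=np.int64)
text_attention_mask = np.zeros((len(seqs), max_len), dtype=np.int64)
for i, s in enumerate(seqs):
    text_input_ids[i, :len(s)] = s
    text_attention_mask[i, :len(s)] = 1  # 1 = real token, 0 = padding
```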
Question: why was the Transformer definition modified? Does that change actually take effect here?

README.md

```
@@ -35,8 +35,8 @@ You can find all scripted/human demo for simulated environments [here](https://d
pip install pyyaml
pip install rospkg
pip install pexpect
pip install mujoco==2.3.7        (previously: pip install mujoco)
pip install dm_control==1.0.14   (previously: pip install dm_control)
pip install opencv-python
pip install matplotlib
pip install einops
```
build_endoscope_act_dataset.py (new file, 680 lines)

@@ -0,0 +1,680 @@
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CropBox:
|
||||
x1: int
|
||||
y1: int
|
||||
x2: int
|
||||
y2: int
|
||||
|
||||
@property
|
||||
def w(self) -> int:
|
||||
return self.x2 - self.x1
|
||||
|
||||
@property
|
||||
def h(self) -> int:
|
||||
return self.y2 - self.y1
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--segment_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to one raw segment, e.g. data/follow_seg_001",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output dir for episode_*.hdf5",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episode_idx",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Output episode index (default: 0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_frames",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--camera_name",
|
||||
type=str,
|
||||
default="top",
|
||||
help="Camera name written to /observations/images/<camera_name>",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop",
|
||||
type=int,
|
||||
nargs=4,
|
||||
default=[733, 30, 1754, 1051],
|
||||
metavar=("X1", "Y1", "X2", "Y2"),
|
||||
help="Crop box in original image coordinates",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resize",
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=[224, 224],
|
||||
metavar=("W", "H"),
|
||||
help="Output image size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instruction_template",
|
||||
type=str,
|
||||
default="Move toward the {label} at {region}.",
|
||||
help="Template for per-frame instruction",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instruction_empty",
|
||||
type=str,
|
||||
default="No target visible.",
|
||||
help="Instruction when no valid target after crop",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop_instruction",
|
||||
type=str,
|
||||
default="Stop move.",
|
||||
help="Instruction used for stationary head/tail frames",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state_norm",
|
||||
choices=["minus1_1", "0_1", "raw"],
|
||||
default="minus1_1",
|
||||
help="Normalization for qpos (motor_pos_y, motor_pos_x)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_norm",
|
||||
choices=["minus1_1", "0_1", "raw"],
|
||||
default="minus1_1",
|
||||
help="Normalization for action (motor_command_0, motor_command_1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encode_text_features",
|
||||
action="store_true",
|
||||
help="Encode per-frame instruction into 768-dim DistilBERT features",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_model_name",
|
||||
type=str,
|
||||
default="distilbert-base-uncased",
|
||||
help="HuggingFace model name for DistilBERT",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_batch_size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for text feature extraction",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--motion_window",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Sliding window size used for stationary detection at beginning/end",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--motion_threshold",
|
||||
type=float,
|
||||
default=0.002,
|
||||
help=(
|
||||
"Motion threshold in normalized delta space (0~1). "
|
||||
"Smaller value means stricter stationary detection"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_stop_override",
|
||||
action="store_true",
|
||||
help="Disable head/tail stationary instruction override",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim_stationary_edges",
|
||||
action="store_true",
|
||||
help="Trim stationary head/tail segments and keep only the middle moving segment",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_text_instruction",
|
||||
action="store_true",
|
||||
help="Do not save instruction/instruction_timestep (and disable text feature encoding)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
|
||||
json_files = list(frames_dir.glob("*.json"))
|
||||
|
||||
def key_fn(p: Path) -> Tuple[int, str]:
|
||||
m = re.search(r"frame_(\d+)", p.name)
|
||||
idx = int(m.group(1)) if m else 10**9
|
||||
return idx, p.name
|
||||
|
||||
json_files.sort(key=key_fn)
|
||||
return json_files
|
||||
|
||||
|
||||
def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
|
||||
with csv_path.open("r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
return list(reader)
|
||||
|
||||
|
||||
def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
|
||||
if mode == "raw":
|
||||
return x.astype(np.float32)
|
||||
x01 = (x - min_v) / (max_v - min_v)
|
||||
if mode == "0_1":
|
||||
return x01.astype(np.float32)
|
||||
# minus1_1
|
||||
return (x01 * 2.0 - 1.0).astype(np.float32)
|
||||
|
||||
|
||||
def clip_bbox_to_crop(
|
||||
x_min: float,
|
||||
y_min: float,
|
||||
x_max: float,
|
||||
y_max: float,
|
||||
crop: CropBox,
|
||||
) -> Optional[Tuple[float, float, float, float]]:
|
||||
nx1 = max(x_min - crop.x1, 0.0)
|
||||
ny1 = max(y_min - crop.y1, 0.0)
|
||||
nx2 = min(x_max - crop.x1, float(crop.w - 1))
|
||||
ny2 = min(y_max - crop.y1, float(crop.h - 1))
|
||||
if nx2 <= nx1 or ny2 <= ny1:
|
||||
return None
|
||||
return nx1, ny1, nx2, ny2
|
||||
|
||||
|
||||
def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
|
||||
x1, y1, x2, y2 = box
|
||||
return (x1 + x2) * 0.5, (y1 + y2) * 0.5
|
||||
|
||||
|
||||
def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
|
||||
x_bin = min(2, max(0, int(cx / (w / 3.0))))
|
||||
y_bin = min(2, max(0, int(cy / (h / 3.0))))
|
||||
xs = ["left", "center", "right"]
|
||||
ys = ["top", "middle", "bottom"]
|
||||
return f"{ys[y_bin]}-{xs[x_bin]}"
|
||||
|
||||
|
||||
def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
|
||||
points = shape.get("points", None)
|
||||
label = shape.get("label", "target")
|
||||
if not points or len(points) < 2:
|
||||
return None
|
||||
pts = np.array(points, dtype=np.float32)
|
||||
x_min, y_min = float(pts[:, 0].min()), float(pts[:, 1].min())
|
||||
x_max, y_max = float(pts[:, 0].max()), float(pts[:, 1].max())
|
||||
area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
|
||||
return label, x_min, y_min, x_max, y_max, area
|
||||
|
||||
|
||||
def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
|
||||
shapes = annotation.get("shapes", [])
|
||||
best = None
|
||||
for shape in shapes:
|
||||
parsed = read_shape_bbox(shape)
|
||||
if parsed is None:
|
||||
continue
|
||||
label, x1, y1, x2, y2, area = parsed
|
||||
clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
|
||||
if clipped is None:
|
||||
continue
|
||||
c_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
|
||||
if best is None or c_area > best[2]:
|
||||
best = (label, clipped, c_area)
|
||||
if best is None:
|
||||
return None
|
||||
return best[0], best[1]
|
||||
|
||||
|
||||
def instruction_from_annotation(
|
||||
annotation: Dict,
|
||||
crop: CropBox,
|
||||
template: str,
|
||||
empty_instruction: str,
|
||||
) -> str:
|
||||
picked = select_target_box(annotation, crop)
|
||||
if picked is None:
|
||||
return empty_instruction
|
||||
label, box = picked
|
||||
cx, cy = bbox_center(box)
|
||||
region = region_3x3(cx, cy, crop.w, crop.h)
|
||||
return template.format(label=label, region=region)
|
||||
|
||||
|
||||
def extract_text_features(
|
||||
instructions: Sequence[str],
|
||||
model_name: str,
|
||||
batch_size: int = 32,
|
||||
) -> np.ndarray:
|
||||
try:
|
||||
import torch
|
||||
from transformers import DistilBertTokenizerFast
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Text feature encoding requires transformers. Please install: pip install transformers"
|
||||
) from exc
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
from models.text_encoder import DistilBERTTextEncoder
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
|
||||
model = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
|
||||
model.eval()
|
||||
|
||||
feats: List[np.ndarray] = []
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(instructions), batch_size):
|
||||
batch = list(instructions[i:i + batch_size])
|
||||
tok = tokenizer(
|
||||
batch,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=32,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = tok["input_ids"].to(device)
|
||||
attention_mask = tok["attention_mask"].to(device)
|
||||
cls = model(input_ids=input_ids, attention_mask=attention_mask).detach().cpu().numpy().astype(np.float32)
|
||||
feats.append(cls)
|
||||
return np.concatenate(feats, axis=0)
|
||||
|
||||
|
||||
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
|
||||
scale = max_v - min_v
|
||||
if scale <= 0:
|
||||
return np.zeros_like(v, dtype=np.float32)
|
||||
return ((v - min_v) / scale).astype(np.float32)
|
||||
|
||||
|
||||
def override_stationary_edge_instructions(
|
||||
instructions: List[str],
|
||||
motor_pos_y: np.ndarray,
|
||||
motor_pos_x: np.ndarray,
|
||||
motor_cmd_0: np.ndarray,
|
||||
motor_cmd_1: np.ndarray,
|
||||
stop_instruction: str,
|
||||
motion_window: int,
|
||||
motion_threshold: float,
|
||||
) -> Tuple[List[str], int, int]:
|
||||
"""
|
||||
Override instruction text at head/tail using qpos velocity.
|
||||
Start side: once qpos speed is above threshold for consecutive frames,
|
||||
stop applying stop_instruction from that point onward.
|
||||
End side: similarly scan backward from the end.
|
||||
"""
|
||||
num = len(instructions)
|
||||
if num == 0:
|
||||
return instructions, 0, 0
|
||||
|
||||
start_count, end_count = detect_stationary_edge_counts_from_qpos(
|
||||
motor_pos_y=motor_pos_y,
|
||||
motor_pos_x=motor_pos_x,
|
||||
motion_window=motion_window,
|
||||
motion_threshold=motion_threshold,
|
||||
)
|
||||
|
||||
updated = list(instructions)
|
||||
for i in range(start_count):
|
||||
updated[i] = stop_instruction
|
||||
for i in range(num - end_count, num):
|
||||
updated[i] = stop_instruction
|
||||
|
||||
return updated, start_count, end_count
|
||||
|
||||
|
||||
def detect_stationary_edge_counts_from_qpos(
|
||||
motor_pos_y: np.ndarray,
|
||||
motor_pos_x: np.ndarray,
|
||||
motion_window: int,
|
||||
motion_threshold: float,
|
||||
) -> Tuple[int, int]:
|
||||
"""Return stationary frame counts on head and tail using qpos velocity rule."""
|
||||
num = int(len(motor_pos_y))
|
||||
if num == 0:
|
||||
return 0, 0
|
||||
if num == 1:
|
||||
return 1, 1
|
||||
|
||||
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
|
||||
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
|
||||
|
||||
consecutive = max(1, int(motion_window))
|
||||
dt = 1.0 / 30.0
|
||||
|
||||
frame_speed = np.zeros((num,), dtype=np.float32)
|
||||
dy = np.abs(np.diff(py)) / dt
|
||||
dx = np.abs(np.diff(px)) / dt
|
||||
frame_speed[1:] = np.maximum(dy, dx)
|
||||
|
||||
high_run = 0
|
||||
start_count = num
|
||||
for i in range(1, num):
|
||||
if frame_speed[i] > motion_threshold:
|
||||
high_run += 1
|
||||
if high_run >= consecutive:
|
||||
start_count = i - consecutive + 1
|
||||
break
|
||||
else:
|
||||
high_run = 0
|
||||
|
||||
high_run = 0
|
||||
end_count = num
|
||||
for i in range(num - 1, 0, -1):
|
||||
if frame_speed[i] > motion_threshold:
|
||||
high_run += 1
|
||||
if high_run >= consecutive:
|
||||
tail_start = i + consecutive
|
||||
end_count = max(0, num - tail_start)
|
||||
break
|
||||
else:
|
||||
high_run = 0
|
||||
|
||||
return start_count, end_count
|
||||
|
||||
|
||||
def find_segment_csv(segment_dir: Path) -> Path:
|
||||
csvs = sorted(segment_dir.glob("*.csv"))
|
||||
if not csvs:
|
||||
raise FileNotFoundError(f"No csv file found in {segment_dir}")
|
||||
return csvs[0]
|
||||
|
||||
|
||||
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
|
||||
"""Convert boolean mask to closed-open segments [start, end)."""
|
||||
segments: List[Tuple[int, int]] = []
|
||||
if mask.size == 0:
|
||||
return segments
|
||||
in_seg = False
|
||||
start = 0
|
||||
for i, v in enumerate(mask.tolist()):
|
||||
if v and not in_seg:
|
||||
in_seg = True
|
||||
start = i
|
||||
elif (not v) and in_seg:
|
||||
in_seg = False
|
||||
segments.append((start, i))
|
||||
if in_seg:
|
||||
segments.append((start, int(mask.size)))
|
||||
return segments
|
||||
|
||||
|
||||
def save_episode_plot_with_stop_segments(
|
||||
qpos: np.ndarray,
|
||||
action: np.ndarray,
|
||||
instructions: Sequence[str],
|
||||
stop_instruction: str,
|
||||
plot_path: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
|
||||
"""
|
||||
qpos = np.asarray(qpos)
|
||||
action = np.asarray(action)
|
||||
if qpos.ndim != 2 or action.ndim != 2:
|
||||
return
|
||||
|
||||
num_ts, num_dim = qpos.shape
|
||||
if num_ts == 0 or num_dim == 0:
|
||||
return
|
||||
|
||||
stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
|
||||
stop_segments = _mask_to_segments(stop_mask)
|
||||
|
||||
fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
|
||||
axs_list = np.atleast_1d(axs).reshape(-1).tolist()
|
||||
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs_list[dim_idx]
|
||||
ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
|
||||
ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
|
||||
for seg_idx, (st, ed) in enumerate(stop_segments):
|
||||
ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
|
||||
label='stop instruction' if seg_idx == 0 else None)
|
||||
ax.set_ylabel(f'dim {dim_idx}')
|
||||
ax.legend(loc='best')
|
||||
ax.grid(alpha=0.25, linestyle='--')
|
||||
|
||||
axs_list[-1].set_xlabel('timestep')
|
||||
fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
|
||||
fig.tight_layout()
|
||||
fig.savefig(str(plot_path), dpi=140)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if args.no_text_instruction and args.encode_text_features:
|
||||
raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')
|
||||
|
||||
segment_dir = Path(args.segment_dir).resolve()
|
||||
output_dir = Path(args.output_dir).resolve()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
frames_dir = segment_dir / "frames"
|
||||
if not frames_dir.exists():
|
||||
raise FileNotFoundError(f"frames dir not found: {frames_dir}")
|
||||
|
||||
csv_path = find_segment_csv(segment_dir)
|
||||
csv_rows = load_csv_rows(csv_path)
|
||||
if len(csv_rows) == 0:
|
||||
raise ValueError(f"CSV has no rows: {csv_path}")
|
||||
|
||||
crop = CropBox(*args.crop)
|
||||
resize_w, resize_h = int(args.resize[0]), int(args.resize[1])
|
||||
|
||||
json_files = sorted_frame_jsons(frames_dir)
|
||||
if not json_files:
|
||||
raise FileNotFoundError(f"No frame json found in: {frames_dir}")
|
||||
|
||||
max_aligned = min(len(json_files), len(csv_rows))
|
||||
num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
|
||||
if num <= 0:
|
||||
raise ValueError("No aligned frames available.")
|
||||
|
||||
images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
|
||||
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
|
||||
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
|
||||
instructions: List[str] = []
|
||||
motor_pos_y_series = np.zeros((num,), dtype=np.float32)
|
||||
motor_pos_x_series = np.zeros((num,), dtype=np.float32)
|
||||
motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
|
||||
motor_cmd_1_series = np.zeros((num,), dtype=np.float32)
|
||||
|
||||
y_min, y_max = 8000.0, 18884.0
|
||||
x_min, x_max = 7000.0, 17384.0
|
||||
cmd_min, cmd_max = 0.0, 65535.0
|
||||
|
||||
for i in range(num):
|
||||
json_path = json_files[i]
|
||||
with json_path.open("r", encoding="utf-8") as f:
|
||||
ann = json.load(f)
|
||||
|
||||
image_path = frames_dir / ann["imagePath"]
|
||||
if not image_path.exists():
|
||||
alt = json_path.with_suffix(".jpg")
|
||||
if alt.exists():
|
||||
image_path = alt
|
||||
else:
|
||||
raise FileNotFoundError(f"Image not found for {json_path.name}")
|
||||
|
||||
img = Image.open(image_path).convert("RGB")
|
||||
img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
|
||||
img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
|
||||
images[i] = np.asarray(img_resize, dtype=np.uint8)
|
||||
|
||||
row = csv_rows[i]
|
||||
motor_pos_y = float(row["motor_pos_y"])
|
||||
motor_pos_x = float(row["motor_pos_x"])
|
||||
motor_cmd_0 = float(row["motor_command_0"])
|
||||
motor_cmd_1 = float(row["motor_command_1"])
|
||||
|
||||
motor_pos_y_series[i] = motor_pos_y
|
||||
motor_pos_x_series[i] = motor_pos_x
|
||||
motor_cmd_0_series[i] = motor_cmd_0
|
||||
motor_cmd_1_series[i] = motor_cmd_1
|
||||
|
||||
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
|
||||
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
|
||||
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||
|
||||
if not args.no_text_instruction:
|
||||
ins = instruction_from_annotation(
|
||||
ann,
|
||||
crop,
|
||||
args.instruction_template,
|
||||
args.instruction_empty,
|
||||
)
|
||||
instructions.append(ins)
|
||||
|
||||
start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
|
||||
motor_pos_y=motor_pos_y_series,
|
||||
motor_pos_x=motor_pos_x_series,
|
||||
motion_window=args.motion_window,
|
||||
motion_threshold=args.motion_threshold,
|
||||
)
|
||||
|
||||
if args.trim_stationary_edges:
|
||||
keep_start = int(start_stop_count)
|
||||
keep_end = int(num - end_stop_count)
|
||||
if keep_end <= keep_start:
|
||||
raise ValueError(
|
||||
f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
|
||||
f'Consider lowering --motion_threshold or --motion_window.'
|
||||
)
|
||||
images = images[keep_start:keep_end]
|
||||
qpos = qpos[keep_start:keep_end]
|
||||
action = action[keep_start:keep_end]
|
||||
motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
|
||||
motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
|
||||
motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
|
||||
motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
|
||||
if not args.no_text_instruction:
|
||||
instructions = instructions[keep_start:keep_end]
|
||||
|
||||
print(
|
||||
f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
|
||||
f'kept={keep_end - keep_start}'
|
||||
)
|
||||
# After trimming, full kept segment is the moving region.
|
||||
start_stop_count, end_stop_count = 0, 0
|
||||
|
||||
if (not args.disable_stop_override) and (not args.no_text_instruction):
|
||||
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
|
||||
instructions=instructions,
|
||||
motor_pos_y=motor_pos_y_series,
|
||||
motor_pos_x=motor_pos_x_series,
|
||||
motor_cmd_0=motor_cmd_0_series,
|
||||
motor_cmd_1=motor_cmd_1_series,
|
||||
        stop_instruction=args.stop_instruction,
        motion_window=args.motion_window,
        motion_threshold=args.motion_threshold,
    )

    text_features = None
    if args.encode_text_features:
        text_features = extract_text_features(
            instructions,
            model_name=args.text_model_name,
            batch_size=args.text_batch_size,
        )

    out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
    dt = 1.0 / 30.0

    with h5py.File(out_path, "w") as root:
        root.attrs["sim"] = False
        root.attrs["source_segment"] = str(segment_dir)
        root.attrs["frame_rate"] = 30
        root.attrs["dt"] = dt
        root.attrs["state_norm_mode"] = args.state_norm
        root.attrs["action_norm_mode"] = args.action_norm
        root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
        root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
        root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)

        obs = root.create_group("observations")
        obs.create_dataset("qpos", data=qpos, dtype=np.float32)
        images_group = obs.create_group("images")
        images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)

        root.create_dataset("action", data=action, dtype=np.float32)

        if not args.no_text_instruction:
            str_dtype = h5py.string_dtype(encoding="utf-8")
            root.create_dataset(
                "instruction_timestep",
                shape=(len(instructions),),
                dtype=str_dtype,
                data=np.asarray(instructions, dtype=object),
            )
            root.create_dataset(
                "instruction",
                shape=(),
                dtype=str_dtype,
                data=instructions[0] if len(instructions) > 0 else "",
            )

        if text_features is not None:
            root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
            root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)

    print(f"Saved: {out_path}")
    print(f"Frames used: {num}")
    print(f"Image shape: {images.shape}")
    print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
    if not args.disable_stop_override:
        print(
            f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
            f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
            f"threshold={args.motion_threshold}, "
            f"instruction='{args.stop_instruction}'"
        )
    if text_features is not None:
        print(f"instruction_features_timestep shape: {text_features.shape}")

    # Save a same-basename plot next to the generated hdf5
    plot_path = out_path.with_suffix('.png')
    save_episode_plot_with_stop_segments(
        qpos=qpos,
        action=action,
        instructions=instructions,
        stop_instruction=args.stop_instruction,
        plot_path=plot_path,
    )
    print(f"Saved episode plot to: {plot_path}")


if __name__ == "__main__":
    main()
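For anyone writing the loader side, a minimal sketch of consuming one generated episode file with the same dataset keys the writer above emits (the tiny stand-in arrays are placeholders, not real episode data):

```python
import os
import tempfile

import h5py
import numpy as np

path = os.path.join(tempfile.gettempdir(), "episode_demo.hdf5")

# Write a stand-in episode with the layout the builder produces.
with h5py.File(path, "w") as root:
    obs = root.create_group("observations")
    obs.create_dataset("qpos", data=np.zeros((10, 2), dtype=np.float32))
    obs.create_group("images").create_dataset(
        "top", data=np.zeros((10, 224, 224, 3), dtype=np.uint8))
    root.create_dataset("action", data=np.zeros((10, 2), dtype=np.float32))

# Read it back the way an EpisodicDataset-style loader would.
with h5py.File(path, "r") as root:
    qpos = root["observations/qpos"][...]       # (T, 2) motor positions
    image = root["observations/images/top"][0]  # first frame, (H, W, 3) uint8
    action = root["action"][...]                # (T, 2) motor commands

print(qpos.shape, image.shape, action.shape)  # (10, 2) (224, 224, 3) (10, 2)
```

The optional `instruction`, `instruction_timestep`, and `instruction_features*` datasets can be read the same way when text is enabled.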
25
build_no_text_dataset.sh
Executable file
@@ -0,0 +1,25 @@
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/00-follow"
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/follow-no-text"
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"

mkdir -p "$OUT_DIR"
i=50
for d in "$SEG_ROOT"/follow_seg_*; do
    [ -d "$d" ] || continue
    echo "Building $d -> episode_$i"
    python "$SCRIPT" \
        --segment_dir "$d" \
        --output_dir "$OUT_DIR" \
        --episode_idx "$i" \
        --max_frames -1 \
        --camera_name top \
        --crop 733 30 1754 1051 \
        --resize 224 224 \
        --motion_window 3 \
        --motion_threshold 0.05 \
        --state_norm minus1_1 \
        --action_norm minus1_1 \
        --trim_stationary_edges \
        --no_text_instruction
    i=$((i+1))
done
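The `--trim_stationary_edges` / `--motion_window` / `--motion_threshold` flags are implemented inside the builder script; as a rough illustration of the rule its log message describes (`mode=qpos_velocity_consecutive`), here is a hedged numpy sketch. The function name and the exact edge rule are assumptions for illustration; the real implementation may differ:

```python
import numpy as np

def stationary_edge_counts(qpos, window=3, threshold=0.05):
    """Count head/tail frames whose per-step qpos velocity stays below
    `threshold`; an edge only counts if it is at least `window` steps long.
    Illustrative only, not the builder's actual code."""
    vel = np.abs(np.diff(qpos, axis=0)).max(axis=1)  # (T-1,) max motor speed
    still = vel < threshold
    head = 0
    while head < len(still) and still[head]:
        head += 1
    tail = 0
    while tail < len(still) and still[len(still) - 1 - tail]:
        tail += 1
    head = head if head >= window else 0
    tail = tail if tail >= window else 0
    return head, tail

# 5 frozen frames, then the motors start moving.
qpos = np.array([[0.0, 0.0]] * 5 + [[0.1 * t, 0.2 * t] for t in range(1, 6)])
print(stationary_edge_counts(qpos))  # (4, 0)
```

With `--motion_threshold 0.05` and `--motion_window 3` as in the script above, the four zero-velocity steps at the head qualify as a stationary edge, while the moving tail does not.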
29
build_text_dataset.sh
Executable file
@@ -0,0 +1,29 @@
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/01-cannulation"
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/cannulation"
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"

mkdir -p "$OUT_DIR"
i=12
for d in "$SEG_ROOT"/seg_*; do
    [ -d "$d" ] || continue
    echo "Building $d -> episode_$i"
    python "$SCRIPT" \
        --segment_dir "$d" \
        --output_dir "$OUT_DIR" \
        --episode_idx "$i" \
        --max_frames -1 \
        --camera_name top \
        --crop 733 30 1754 1051 \
        --resize 224 224 \
        --instruction_template 'Cannulate the {label} on the phantom located at the {region} with the sphincterotome.' \
        --instruction_empty 'No target visible.' \
        --stop_instruction 'Stop move.' \
        --motion_window 3 \
        --motion_threshold 0.05 \
        --state_norm minus1_1 \
        --action_norm minus1_1 \
        --encode_text_features \
        --text_model_name distilbert-base-uncased \
        --text_batch_size 32
    i=$((i+1))
done
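Both scripts normalize states and actions with the `minus1_1` mode. A plausible reading of that mode, sketched in numpy, is a linear map from the raw motor range into [-1, 1]; the specific `lo`/`hi` range below is a made-up stand-in, since the builder may instead use calibrated motor limits:

```python
import numpy as np

def to_minus1_1(x, lo, hi):
    """Linearly map raw values in [lo, hi] to [-1, 1]."""
    return 2.0 * (x - lo) / (hi - lo) - 1.0

def from_minus1_1(x, lo, hi):
    """Inverse map from [-1, 1] back to raw units (needed at rollout time)."""
    return (x + 1.0) / 2.0 * (hi - lo) + lo

lo, hi = np.array([0.0, 0.0]), np.array([4096.0, 4096.0])  # hypothetical motor range
raw = np.array([[0.0, 2048.0], [4096.0, 1024.0]])
norm = to_minus1_1(raw, lo, hi)
print(norm)  # [[-1.  0.] [ 1. -0.5]]
```

Whatever the exact range source is, keeping the inverse map around matters: predicted actions in [-1, 1] must be denormalized with the same `lo`/`hi` before being sent to the motors.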
@@ -21,3 +21,5 @@ dependencies:
  - packaging=23.0
  - h5py=3.8.0
  - ipython=8.12.0
  - pip:
    - transformers==4.38.2
81
constants.py
@@ -1,7 +1,7 @@
import pathlib

### Task parameters
DATA_DIR = '<put your data dir here>'
DATA_DIR = str(pathlib.Path(__file__).parent.resolve() / 'data')
SIM_TASK_CONFIGS = {
    'sim_transfer_cube_scripted':{
        'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
@@ -32,6 +32,85 @@ SIM_TASK_CONFIGS = {
    },
}

ENDOSCOPE_TASK_CONFIGS = {
    'endoscope_default': {
        'dataset_dir': DATA_DIR + '/endoscope_default',
        'num_episodes': 50,
        'episode_len': 400,
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'real_action_t_minus_1': False,
        'use_text_instruction': True,
        'instruction_mode': 'timestep-level',
        'use_cached_text_features': True,
        'text_encoder_type': 'distilbert',
        'text_feature_dim': 768,
        'text_fusion_type': 'concat_transformer_input',
        'freeze_text_encoder': True,
        'text_max_length': 32,
        'text_tokenizer_name': 'distilbert-base-uncased',
    },
    'endoscope_follow': {
        'dataset_dir': DATA_DIR + '/follow',
        'num_episodes': 3,
        'episode_len': 400,
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'real_action_t_minus_1': False,
        'use_text_instruction': True,
        'instruction_mode': 'timestep-level',
        'use_cached_text_features': True,
        'text_encoder_type': 'distilbert',
        'text_feature_dim': 768,
        'text_fusion_type': 'concat_transformer_input',
        'freeze_text_encoder': True,
        'text_max_length': 32,
        'text_tokenizer_name': 'distilbert-base-uncased',
    },
    'endoscope_both_no_text': {
        'dataset_dir': DATA_DIR + '/both-no-text',
        'num_episodes': 3,
        'episode_len': 400,
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'real_action_t_minus_1': False,
        'use_text_instruction': False,
    },
    'endoscope_sanity_check': {
        'dataset_dir': DATA_DIR + '/sanity-check',
        'num_episodes': 3,
        'episode_len': 400,
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'real_action_t_minus_1': False,
        'use_text_instruction': False,
    },
    'endoscope_cannulation_no_text': {
        'dataset_dir': DATA_DIR + '/cannulation-no-text',
        'num_episodes': 3,
        'episode_len': 400,
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'real_action_t_minus_1': False,
        'use_text_instruction': False,
    },
    'endoscope_follow_no_text': {
        'dataset_dir': DATA_DIR + '/follow-no-text',
        'num_episodes': 3,
        'episode_len': 400,
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'real_action_t_minus_1': False,
        'use_text_instruction': False,
    },
}

### Simulation envs fixed constants
DT = 0.02
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]

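Downstream code reads these entries with `.get()` fallbacks (as the `imitate_episodes.py` changes in this diff do), so legacy ALOHA task configs without the new keys keep their 14-DoF, no-text defaults. A minimal sketch of that lookup pattern against a hypothetical config dict:

```python
# Hypothetical task config mirroring an ENDOSCOPE_TASK_CONFIGS entry.
task_config = {
    'dataset_dir': '/data/follow',
    'num_episodes': 3,
    'episode_len': 400,
    'camera_names': ['top'],
    'state_dim': 2,
    'action_dim': 2,
    'use_text_instruction': True,
}

# Missing keys fall back to the old hardcoded values, so legacy configs still work.
state_dim = task_config.get('state_dim', 14)
action_dim = task_config.get('action_dim', state_dim)  # default action_dim to state_dim
use_text = task_config.get('use_text_instruction', False)
text_feature_dim = task_config.get('text_feature_dim', 768)

print(state_dim, action_dim, use_text, text_feature_dim)  # 2 2 True 768
```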
25
detr/main.py
@@ -4,6 +4,7 @@ from pathlib import Path
import numpy as np
import torch
from torch.optim.adamw import AdamW
from .models import build_ACT_model, build_CNNMLP_model

import IPython
@@ -30,6 +31,15 @@
                        help="Type of positional embedding to use on top of the image features")
    parser.add_argument('--camera_names', default=[], type=list, # will be overridden
                        help="A list of camera names")
    parser.add_argument('--state_dim', default=14, type=int)
    parser.add_argument('--action_dim', default=14, type=int)
    parser.add_argument('--use_text', action='store_true')
    parser.add_argument('--text_encoder_type', default='distilbert', type=str)
    parser.add_argument('--text_feature_dim', default=768, type=int)
    parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
    parser.add_argument('--freeze_text_encoder', action='store_true')
    parser.add_argument('--text_max_length', default=32, type=int)
    parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)

    # * Transformer
    parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
@@ -60,16 +70,17 @@
    parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
    parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
    parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
    parser.add_argument('--temporal_agg', action='store_true')
    parser.add_argument('--image_aug', action='store_true')

    return parser


def build_ACT_model_and_optimizer(args_override):
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    args, _ = parser.parse_known_args()

    for k, v in args_override.items():
        setattr(args, k, v)
@@ -84,15 +95,15 @@
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    optimizer = AdamW(param_dicts, lr=args.lr,
                      weight_decay=args.weight_decay)

    return model, optimizer


def build_CNNMLP_model_and_optimizer(args_override):
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    args, _ = parser.parse_known_args()

    for k, v in args_override.items():
        setattr(args, k, v)
@@ -107,8 +118,8 @@
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                  weight_decay=args.weight_decay)
    optimizer = AdamW(param_dicts, lr=args.lr,
                      weight_decay=args.weight_decay)

    return model, optimizer

detr/models/backbone.py
@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
        backbone_builder = getattr(torchvision.models, name)
        weights = None
        if is_main_process():
            weight_enum_name_map = {
                'resnet18': 'ResNet18_Weights',
                'resnet34': 'ResNet34_Weights',
                'resnet50': 'ResNet50_Weights',
                'resnet101': 'ResNet101_Weights',
            }
            enum_name = weight_enum_name_map.get(name)
            if enum_name is not None and hasattr(torchvision.models, enum_name):
                weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')

        try:
            backbone = backbone_builder(
                replace_stride_with_dilation=[False, False, dilation],
                weights=weights,
                norm_layer=FrozenBatchNorm2d,
            )
        except TypeError:
            # Backward compatibility for older torchvision that still expects `pretrained`.
            backbone = backbone_builder(
                replace_stride_with_dilation=[False, False, dilation],
                pretrained=(weights is not None),
                norm_layer=FrozenBatchNorm2d,
            )
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)

detr/models/detr_vae.py
@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):

class DETRVAE(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
    def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
                 use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
@@ -48,17 +49,18 @@
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.use_text = use_text
        self.text_fusion_type = text_fusion_type
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, state_dim)
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None
@@ -66,16 +68,18 @@
        # encoder extra parameters
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
        num_extra_tokens = 3 if self.use_text else 2
        self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
        self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None

    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
    def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
@@ -125,10 +129,25 @@
            all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            extra_input_tokens = None
            if self.use_text and text_features is not None:
                if self.text_fusion_type != 'concat_transformer_input':
                    raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
                text_input = self.text_proj(text_features)
                extra_input_tokens = text_input.unsqueeze(0)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
            hs = self.transformer(
                src,
                None,
                self.query_embed.weight,
                pos,
                latent_input,
                proprio_input,
                self.additional_pos_embed.weight,
                extra_input_tokens=extra_input_tokens,
            )[0]
        else:
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
@@ -141,7 +160,7 @@


class CNNMLP(nn.Module):
    def __init__(self, backbones, state_dim, camera_names):
    def __init__(self, backbones, state_dim, action_dim, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: torch module of the backbone to be used. See backbone.py
@@ -153,7 +172,7 @@
        """
        super().__init__()
        self.camera_names = camera_names
        self.action_head = nn.Linear(1000, state_dim) # TODO add more
        self.action_head = nn.Linear(1000, action_dim) # TODO add more
        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
            backbone_down_projs = []
@@ -166,8 +185,8 @@
                backbone_down_projs.append(down_proj)
            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

            mlp_in_dim = 768 * len(backbones) + 14
            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
            mlp_in_dim = 768 * len(backbones) + state_dim
            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
        else:
            raise NotImplementedError

@@ -192,7 +211,7 @@
        for cam_feature in all_cam_features:
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, axis=1) # 768 each
        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
        features = torch.cat([flattened_features, qpos], axis=1)
        a_hat = self.mlp(features)
        return a_hat

@@ -227,7 +246,8 @@ def build_encoder(args):


def build(args):
    state_dim = 14 # TODO hardcode
    state_dim = args.state_dim
    action_dim = args.action_dim

    # From state
    # backbone = None # from state for now, no need for conv nets
@@ -245,8 +265,12 @@
        transformer,
        encoder,
        state_dim=state_dim,
        action_dim=action_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        use_text=args.use_text,
        text_feature_dim=args.text_feature_dim,
        text_fusion_type=args.text_fusion_type,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -255,7 +279,8 @@
    return model

def build_cnnmlp(args):
    state_dim = 14 # TODO hardcode
    state_dim = args.state_dim
    action_dim = args.action_dim

    # From state
    # backbone = None # from state for now, no need for conv nets
@@ -268,6 +293,7 @@
    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        action_dim=action_dim,
        camera_names=args.camera_names,
    )

detr/models/transformer.py
@@ -46,7 +46,7 @@ class Transformer(nn.Module):
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
        # TODO flatten only when input has H and W
        if len(src.shape) == 4: # has H and W
            # flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            # mask = mask.flatten(1)

            additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
            pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
            additional_inputs = [latent_input, proprio_input]
            if extra_input_tokens is not None:
                if len(extra_input_tokens.shape) == 2:
                    extra_input_tokens = extra_input_tokens.unsqueeze(0)
                for i in range(extra_input_tokens.shape[0]):
                    additional_inputs.append(extra_input_tokens[i])

            addition_input = torch.stack(additional_inputs, axis=0)
            if additional_pos_embed is not None:
                additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
                additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
                pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

            addition_input = torch.stack([latent_input, proprio_input], axis=0)
            src = torch.cat([addition_input, src], axis=0)
        else:
            assert len(src.shape) == 3
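The token bookkeeping in this hunk can be sketched without torch: the encoder memory becomes [latent, proprio, optional text] tokens prepended to the flattened image tokens, with the learned extra position table sliced to however many extra tokens are actually present. A hedged numpy illustration (all dimensions below are made up):

```python
import numpy as np

hidden_dim, bs, n_img_tokens = 8, 2, 15
latent = np.zeros((bs, hidden_dim))
proprio = np.zeros((bs, hidden_dim))
text = np.zeros((bs, hidden_dim))            # the optional text token
img_tokens = np.zeros((n_img_tokens, bs, hidden_dim))
extra_pos_table = np.zeros((3, hidden_dim))  # sized for latent + proprio + text

extra = [latent, proprio, text]              # drop `text` when use_text=False
addition_input = np.stack(extra, axis=0)     # (n_extra, bs, hidden_dim)
extra_pos = extra_pos_table[: addition_input.shape[0]]   # slice table to match
src = np.concatenate([addition_input, img_tokens], axis=0)

print(addition_input.shape, extra_pos.shape, src.shape)
# (3, 2, 8) (3, 8) (18, 2, 8)
```

Slicing the position table to `addition_input.shape[0]` is what lets the same checkpoint code path serve both the 2-token (no text) and 3-token (with text) configurations.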
imitate_episodes.py
@@ -10,17 +10,25 @@ from einops import rearrange
|
||||
|
||||
from constants import DT
|
||||
from constants import PUPPET_GRIPPER_JOINT_OPEN
|
||||
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
|
||||
from utils import load_data # data functions
|
||||
from utils import sample_box_pose, sample_insertion_pose # robot functions
|
||||
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
|
||||
from policy import ACTPolicy, CNNMLPPolicy
|
||||
from visualize_episodes import save_videos
|
||||
|
||||
from sim_env import BOX_POSE
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def load_checkpoint_state_dict(ckpt_path):
|
||||
"""Load checkpoint state_dict safely across different torch versions."""
|
||||
try:
|
||||
return torch.load(ckpt_path, map_location='cpu', weights_only=True)
|
||||
except TypeError:
|
||||
# For older PyTorch versions that do not support `weights_only`.
|
||||
return torch.load(ckpt_path, map_location='cpu')
|
||||
|
||||
def main(args):
|
||||
set_seed(1)
|
||||
# command line parameters
|
||||
@@ -34,25 +42,50 @@ def main(args):
|
||||
num_epochs = args['num_epochs']
|
||||
|
||||
# get task parameters
|
||||
is_sim = task_name[:4] == 'sim_'
|
||||
if is_sim:
|
||||
from constants import SIM_TASK_CONFIGS
|
||||
is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
|
||||
if is_endoscope:
|
||||
task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
|
||||
is_sim = False
|
||||
elif task_name in SIM_TASK_CONFIGS:
|
||||
task_config = SIM_TASK_CONFIGS[task_name]
|
||||
is_sim = True
|
||||
else:
|
||||
from aloha_scripts.constants import TASK_CONFIGS
|
||||
task_config = TASK_CONFIGS[task_name]
|
||||
is_sim = False
|
||||
|
||||
dataset_dir = task_config['dataset_dir']
|
||||
num_episodes = task_config['num_episodes']
|
||||
episode_len = task_config['episode_len']
|
||||
camera_names = task_config['camera_names']
|
||||
state_dim = task_config.get('state_dim', 14)
|
||||
action_dim = task_config.get('action_dim', state_dim)
|
||||
use_text_instruction = task_config.get('use_text_instruction', False)
|
||||
instruction_mode = task_config.get('instruction_mode', 'timestep-level')
|
||||
use_cached_text_features = task_config.get('use_cached_text_features', True)
|
||||
text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
|
||||
text_feature_dim = task_config.get('text_feature_dim', 768)
|
||||
text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
|
||||
text_max_length = task_config.get('text_max_length', 32)
|
||||
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
|
||||
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
|
||||
real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
|
||||
|
||||
if args.get('text_encoder_type') is not None:
|
||||
text_encoder_type = args['text_encoder_type']
|
||||
if args.get('text_max_length') is not None:
|
||||
text_max_length = args['text_max_length']
|
||||
if args.get('freeze_text_encoder', False):
|
||||
freeze_text_encoder = True
|
||||
if args.get('disable_real_action_shift', False):
|
||||
real_action_t_minus_1 = False
|
||||
|
||||
# fixed parameters
|
||||
state_dim = 14
|
||||
lr_backbone = 1e-5
|
||||
backbone = 'resnet18'
|
||||
if policy_class == 'ACT':
|
||||
enc_layers = 4
|
||||
dec_layers = 7
|
||||
enc_layers = 2
|
||||
dec_layers = 4
|
||||
nheads = 8
|
||||
policy_config = {'lr': args['lr'],
|
||||
'num_queries': args['chunk_size'],
|
||||
@@ -65,18 +98,36 @@ def main(args):
|
||||
'dec_layers': dec_layers,
|
||||
'nheads': nheads,
|
||||
'camera_names': camera_names,
|
||||
'state_dim': state_dim,
|
||||
'action_dim': action_dim,
|
||||
'use_text': use_text_instruction,
|
||||
'text_encoder_type': text_encoder_type,
|
||||
'text_feature_dim': text_feature_dim,
|
||||
'text_fusion_type': text_fusion_type,
|
||||
'freeze_text_encoder': freeze_text_encoder,
|
||||
'instruction_mode': instruction_mode,
|
||||
'use_cached_text_features': use_cached_text_features,
|
||||
'text_max_length': text_max_length,
|
||||
'text_tokenizer_name': text_tokenizer_name,
|
||||
}
|
||||
elif policy_class == 'CNNMLP':
|
||||
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
|
||||
'camera_names': camera_names,}
|
||||
'camera_names': camera_names,
|
||||
'state_dim': state_dim,
|
||||
'action_dim': action_dim,
|
||||
'use_text': use_text_instruction,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
config = {
|
||||
'num_epochs': num_epochs,
|
||||
'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
|
||||
'resume_ckpt_path': args.get('resume_ckpt', None),
|
||||
'ckpt_dir': ckpt_dir,
|
||||
'episode_len': episode_len,
|
||||
'state_dim': state_dim,
|
||||
'action_dim': action_dim,
|
||||
'lr': args['lr'],
|
||||
'policy_class': policy_class,
|
||||
'onscreen_render': onscreen_render,
|
||||
@@ -85,9 +136,25 @@ def main(args):
|
||||
'seed': args['seed'],
|
||||
'temporal_agg': args['temporal_agg'],
|
||||
'camera_names': camera_names,
|
||||
'real_robot': not is_sim
|
||||
'real_robot': (not is_sim) and (not is_endoscope),
|
||||
'use_text_instruction': use_text_instruction,
|
||||
'instruction_mode': instruction_mode,
|
||||
'use_cached_text_features': use_cached_text_features,
|
||||
'text_tokenizer_name': text_tokenizer_name,
|
||||
'text_max_length': text_max_length,
|
||||
'debug_input': args.get('debug_input', False),
|
||||
}
|
||||
|
||||
if config['resume_ckpt_path']:
|
||||
resume_ckpt_path = config['resume_ckpt_path']
|
||||
if not os.path.isabs(resume_ckpt_path):
|
||||
candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
|
||||
if os.path.isfile(candidate_path):
|
||||
resume_ckpt_path = candidate_path
|
||||
if not os.path.isfile(resume_ckpt_path):
|
||||
raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
|
||||
config['resume_ckpt_path'] = resume_ckpt_path
|
||||
|
||||
if is_eval:
|
||||
ckpt_names = [f'policy_best.ckpt']
|
||||
results = []
|
||||
@@ -100,7 +167,21 @@ def main(args):
|
||||
print()
|
||||
exit()
|
||||
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
|
||||
train_dataloader, val_dataloader, stats, _ = load_data(
|
||||
dataset_dir,
|
||||
num_episodes,
|
||||
camera_names,
|
||||
batch_size_train,
|
||||
batch_size_val,
|
||||
use_text_instruction=use_text_instruction,
|
||||
instruction_mode=instruction_mode,
|
||||
use_cached_text_features=use_cached_text_features,
|
||||
text_feature_dim=text_feature_dim,
|
||||
text_tokenizer_name=text_tokenizer_name,
|
||||
text_max_length=text_max_length,
|
||||
real_action_t_minus_1=real_action_t_minus_1,
|
||||
image_augment=args['image_aug'],
|
||||
)
|
||||
|
||||
# save dataset stats
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
@@ -152,6 +233,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
set_seed(1000)
|
||||
ckpt_dir = config['ckpt_dir']
|
||||
state_dim = config['state_dim']
|
||||
action_dim = config['action_dim']
|
||||
real_robot = config['real_robot']
|
||||
policy_class = config['policy_class']
|
||||
onscreen_render = config['onscreen_render']
|
||||
@@ -161,11 +243,12 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
task_name = config['task_name']
|
||||
temporal_agg = config['temporal_agg']
|
||||
onscreen_cam = 'angle'
|
||||
BOX_POSE = None
|
||||
|
||||
# load policy and stats
|
||||
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
|
||||
policy = make_policy(policy_class, policy_config)
|
||||
loading_status = policy.load_state_dict(torch.load(ckpt_path))
|
||||
loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
|
||||
print(loading_status)
|
||||
policy.cuda()
|
||||
policy.eval()
|
||||
@@ -202,8 +285,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
rollout_id += 0
|
||||
### set task
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
if BOX_POSE is None:
|
||||
from sim_env import BOX_POSE as _BOX_POSE
|
||||
BOX_POSE = _BOX_POSE
|
||||
BOX_POSE[0] = sample_box_pose() # used in sim reset
|
||||
elif 'sim_insertion' in task_name:
|
||||
if BOX_POSE is None:
|
||||
from sim_env import BOX_POSE as _BOX_POSE
|
||||
BOX_POSE = _BOX_POSE
|
||||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
ts = env.reset()
|
||||
@@ -216,7 +305,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
|
||||
### evaluation loop
|
||||
if temporal_agg:
|
||||
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
|
||||
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()
|
||||
|
||||
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
|
||||
image_list = [] # for visualization
|
||||
@@ -313,23 +402,72 @@ def eval_bc(config, ckpt_name, save_episode=True):
|
||||
return success_rate, avg_return
|
||||
|
||||
|
||||
def forward_pass(data, policy):
|
||||
-    image_data, qpos_data, action_data, is_pad = data
-    image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
-    return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None
+def forward_pass(data, policy, debug_input=False, debug_tag=''):
+    image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
+    image_data = image_data.cuda()
+    qpos_data = qpos_data.cuda()
+    action_data = action_data.cuda()
+    is_pad = is_pad.cuda()
+    text_input_ids = text_input_ids.cuda()
+    text_attention_mask = text_attention_mask.cuda()
+    text_feature_data = text_feature_data.cuda()
+    text_feature_valid = text_feature_valid.cuda()
+
+    text_features = None
+    if torch.any(text_feature_valid):
+        text_features = text_feature_data
+
+    if debug_input:
+        image_min = float(image_data.min().item())
+        image_max = float(image_data.max().item())
+        qpos_mean = float(qpos_data.mean().item())
+        qpos_std = float(qpos_data.std().item())
+        action_mean = float(action_data.mean().item())
+        action_std = float(action_data.std().item())
+        pad_ratio = float(is_pad.float().mean().item())
+
+        print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]')
+        print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})')
+        print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})')
+        print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}')
+        print(
+            f'[debug_input] {debug_tag} has_nan_or_inf: '
+            f'image={bool(torch.logical_not(torch.isfinite(image_data)).any().item())}, '
+            f'qpos={bool(torch.logical_not(torch.isfinite(qpos_data)).any().item())}, '
+            f'action={bool(torch.logical_not(torch.isfinite(action_data)).any().item())}'
+        )
+
+    return policy(
+        qpos_data,
+        image_data,
+        text_input_ids=text_input_ids,
+        text_attention_mask=text_attention_mask,
+        text_features=text_features,
+        actions=action_data,
+        is_pad=is_pad,
+    )
 def train_bc(train_dataloader, val_dataloader, config):
     num_epochs = config['num_epochs']
+    train_steps_per_epoch = config.get('train_steps_per_epoch', None)
+    resume_ckpt_path = config.get('resume_ckpt_path', None)
     ckpt_dir = config['ckpt_dir']
     seed = config['seed']
     policy_class = config['policy_class']
     policy_config = config['policy_config']
+    debug_input = config.get('debug_input', False)

     set_seed(seed)

     policy = make_policy(policy_class, policy_config)
     policy.cuda()

+    if resume_ckpt_path:
+        loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
+        print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
+        print(loading_status)
+
     optimizer = make_optimizer(policy_class, policy)

     train_history = []

@@ -343,7 +481,8 @@ def train_bc(train_dataloader, val_dataloader, config):
             policy.eval()
             epoch_dicts = []
             for batch_idx, data in enumerate(val_dataloader):
-                forward_dict = forward_pass(data, policy)
+                should_debug = debug_input and epoch == 0 and batch_idx == 0
+                forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
                 epoch_dicts.append(forward_dict)
             epoch_summary = compute_dict_mean(epoch_dicts)
             validation_history.append(epoch_summary)

@@ -361,15 +500,31 @@ def train_bc(train_dataloader, val_dataloader, config):
         # training
         policy.train()
         optimizer.zero_grad()
-        for batch_idx, data in enumerate(train_dataloader):
-            forward_dict = forward_pass(data, policy)
+        epoch_train_dicts = []
+        if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
+            train_steps_this_epoch = len(train_dataloader)
+            train_iterator = iter(train_dataloader)
+        else:
+            train_steps_this_epoch = int(train_steps_per_epoch)
+            train_iterator = iter(train_dataloader)
+
+        for step_idx in range(train_steps_this_epoch):
+            try:
+                data = next(train_iterator)
+            except StopIteration:
+                train_iterator = iter(train_dataloader)
+                data = next(train_iterator)
+
+            should_debug = debug_input and epoch == 0 and step_idx == 0
+            forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
             # backward
             loss = forward_dict['loss']
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
-            train_history.append(detach_dict(forward_dict))
-        epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
+            epoch_train_dicts.append(detach_dict(forward_dict))
+        epoch_summary = compute_dict_mean(epoch_train_dicts)
         epoch_train_loss = epoch_summary['loss']
         print(f'Train loss: {epoch_train_loss:.5f}')
         summary_string = ''
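The train-loop change above swaps `enumerate(train_dataloader)` for a fixed number of optimizer steps per epoch, restarting the iterator when it raises `StopIteration`. The same cycling pattern in isolation, using a plain list as a stand-in for the `DataLoader` (the name `cycle_steps` is illustrative, not from the repo):

```python
def cycle_steps(loader, num_steps):
    """Yield exactly num_steps batches, restarting the loader when exhausted."""
    it = iter(loader)
    for _ in range(num_steps):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)  # restart, mirroring the diff's StopIteration handler
            batch = next(it)
        yield batch

batches = list(cycle_steps([1, 2, 3], 7))
print(batches)  # [1, 2, 3, 1, 2, 3, 1]
```

This keeps epochs the same length even when the dataset (e.g. a single sanity-check episode) yields fewer batches than the requested step count.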
@@ -426,10 +581,23 @@ if __name__ == '__main__':
     parser.add_argument('--lr', action='store', type=float, help='lr', required=True)

     # for ACT
-    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
+    parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
     parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
     parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
     parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
     parser.add_argument('--temporal_agg', action='store_true')
+    parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
+    parser.add_argument('--freeze_text_encoder', action='store_true')
+    parser.add_argument('--text_max_length', action='store', type=int, required=False)
+    parser.add_argument('--image_aug', action='store_true',
+                        help='Enable training-time image augmentation (color/highlight/noise/blur)')
+    parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
+                        help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
+    parser.add_argument('--disable_real_action_shift', action='store_true',
+                        help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
+    parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
+                        help='Optional checkpoint path to initialize model weights for fine-tuning')
+    parser.add_argument('--debug_input', action='store_true',
+                        help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')

     main(vars(parser.parse_args()))
models/__init__.py (new file, 0 lines)

models/text_encoder.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+
+
+class DistilBERTTextEncoder(nn.Module):
+    def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
+        super().__init__()
+        try:
+            from transformers import DistilBertModel
+        except ImportError as exc:
+            raise ImportError(
+                'transformers is required for DistilBERT text encoding. '
+                'Install it with: pip install transformers'
+            ) from exc
+
+        self.encoder = DistilBertModel.from_pretrained(model_name)
+        self.output_dim = output_dim
+        self.freeze = freeze
+
+        if self.freeze:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+            self.encoder.eval()
+
+    def forward(self, input_ids, attention_mask):
+        if self.freeze:
+            self.encoder.eval()
+        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+        # DistilBERT has no pooled output; use [CLS] token embedding
+        cls_feature = outputs.last_hidden_state[:, 0, :]
+        return cls_feature
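The encoder above pools the sequence with the [CLS] token embedding. A common alternative (not used in this diff) is attention-mask-weighted mean pooling over all token embeddings, which ignores padded positions. A numpy sketch of just that masking arithmetic, with made-up toy shapes:

```python
import numpy as np

def masked_mean_pool(hidden, mask):
    """hidden: [B, T, D] token embeddings; mask: [B, T], 1 for real tokens, 0 for padding."""
    mask = mask[:, :, None].astype(hidden.dtype)    # [B, T, 1]
    summed = (hidden * mask).sum(axis=1)            # padded positions contribute nothing
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid divide-by-zero on empty masks
    return summed / counts

hidden = np.ones((2, 4, 3))
hidden[0, 2:] = 5.0                                 # garbage values at padded positions
mask = np.array([[1, 1, 0, 0], [1, 1, 1, 1]])
pooled = masked_mean_pool(hidden, mask)
print(pooled[0])  # [1. 1. 1.] — the padded garbage is ignored
```

Either pooling yields a single `[B, D]` feature compatible with the `text_features` path in `forward_pass`.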
policy.py (35 changed lines)
@@ -3,6 +3,7 @@ from torch.nn import functional as F
 import torchvision.transforms as transforms

 from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
+from models.text_encoder import DistilBERTTextEncoder
 import IPython
 e = IPython.embed

@@ -13,18 +14,44 @@ class ACTPolicy(nn.Module):
         self.model = model # CVAE decoder
         self.optimizer = optimizer
         self.kl_weight = args_override['kl_weight']
+        self.use_text = args_override.get('use_text', False)
+        self.text_encoder = None
+        if self.use_text:
+            text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
+            if text_encoder_type != 'distilbert':
+                raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
+            self.text_encoder = DistilBERTTextEncoder(
+                model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
+                output_dim=args_override.get('text_feature_dim', 768),
+                freeze=args_override.get('freeze_text_encoder', True),
+            )
         print(f'KL Weight {self.kl_weight}')

-    def __call__(self, qpos, image, actions=None, is_pad=None):
+    def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
         env_state = None
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
         image = normalize(image)
+
+        if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
+            if self.text_encoder is None:
+                raise RuntimeError('Text encoder is not initialized while use_text=True.')
+            text_features = self.text_encoder(text_input_ids, text_attention_mask)
+
         if actions is not None: # training time
+            if is_pad is None:
+                raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
             actions = actions[:, :self.model.num_queries]
             is_pad = is_pad[:, :self.model.num_queries]

-            a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
+            a_hat, is_pad_hat, (mu, logvar) = self.model(
+                qpos,
+                image,
+                env_state,
+                text_features=text_features,
+                actions=actions,
+                is_pad=is_pad,
+            )
             total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
             loss_dict = dict()
             all_l1 = F.l1_loss(actions, a_hat, reduction='none')
@@ -34,7 +61,7 @@ class ACTPolicy(nn.Module):
             loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
             return loss_dict
         else: # inference time
-            a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
+            a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior
             return a_hat

     def configure_optimizers(self):
@@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module):
         self.model = model # decoder
         self.optimizer = optimizer

-    def __call__(self, qpos, image, actions=None, is_pad=None):
+    def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
         env_state = None # TODO
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
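The `kl_divergence(mu, logvar)` helper is not shown in this diff; in ACT-style CVAEs it is typically the closed-form KL between a diagonal Gaussian N(mu, exp(logvar)) and the standard normal prior. A numpy sketch of that standard formula (an assumption about the helper, not a copy of it):

```python
import numpy as np

def kl_divergence(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), elementwise then aggregated."""
    klds = -0.5 * (1 + logvar - mu ** 2 - np.exp(logvar))  # [B, Z]
    total_kld = klds.sum(axis=1).mean()   # scalar: per-sample sum, batch mean
    dim_wise_kld = klds.mean(axis=0)      # [Z]
    mean_kld = klds.mean()                # scalar
    return total_kld, dim_wise_kld, mean_kld

mu = np.zeros((2, 4)); logvar = np.zeros((2, 4))
total, dim_wise, mean = kl_divergence(mu, logvar)
print(total)  # 0.0 — matching the prior exactly gives zero KL
```

`loss_dict['kl']` then scales this term by `--kl_weight` (now a float, per the argparse change) against the masked L1 reconstruction loss.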
utils.py (337 changed lines)
@@ -2,21 +2,170 @@ import numpy as np
 import torch
 import os
 import h5py
+import re
 from torch.utils.data import TensorDataset, DataLoader
+import torchvision.transforms.functional as TF

 import IPython
 e = IPython.embed

 class EpisodicDataset(torch.utils.data.Dataset):
-    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats):
+    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
+                 use_text_instruction=False,
+                 instruction_mode='timestep-level',
+                 use_cached_text_features=True,
+                 text_feature_dim=768,
+                 text_tokenizer_name='distilbert-base-uncased',
+                 text_max_length=32,
+                 real_action_t_minus_1=True,
+                 image_augment=False,
+                 image_aug_cfg=None):
         super(EpisodicDataset).__init__()
         self.episode_ids = episode_ids
         self.dataset_dir = dataset_dir
         self.camera_names = camera_names
         self.norm_stats = norm_stats
+        self.use_text_instruction = use_text_instruction
+        self.instruction_mode = instruction_mode
+        self.use_cached_text_features = use_cached_text_features
+        self.text_feature_dim = text_feature_dim
+        self.text_max_length = text_max_length
+        self.real_action_t_minus_1 = real_action_t_minus_1
+        self.image_augment = image_augment
+        self.image_aug_cfg = {
+            'p_color': 0.4,
+            'p_highlight': 0.3,
+            'p_noise': 0.35,
+            'p_blur': 0.15,
+            'brightness': 0.12,
+            'contrast': 0.12,
+            'saturation': 0.12,
+            'hue': 0.03,
+            'highlight_strength': (0.08, 0.25),
+            'noise_std': (0.003, 0.015),
+            'blur_sigma': (0.1, 0.8),
+            'blur_kernel_choices': (3, ),
+        }
+        if image_aug_cfg is not None:
+            self.image_aug_cfg.update(image_aug_cfg)
         self.is_sim = None
+        self.max_episode_len = None
+        self.action_dim = None
+
+        self.text_tokenizer = None
+        if self.use_text_instruction:
+            try:
+                from transformers import DistilBertTokenizerFast
+            except ImportError as exc:
+                raise ImportError(
+                    'transformers is required for text instruction loading. '
+                    'Install it with: pip install transformers'
+                ) from exc
+            self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
+
+        self._init_episode_shapes()
+
         self.__getitem__(0) # initialize self.is_sim
+    def _apply_image_augmentation(self, all_cam_images):
+        """
+        Apply identical augmentation parameters to all camera images for one sample.
+        all_cam_images: np.ndarray [K, H, W, C], uint8
+        """
+        imgs = torch.from_numpy(all_cam_images).float() / 255.0
+        imgs = torch.einsum('k h w c -> k c h w', imgs)
+
+        cfg = self.image_aug_cfg
+        # color jitter (shared params)
+        if np.random.rand() < cfg['p_color']:
+            b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
+            c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
+            s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
+            h = np.random.uniform(-cfg['hue'], cfg['hue'])
+            for cam_idx in range(imgs.shape[0]):
+                img = imgs[cam_idx]
+                img = TF.adjust_brightness(img, b)
+                img = TF.adjust_contrast(img, c)
+                img = TF.adjust_saturation(img, s)
+                img = TF.adjust_hue(img, h)
+                imgs[cam_idx] = img
+
+        # synthetic highlight / glare (shared parameters)
+        if np.random.rand() < cfg['p_highlight']:
+            _, h_img, w_img = imgs[0].shape
+            cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
+            cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
+            sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
+            strength = np.random.uniform(*cfg['highlight_strength'])
+            yy, xx = torch.meshgrid(
+                torch.arange(h_img, dtype=torch.float32),
+                torch.arange(w_img, dtype=torch.float32),
+                indexing='ij',
+            )
+            gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
+            gauss = (gauss * strength).unsqueeze(0)
+            imgs = imgs + gauss
+
+        # gaussian noise
+        if np.random.rand() < cfg['p_noise']:
+            noise_std = np.random.uniform(*cfg['noise_std'])
+            imgs = imgs + torch.randn_like(imgs) * noise_std
+
+        # gaussian blur
+        if np.random.rand() < cfg['p_blur']:
+            kernel = int(np.random.choice(cfg['blur_kernel_choices']))
+            sigma = float(np.random.uniform(*cfg['blur_sigma']))
+            for cam_idx in range(imgs.shape[0]):
+                imgs[cam_idx] = TF.gaussian_blur(
+                    imgs[cam_idx],
+                    kernel_size=[kernel, kernel],
+                    sigma=[sigma, sigma],
+                )
+
+        imgs = imgs.clamp(0.0, 1.0)
+        imgs = torch.einsum('k c h w -> k h w c', imgs)
+        imgs = (imgs * 255.0).byte().cpu().numpy()
+        return imgs
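The synthetic glare above adds a 2-D Gaussian bump (shared across all cameras of a sample) to mimic endoscope specular highlights. The mask arithmetic in isolation, as a numpy sketch with toy sizes:

```python
import numpy as np

def gaussian_highlight(h, w, cx, cy, sigma, strength):
    """Additive glare mask: peaks at (cx, cy), decays with squared distance."""
    yy, xx = np.meshgrid(np.arange(h, dtype=np.float32),
                         np.arange(w, dtype=np.float32), indexing='ij')
    gauss = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
    return gauss * strength

mask = gaussian_highlight(h=9, w=9, cx=4, cy=4, sigma=2.0, strength=0.2)
print(mask[4, 4])  # 0.2 — full strength at the centre, decaying outward
```

Because the mask is added in [0, 1] image space and then clamped, a strength in the diff's `(0.08, 0.25)` range brightens without fully saturating the image.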
+    def _init_episode_shapes(self):
+        max_len = 0
+        action_dim = None
+        for episode_id in self.episode_ids:
+            dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
+            with h5py.File(dataset_path, 'r') as root:
+                shape = root['/action'].shape
+                if len(shape) != 2:
+                    raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
+                max_len = max(max_len, int(shape[0]))
+                if action_dim is None:
+                    action_dim = int(shape[1])
+                elif int(shape[1]) != action_dim:
+                    raise ValueError(
+                        f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
+                    )
+
+        if max_len <= 0 or action_dim is None:
+            raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
+
+        self.max_episode_len = max_len
+        self.action_dim = action_dim
+
+    @staticmethod
+    def _decode_instruction(raw_value):
+        if raw_value is None:
+            return ''
+        if isinstance(raw_value, bytes):
+            return raw_value.decode('utf-8')
+        if isinstance(raw_value, np.bytes_):
+            return raw_value.tobytes().decode('utf-8')
+        if isinstance(raw_value, np.ndarray):
+            if raw_value.shape == ():
+                return EpisodicDataset._decode_instruction(raw_value.item())
+            if raw_value.size == 0:
+                return ''
+            return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
+        return str(raw_value)

     def __len__(self):
         return len(self.episode_ids)
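HDF5 string datasets can surface through h5py as `bytes`, `np.bytes_` scalars, or arrays of byte strings, which is why `_decode_instruction` above branches the way it does. A standalone check of the same normalization idea (a simplified sketch, not the class method itself):

```python
import numpy as np

def decode(raw):
    """Normalize the string-ish values h5py may return into a plain str."""
    if raw is None:
        return ''
    if isinstance(raw, bytes):          # also covers np.bytes_, a bytes subclass
        return raw.decode('utf-8')
    if isinstance(raw, np.ndarray):
        if raw.size == 0:
            return ''
        return decode(raw.reshape(-1)[0])
    return str(raw)

print(decode(b'retract scope'))                 # retract scope
print(decode(np.array([b'advance', b'stop'])))  # advance — first entry wins
```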
@@ -26,7 +175,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
         episode_id = self.episode_ids[index]
         dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
         with h5py.File(dataset_path, 'r') as root:
-            is_sim = root.attrs['sim']
+            is_sim = bool(root.attrs.get('sim', False))
             original_action_shape = root['/action'].shape
             episode_len = original_action_shape[0]
             if sample_full_episode:
@@ -35,29 +184,62 @@ class EpisodicDataset(torch.utils.data.Dataset):
                 start_ts = np.random.choice(episode_len)
             # get observation at start_ts only
             qpos = root['/observations/qpos'][start_ts]
             qvel = root['/observations/qvel'][start_ts]
             image_dict = dict()
             for cam_name in self.camera_names:
                 image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
+
+            instruction = ''
+            text_feature = None
+            if self.use_text_instruction:
+                effective_mode = self.instruction_mode
+                if effective_mode == 'timestep-level' and '/instruction_timestep' in root:
+                    instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
+                elif '/instruction' in root:
+                    instruction_node = root['/instruction']
+                    if getattr(instruction_node, 'shape', ()) == ():
+                        instruction = self._decode_instruction(instruction_node[()])
+                    else:
+                        if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
+                            instruction = self._decode_instruction(instruction_node[start_ts])
+                        else:
+                            instruction = self._decode_instruction(instruction_node[0])
+
+                if self.use_cached_text_features:
+                    if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root:
+                        text_feature = root['/instruction_features_timestep'][start_ts]
+                    elif '/instruction_features' in root:
+                        feat_node = root['/instruction_features']
+                        if getattr(feat_node, 'shape', ()) == ():
+                            text_feature = np.array(feat_node[()])
+                        elif len(feat_node.shape) == 1:
+                            text_feature = feat_node[()]
+                        elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len:
+                            text_feature = feat_node[start_ts]
+                        else:
+                            text_feature = feat_node[0]
+
             # get all actions after and including start_ts
             if is_sim:
                 action = root['/action'][start_ts:]
                 action_len = episode_len - start_ts
             else:
-                action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned
-                action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
+                action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
+                action = root['/action'][action_start:]
+                action_len = episode_len - action_start

         self.is_sim = is_sim
-        padded_action = np.zeros(original_action_shape, dtype=np.float32)
+        padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
         padded_action[:action_len] = action
-        is_pad = np.zeros(episode_len)
-        is_pad[action_len:] = 1
+        is_pad = np.ones(self.max_episode_len)
+        is_pad[:action_len] = 0

         # new axis for different cameras
         all_cam_images = []
         for cam_name in self.camera_names:
             all_cam_images.append(image_dict[cam_name])
         all_cam_images = np.stack(all_cam_images, axis=0)
+        if self.image_augment:
+            all_cam_images = self._apply_image_augmentation(all_cam_images)

         # construct observations
         image_data = torch.from_numpy(all_cam_images)
@@ -73,55 +255,146 @@ class EpisodicDataset(torch.utils.data.Dataset):
         action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
         qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]

-        return image_data, qpos_data, action_data, is_pad
+        if self.use_text_instruction and text_feature is not None:
+            text_feature_data = torch.from_numpy(np.array(text_feature)).float()
+            text_feature_valid = torch.tensor(True, dtype=torch.bool)
+            text_input_ids = torch.zeros(1, dtype=torch.long)
+            text_attention_mask = torch.zeros(1, dtype=torch.long)
+        elif self.use_text_instruction:
+            tokenized = self.text_tokenizer(
+                instruction,
+                padding='max_length',
+                truncation=True,
+                max_length=self.text_max_length,
+                return_tensors='pt',
+            )
+            text_input_ids = tokenized['input_ids'].squeeze(0).long()
+            text_attention_mask = tokenized['attention_mask'].squeeze(0).long()
+            text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
+            text_feature_valid = torch.tensor(False, dtype=torch.bool)
+        else:
+            text_input_ids = torch.zeros(1, dtype=torch.long)
+            text_attention_mask = torch.zeros(1, dtype=torch.long)
+            text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
+            text_feature_valid = torch.tensor(False, dtype=torch.bool)

+        return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid
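The padding change above moves from the per-episode `original_action_shape` to a dataset-wide `max_episode_len`, initializing `is_pad` to ones and clearing the valid prefix (so batched episodes of different lengths share one tensor shape). The same construction in numpy, with toy numbers:

```python
import numpy as np

max_episode_len, action_dim = 6, 2
action = np.arange(8, dtype=np.float32).reshape(4, 2)  # 4 valid steps of a 2-DOF action
action_len = action.shape[0]

padded_action = np.zeros((max_episode_len, action_dim), dtype=np.float32)
padded_action[:action_len] = action       # valid prefix, zeros beyond
is_pad = np.ones(max_episode_len)
is_pad[:action_len] = 0                   # 0 = real step, 1 = padding

print(is_pad)  # [0. 0. 0. 0. 1. 1.]
```

The policy later truncates both to `num_queries` and uses `is_pad` to mask the L1 loss, so the zero-filled tail never contributes gradient.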
-def get_norm_stats(dataset_dir, num_episodes):
+def _discover_episode_ids(dataset_dir, num_episodes=None):
+    pattern = re.compile(r'^episode_(\d+)\.hdf5$')
+    episode_ids = []
+    for fname in os.listdir(dataset_dir):
+        m = pattern.match(fname)
+        if m:
+            episode_ids.append(int(m.group(1)))
+    episode_ids.sort()
+    if num_episodes is not None:
+        episode_ids = episode_ids[:num_episodes]
+    return episode_ids
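`_discover_episode_ids` scans the directory instead of assuming a dense `0..N-1` range, so gaps in episode numbering survive deletion of bad recordings. The regex behaviour on example filenames:

```python
import re

pattern = re.compile(r'^episode_(\d+)\.hdf5$')
filenames = ['episode_0.hdf5', 'episode_12.hdf5', 'episode_3.hdf5',
             'notes.txt', 'episode_7.hdf5.bak']  # distractors are skipped
ids = sorted(int(m.group(1)) for f in filenames if (m := pattern.match(f)))
print(ids)  # [0, 3, 12]
```

The anchors `^` and `$` reject partial matches like the `.bak` file, and numeric sorting avoids the `episode_10 < episode_2` pitfall of lexicographic filename order.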
+def get_norm_stats(dataset_dir, episode_ids):
     all_qpos_data = []
     all_action_data = []
-    for episode_idx in range(num_episodes):
+    example_qpos = None
+    for episode_idx in episode_ids:
         dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
         with h5py.File(dataset_path, 'r') as root:
             qpos = root['/observations/qpos'][()]
             qvel = root['/observations/qvel'][()]
             action = root['/action'][()]
-        all_qpos_data.append(torch.from_numpy(qpos))
-        all_action_data.append(torch.from_numpy(action))
-    all_qpos_data = torch.stack(all_qpos_data)
-    all_action_data = torch.stack(all_action_data)
-    all_action_data = all_action_data
+        qpos_t = torch.from_numpy(qpos)
+        action_t = torch.from_numpy(action)
+        all_qpos_data.append(qpos_t)
+        all_action_data.append(action_t)
+        if example_qpos is None and len(qpos) > 0:
+            example_qpos = qpos[0]
+
+    # Episodes may have different lengths; concatenate over time axis.
+    all_qpos_data = torch.cat(all_qpos_data, dim=0)
+    all_action_data = torch.cat(all_action_data, dim=0)

     # normalize action data
-    action_mean = all_action_data.mean(dim=[0, 1], keepdim=True)
-    action_std = all_action_data.std(dim=[0, 1], keepdim=True)
+    action_mean = all_action_data.mean(dim=0, keepdim=True)
+    action_std = all_action_data.std(dim=0, keepdim=True)
     action_std = torch.clip(action_std, 1e-2, np.inf) # clipping

     # normalize qpos data
-    qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True)
-    qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True)
+    qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
+    qpos_std = all_qpos_data.std(dim=0, keepdim=True)
     qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping

     stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
              "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
-             "example_qpos": qpos}
+             "example_qpos": example_qpos}

     return stats
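Concatenating along the time axis (instead of `torch.stack`, which requires equal episode lengths) lets every episode contribute to the same per-dimension statistics. A numpy sketch with two unequal 2-DOF episodes, including the std clipping from the diff:

```python
import numpy as np

ep1 = np.array([[0.0, 10.0], [2.0, 10.0]])               # 2 timesteps, 2-DOF
ep2 = np.array([[4.0, 10.0], [6.0, 10.0], [8.0, 10.0]])  # 3 timesteps

all_action = np.concatenate([ep1, ep2], axis=0)          # [5, 2] — stack would fail here
action_mean = all_action.mean(axis=0)
action_std = np.clip(all_action.std(axis=0), 1e-2, np.inf)  # same floor as the diff

print(action_mean)    # [ 4. 10.]
print(action_std[1])  # 0.01 — constant dimension clipped to the floor
```

The `1e-2` floor matters for a 2-DOF endoscope: a motor that barely moves in the demos would otherwise produce a near-zero std and explode the normalized actions.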
-def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
+def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
+              use_text_instruction=False,
+              instruction_mode='timestep-level',
+              use_cached_text_features=True,
+              text_feature_dim=768,
+              text_tokenizer_name='distilbert-base-uncased',
+              text_max_length=32,
+              real_action_t_minus_1=True,
+              image_augment=False,
+              image_aug_cfg=None):
     print(f'\nData from: {dataset_dir}\n')
-    # obtain train test split
-    train_ratio = 0.8
-    shuffled_indices = np.random.permutation(num_episodes)
-    train_indices = shuffled_indices[:int(train_ratio * num_episodes)]
-    val_indices = shuffled_indices[int(train_ratio * num_episodes):]
+    episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
+    if len(episode_ids) == 0:
+        raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
+
+    # obtain train/val split
+    if len(episode_ids) == 1:
+        # sanity-check mode: reuse the same episode for both train and val
+        # so training/evaluation loops remain unchanged.
+        train_episode_ids = np.array(episode_ids)
+        val_episode_ids = np.array(episode_ids)
+        print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
+    else:
+        train_ratio = 0.9
+        shuffled_indices = np.random.permutation(len(episode_ids))
+        train_count = int(train_ratio * len(episode_ids))
+        train_count = max(1, min(len(episode_ids) - 1, train_count))
+        train_indices = shuffled_indices[:train_count]
+        val_indices = shuffled_indices[train_count:]
+        train_episode_ids = np.array(episode_ids)[train_indices]
+        val_episode_ids = np.array(episode_ids)[val_indices]

     # obtain normalization stats for qpos and action
-    norm_stats = get_norm_stats(dataset_dir, num_episodes)
+    norm_stats = get_norm_stats(dataset_dir, episode_ids)

     # construct dataset and dataloader
-    train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats)
-    val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats)
+    train_dataset = EpisodicDataset(
+        train_episode_ids,
+        dataset_dir,
+        camera_names,
+        norm_stats,
+        use_text_instruction=use_text_instruction,
+        instruction_mode=instruction_mode,
+        use_cached_text_features=use_cached_text_features,
+        text_feature_dim=text_feature_dim,
+        text_tokenizer_name=text_tokenizer_name,
+        text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
+        image_augment=image_augment,
+        image_aug_cfg=image_aug_cfg,
+    )
+    val_dataset = EpisodicDataset(
+        val_episode_ids,
+        dataset_dir,
+        camera_names,
+        norm_stats,
+        use_text_instruction=use_text_instruction,
+        instruction_mode=instruction_mode,
+        use_cached_text_features=use_cached_text_features,
+        text_feature_dim=text_feature_dim,
+        text_tokenizer_name=text_tokenizer_name,
+        text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
+        image_augment=False,
+    )
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
     val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
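The split above clamps `train_count` so that neither side ends up empty on small datasets (the single-episode case is handled separately as sanity-check mode). The clamp in isolation (the helper name `split_counts` is illustrative only):

```python
def split_counts(n_episodes, train_ratio=0.9):
    """Clamp so both splits are non-empty whenever n_episodes >= 2."""
    train_count = int(train_ratio * n_episodes)
    train_count = max(1, min(n_episodes - 1, train_count))
    return train_count, n_episodes - train_count

print(split_counts(2))   # (1, 1)
print(split_counts(10))  # (9, 1)
```

Without the clamp, e.g. `int(0.9 * 10) == 9` is fine, but extreme ratios or tiny datasets could leave zero validation episodes and crash the val loop.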
@@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name):

     with h5py.File(dataset_path, 'r') as root:
         is_sim = root.attrs['sim']
+        dt = float(root.attrs.get('dt', DT))
         qpos = root['/observations/qpos'][()]
-        qvel = root['/observations/qvel'][()]
+        qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos)
         action = root['/action'][()]
         image_dict = dict()
         for cam_name in root[f'/observations/images/'].keys():
             image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]

-    return qpos, qvel, action, image_dict
+    return qpos, qvel, action, image_dict, dt

 def main(args):
     dataset_dir = args['dataset_dir']
     episode_idx = args['episode_idx']
     dataset_name = f'episode_{episode_idx}'

-    qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
-    save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
+    qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name)
+    save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
     visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
     # visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back


-def save_videos(video, dt, video_path=None):
+def save_videos(video, dt, video_path):
     if isinstance(video, list):
         cam_names = list(video[0].keys())
         h, w, _ = video[0][cam_names[0]].shape
         w = w * len(cam_names)
-        fps = int(1/dt)
+        fps = max(1, int(round(1 / dt)))
         out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
         for ts, image_dict in enumerate(video):
             images = []
@@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None):
         all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension

         n_frames, h, w, _ = all_cam_videos.shape
-        fps = int(1 / dt)
+        fps = max(1, int(round(1 / dt)))
         out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
         for t in range(n_frames):
             image = all_cam_videos[t]
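The fps change matters because `int(1/dt)` truncates, and for `dt > 1 s` it yields 0, which is not a valid frame rate for `cv2.VideoWriter`. The replacement rounds to the nearest integer and floors at 1:

```python
def safe_fps(dt):
    """Frame rate from a control timestep, rounded and floored at 1."""
    return max(1, int(round(1 / dt)))

print(safe_fps(0.02))  # 50
print(safe_fps(2.0))   # 1 — plain int(1/dt) would give 0 here and break VideoWriter
```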