Compare commits


10 Commits

Author SHA1 Message Date
96d19c0ffc chore: ignore checkpoints and model weights 2026-02-20 16:56:57 +08:00
81e1bf8838 follow uses policy_last 2026-02-20 16:45:16 +08:00
88d0cc5ca2 Add --debug_input argument 2026-02-20 14:59:44 +08:00
d85cce8a52 Build no-text dataset 2026-02-20 14:13:25 +08:00
ee257bcb6c Add stop command 2026-02-19 22:11:10 +08:00
7023d5dde4 Data augmentation 2026-02-19 21:29:32 +08:00
88d14221ae Code runs now 2026-02-19 15:32:28 +08:00
b701d939c2 Can generate hdf5 data for now 2026-02-17 22:20:25 +08:00
ba006e14c4 Modify training code 2026-02-17 19:16:09 +08:00
Tony Z. Zhao
d4b4d554f8 Update README.md 2024-01-28 12:18:07 -08:00
18 changed files with 1784 additions and 99 deletions

.gitignore vendored

@@ -4,6 +4,10 @@ wandb
 outputs
 data
 data_local
+ckpt
+*.ckpt
+*.pt
+*.pth
 .vscode
 _wandb


@@ -0,0 +1,297 @@
# Modification Checklist: Adapting the ACT Repo to the Endoscope Robot (2-DOF + image/qpos/text instruction) — Training Only
## 1. Goals and Constraints
### Goals
Turn the current stock ACT repo into an offline-training pipeline for your endoscope robot, supporting:
- **Action dimension of only 2** (2 motors)
- **No dependency on Gym / simulation environments**
- Inputs of **image + qpos + text instruction**
- Offline training only (no real-robot online interface in this phase)
### Constraints
The current code defaults to the dual-arm ALOHA setup (14-dim state/action) and the sim/real ALOHA environment interfaces, has **no text branch**, and is heavily hard-coded, so a systematic refactor is required.
---
## 2. Key Hard-Coded Assumptions in the Current Code (must change)
1. **State/action dimension hard-coded to 14**
   - `state_dim = 14` in `imitate_episodes.py`
   - multiple `nn.Linear(14, ...)` layers in `detr/models/detr_vae.py`
   - data written by `record_sim_episodes.py` with a fixed `(T, 14)` shape
2. **Training/eval flow bound to sim or aloha_scripts real_env**
   - `eval_bc()` in `imitate_episodes.py` depends on `sim_env.make_sim_env()` / `aloha_scripts.real_env`
3. **Data loader defaults to qpos/qvel/action + images fields**
   - `EpisodicDataset` in `utils.py` only loads `qpos`, `action`, and `images`; no text
4. **Model fuses only image + state; no text encoder**
   - `policy.py` and `detr/models/detr_vae.py` currently have no text input path
---
## 3. Modules That Must Change (by file)
## A. Configuration and Task Definition
### Files
- `constants.py`
### Changes
- Add an endoscope task config (suggested: a new dict `ENDOSCOPE_TASK_CONFIGS`):
  - `dataset_dir`
  - `num_episodes`
  - `episode_len`
  - `camera_names`
  - `state_dim=2`
  - `action_dim=2`
  - `use_text_instruction=True`
  - `instruction_mode` (episode-level / timestep-level)
  - `text_encoder_type="distilbert"`
  - `text_feature_dim=768`
  - `text_fusion_type="concat_transformer_input"`
- Stop relying on the `sim_` prefix to determine task type.
### Purpose
Decouple task parameters from the ALOHA defaults so they become the single entry point for training and model construction.
---
## B. Data Protocol and Dataset Loading
### Files
- `utils.py`
- (new) data-conversion scripts under `dataset_tools/`
### Changes
1. **Define a unified data protocol (HDF5 recommended)**
   - `/observations/images/<cam_name>`: `(T, H, W, C)`
   - `/observations/qpos`: `(T, 2)`
   - `/action`: `(T, 2)`
   - `/instruction` (string or tokens)
   - optional: `/instruction_timestep` (if the instruction varies per step)
2. **Refactor `EpisodicDataset`**
   - keep the `qpos` name; no renaming
   - load the text instruction
   - return training samples as:
     - `image_data`
     - `qpos_data`
     - `action_data`
     - `is_pad`
     - `text_input_ids`
     - `text_attention_mask`
3. **Extend normalization statistics**
   - `get_norm_stats()` supports arbitrary `qpos/action` dimensions (both 2 for this task)
   - text uses online DistilBERT encoding by default, with optional feature caching
4. **Compatibility policy**
   - keep direct compatibility with the legacy `qpos` field
   - support single or multiple cameras
### Purpose
Build a data pipeline matching your real data and fully drop the 14-dim and simulation-field assumptions.
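The protocol above can be exercised with a short reader. `load_episode` is a hypothetical helper (not part of the repo) that pulls one training sample from an episode file using the field names listed:

```python
import numpy as np
import h5py

def load_episode(path, start_ts=0):
    """Read one training sample from an episode HDF5 following the protocol:
    /observations/images/<cam>, /observations/qpos, /action, /instruction[_timestep]."""
    with h5py.File(path, "r") as root:
        qpos = root["/observations/qpos"][start_ts]                # (2,)
        images = {cam: root[f"/observations/images/{cam}"][start_ts]
                  for cam in root["/observations/images"]}         # each (H, W, C)
        action = root["/action"][start_ts:]                        # (T - start_ts, 2)
        if "/instruction_timestep" in root:
            instruction = root["/instruction_timestep"][start_ts]  # per-step text
        else:
            instruction = root["/instruction"][()]                 # episode-level text
    return qpos, images, action, instruction
```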
---
## C. Training Entry Point and Flow Control
### Files
- `imitate_episodes.py`
### Changes
1. **Config loading**
   - read `state_dim=2/action_dim=2/camera_names/use_text_instruction` from the new task config
2. **Remove/isolate simulation coupling (within training scope)**
   - `main()` keeps a pure offline training path
   - `eval_bc()` keeps only the offline evaluation path
3. **Forward-pass input changes**
   - `forward_pass()` supports text
   - dataloader batch unpacking gains the text components
4. **Additional command-line arguments**
   - `--task_config` / `--task_name` mapped to the new config
   - `--text_encoder_type`
   - `--freeze_text_encoder`
### Purpose
Make the training script the unified entry point for real-robot offline imitation learning rather than a sim demo entry point.
---
## D. Policy Wrapper Layer (Policy API)
### Files
- `policy.py`
### Changes
1. Extend the signatures of `ACTPolicy.__call__()` and `CNNMLPPolicy.__call__()`:
   - current: `(qpos, image, actions=None, is_pad=None)`
   - target: `(qpos, image, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
2. Keep image normalization, but make sure any number of cameras is supported.
3. Keep the loss computation unchanged, while allowing graceful degradation when text is absent (useful for ablations).
### Purpose
Pass text cleanly from the data layer through to the model layer.
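A minimal sketch of the target signature, assuming the text-aware model of section F exists and returns `(a_hat, is_pad_hat, (mu, logvar))` like the original ACT model; `ACTPolicyWithText` and its loss handling are illustrative, not the final implementation:

```python
import torch

class ACTPolicyWithText(torch.nn.Module):
    """Sketch of the extended Policy API; `self.model` is assumed to be a
    (hypothetical) text-aware ACT model accepting the text kwargs."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None,
                 actions=None, is_pad=None):
        if actions is not None:  # training path
            a_hat, is_pad_hat, (mu, logvar) = self.model(
                qpos, image, None,
                text_input_ids=text_input_ids,
                text_attention_mask=text_attention_mask,
                actions=actions, is_pad=is_pad)
            l1 = torch.nn.functional.l1_loss(actions, a_hat, reduction="none")
            return {"l1": (l1 * ~is_pad.unsqueeze(-1)).mean()}
        # inference path: text may be None, enabling a no-text ablation
        a_hat, _, _ = self.model(qpos, image, None,
                                 text_input_ids=text_input_ids,
                                 text_attention_mask=text_attention_mask)
        return a_hat
```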
---
## E. Model Construction Parameters and Entry Point
### Files
- `detr/main.py`
### Changes
- New model parameters:
  - `state_dim`
  - `action_dim`
  - `use_text`
  - `text_encoder_type="distilbert"`
  - `text_feature_dim=768`
  - `text_fusion_type="concat_transformer_input"`
- Remove or de-emphasize placeholder arguments unrelated to this script.
### Purpose
Turn all hard-coded dimensions into configurable options for easier iteration.
---
## F. ACT Backbone Network (core refactor)
### Files
- `detr/models/detr_vae.py`
### Changes
1. **Remove the 14-dim hard-coding**
   - `input_proj_robot_state = nn.Linear(state_dim, hidden_dim)`
   - `encoder_action_proj = nn.Linear(action_dim, hidden_dim)`
   - `encoder_joint_proj = nn.Linear(state_dim, hidden_dim)`
   - `action_head = nn.Linear(hidden_dim, action_dim)`
2. **Add a text branch**
   - use DistilBERT output features (768-dim)
   - add a text projection layer: `nn.Linear(768, hidden_dim)`
   - fix the fusion strategy to: **concat the text token/feature as an extra token directly into the Transformer input sequence**
3. **Add text inputs to the forward function**
   - `forward(self, qpos, image, env_state, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`
4. **Keep training/inference modes consistent**
   - training: action sequence + text-conditioned VAE
   - inference: prior sampling + text-conditioned generation
### Purpose
Turn ACT from an "image + 14-dim joints" model into an "image + 2-dim qpos + text-conditioned" model.
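The fusion strategy in point 2 fits in a few lines. `fuse_text_token`, the shapes, and `hidden_dim=256` are illustrative assumptions, not repo code:

```python
import torch
import torch.nn as nn

hidden_dim = 256                          # assumed transformer width
text_proj = nn.Linear(768, hidden_dim)    # the new text projection layer

def fuse_text_token(src_tokens, text_feat):
    """Append the projected text feature as one extra token.

    src_tokens: (S, B, hidden_dim) transformer input sequence (image + qpos tokens)
    text_feat:  (B, 768) DistilBERT [CLS] feature
    returns:    (S + 1, B, hidden_dim)
    """
    text_token = text_proj(text_feat).unsqueeze(0)  # (1, B, hidden_dim)
    return torch.cat([src_tokens, text_token], dim=0)
```

Note that the positional embedding fed to the Transformer must also gain one slot for the extra token.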
---
## G. Real-Robot Interface and Data Collection (out of scope this round)
This round covers the training side only; the following is deferred:
- online inference interface
- real-robot data collection scripts
- online safety and rate control
---
## H. Documentation and Scripts
### Files
- `README.md`
- (new) `docs/endoscope_data_format.md`
- (new) `docs/endoscope_train_eval.md`
### Changes
- Provide a minimal runnable workflow:
  1) prepare data
  2) training command
  3) offline evaluation
- Specify the text instruction format.
---
## 4. Suggested New Files (checklist)
- `configs/endoscope_task.yaml` (or keep using a python dict)
- `dataset_tools/convert_endoscope_to_act_hdf5.py`
- `dataset_tools/validate_endoscope_dataset.py`
- `models/text_encoder.py` (DistilBERT wrapper)
- `docs/endoscope_data_format.md`
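One possible shape for the `models/text_encoder.py` wrapper, matching the interface the conversion script later in this diff expects (`model_name`, `output_dim`, `freeze`, and a forward over `input_ids`/`attention_mask`). This is a sketch; the projection layer only engages when `output_dim` differs from DistilBERT's native 768:

```python
import torch
import torch.nn as nn

class DistilBERTTextEncoder(nn.Module):
    """Frozen DistilBERT returning the [CLS] hidden state (768-dim by default).
    Requires `pip install transformers` (imported lazily in __init__)."""

    def __init__(self, model_name="distilbert-base-uncased", output_dim=768, freeze=True):
        super().__init__()
        from transformers import DistilBertModel
        self.bert = DistilBertModel.from_pretrained(model_name)
        if freeze:
            for p in self.bert.parameters():
                p.requires_grad = False
        hidden = self.bert.config.dim  # 768 for distilbert-base
        self.proj = nn.Identity() if output_dim == hidden else nn.Linear(hidden, output_dim)

    def forward(self, input_ids, attention_mask=None):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]  # (B, hidden) [CLS] token
        return self.proj(cls)
```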
---
## 5. Suggested Phased Rollout
### Phase 1: get the "no text" 2-DOF version running
- change `state_dim/action_dim`
- get data loading + training + offline validation working
### Phase 2: add the text instruction
- add instruction to the data protocol
- plug in DistilBERT (768)
- complete text-fusion training via `concat_transformer_input`
---
## 6. Acceptance Criteria (Definition of Done)
1. Trains directly on your HDF5 data with no sim/gym dependency.
2. Model inputs cover image, 2D qpos, and text instruction simultaneously.
3. The text encoder is DistilBERT with 768-dim output features.
4. Text fusion is Transformer input-level concat.
5. README contains complete training and offline evaluation command examples, reproducible by the team.
---
## 7. Risks and Mitigations
1. **Text/action temporal alignment**
   - decide whether the instruction is episode-level or timestep-level.
2. **Action jitter under low-dimensional control**
   - add low-pass / action smoothing in post-processing.
3. **Modality scale imbalance**
   - watch for one modality dominating gradients after image/state/text fusion (consider modality dropout or loss re-weighting).
4. **Text encoding slowing down training**
   - optionally cache DistilBERT features, or freeze the text encoder.
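For risk 2, a minimal post-processing sketch: an exponential moving average over the predicted action sequence (`alpha` is a tuning assumption, not a repo value):

```python
import numpy as np

def smooth_actions(actions, alpha=0.3):
    """EMA low-pass filter over an action sequence of shape (T, action_dim).

    alpha closer to 1 means less smoothing. Intended to run on policy output
    before motor commands are sent.
    """
    actions = np.asarray(actions, dtype=np.float32)
    out = np.empty_like(actions)
    out[0] = actions[0]
    for t in range(1, len(actions)):
        out[t] = alpha * actions[t] + (1.0 - alpha) * out[t - 1]
    return out
```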
---
## 8. Minimal Information You Still Need to Provide (before code changes)
1. Physical meaning and value range of the 2 motors (units, limits): the motors are motor_x and motor_y; x ranges 7000-17384, y ranges 8000-18884. The corresponding action_x and action_y are both in 0~65535.
2. Whether `qpos` and `action` in your current data share the same definition: action and qpos are defined nearly identically; action simply maps 0~65535 onto the motor magnetic-encoder values.
3. Whether the text instruction is one per episode or one per timestep: one per timestep.
4. Camera count, resolution, frame rate: 1 camera, 224*224 resolution, 30 Hz; the motor control rate is also 30 Hz.
5. Whether DistilBERT is frozen during training (`freeze_text_encoder=True/False`): DistilBERT is fully frozen. When building the training set, encode each frame's text instruction with DistilBERT first and save the features, so DistilBERT is never called during training.
> With these 5 items, code changes can begin.
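Based on the ranges in answers 1-2, the command-to-encoder relationship can be sketched as a linear map. The helper names and the assumption that the mapping is linear over the full range are illustrative:

```python
import numpy as np

# Ranges stated in answers 1-2 above
CMD_MIN, CMD_MAX = 0.0, 65535.0
ENC_RANGE = {"x": (7000.0, 17384.0), "y": (8000.0, 18884.0)}

def command_to_encoder(cmd, axis):
    """Linearly map a 0~65535 motor command to encoder counts (assumed linear)."""
    lo, hi = ENC_RANGE[axis]
    return lo + (np.asarray(cmd, np.float32) - CMD_MIN) / (CMD_MAX - CMD_MIN) * (hi - lo)

def encoder_to_command(enc, axis):
    """Inverse mapping: encoder counts back to the 0~65535 command space."""
    lo, hi = ENC_RANGE[axis]
    return (np.asarray(enc, np.float32) - lo) / (hi - lo) * (CMD_MAX - CMD_MIN)
```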
What do text_input_ids and text_attention_mask mean? `'instruction_mode': 'episode-level'` is never used;
```python
instruction = ''
if self.use_text_instruction:
if '/instruction_timestep' in root:
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
elif '/instruction' in root:
instruction_node = root['/instruction']
if getattr(instruction_node, 'shape', ()) == ():
instruction = self._decode_instruction(instruction_node[()])
else:
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
instruction = self._decode_instruction(instruction_node[start_ts])
else:
instruction = self._decode_instruction(instruction_node[0])
```
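On the first question: `text_input_ids` are the vocabulary indices the tokenizer assigns to each (sub)word, and `text_attention_mask` is 1 for real tokens and 0 for padding, so the encoder ignores padded positions. A dependency-free sketch of how a batch gets padded and masked (a real tokenizer such as `DistilBertTokenizerFast` produces both arrays directly):

```python
import numpy as np

def pad_and_mask(token_id_lists, pad_id=0, max_length=8):
    """Pad variable-length token-id lists to max_length and build the mask.

    ids:  vocabulary indices per token, pad_id where padded (= input_ids)
    mask: 1 for real tokens, 0 for padding     (= attention_mask)
    """
    ids = np.full((len(token_id_lists), max_length), pad_id, np.int64)
    mask = np.zeros((len(token_id_lists), max_length), np.int64)
    for i, toks in enumerate(token_id_lists):
        toks = toks[:max_length]
        ids[i, :len(toks)] = toks
        mask[i, :len(toks)] = 1
    return ids, mask
```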
Why was the definition of the Transformer modified? Does that change actually take effect here?


@@ -35,8 +35,8 @@ You can find all scripted/human demo for simulated environments [here](https://d
 pip install pyyaml
 pip install rospkg
 pip install pexpect
-pip install mujoco
-pip install dm_control
+pip install mujoco==2.3.7
+pip install dm_control==1.0.14
 pip install opencv-python
 pip install matplotlib
 pip install einops


@@ -0,0 +1,680 @@
#!/usr/bin/env python3
import argparse
import csv
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
@dataclass
class CropBox:
x1: int
y1: int
x2: int
y2: int
@property
def w(self) -> int:
return self.x2 - self.x1
@property
def h(self) -> int:
return self.y2 - self.y1
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
)
)
parser.add_argument(
"--segment_dir",
type=str,
required=True,
help="Path to one raw segment, e.g. data/follow_seg_001",
)
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Output dir for episode_*.hdf5",
)
parser.add_argument(
"--episode_idx",
type=int,
default=0,
help="Output episode index (default: 0)",
)
parser.add_argument(
"--max_frames",
type=int,
default=-1,
help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)",
)
parser.add_argument(
"--camera_name",
type=str,
default="top",
help="Camera name written to /observations/images/<camera_name>",
)
parser.add_argument(
"--crop",
type=int,
nargs=4,
default=[733, 30, 1754, 1051],
metavar=("X1", "Y1", "X2", "Y2"),
help="Crop box in original image coordinates",
)
parser.add_argument(
"--resize",
type=int,
nargs=2,
default=[224, 224],
metavar=("W", "H"),
help="Output image size",
)
parser.add_argument(
"--instruction_template",
type=str,
default="Move toward the {label} at {region}.",
help="Template for per-frame instruction",
)
parser.add_argument(
"--instruction_empty",
type=str,
default="No target visible.",
help="Instruction when no valid target after crop",
)
parser.add_argument(
"--stop_instruction",
type=str,
default="Stop move.",
help="Instruction used for stationary head/tail frames",
)
parser.add_argument(
"--state_norm",
choices=["minus1_1", "0_1", "raw"],
default="minus1_1",
help="Normalization for qpos (motor_pos_y, motor_pos_x)",
)
parser.add_argument(
"--action_norm",
choices=["minus1_1", "0_1", "raw"],
default="minus1_1",
help="Normalization for action (motor_command_0, motor_command_1)",
)
parser.add_argument(
"--encode_text_features",
action="store_true",
help="Encode per-frame instruction into 768-dim DistilBERT features",
)
parser.add_argument(
"--text_model_name",
type=str,
default="distilbert-base-uncased",
help="HuggingFace model name for DistilBERT",
)
parser.add_argument(
"--text_batch_size",
type=int,
default=32,
help="Batch size for text feature extraction",
)
parser.add_argument(
"--motion_window",
type=int,
default=5,
help="Sliding window size used for stationary detection at beginning/end",
)
parser.add_argument(
"--motion_threshold",
type=float,
default=0.002,
help=(
"Motion threshold in normalized delta space (0~1). "
"Smaller value means stricter stationary detection"
),
)
parser.add_argument(
"--disable_stop_override",
action="store_true",
help="Disable head/tail stationary instruction override",
)
parser.add_argument(
"--trim_stationary_edges",
action="store_true",
help="Trim stationary head/tail segments and keep only the middle moving segment",
)
parser.add_argument(
"--no_text_instruction",
action="store_true",
help="Do not save instruction/instruction_timestep (and disable text feature encoding)",
)
return parser.parse_args()
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
json_files = list(frames_dir.glob("*.json"))
def key_fn(p: Path) -> Tuple[int, str]:
m = re.search(r"frame_(\d+)", p.name)
idx = int(m.group(1)) if m else 10**9
return idx, p.name
json_files.sort(key=key_fn)
return json_files
def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
with csv_path.open("r", encoding="utf-8") as f:
reader = csv.DictReader(f)
return list(reader)
def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
if mode == "raw":
return x.astype(np.float32)
x01 = (x - min_v) / (max_v - min_v)
if mode == "0_1":
return x01.astype(np.float32)
# minus1_1
return (x01 * 2.0 - 1.0).astype(np.float32)
def clip_bbox_to_crop(
x_min: float,
y_min: float,
x_max: float,
y_max: float,
crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
nx1 = max(x_min - crop.x1, 0.0)
ny1 = max(y_min - crop.y1, 0.0)
nx2 = min(x_max - crop.x1, float(crop.w - 1))
ny2 = min(y_max - crop.y1, float(crop.h - 1))
if nx2 <= nx1 or ny2 <= ny1:
return None
return nx1, ny1, nx2, ny2
def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
x1, y1, x2, y2 = box
return (x1 + x2) * 0.5, (y1 + y2) * 0.5
def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
x_bin = min(2, max(0, int(cx / (w / 3.0))))
y_bin = min(2, max(0, int(cy / (h / 3.0))))
xs = ["left", "center", "right"]
ys = ["top", "middle", "bottom"]
return f"{ys[y_bin]}-{xs[x_bin]}"
def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
points = shape.get("points", None)
label = shape.get("label", "target")
if not points or len(points) < 2:
return None
pts = np.array(points, dtype=np.float32)
x_min, y_min = float(pts[:, 0].min()), float(pts[:, 1].min())
x_max, y_max = float(pts[:, 0].max()), float(pts[:, 1].max())
area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
return label, x_min, y_min, x_max, y_max, area
def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
shapes = annotation.get("shapes", [])
best = None
for shape in shapes:
parsed = read_shape_bbox(shape)
if parsed is None:
continue
label, x1, y1, x2, y2, area = parsed
clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
if clipped is None:
continue
c_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
if best is None or c_area > best[2]:
best = (label, clipped, c_area)
if best is None:
return None
return best[0], best[1]
def instruction_from_annotation(
annotation: Dict,
crop: CropBox,
template: str,
empty_instruction: str,
) -> str:
picked = select_target_box(annotation, crop)
if picked is None:
return empty_instruction
label, box = picked
cx, cy = bbox_center(box)
region = region_3x3(cx, cy, crop.w, crop.h)
return template.format(label=label, region=region)
def extract_text_features(
instructions: Sequence[str],
model_name: str,
batch_size: int = 32,
) -> np.ndarray:
try:
import torch
from transformers import DistilBertTokenizerFast
except ImportError as exc:
raise ImportError(
"Text feature encoding requires transformers. Please install: pip install transformers"
) from exc
repo_root = Path(__file__).resolve().parents[1]
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
from models.text_encoder import DistilBERTTextEncoder
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
model.eval()
feats: List[np.ndarray] = []
with torch.no_grad():
for i in range(0, len(instructions), batch_size):
batch = list(instructions[i:i + batch_size])
tok = tokenizer(
batch,
padding=True,
truncation=True,
max_length=32,
return_tensors="pt",
)
input_ids = tok["input_ids"].to(device)
attention_mask = tok["attention_mask"].to(device)
cls = model(input_ids=input_ids, attention_mask=attention_mask).detach().cpu().numpy().astype(np.float32)
feats.append(cls)
return np.concatenate(feats, axis=0)
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
scale = max_v - min_v
if scale <= 0:
return np.zeros_like(v, dtype=np.float32)
return ((v - min_v) / scale).astype(np.float32)
def override_stationary_edge_instructions(
instructions: List[str],
motor_pos_y: np.ndarray,
motor_pos_x: np.ndarray,
motor_cmd_0: np.ndarray,
motor_cmd_1: np.ndarray,
stop_instruction: str,
motion_window: int,
motion_threshold: float,
) -> Tuple[List[str], int, int]:
"""
Override instruction text at head/tail using qpos velocity.
Start side: once qpos speed is above threshold for consecutive frames,
stop applying stop_instruction from that point onward.
End side: similarly scan backward from the end.
"""
num = len(instructions)
if num == 0:
return instructions, 0, 0
start_count, end_count = detect_stationary_edge_counts_from_qpos(
motor_pos_y=motor_pos_y,
motor_pos_x=motor_pos_x,
motion_window=motion_window,
motion_threshold=motion_threshold,
)
updated = list(instructions)
for i in range(start_count):
updated[i] = stop_instruction
for i in range(num - end_count, num):
updated[i] = stop_instruction
return updated, start_count, end_count
def detect_stationary_edge_counts_from_qpos(
motor_pos_y: np.ndarray,
motor_pos_x: np.ndarray,
motion_window: int,
motion_threshold: float,
) -> Tuple[int, int]:
"""Return stationary frame counts on head and tail using qpos velocity rule."""
num = int(len(motor_pos_y))
if num == 0:
return 0, 0
if num == 1:
return 1, 1
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
consecutive = max(1, int(motion_window))
dt = 1.0 / 30.0
frame_speed = np.zeros((num,), dtype=np.float32)
dy = np.abs(np.diff(py)) / dt
dx = np.abs(np.diff(px)) / dt
frame_speed[1:] = np.maximum(dy, dx)
high_run = 0
start_count = num
for i in range(1, num):
if frame_speed[i] > motion_threshold:
high_run += 1
if high_run >= consecutive:
start_count = i - consecutive + 1
break
else:
high_run = 0
high_run = 0
end_count = num
for i in range(num - 1, 0, -1):
if frame_speed[i] > motion_threshold:
high_run += 1
if high_run >= consecutive:
tail_start = i + consecutive
end_count = max(0, num - tail_start)
break
else:
high_run = 0
return start_count, end_count
def find_segment_csv(segment_dir: Path) -> Path:
csvs = sorted(segment_dir.glob("*.csv"))
if not csvs:
raise FileNotFoundError(f"No csv file found in {segment_dir}")
return csvs[0]
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
"""Convert boolean mask to closed-open segments [start, end)."""
segments: List[Tuple[int, int]] = []
if mask.size == 0:
return segments
in_seg = False
start = 0
for i, v in enumerate(mask.tolist()):
if v and not in_seg:
in_seg = True
start = i
elif (not v) and in_seg:
in_seg = False
segments.append((start, i))
if in_seg:
segments.append((start, int(mask.size)))
return segments
def save_episode_plot_with_stop_segments(
qpos: np.ndarray,
action: np.ndarray,
instructions: Sequence[str],
stop_instruction: str,
plot_path: Path,
) -> None:
"""
Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
"""
qpos = np.asarray(qpos)
action = np.asarray(action)
if qpos.ndim != 2 or action.ndim != 2:
return
num_ts, num_dim = qpos.shape
if num_ts == 0 or num_dim == 0:
return
stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
stop_segments = _mask_to_segments(stop_mask)
fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
axs_list = np.atleast_1d(axs).reshape(-1).tolist()
for dim_idx in range(num_dim):
ax = axs_list[dim_idx]
ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
for seg_idx, (st, ed) in enumerate(stop_segments):
ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
label='stop instruction' if seg_idx == 0 else None)
ax.set_ylabel(f'dim {dim_idx}')
ax.legend(loc='best')
ax.grid(alpha=0.25, linestyle='--')
axs_list[-1].set_xlabel('timestep')
fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
fig.tight_layout()
fig.savefig(str(plot_path), dpi=140)
plt.close(fig)
def main() -> None:
args = parse_args()
if args.no_text_instruction and args.encode_text_features:
raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')
segment_dir = Path(args.segment_dir).resolve()
output_dir = Path(args.output_dir).resolve()
output_dir.mkdir(parents=True, exist_ok=True)
frames_dir = segment_dir / "frames"
if not frames_dir.exists():
raise FileNotFoundError(f"frames dir not found: {frames_dir}")
csv_path = find_segment_csv(segment_dir)
csv_rows = load_csv_rows(csv_path)
if len(csv_rows) == 0:
raise ValueError(f"CSV has no rows: {csv_path}")
crop = CropBox(*args.crop)
resize_w, resize_h = int(args.resize[0]), int(args.resize[1])
json_files = sorted_frame_jsons(frames_dir)
if not json_files:
raise FileNotFoundError(f"No frame json found in: {frames_dir}")
max_aligned = min(len(json_files), len(csv_rows))
num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
if num <= 0:
raise ValueError("No aligned frames available.")
images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
instructions: List[str] = []
motor_pos_y_series = np.zeros((num,), dtype=np.float32)
motor_pos_x_series = np.zeros((num,), dtype=np.float32)
motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
motor_cmd_1_series = np.zeros((num,), dtype=np.float32)
y_min, y_max = 8000.0, 18884.0
x_min, x_max = 7000.0, 17384.0
cmd_min, cmd_max = 0.0, 65535.0
for i in range(num):
json_path = json_files[i]
with json_path.open("r", encoding="utf-8") as f:
ann = json.load(f)
image_path = frames_dir / ann["imagePath"]
if not image_path.exists():
alt = json_path.with_suffix(".jpg")
if alt.exists():
image_path = alt
else:
raise FileNotFoundError(f"Image not found for {json_path.name}")
img = Image.open(image_path).convert("RGB")
img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
images[i] = np.asarray(img_resize, dtype=np.uint8)
row = csv_rows[i]
motor_pos_y = float(row["motor_pos_y"])
motor_pos_x = float(row["motor_pos_x"])
motor_cmd_0 = float(row["motor_command_0"])
motor_cmd_1 = float(row["motor_command_1"])
motor_pos_y_series[i] = motor_pos_y
motor_pos_x_series[i] = motor_pos_x
motor_cmd_0_series[i] = motor_cmd_0
motor_cmd_1_series[i] = motor_cmd_1
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
if not args.no_text_instruction:
ins = instruction_from_annotation(
ann,
crop,
args.instruction_template,
args.instruction_empty,
)
instructions.append(ins)
start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
motor_pos_y=motor_pos_y_series,
motor_pos_x=motor_pos_x_series,
motion_window=args.motion_window,
motion_threshold=args.motion_threshold,
)
if args.trim_stationary_edges:
keep_start = int(start_stop_count)
keep_end = int(num - end_stop_count)
if keep_end <= keep_start:
raise ValueError(
f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
f'Consider lowering --motion_threshold or --motion_window.'
)
images = images[keep_start:keep_end]
qpos = qpos[keep_start:keep_end]
action = action[keep_start:keep_end]
motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
if not args.no_text_instruction:
instructions = instructions[keep_start:keep_end]
print(
f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
f'kept={keep_end - keep_start}'
)
# After trimming, full kept segment is the moving region.
start_stop_count, end_stop_count = 0, 0
if (not args.disable_stop_override) and (not args.no_text_instruction):
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
instructions=instructions,
motor_pos_y=motor_pos_y_series,
motor_pos_x=motor_pos_x_series,
motor_cmd_0=motor_cmd_0_series,
motor_cmd_1=motor_cmd_1_series,
stop_instruction=args.stop_instruction,
motion_window=args.motion_window,
motion_threshold=args.motion_threshold,
)
text_features = None
if args.encode_text_features:
text_features = extract_text_features(
instructions,
model_name=args.text_model_name,
batch_size=args.text_batch_size,
)
out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
dt = 1.0 / 30.0
with h5py.File(out_path, "w") as root:
root.attrs["sim"] = False
root.attrs["source_segment"] = str(segment_dir)
root.attrs["frame_rate"] = 30
root.attrs["dt"] = dt
root.attrs["state_norm_mode"] = args.state_norm
root.attrs["action_norm_mode"] = args.action_norm
root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)
obs = root.create_group("observations")
obs.create_dataset("qpos", data=qpos, dtype=np.float32)
images_group = obs.create_group("images")
images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)
root.create_dataset("action", data=action, dtype=np.float32)
if not args.no_text_instruction:
str_dtype = h5py.string_dtype(encoding="utf-8")
root.create_dataset(
"instruction_timestep",
shape=(len(instructions),),
dtype=str_dtype,
data=np.asarray(instructions, dtype=object),
)
root.create_dataset(
"instruction",
shape=(),
dtype=str_dtype,
data=instructions[0] if len(instructions) > 0 else "",
)
if text_features is not None:
root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)
print(f"Saved: {out_path}")
print(f"Frames used: {num}")
print(f"Image shape: {images.shape}")
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
if not args.disable_stop_override:
print(
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
f"threshold={args.motion_threshold}, "
f"instruction='{args.stop_instruction}'"
)
if text_features is not None:
print(f"instruction_features_timestep shape: {text_features.shape}")
# Save a same-basename plot next to the generated hdf5
plot_path = out_path.with_suffix('.png')
save_episode_plot_with_stop_segments(
qpos=qpos,
action=action,
instructions=instructions,
stop_instruction=args.stop_instruction,
plot_path=plot_path,
)
print(f"Saved episode plot to: {plot_path}")
if __name__ == "__main__":
main()

build_no_text_dataset.sh Executable file

@@ -0,0 +1,25 @@
#!/usr/bin/env bash
set -euo pipefail
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/00-follow"
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/follow-no-text"
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"
mkdir -p "$OUT_DIR"
i=50
for d in "$SEG_ROOT"/follow_seg_*; do
[ -d "$d" ] || continue
echo "Building $d -> episode_$i"
python "$SCRIPT" \
--segment_dir "$d" \
--output_dir "$OUT_DIR" \
--episode_idx "$i" \
--max_frames -1 \
--camera_name top \
--crop 733 30 1754 1051 \
--resize 224 224 \
--motion_window 3 \
--motion_threshold 0.05 \
--state_norm minus1_1 \
--action_norm minus1_1 \
--trim_stationary_edges \
--no_text_instruction
i=$((i+1))
done

build_text_dataset.sh Executable file

@@ -0,0 +1,29 @@
#!/usr/bin/env bash
set -euo pipefail
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/01-cannulation"
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/cannulation"
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"
mkdir -p "$OUT_DIR"
i=12
for d in "$SEG_ROOT"/seg_*; do
[ -d "$d" ] || continue
echo "Building $d -> episode_$i"
python "$SCRIPT" \
--segment_dir "$d" \
--output_dir "$OUT_DIR" \
--episode_idx "$i" \
--max_frames -1 \
--camera_name top \
--crop 733 30 1754 1051 \
--resize 224 224 \
--instruction_template 'Cannulate the {label} on the phantom located at the {region} with the sphincterotome.' \
--instruction_empty 'No target visible.' \
--stop_instruction 'Stop move.' \
--motion_window 3 \
--motion_threshold 0.05 \
--state_norm minus1_1 \
--action_norm minus1_1 \
--encode_text_features \
--text_model_name distilbert-base-uncased \
--text_batch_size 32
i=$((i+1))
done


@@ -21,3 +21,5 @@ dependencies:
 - packaging=23.0
 - h5py=3.8.0
 - ipython=8.12.0
+- pip:
+  - transformers==4.38.2


@@ -1,7 +1,7 @@
 import pathlib
 ### Task parameters
-DATA_DIR = '<put your data dir here>'
+DATA_DIR = str(pathlib.Path(__file__).parent.resolve() / 'data')
 SIM_TASK_CONFIGS = {
 'sim_transfer_cube_scripted':{
 'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
@@ -32,6 +32,85 @@ SIM_TASK_CONFIGS = {
 },
 }
ENDOSCOPE_TASK_CONFIGS = {
'endoscope_default': {
'dataset_dir': DATA_DIR + '/endoscope_default',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': True,
'instruction_mode': 'timestep-level',
'use_cached_text_features': True,
'text_encoder_type': 'distilbert',
'text_feature_dim': 768,
'text_fusion_type': 'concat_transformer_input',
'freeze_text_encoder': True,
'text_max_length': 32,
'text_tokenizer_name': 'distilbert-base-uncased',
},
'endoscope_follow': {
'dataset_dir': DATA_DIR + '/follow',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': True,
'instruction_mode': 'timestep-level',
'use_cached_text_features': True,
'text_encoder_type': 'distilbert',
'text_feature_dim': 768,
'text_fusion_type': 'concat_transformer_input',
'freeze_text_encoder': True,
'text_max_length': 32,
'text_tokenizer_name': 'distilbert-base-uncased',
},
'endoscope_both_no_text': {
'dataset_dir': DATA_DIR + '/both-no-text',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False,
},
'endoscope_sanity_check': {
'dataset_dir': DATA_DIR + '/sanity-check',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False,
},
'endoscope_cannulation_no_text': {
'dataset_dir': DATA_DIR + '/cannulation-no-text',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False,
},
'endoscope_follow_no_text': {
'dataset_dir': DATA_DIR + '/follow-no-text',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'real_action_t_minus_1': False,
'use_text_instruction': False,
},
}
 ### Simulation envs fixed constants
 DT = 0.02
 JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]


@@ -4,6 +4,7 @@ from pathlib import Path
 import numpy as np
 import torch
+from torch.optim.adamw import AdamW
 from .models import build_ACT_model, build_CNNMLP_model
 import IPython
@@ -30,6 +31,15 @@ def get_args_parser():
 help="Type of positional embedding to use on top of the image features")
 parser.add_argument('--camera_names', default=[], type=list, # will be overridden
 help="A list of camera names")
+parser.add_argument('--state_dim', default=14, type=int)
+parser.add_argument('--action_dim', default=14, type=int)
+parser.add_argument('--use_text', action='store_true')
+parser.add_argument('--text_encoder_type', default='distilbert', type=str)
+parser.add_argument('--text_feature_dim', default=768, type=int)
+parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
+parser.add_argument('--freeze_text_encoder', action='store_true')
+parser.add_argument('--text_max_length', default=32, type=int)
+parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)
 # * Transformer
 parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
@@ -60,16 +70,17 @@ def get_args_parser():
     parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
     parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
     parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
-    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
+    parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
     parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
     parser.add_argument('--temporal_agg', action='store_true')
+    parser.add_argument('--image_aug', action='store_true')
     return parser

 def build_ACT_model_and_optimizer(args_override):
     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
     for k, v in args_override.items():
         setattr(args, k, v)
@@ -84,7 +95,7 @@ def build_ACT_model_and_optimizer(args_override):
"lr": args.lr_backbone, "lr": args.lr_backbone,
}, },
] ]
optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, optimizer = AdamW(param_dicts, lr=args.lr,
weight_decay=args.weight_decay) weight_decay=args.weight_decay)
return model, optimizer return model, optimizer
@@ -92,7 +103,7 @@ def build_ACT_model_and_optimizer(args_override):
 def build_CNNMLP_model_and_optimizer(args_override):
     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
     for k, v in args_override.items():
         setattr(args, k, v)
@@ -107,7 +118,7 @@ def build_CNNMLP_model_and_optimizer(args_override):
             "lr": args.lr_backbone,
         },
     ]
-    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
+    optimizer = AdamW(param_dicts, lr=args.lr,
                       weight_decay=args.weight_decay)
     return model, optimizer


@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
                  train_backbone: bool,
                  return_interm_layers: bool,
                  dilation: bool):
-        backbone = getattr(torchvision.models, name)(
-            replace_stride_with_dilation=[False, False, dilation],
-            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
+        backbone_builder = getattr(torchvision.models, name)
+        weights = None
+        if is_main_process():
+            weight_enum_name_map = {
+                'resnet18': 'ResNet18_Weights',
+                'resnet34': 'ResNet34_Weights',
+                'resnet50': 'ResNet50_Weights',
+                'resnet101': 'ResNet101_Weights',
+            }
+            enum_name = weight_enum_name_map.get(name)
+            if enum_name is not None and hasattr(torchvision.models, enum_name):
+                weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
+        try:
+            backbone = backbone_builder(
+                replace_stride_with_dilation=[False, False, dilation],
+                weights=weights,
+                norm_layer=FrozenBatchNorm2d,
+            )
+        except TypeError:
+            # Backward compatibility for older torchvision that still expects `pretrained`.
+            backbone = backbone_builder(
+                replace_stride_with_dilation=[False, False, dilation],
+                pretrained=(weights is not None),
+                norm_layer=FrozenBatchNorm2d,
+            )
         num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
         super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
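The `try/except TypeError` above is a version-compatibility pattern: prefer the newer `weights=` keyword and fall back to the legacy `pretrained=` one when the installed library rejects it. A stripped-down sketch with stand-in functions (`old_api`/`build` are hypothetical, not torchvision):

```python
# Stand-in for an older builder whose signature only accepts `pretrained`.
def old_api(pretrained=False):
    return {'pretrained': pretrained}

def build(builder, weights=None):
    # Try the new keyword first; an unexpected-keyword TypeError signals the
    # legacy signature, so retry with the old flag derived from `weights`.
    try:
        return builder(weights=weights)
    except TypeError:
        return builder(pretrained=(weights is not None))

print(build(old_api, weights='DEFAULT'))  # {'pretrained': True}
print(build(old_api))                     # {'pretrained': False}
```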


@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):
 class DETRVAE(nn.Module):
     """ This is the DETR module that performs object detection """
-    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
+    def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
+                 use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
         """ Initializes the model.
         Parameters:
             backbones: torch module of the backbone to be used. See backbone.py
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
         self.camera_names = camera_names
         self.transformer = transformer
         self.encoder = encoder
+        self.use_text = use_text
+        self.text_fusion_type = text_fusion_type
         hidden_dim = transformer.d_model
-        self.action_head = nn.Linear(hidden_dim, state_dim)
+        self.action_head = nn.Linear(hidden_dim, action_dim)
         self.is_pad_head = nn.Linear(hidden_dim, 1)
         self.query_embed = nn.Embedding(num_queries, hidden_dim)
         if backbones is not None:
             self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
             self.backbones = nn.ModuleList(backbones)
-            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
         else:
-            # input_dim = 14 + 7 # robot_state + env_state
-            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
             self.input_proj_env_state = nn.Linear(7, hidden_dim)
             self.pos = torch.nn.Embedding(2, hidden_dim)
             self.backbones = None
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
         # encoder extra parameters
         self.latent_dim = 32 # final size of latent z # TODO tune
         self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
-        self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
-        self.encoder_joint_proj = nn.Linear(14, hidden_dim) # project qpos to embedding
+        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
+        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
         self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
         self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

         # decoder extra parameters
         self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
-        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
+        num_extra_tokens = 3 if self.use_text else 2
+        self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
+        self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None

-    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
+    def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
         """
         qpos: batch, qpos_dim
         image: batch, num_cam, channel, height, width
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
                 all_cam_pos.append(pos)
             # proprioception features
             proprio_input = self.input_proj_robot_state(qpos)
+            extra_input_tokens = None
+            if self.use_text and text_features is not None:
+                if self.text_fusion_type != 'concat_transformer_input':
+                    raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
+                text_input = self.text_proj(text_features)
+                extra_input_tokens = text_input.unsqueeze(0)
             # fold camera dimension into width dimension
             src = torch.cat(all_cam_features, axis=3)
             pos = torch.cat(all_cam_pos, axis=3)
-            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
+            hs = self.transformer(
+                src,
+                None,
+                self.query_embed.weight,
+                pos,
+                latent_input,
+                proprio_input,
+                self.additional_pos_embed.weight,
+                extra_input_tokens=extra_input_tokens,
+            )[0]
         else:
             qpos = self.input_proj_robot_state(qpos)
             env_state = self.input_proj_env_state(env_state)
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):
 class CNNMLP(nn.Module):
-    def __init__(self, backbones, state_dim, camera_names):
+    def __init__(self, backbones, state_dim, action_dim, camera_names):
         """ Initializes the model.
         Parameters:
             backbones: torch module of the backbone to be used. See backbone.py
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
         """
         super().__init__()
         self.camera_names = camera_names
-        self.action_head = nn.Linear(1000, state_dim) # TODO add more
+        self.action_head = nn.Linear(1000, action_dim) # TODO add more
         if backbones is not None:
             self.backbones = nn.ModuleList(backbones)
             backbone_down_projs = []
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
                 backbone_down_projs.append(down_proj)
             self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

-            mlp_in_dim = 768 * len(backbones) + 14
-            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
+            mlp_in_dim = 768 * len(backbones) + state_dim
+            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
         else:
             raise NotImplementedError
@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
         for cam_feature in all_cam_features:
             flattened_features.append(cam_feature.reshape([bs, -1]))
         flattened_features = torch.cat(flattened_features, axis=1) # 768 each
-        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
+        features = torch.cat([flattened_features, qpos], axis=1)
         a_hat = self.mlp(features)
         return a_hat
@@ -227,7 +246,8 @@ def build_encoder(args):
 def build(args):
-    state_dim = 14 # TODO hardcode
+    state_dim = args.state_dim
+    action_dim = args.action_dim

     # From state
     # backbone = None # from state for now, no need for conv nets
@@ -245,8 +265,12 @@ def build(args):
         transformer,
         encoder,
         state_dim=state_dim,
+        action_dim=action_dim,
         num_queries=args.num_queries,
         camera_names=args.camera_names,
+        use_text=args.use_text,
+        text_feature_dim=args.text_feature_dim,
+        text_fusion_type=args.text_fusion_type,
     )

     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -255,7 +279,8 @@ def build(args):
     return model

 def build_cnnmlp(args):
-    state_dim = 14 # TODO hardcode
+    state_dim = args.state_dim
+    action_dim = args.action_dim

     # From state
     # backbone = None # from state for now, no need for conv nets
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
     model = CNNMLP(
         backbones,
         state_dim=state_dim,
+        action_dim=action_dim,
         camera_names=args.camera_names,
     )


@@ -46,7 +46,7 @@ class Transformer(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)

-    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
+    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
         # TODO flatten only when input has H and W
         if len(src.shape) == 4: # has H and W
             # flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@
             query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
             # mask = mask.flatten(1)

+            additional_inputs = [latent_input, proprio_input]
+            if extra_input_tokens is not None:
+                if len(extra_input_tokens.shape) == 2:
+                    extra_input_tokens = extra_input_tokens.unsqueeze(0)
+                for i in range(extra_input_tokens.shape[0]):
+                    additional_inputs.append(extra_input_tokens[i])
+            addition_input = torch.stack(additional_inputs, axis=0)
+            if additional_pos_embed is not None:
+                additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
             additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
             pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)

-            addition_input = torch.stack([latent_input, proprio_input], axis=0)
             src = torch.cat([addition_input, src], axis=0)
         else:
             assert len(src.shape) == 3
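The token bookkeeping in this hunk can be sketched with plain lists: the prefix grows from 2 tokens (latent, proprio) to 3 when a text token is supplied, and the positional table is sliced to the same length (a toy sketch; the names `prefix_tokens`, `pos_latent`, etc. are illustrative, not from the repo):

```python
def prefix_tokens(latent, proprio, text=None):
    # Collect the extra transformer-input tokens, then slice the positional
    # table so it always matches the number of tokens actually present.
    extras = [latent, proprio]
    if text is not None:
        extras.append(text)
    pos_table = ['pos_latent', 'pos_proprio', 'pos_text']
    return extras, pos_table[:len(extras)]

print(prefix_tokens('z', 'qpos'))         # (['z', 'qpos'], ['pos_latent', 'pos_proprio'])
print(prefix_tokens('z', 'qpos', 'txt'))  # 3 tokens paired with 3 positions
```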


@@ -10,17 +10,25 @@ from einops import rearrange
 from constants import DT
 from constants import PUPPET_GRIPPER_JOINT_OPEN
+from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
 from utils import load_data # data functions
 from utils import sample_box_pose, sample_insertion_pose # robot functions
 from utils import compute_dict_mean, set_seed, detach_dict # helper functions
 from policy import ACTPolicy, CNNMLPPolicy
 from visualize_episodes import save_videos
-from sim_env import BOX_POSE

 import IPython
 e = IPython.embed

+def load_checkpoint_state_dict(ckpt_path):
+    """Load checkpoint state_dict safely across different torch versions."""
+    try:
+        return torch.load(ckpt_path, map_location='cpu', weights_only=True)
+    except TypeError:
+        # For older PyTorch versions that do not support `weights_only`.
+        return torch.load(ckpt_path, map_location='cpu')

 def main(args):
     set_seed(1)
     # command line parameters
@@ -34,25 +42,50 @@ def main(args):
     num_epochs = args['num_epochs']

     # get task parameters
-    is_sim = task_name[:4] == 'sim_'
-    if is_sim:
-        from constants import SIM_TASK_CONFIGS
+    is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
+    if is_endoscope:
+        task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
+        is_sim = False
+    elif task_name in SIM_TASK_CONFIGS:
         task_config = SIM_TASK_CONFIGS[task_name]
+        is_sim = True
     else:
         from aloha_scripts.constants import TASK_CONFIGS
         task_config = TASK_CONFIGS[task_name]
+        is_sim = False
     dataset_dir = task_config['dataset_dir']
     num_episodes = task_config['num_episodes']
     episode_len = task_config['episode_len']
     camera_names = task_config['camera_names']
+    state_dim = task_config.get('state_dim', 14)
+    action_dim = task_config.get('action_dim', state_dim)
+    use_text_instruction = task_config.get('use_text_instruction', False)
+    instruction_mode = task_config.get('instruction_mode', 'timestep-level')
+    use_cached_text_features = task_config.get('use_cached_text_features', True)
+    text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
+    text_feature_dim = task_config.get('text_feature_dim', 768)
+    text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
+    text_max_length = task_config.get('text_max_length', 32)
+    text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
+    freeze_text_encoder = task_config.get('freeze_text_encoder', True)
+    real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
+    if args.get('text_encoder_type') is not None:
+        text_encoder_type = args['text_encoder_type']
+    if args.get('text_max_length') is not None:
+        text_max_length = args['text_max_length']
+    if args.get('freeze_text_encoder', False):
+        freeze_text_encoder = True
+    if args.get('disable_real_action_shift', False):
+        real_action_t_minus_1 = False

     # fixed parameters
-    state_dim = 14
     lr_backbone = 1e-5
     backbone = 'resnet18'
     if policy_class == 'ACT':
-        enc_layers = 4
-        dec_layers = 7
+        enc_layers = 2
+        dec_layers = 4
         nheads = 8
         policy_config = {'lr': args['lr'],
                          'num_queries': args['chunk_size'],
@@ -65,18 +98,36 @@ def main(args):
'dec_layers': dec_layers, 'dec_layers': dec_layers,
'nheads': nheads, 'nheads': nheads,
'camera_names': camera_names, 'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
'text_encoder_type': text_encoder_type,
'text_feature_dim': text_feature_dim,
'text_fusion_type': text_fusion_type,
'freeze_text_encoder': freeze_text_encoder,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_max_length': text_max_length,
'text_tokenizer_name': text_tokenizer_name,
} }
elif policy_class == 'CNNMLP': elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1, policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,} 'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
}
else: else:
raise NotImplementedError raise NotImplementedError
config = { config = {
'num_epochs': num_epochs, 'num_epochs': num_epochs,
'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
'resume_ckpt_path': args.get('resume_ckpt', None),
'ckpt_dir': ckpt_dir, 'ckpt_dir': ckpt_dir,
'episode_len': episode_len, 'episode_len': episode_len,
'state_dim': state_dim, 'state_dim': state_dim,
'action_dim': action_dim,
'lr': args['lr'], 'lr': args['lr'],
'policy_class': policy_class, 'policy_class': policy_class,
'onscreen_render': onscreen_render, 'onscreen_render': onscreen_render,
@@ -85,9 +136,25 @@ def main(args):
'seed': args['seed'], 'seed': args['seed'],
'temporal_agg': args['temporal_agg'], 'temporal_agg': args['temporal_agg'],
'camera_names': camera_names, 'camera_names': camera_names,
'real_robot': not is_sim 'real_robot': (not is_sim) and (not is_endoscope),
'use_text_instruction': use_text_instruction,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_tokenizer_name': text_tokenizer_name,
'text_max_length': text_max_length,
'debug_input': args.get('debug_input', False),
} }
if config['resume_ckpt_path']:
resume_ckpt_path = config['resume_ckpt_path']
if not os.path.isabs(resume_ckpt_path):
candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
if os.path.isfile(candidate_path):
resume_ckpt_path = candidate_path
if not os.path.isfile(resume_ckpt_path):
raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
config['resume_ckpt_path'] = resume_ckpt_path
if is_eval: if is_eval:
ckpt_names = [f'policy_best.ckpt'] ckpt_names = [f'policy_best.ckpt']
results = [] results = []
@@ -100,7 +167,21 @@ def main(args):
         print()
         exit()

-    train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
+    train_dataloader, val_dataloader, stats, _ = load_data(
+        dataset_dir,
+        num_episodes,
+        camera_names,
+        batch_size_train,
+        batch_size_val,
+        use_text_instruction=use_text_instruction,
+        instruction_mode=instruction_mode,
+        use_cached_text_features=use_cached_text_features,
+        text_feature_dim=text_feature_dim,
+        text_tokenizer_name=text_tokenizer_name,
+        text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
+        image_augment=args['image_aug'],
+    )

     # save dataset stats
     if not os.path.isdir(ckpt_dir):
@@ -152,6 +233,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
     set_seed(1000)
     ckpt_dir = config['ckpt_dir']
     state_dim = config['state_dim']
+    action_dim = config['action_dim']
     real_robot = config['real_robot']
     policy_class = config['policy_class']
     onscreen_render = config['onscreen_render']
@@ -161,11 +243,12 @@ def eval_bc(config, ckpt_name, save_episode=True):
     task_name = config['task_name']
     temporal_agg = config['temporal_agg']
     onscreen_cam = 'angle'
+    BOX_POSE = None

     # load policy and stats
     ckpt_path = os.path.join(ckpt_dir, ckpt_name)
     policy = make_policy(policy_class, policy_config)
-    loading_status = policy.load_state_dict(torch.load(ckpt_path))
+    loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
     print(loading_status)
     policy.cuda()
     policy.eval()
@@ -202,8 +285,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
         rollout_id += 0
         ### set task
         if 'sim_transfer_cube' in task_name:
+            if BOX_POSE is None:
+                from sim_env import BOX_POSE as _BOX_POSE
+                BOX_POSE = _BOX_POSE
             BOX_POSE[0] = sample_box_pose() # used in sim reset
         elif 'sim_insertion' in task_name:
+            if BOX_POSE is None:
+                from sim_env import BOX_POSE as _BOX_POSE
+                BOX_POSE = _BOX_POSE
             BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
         ts = env.reset()
@@ -216,7 +305,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
         ### evaluation loop
         if temporal_agg:
-            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
+            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()

         qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
         image_list = [] # for visualization
@@ -313,23 +402,72 @@ def eval_bc(config, ckpt_name, save_episode=True):
     return success_rate, avg_return

-def forward_pass(data, policy):
-    image_data, qpos_data, action_data, is_pad = data
-    image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
-    return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None
+def forward_pass(data, policy, debug_input=False, debug_tag=''):
+    image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
+    image_data = image_data.cuda()
+    qpos_data = qpos_data.cuda()
+    action_data = action_data.cuda()
+    is_pad = is_pad.cuda()
+    text_input_ids = text_input_ids.cuda()
+    text_attention_mask = text_attention_mask.cuda()
+    text_feature_data = text_feature_data.cuda()
+    text_feature_valid = text_feature_valid.cuda()
+    text_features = None
+    if torch.any(text_feature_valid):
+        text_features = text_feature_data
+    if debug_input:
+        image_min = float(image_data.min().item())
+        image_max = float(image_data.max().item())
+        qpos_mean = float(qpos_data.mean().item())
+        qpos_std = float(qpos_data.std().item())
+        action_mean = float(action_data.mean().item())
+        action_std = float(action_data.std().item())
+        pad_ratio = float(is_pad.float().mean().item())
+        print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]')
+        print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})')
+        print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})')
+        print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}')
+        print(
+            f'[debug_input] {debug_tag} has_nan_or_inf: '
+            f'image={bool(torch.logical_not(torch.isfinite(image_data)).any().item())}, '
+            f'qpos={bool(torch.logical_not(torch.isfinite(qpos_data)).any().item())}, '
+            f'action={bool(torch.logical_not(torch.isfinite(action_data)).any().item())}'
+        )
+    return policy(
+        qpos_data,
+        image_data,
+        text_input_ids=text_input_ids,
+        text_attention_mask=text_attention_mask,
+        text_features=text_features,
+        actions=action_data,
+        is_pad=is_pad,
+    )

 def train_bc(train_dataloader, val_dataloader, config):
     num_epochs = config['num_epochs']
+    train_steps_per_epoch = config.get('train_steps_per_epoch', None)
+    resume_ckpt_path = config.get('resume_ckpt_path', None)
     ckpt_dir = config['ckpt_dir']
     seed = config['seed']
     policy_class = config['policy_class']
     policy_config = config['policy_config']
+    debug_input = config.get('debug_input', False)

     set_seed(seed)

     policy = make_policy(policy_class, policy_config)
     policy.cuda()
+    if resume_ckpt_path:
+        loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
+        print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
+        print(loading_status)
     optimizer = make_optimizer(policy_class, policy)

     train_history = []
@@ -343,7 +481,8 @@ def train_bc(train_dataloader, val_dataloader, config):
         policy.eval()
         epoch_dicts = []
         for batch_idx, data in enumerate(val_dataloader):
-            forward_dict = forward_pass(data, policy)
+            should_debug = debug_input and epoch == 0 and batch_idx == 0
+            forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
             epoch_dicts.append(forward_dict)
         epoch_summary = compute_dict_mean(epoch_dicts)
         validation_history.append(epoch_summary)
@@ -361,15 +500,31 @@ def train_bc(train_dataloader, val_dataloader, config):
         # training
         policy.train()
         optimizer.zero_grad()
-        for batch_idx, data in enumerate(train_dataloader):
-            forward_dict = forward_pass(data, policy)
+        epoch_train_dicts = []
+        if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
+            train_steps_this_epoch = len(train_dataloader)
+            train_iterator = iter(train_dataloader)
+        else:
+            train_steps_this_epoch = int(train_steps_per_epoch)
+            train_iterator = iter(train_dataloader)
+        for step_idx in range(train_steps_this_epoch):
+            try:
+                data = next(train_iterator)
+            except StopIteration:
+                train_iterator = iter(train_dataloader)
+                data = next(train_iterator)
+            should_debug = debug_input and epoch == 0 and step_idx == 0
+            forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
             # backward
             loss = forward_dict['loss']
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
             train_history.append(detach_dict(forward_dict))
-        epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
+            epoch_train_dicts.append(detach_dict(forward_dict))
+        epoch_summary = compute_dict_mean(epoch_train_dicts)
         epoch_train_loss = epoch_summary['loss']
         print(f'Train loss: {epoch_train_loss:.5f}')
         summary_string = ''
@@ -426,10 +581,23 @@ if __name__ == '__main__':
parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
# for ACT
parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
parser.add_argument('--freeze_text_encoder', action='store_true')
parser.add_argument('--text_max_length', action='store', type=int, required=False)
parser.add_argument('--image_aug', action='store_true',
help='Enable training-time image augmentation (color/highlight/noise/blur)')
parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
parser.add_argument('--disable_real_action_shift', action='store_true',
help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
help='Optional checkpoint path to initialize model weights for fine-tuning')
parser.add_argument('--debug_input', action='store_true',
help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')
main(vars(parser.parse_args()))

models/__init__.py (new file)

models/text_encoder.py (new file)

@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
class DistilBERTTextEncoder(nn.Module):
def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
super().__init__()
try:
from transformers import DistilBertModel
except ImportError as exc:
raise ImportError(
'transformers is required for DistilBERT text encoding. '
'Install it with: pip install transformers'
) from exc
self.encoder = DistilBertModel.from_pretrained(model_name)
self.output_dim = output_dim
self.freeze = freeze
if self.freeze:
for param in self.encoder.parameters():
param.requires_grad = False
self.encoder.eval()
def forward(self, input_ids, attention_mask):
if self.freeze:
self.encoder.eval()
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
# DistilBERT has no pooled output; use [CLS] token embedding
cls_feature = outputs.last_hidden_state[:, 0, :]
return cls_feature
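As the comment in the encoder notes, DistilBERT exposes no pooled output, so the forward pass takes the first ([CLS]) token of `last_hidden_state` as the sentence feature. A numpy sketch of that pooling step (fake hidden states, shapes for illustration only):

```python
import numpy as np

# Fake "last_hidden_state" from a text encoder: [batch, seq_len, hidden_dim]
last_hidden_state = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)

# CLS-style pooling: keep the first token's embedding per sequence
cls_feature = last_hidden_state[:, 0, :]
print(cls_feature.shape)  # (2, 3)
```

In the real module the same slice is applied to the `[B, T, 768]` output of `DistilBertModel`.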


@@ -3,6 +3,7 @@ from torch.nn import functional as F
import torchvision.transforms as transforms
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
from models.text_encoder import DistilBERTTextEncoder
import IPython
e = IPython.embed
@@ -13,18 +14,44 @@ class ACTPolicy(nn.Module):
self.model = model # CVAE decoder
self.optimizer = optimizer
self.kl_weight = args_override['kl_weight']
self.use_text = args_override.get('use_text', False)
self.text_encoder = None
if self.use_text:
text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
if text_encoder_type != 'distilbert':
raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
self.text_encoder = DistilBERTTextEncoder(
model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
output_dim=args_override.get('text_feature_dim', 768),
freeze=args_override.get('freeze_text_encoder', True),
)
print(f'KL Weight {self.kl_weight}')
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
image = normalize(image)
if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
if self.text_encoder is None:
raise RuntimeError('Text encoder is not initialized while use_text=True.')
text_features = self.text_encoder(text_input_ids, text_attention_mask)
if actions is not None: # training time
if is_pad is None:
raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
actions = actions[:, :self.model.num_queries]
is_pad = is_pad[:, :self.model.num_queries]
a_hat, is_pad_hat, (mu, logvar) = self.model(
qpos,
image,
env_state,
text_features=text_features,
actions=actions,
is_pad=is_pad,
)
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
loss_dict = dict()
all_l1 = F.l1_loss(actions, a_hat, reduction='none')
@@ -34,7 +61,7 @@ class ACTPolicy(nn.Module):
loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
return loss_dict
else: # inference time
a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior
return a_hat
def configure_optimizers(self):
@@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module):
self.model = model # decoder
self.optimizer = optimizer
def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
env_state = None # TODO
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

utils.py

@@ -2,21 +2,170 @@ import numpy as np
import torch
import os
import h5py
import re
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms.functional as TF
import IPython
e = IPython.embed
class EpisodicDataset(torch.utils.data.Dataset):
def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
use_text_instruction=False,
instruction_mode='timestep-level',
use_cached_text_features=True,
text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased',
text_max_length=32,
real_action_t_minus_1=True,
image_augment=False,
image_aug_cfg=None):
super(EpisodicDataset).__init__()
self.episode_ids = episode_ids
self.dataset_dir = dataset_dir
self.camera_names = camera_names
self.norm_stats = norm_stats
self.use_text_instruction = use_text_instruction
self.instruction_mode = instruction_mode
self.use_cached_text_features = use_cached_text_features
self.text_feature_dim = text_feature_dim
self.text_max_length = text_max_length
self.real_action_t_minus_1 = real_action_t_minus_1
self.image_augment = image_augment
self.image_aug_cfg = {
'p_color': 0.4,
'p_highlight': 0.3,
'p_noise': 0.35,
'p_blur': 0.15,
'brightness': 0.12,
'contrast': 0.12,
'saturation': 0.12,
'hue': 0.03,
'highlight_strength': (0.08, 0.25),
'noise_std': (0.003, 0.015),
'blur_sigma': (0.1, 0.8),
'blur_kernel_choices': (3, ),
}
if image_aug_cfg is not None:
self.image_aug_cfg.update(image_aug_cfg)
self.is_sim = None
self.max_episode_len = None
self.action_dim = None
self.text_tokenizer = None
if self.use_text_instruction:
try:
from transformers import DistilBertTokenizerFast
except ImportError as exc:
raise ImportError(
'transformers is required for text instruction loading. '
'Install it with: pip install transformers'
) from exc
self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
self._init_episode_shapes()
self.__getitem__(0) # initialize self.is_sim
def _apply_image_augmentation(self, all_cam_images):
"""
Apply identical augmentation parameters to all camera images for one sample.
all_cam_images: np.ndarray [K, H, W, C], uint8
"""
imgs = torch.from_numpy(all_cam_images).float() / 255.0
imgs = torch.einsum('k h w c -> k c h w', imgs)
cfg = self.image_aug_cfg
# color jitter (shared params)
if np.random.rand() < cfg['p_color']:
b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
h = np.random.uniform(-cfg['hue'], cfg['hue'])
for cam_idx in range(imgs.shape[0]):
img = imgs[cam_idx]
img = TF.adjust_brightness(img, b)
img = TF.adjust_contrast(img, c)
img = TF.adjust_saturation(img, s)
img = TF.adjust_hue(img, h)
imgs[cam_idx] = img
# synthetic highlight / glare (shared parameters)
if np.random.rand() < cfg['p_highlight']:
_, h_img, w_img = imgs[0].shape
cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
strength = np.random.uniform(*cfg['highlight_strength'])
yy, xx = torch.meshgrid(
torch.arange(h_img, dtype=torch.float32),
torch.arange(w_img, dtype=torch.float32),
indexing='ij',
)
gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
gauss = (gauss * strength).unsqueeze(0)
imgs = imgs + gauss
# gaussian noise
if np.random.rand() < cfg['p_noise']:
noise_std = np.random.uniform(*cfg['noise_std'])
imgs = imgs + torch.randn_like(imgs) * noise_std
# gaussian blur
if np.random.rand() < cfg['p_blur']:
kernel = int(np.random.choice(cfg['blur_kernel_choices']))
sigma = float(np.random.uniform(*cfg['blur_sigma']))
for cam_idx in range(imgs.shape[0]):
imgs[cam_idx] = TF.gaussian_blur(
imgs[cam_idx],
kernel_size=[kernel, kernel],
sigma=[sigma, sigma],
)
imgs = imgs.clamp(0.0, 1.0)
imgs = torch.einsum('k c h w -> k h w c', imgs)
imgs = (imgs * 255.0).byte().cpu().numpy()
return imgs
def _init_episode_shapes(self):
max_len = 0
action_dim = None
for episode_id in self.episode_ids:
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
with h5py.File(dataset_path, 'r') as root:
shape = root['/action'].shape
if len(shape) != 2:
raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
max_len = max(max_len, int(shape[0]))
if action_dim is None:
action_dim = int(shape[1])
elif int(shape[1]) != action_dim:
raise ValueError(
f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
)
if max_len <= 0 or action_dim is None:
raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
self.max_episode_len = max_len
self.action_dim = action_dim
@staticmethod
def _decode_instruction(raw_value):
if raw_value is None:
return ''
if isinstance(raw_value, bytes):
return raw_value.decode('utf-8')
if isinstance(raw_value, np.bytes_):
return raw_value.tobytes().decode('utf-8')
if isinstance(raw_value, np.ndarray):
if raw_value.shape == ():
return EpisodicDataset._decode_instruction(raw_value.item())
if raw_value.size == 0:
return ''
return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
return str(raw_value)
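`_decode_instruction` normalizes the assorted value types h5py can return for stored strings. A standalone sketch mirroring the same logic, with illustrative sample strings:

```python
import numpy as np

def decode_instruction(raw):
    # Mirrors EpisodicDataset._decode_instruction for h5py string values.
    if raw is None:
        return ''
    if isinstance(raw, bytes):          # also catches np.bytes_ (a bytes subclass)
        return raw.decode('utf-8')
    if isinstance(raw, np.ndarray):
        if raw.shape == ():             # 0-d scalar array
            return decode_instruction(raw.item())
        if raw.size == 0:
            return ''
        return decode_instruction(raw.reshape(-1)[0])
    return str(raw)

print(decode_instruction(b'follow the lumen'))     # follow the lumen
print(decode_instruction(np.array([b'retract'])))  # retract
```

Whatever shape the dataset writer used, the loader always sees a plain `str`.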
def __len__(self):
return len(self.episode_ids)
@@ -26,7 +175,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
episode_id = self.episode_ids[index]
dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
with h5py.File(dataset_path, 'r') as root:
is_sim = bool(root.attrs.get('sim', False))
original_action_shape = root['/action'].shape
episode_len = original_action_shape[0]
if sample_full_episode:
@@ -35,29 +184,62 @@ class EpisodicDataset(torch.utils.data.Dataset):
start_ts = np.random.choice(episode_len)
# get observation at start_ts only
qpos = root['/observations/qpos'][start_ts]
image_dict = dict()
for cam_name in self.camera_names:
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
instruction = ''
text_feature = None
if self.use_text_instruction:
effective_mode = self.instruction_mode
if effective_mode == 'timestep-level' and '/instruction_timestep' in root:
instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
elif '/instruction' in root:
instruction_node = root['/instruction']
if getattr(instruction_node, 'shape', ()) == ():
instruction = self._decode_instruction(instruction_node[()])
else:
if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
instruction = self._decode_instruction(instruction_node[start_ts])
else:
instruction = self._decode_instruction(instruction_node[0])
if self.use_cached_text_features:
if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root:
text_feature = root['/instruction_features_timestep'][start_ts]
elif '/instruction_features' in root:
feat_node = root['/instruction_features']
if getattr(feat_node, 'shape', ()) == ():
text_feature = np.array(feat_node[()])
elif len(feat_node.shape) == 1:
text_feature = feat_node[()]
elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len:
text_feature = feat_node[start_ts]
else:
text_feature = feat_node[0]
# get all actions after and including start_ts
if is_sim:
action = root['/action'][start_ts:]
action_len = episode_len - start_ts
else:
action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
action = root['/action'][action_start:]
action_len = episode_len - action_start
self.is_sim = is_sim
padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
padded_action[:action_len] = action
is_pad = np.ones(self.max_episode_len)
is_pad[:action_len] = 0
# new axis for different cameras
all_cam_images = []
for cam_name in self.camera_names:
all_cam_images.append(image_dict[cam_name])
all_cam_images = np.stack(all_cam_images, axis=0)
if self.image_augment:
all_cam_images = self._apply_image_augmentation(all_cam_images)
# construct observations
image_data = torch.from_numpy(all_cam_images)
@@ -73,55 +255,146 @@ class EpisodicDataset(torch.utils.data.Dataset):
action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
if self.use_text_instruction and text_feature is not None:
text_feature_data = torch.from_numpy(np.array(text_feature)).float()
text_feature_valid = torch.tensor(True, dtype=torch.bool)
text_input_ids = torch.zeros(1, dtype=torch.long)
text_attention_mask = torch.zeros(1, dtype=torch.long)
elif self.use_text_instruction:
tokenized = self.text_tokenizer(
instruction,
padding='max_length',
truncation=True,
max_length=self.text_max_length,
return_tensors='pt',
)
text_input_ids = tokenized['input_ids'].squeeze(0).long()
text_attention_mask = tokenized['attention_mask'].squeeze(0).long()
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
text_feature_valid = torch.tensor(False, dtype=torch.bool)
else:
text_input_ids = torch.zeros(1, dtype=torch.long)
text_attention_mask = torch.zeros(1, dtype=torch.long)
text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
text_feature_valid = torch.tensor(False, dtype=torch.bool)
return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid
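This hunk changes the padding convention: every sample is now padded to `max_episode_len`, and `is_pad` starts as all ones with the valid prefix zeroed (1 = padding), so variable-length episodes batch cleanly. A minimal numpy sketch with assumed `max_episode_len=6`, `action_dim=2`:

```python
import numpy as np

max_episode_len, action_dim = 6, 2  # illustrative values
action = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=np.float32)  # 3 real steps
action_len = action.shape[0]

# pad actions with zeros up to the dataset-wide maximum length
padded_action = np.zeros((max_episode_len, action_dim), dtype=np.float32)
padded_action[:action_len] = action

# 0 = real step, 1 = padding (masked out of the L1 loss)
is_pad = np.ones(max_episode_len)
is_pad[:action_len] = 0

print(is_pad.tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

The old code padded only to the current episode's length, which breaks batching once episode lengths differ.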
def _discover_episode_ids(dataset_dir, num_episodes=None):
pattern = re.compile(r'^episode_(\d+)\.hdf5$')
episode_ids = []
for fname in os.listdir(dataset_dir):
m = pattern.match(fname)
if m:
episode_ids.append(int(m.group(1)))
episode_ids.sort()
if num_episodes is not None:
episode_ids = episode_ids[:num_episodes]
return episode_ids
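`_discover_episode_ids` replaces the old `range(num_episodes)` assumption with regex-based discovery, so episode files can be sparse or non-contiguous. The same pattern can be sketched against an in-memory file list (hypothetical names) instead of `os.listdir`:

```python
import re

pattern = re.compile(r'^episode_(\d+)\.hdf5$')
filenames = ['episode_2.hdf5', 'episode_0.hdf5', 'notes.txt', 'episode_10.hdf5']

# keep only matching files, extract the numeric id, sort numerically
episode_ids = sorted(int(m.group(1)) for f in filenames if (m := pattern.match(f)))
print(episode_ids)  # [0, 2, 10]
```

Numeric sorting matters: lexicographic order would place `episode_10` before `episode_2`.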
def get_norm_stats(dataset_dir, episode_ids):
all_qpos_data = []
all_action_data = []
example_qpos = None
for episode_idx in episode_ids:
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
with h5py.File(dataset_path, 'r') as root:
qpos = root['/observations/qpos'][()]
action = root['/action'][()]
qpos_t = torch.from_numpy(qpos)
action_t = torch.from_numpy(action)
all_qpos_data.append(qpos_t)
all_action_data.append(action_t)
if example_qpos is None and len(qpos) > 0:
example_qpos = qpos[0]
# Episodes may have different lengths; concatenate over time axis.
all_qpos_data = torch.cat(all_qpos_data, dim=0)
all_action_data = torch.cat(all_action_data, dim=0)
# normalize action data
action_mean = all_action_data.mean(dim=0, keepdim=True)
action_std = all_action_data.std(dim=0, keepdim=True)
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
# normalize qpos data
qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
qpos_std = all_qpos_data.std(dim=0, keepdim=True)
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
"qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
"example_qpos": example_qpos}
return stats
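The reworked stats computation concatenates variable-length episodes along time, then takes per-dimension mean/std with the std clipped away from zero. A numpy sketch of the idea (note the patched code uses `torch.std`, which is unbiased by default, so exact values can differ slightly from numpy's population std):

```python
import numpy as np

# Two episodes of different lengths, action_dim=2 (illustrative values)
ep1 = np.array([[0.0, 1.0], [2.0, 1.0]])
ep2 = np.array([[4.0, 1.0]])
all_actions = np.concatenate([ep1, ep2], axis=0)  # [T_total, action_dim]

mean = all_actions.mean(axis=0)
# clip std so constant dimensions don't cause divide-by-zero on normalization
std = np.clip(all_actions.std(axis=0), 1e-2, np.inf)
normalized = (all_actions - mean) / std

print(mean.tolist())  # [2.0, 1.0]
print(std[1])         # 0.01 (second dim is constant, std clipped)
```

The old `torch.stack` version required every episode to have the same length; concatenation removes that constraint.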
def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
use_text_instruction=False,
instruction_mode='timestep-level',
use_cached_text_features=True,
text_feature_dim=768,
text_tokenizer_name='distilbert-base-uncased',
text_max_length=32,
real_action_t_minus_1=True,
image_augment=False,
image_aug_cfg=None):
print(f'\nData from: {dataset_dir}\n')
episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
if len(episode_ids) == 0:
raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')

# obtain train/val split
if len(episode_ids) == 1:
# sanity-check mode: reuse the same episode for both train and val
# so training/evaluation loops remain unchanged.
train_episode_ids = np.array(episode_ids)
val_episode_ids = np.array(episode_ids)
print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
else:
train_ratio = 0.9
shuffled_indices = np.random.permutation(len(episode_ids))
train_count = int(train_ratio * len(episode_ids))
train_count = max(1, min(len(episode_ids) - 1, train_count))
train_indices = shuffled_indices[:train_count]
val_indices = shuffled_indices[train_count:]
train_episode_ids = np.array(episode_ids)[train_indices]
val_episode_ids = np.array(episode_ids)[val_indices]
# obtain normalization stats for qpos and action
norm_stats = get_norm_stats(dataset_dir, episode_ids)

# construct dataset and dataloader
train_dataset = EpisodicDataset(
train_episode_ids,
dataset_dir,
camera_names,
norm_stats,
use_text_instruction=use_text_instruction,
instruction_mode=instruction_mode,
use_cached_text_features=use_cached_text_features,
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=image_augment,
image_aug_cfg=image_aug_cfg,
)
val_dataset = EpisodicDataset(
val_episode_ids,
dataset_dir,
camera_names,
norm_stats,
use_text_instruction=use_text_instruction,
instruction_mode=instruction_mode,
use_cached_text_features=use_cached_text_features,
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
real_action_t_minus_1=real_action_t_minus_1,
image_augment=False,
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)


@@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name):
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
dt = float(root.attrs.get('dt', DT))
qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos)
action = root['/action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
return qpos, qvel, action, image_dict, dt
def main(args):
dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx']
dataset_name = f'episode_{episode_idx}'
qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name)
save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
# visualize_timestamp(t_list, dataset_path) # TODO add timestamp back
def save_videos(video, dt, video_path):
if isinstance(video, list):
cam_names = list(video[0].keys())
h, w, _ = video[0][cam_names[0]].shape
w = w * len(cam_names)
fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for ts, image_dict in enumerate(video):
images = []
@@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None):
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
n_frames, h, w, _ = all_cam_videos.shape
fps = max(1, int(round(1 / dt)))
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames):
image = all_cam_videos[t]
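The fps fix in both hunks rounds `1/dt` instead of truncating and floors at 1, so `cv2.VideoWriter` never receives `fps=0` for slow recordings (dt ≥ 1 s). A quick sketch of the edge cases:

```python
def fps_from_dt(dt):
    # round to the nearest integer fps, but never below 1
    return max(1, int(round(1 / dt)))

print(fps_from_dt(0.02))  # 50  (50 Hz control loop)
print(fps_from_dt(1.5))   # 1   (old int(1/dt) would give 0)
```

With the old `int(1 / dt)`, any dt above 1 second truncated to fps 0 and produced an unplayable video.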