Compare commits

10 Commits: 2133db326e...96d19c0ffc

96d19c0ffc, 81e1bf8838, 88d0cc5ca2, d85cce8a52, ee257bcb6c, 7023d5dde4, 88d14221ae, b701d939c2, ba006e14c4, d4b4d554f8

.gitignore (vendored): 4 additions
```diff
@@ -4,6 +4,10 @@ wandb
 outputs
 data
 data_local
+ckpt
+*.ckpt
+*.pt
+*.pth
 .vscode
 _wandb
```
ENDOSCOPE_ACT_ADAPTATION_PLAN.md (new file, 297 lines)
# ACT Repository Adaptation for an Endoscope Robot (2-DOF + image/qpos/text instruction): Change List (Training Only)

## 1. Goals and Constraints

### Goals

Adapt the current stock ACT repository for offline training on your endoscope robot, supporting:

- **Action dimension of only 2** (2 motors)
- **No dependency on Gym / simulation environments**
- Inputs of **image + qpos + text instruction**
- Offline-data training only (no real-robot online interface in this phase)

### Constraints

The current code defaults to the dual-arm ALOHA setup (14-dim state/action) and the sim/real ALOHA environment interfaces, has **no text branch**, and contains many hardcoded values, so it needs systematic refactoring.
---

## 2. Key Hardcoded Values in the Current Code (Must Change)

1. **State/action dimension hardcoded to 14**
   - `state_dim = 14` in `imitate_episodes.py`
   - Several `nn.Linear(14, ...)` layers in `detr/models/detr_vae.py`
   - Data written with a fixed `(T, 14)` shape in `record_sim_episodes.py`

2. **Training/evaluation flow bound to sim or the `aloha_scripts` real_env**
   - `eval_bc()` in `imitate_episodes.py` depends on `sim_env.make_sim_env()` or `aloha_scripts.real_env`

3. **Data loader defaults to qpos/qvel/action + images fields**
   - `EpisodicDataset` in `utils.py` loads only `qpos`, `action`, and `images`; no text

4. **Model fuses only image + state; no text encoding**
   - `policy.py` and `detr/models/detr_vae.py` currently have no text input path
---

## 3. Modules That Must Change (by File)

## A. Configuration and Task Definition

### Files

- `constants.py`

### Changes

- Add an endoscope task configuration (suggested new dict `ENDOSCOPE_TASK_CONFIGS`):
  - `dataset_dir`
  - `num_episodes`
  - `episode_len`
  - `camera_names`
  - `state_dim=2`
  - `action_dim=2`
  - `use_text_instruction=True`
  - `instruction_mode` (episode-level / timestep-level)
  - `text_encoder_type="distilbert"`
  - `text_feature_dim=768`
  - `text_fusion_type="concat_transformer_input"`
- Stop relying on the `sim_` prefix to determine the task type.

### Purpose

Decouple task parameters from the ALOHA defaults so the config becomes the single entry point for training and model construction.
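A minimal sketch of such a config entry (field names follow the list above; the dataset path, task key, and episode counts are placeholders, not values from the repo):

```python
# constants.py (sketch): an endoscope task entry; values here are illustrative.
ENDOSCOPE_TASK_CONFIGS = {
    'endoscope_follow': {                         # hypothetical task key
        'dataset_dir': 'data/endoscope_follow',   # placeholder path
        'num_episodes': 50,                       # placeholder
        'episode_len': 300,                       # placeholder
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'use_text_instruction': True,
        'instruction_mode': 'timestep-level',
        'text_encoder_type': 'distilbert',
        'text_feature_dim': 768,
        'text_fusion_type': 'concat_transformer_input',
    },
}
```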
---

## B. Data Protocol and Dataset Loading

### Files

- `utils.py`
- (new) data conversion scripts under `dataset_tools/`

### Changes

1. **Define a unified data protocol (HDF5 suggested)**
   - `/observations/images/<cam_name>`: `(T, H, W, C)`
   - `/observations/qpos`: `(T, 2)`
   - `/action`: `(T, 2)`
   - `/instruction` (string or tokens)
   - Optional: `/instruction_timestep` (if the instruction differs per step)

2. **Refactor `EpisodicDataset`**
   - Keep the `qpos` name; do not rename it
   - Load the text instruction
   - Change the returned training sample to:
     - `image_data`
     - `qpos_data`
     - `action_data`
     - `is_pad`
     - `text_input_ids`
     - `text_attention_mask`

3. **Extend normalization statistics**
   - `get_norm_stats()` supports arbitrary `qpos/action` dimensions (both are 2 for this task)
   - Text uses online DistilBERT encoding by default, with optional feature caching

4. **Compatibility strategy**
   - Keep direct compatibility with the legacy `qpos` field
   - Support either multiple cameras or a single camera

### Purpose

Build a data pipeline that matches your real data and fully drop the 14-dim and simulation-field assumptions.
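The protocol in item 1 can be sketched with `h5py` as follows (T/H/W are toy values and the camera name `top` is an assumption; real episodes would hold recorded data):

```python
import tempfile
from pathlib import Path

import h5py
import numpy as np

# Write one toy episode in the layout described above.
T, H, W = 4, 224, 224
path = Path(tempfile.mkdtemp()) / 'episode_0.hdf5'

with h5py.File(path, 'w') as root:
    obs = root.create_group('observations')
    obs.create_dataset('images/top', data=np.zeros((T, H, W, 3), dtype=np.uint8))
    obs.create_dataset('qpos', data=np.zeros((T, 2), dtype=np.float32))
    root.create_dataset('action', data=np.zeros((T, 2), dtype=np.float32))
    str_dt = h5py.string_dtype(encoding='utf-8')
    root.create_dataset('instruction_timestep', data=['Stop move.'] * T, dtype=str_dt)

# Read it back the way EpisodicDataset would.
with h5py.File(path, 'r') as root:
    shape = root['/observations/qpos'].shape
```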
---

## C. Training Entry Point and Flow Control

### Files

- `imitate_episodes.py`

### Changes

1. **Config loading**
   - Read `state_dim=2`, `action_dim=2`, `camera_names`, and `use_text_instruction` from the new task config

2. **Remove/isolate simulation coupling (within the training scope)**
   - `main()` keeps only the pure offline training path
   - `eval_bc()` keeps only the offline evaluation path

3. **Forward-input changes**
   - `forward_pass()` supports text
   - Dataloader batch unpacking gains the text components

4. **Additional command-line arguments**
   - `--task_config` or `--task_name` pointing at the new config
   - `--text_encoder_type`
   - `--freeze_text_encoder`

### Purpose

Make the training script the single entry point for real-robot offline imitation learning rather than a sim demo entry point.
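A sketch of the added flags from item 4 (flag names follow the list above; defaults are illustrative):

```python
import argparse

# Sketch of the new CLI surface; only the text-related flags are shown.
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', type=str, default='endoscope_follow')  # assumed task key
parser.add_argument('--text_encoder_type', type=str, default='distilbert')
parser.add_argument('--freeze_text_encoder', action='store_true')

args = parser.parse_args(['--freeze_text_encoder'])
```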
---

## D. Policy Wrapper Layer (Policy API)

### Files

- `policy.py`

### Changes

1. Extend the signatures of `ACTPolicy.__call__()` and `CNNMLPPolicy.__call__()`:
   - Current: `(qpos, image, actions=None, is_pad=None)`
   - Target: `(qpos, image, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`

2. Keep the image normalization, but make sure it supports any number of cameras.

3. Keep the loss computation unchanged, and make sure the policy degrades gracefully when text is absent (useful for ablations).

### Purpose

Pass text cleanly from the data layer through to the model layer.
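The graceful-degradation requirement in item 3 amounts to making the text arguments optional. A minimal sketch of the dispatch pattern (this is not the actual `ACTPolicy` code; the helper name is ours):

```python
def build_model_inputs(qpos, image, text_input_ids=None, text_attention_mask=None):
    """Collect model inputs, including the text pair only when it is provided."""
    inputs = {'qpos': qpos, 'image': image}
    if text_input_ids is not None:
        inputs['text_input_ids'] = text_input_ids
        inputs['text_attention_mask'] = text_attention_mask
    return inputs
```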
---

## E. Model Construction Parameters and Entry Point

### Files

- `detr/main.py`

### Changes

- Add model parameters:
  - `state_dim`
  - `action_dim`
  - `use_text`
  - `text_encoder_type="distilbert"`
  - `text_feature_dim=768`
  - `text_fusion_type="concat_transformer_input"`
- Remove or de-emphasize placeholder arguments unrelated to the original scripts.

### Purpose

Turn every hardcoded dimension into a configurable parameter to ease future iteration.
---

## F. ACT Backbone Network (Core Changes)

### Files

- `detr/models/detr_vae.py`

### Changes

1. **Remove the 14-dim hardcoding**
   - `input_proj_robot_state = nn.Linear(state_dim, hidden_dim)`
   - `encoder_action_proj = nn.Linear(action_dim, hidden_dim)`
   - `encoder_joint_proj = nn.Linear(state_dim, hidden_dim)`
   - `action_head = nn.Linear(hidden_dim, action_dim)`

2. **Add a text branch**
   - Use the DistilBERT output feature (768-dim)
   - Add a text projection layer: `nn.Linear(768, hidden_dim)`
   - Fix the fusion strategy: **treat the text token/feature as an extra token and concat it directly into the Transformer input sequence**

3. **Add text inputs to the forward function**
   - `forward(self, qpos, image, env_state, text_input_ids=None, text_attention_mask=None, actions=None, is_pad=None)`

4. **Keep training and inference modes consistent**
   - Training: action sequence + text-conditioned VAE
   - Inference: prior sampling + text-conditioned generation

### Purpose

Turn ACT from an "image + 14-dim joints" model into an "image + 2-dim qpos + text-conditioned" model.
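The fusion strategy in item 2 can be sketched as follows (`hidden_dim`, the token count, and the placeholder tensors are illustrative; in `detr_vae.py` the image/qpos token sequence is built elsewhere):

```python
import torch
import torch.nn as nn

hidden_dim, text_feature_dim = 512, 768
text_proj = nn.Linear(text_feature_dim, hidden_dim)  # the new projection layer

batch = 2
src = torch.zeros(batch, 10, hidden_dim)             # existing image + qpos tokens (placeholder)
text_feat = torch.zeros(batch, text_feature_dim)     # DistilBERT [CLS] feature (placeholder)

text_token = text_proj(text_feat).unsqueeze(1)       # (batch, 1, hidden_dim)
src = torch.cat([text_token, src], dim=1)            # text token prepended: (batch, 11, hidden_dim)
```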
---

## G. Real-Robot Interface and Data Collection (Out of Scope This Round)

This round covers only the training side; the following are deferred:

- Online inference interface
- Real-robot data collection scripts
- Online safety and rate control
---

## H. Documentation and Scripts

### Files

- `README.md`
- (new) `docs/endoscope_data_format.md`
- (new) `docs/endoscope_train_eval.md`

### Changes

- Provide a minimal runnable workflow:
  1) prepare data
  2) training command
  3) offline evaluation
- Specify the text instruction format precisely.
---

## 4. Suggested New Files (Checklist)

- `configs/endoscope_task.yaml` (or keep using a Python dict)
- `dataset_tools/convert_endoscope_to_act_hdf5.py`
- `dataset_tools/validate_endoscope_dataset.py`
- `models/text_encoder.py` (DistilBERT wrapper)
- `docs/endoscope_data_format.md`
---

## 5. Suggested Implementation Phases

### Phase 1: get the 2-DOF version running without text

- Change `state_dim/action_dim`
- Get data loading + training + offline validation working

### Phase 2: add the text instruction

- Add the instruction to the data protocol
- Wire in DistilBERT (768)
- Train with text fused via `concat_transformer_input`
---

## 6. Acceptance Criteria (Definition of Done)

1. Training runs directly on your HDF5 data with no sim/gym dependency.
2. The model takes image, 2-dim qpos, and text instruction inputs simultaneously.
3. The text encoder is DistilBERT with a 768-dim output feature.
4. Text fusion is input-level concat into the Transformer.
5. The README has complete training and offline-evaluation commands that the team can reproduce.
---

## 7. Risks and Mitigations

1. **Text/action temporal alignment**
   - Make explicit whether the instruction is episode-level or timestep-level.

2. **Action jitter under low-dimensional control**
   - Add low-pass filtering / action smoothing in post-processing.

3. **Multimodal scale imbalance**
   - Watch for one modality dominating the gradients after image/state/text fusion (modality dropout or loss re-weighting can help).

4. **Text encoding slowing down training**
   - Optionally cache DistilBERT features or freeze the text encoder.
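The smoothing suggested for risk 2 can be as simple as a first-order low-pass (exponential moving average) over the 2-dim action stream; `alpha` is a tuning knob, not a value from this plan:

```python
import numpy as np

def smooth_actions(actions: np.ndarray, alpha: float = 0.3) -> np.ndarray:
    """First-order low-pass over a (T, action_dim) array: out[t] = a*x[t] + (1-a)*out[t-1]."""
    out = np.empty_like(actions, dtype=np.float64)
    out[0] = actions[0]
    for t in range(1, len(actions)):
        out[t] = alpha * actions[t] + (1.0 - alpha) * out[t - 1]
    return out

# A step input gets pulled toward the target gradually instead of jumping.
smoothed = smooth_actions(np.array([[0.0, 0.0], [1.0, 1.0], [1.0, 1.0]]), alpha=0.3)
```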
---

## 8. The Minimal Information You Still Need to Provide (Before Code Changes Begin)

1. The physical meaning and value range (units, limits) of each motor: the motors are motor_x and motor_y; x ranges over 7000-17384 and y over 8000-18884. The corresponding action_x and action_y both lie in 0~65535.

2. The actual definitions of `qpos` and `action` in your current data (whether they are the same): action and qpos are defined almost identically, except that 0~65535 is mapped onto each motor's magnetic-encoder range.

3. Whether the text instruction is one per episode or one per timestep: one per timestep.

4. Camera count, resolution, frame rate: 1 camera, 224*224 resolution, 30 Hz; the motor control rate is also 30 Hz.

5. Whether DistilBERT is frozen during training (`freeze_text_encoder=True/False`): DistilBERT is fully frozen. When building the training set, each frame's text instruction is first encoded with DistilBERT and then saved, so training never needs to call DistilBERT.

> With these 5 items, the code changes can begin.
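The mapping described in items 1-2 is a linear rescale of the 0~65535 command range onto each motor's encoder range; the ranges below come from the answers above, while the helper name is ours:

```python
def command_to_encoder(cmd: float, enc_min: float, enc_max: float) -> float:
    """Map a command in 0~65535 linearly onto [enc_min, enc_max]."""
    return enc_min + (cmd / 65535.0) * (enc_max - enc_min)
```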
What do `text_input_ids` and `text_attention_mask` mean? Also, `'instruction_mode': 'episode-level'` is never used;
```python
instruction = ''
if self.use_text_instruction:
    if '/instruction_timestep' in root:
        instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
    elif '/instruction' in root:
        instruction_node = root['/instruction']
        if getattr(instruction_node, 'shape', ()) == ():
            instruction = self._decode_instruction(instruction_node[()])
        else:
            if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
                instruction = self._decode_instruction(instruction_node[start_ts])
            else:
                instruction = self._decode_instruction(instruction_node[0])
```
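On the question above: `text_input_ids` and `text_attention_mask` are the standard HuggingFace tokenizer outputs, i.e. the instruction's token IDs and a 0/1 mask marking real tokens versus padding. A toy illustration of how the pair is built (hand-rolled, with made-up IDs, so it runs without the `transformers` dependency):

```python
# Two "tokenized" instructions of different lengths; the IDs are examples, not real vocab.
ids = [[101, 2644, 2693, 1012, 102], [101, 2644, 102]]

# Pad to a common length and mark real tokens with 1, padding with 0,
# which is exactly what a HF tokenizer returns as input_ids / attention_mask.
max_len = max(len(s) for s in ids)
input_ids = [s + [0] * (max_len - len(s)) for s in ids]
attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in ids]
```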
Why was the Transformer definition modified? Does the change there actually take effect?
README.md

```diff
@@ -35,8 +35,8 @@ You can find all scripted/human demo for simulated environments [here](https://d
 pip install pyyaml
 pip install rospkg
 pip install pexpect
-pip install mujoco
-pip install dm_control
+pip install mujoco==2.3.7
+pip install dm_control==1.0.14
 pip install opencv-python
 pip install matplotlib
 pip install einops
```
build_endoscope_act_dataset.py (new file, 680 lines)
```python
#!/usr/bin/env python3
import argparse
import csv
import json
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
```
```python
@dataclass
class CropBox:
    x1: int
    y1: int
    x2: int
    y2: int

    @property
    def w(self) -> int:
        return self.x2 - self.x1

    @property
    def h(self) -> int:
        return self.y2 - self.y1
```
```python
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert endoscope raw data (frames + json + csv) to ACT-compatible HDF5 episode(s)."
    )
    parser.add_argument("--segment_dir", type=str, required=True,
                        help="Path to one raw segment, e.g. data/follow_seg_001")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output dir for episode_*.hdf5")
    parser.add_argument("--episode_idx", type=int, default=0,
                        help="Output episode index (default: 0)")
    parser.add_argument("--max_frames", type=int, default=-1,
                        help="Use first N frames from this segment; <=0 means use all aligned frames (default: -1)")
    parser.add_argument("--camera_name", type=str, default="top",
                        help="Camera name written to /observations/images/<camera_name>")
    parser.add_argument("--crop", type=int, nargs=4, default=[733, 30, 1754, 1051],
                        metavar=("X1", "Y1", "X2", "Y2"),
                        help="Crop box in original image coordinates")
    parser.add_argument("--resize", type=int, nargs=2, default=[224, 224], metavar=("W", "H"),
                        help="Output image size")
    parser.add_argument("--instruction_template", type=str,
                        default="Move toward the {label} at {region}.",
                        help="Template for per-frame instruction")
    parser.add_argument("--instruction_empty", type=str, default="No target visible.",
                        help="Instruction when no valid target after crop")
    parser.add_argument("--stop_instruction", type=str, default="Stop move.",
                        help="Instruction used for stationary head/tail frames")
    parser.add_argument("--state_norm", choices=["minus1_1", "0_1", "raw"], default="minus1_1",
                        help="Normalization for qpos (motor_pos_y, motor_pos_x)")
    parser.add_argument("--action_norm", choices=["minus1_1", "0_1", "raw"], default="minus1_1",
                        help="Normalization for action (motor_command_0, motor_command_1)")
    parser.add_argument("--encode_text_features", action="store_true",
                        help="Encode per-frame instruction into 768-dim DistilBERT features")
    parser.add_argument("--text_model_name", type=str, default="distilbert-base-uncased",
                        help="HuggingFace model name for DistilBERT")
    parser.add_argument("--text_batch_size", type=int, default=32,
                        help="Batch size for text feature extraction")
    parser.add_argument("--motion_window", type=int, default=5,
                        help="Sliding window size used for stationary detection at beginning/end")
    parser.add_argument("--motion_threshold", type=float, default=0.002,
                        help="Motion threshold in normalized delta space (0~1). "
                             "Smaller value means stricter stationary detection")
    parser.add_argument("--disable_stop_override", action="store_true",
                        help="Disable head/tail stationary instruction override")
    parser.add_argument("--trim_stationary_edges", action="store_true",
                        help="Trim stationary head/tail segments and keep only the middle moving segment")
    parser.add_argument("--no_text_instruction", action="store_true",
                        help="Do not save instruction/instruction_timestep (and disable text feature encoding)")
    return parser.parse_args()
```
```python
def sorted_frame_jsons(frames_dir: Path) -> List[Path]:
    json_files = list(frames_dir.glob("*.json"))

    def key_fn(p: Path) -> Tuple[int, str]:
        m = re.search(r"frame_(\d+)", p.name)
        idx = int(m.group(1)) if m else 10**9
        return idx, p.name

    json_files.sort(key=key_fn)
    return json_files


def load_csv_rows(csv_path: Path) -> List[Dict[str, str]]:
    with csv_path.open("r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        return list(reader)


def normalize_value(x: np.ndarray, min_v: float, max_v: float, mode: str) -> np.ndarray:
    if mode == "raw":
        return x.astype(np.float32)
    x01 = (x - min_v) / (max_v - min_v)
    if mode == "0_1":
        return x01.astype(np.float32)
    # minus1_1
    return (x01 * 2.0 - 1.0).astype(np.float32)
```
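For reference, `normalize_value` is a plain linear rescale of the raw motor range. A standalone check of the `minus1_1` mode (the mapping is repeated here so the check runs on its own; the motor-y range comes from the plan above):

```python
import numpy as np

# Same linear mapping as normalize_value above, repeated so this check is self-contained.
def norm(x, min_v, max_v, mode):
    if mode == "raw":
        return np.asarray(x, dtype=np.float32)
    x01 = (np.asarray(x, dtype=np.float32) - min_v) / (max_v - min_v)
    return x01 if mode == "0_1" else x01 * 2.0 - 1.0

# Motor-y endpoints and midpoint map to -1, 0, +1.
q = norm([8000.0, 13442.0, 18884.0], 8000.0, 18884.0, "minus1_1")
```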
```python
def clip_bbox_to_crop(
    x_min: float,
    y_min: float,
    x_max: float,
    y_max: float,
    crop: CropBox,
) -> Optional[Tuple[float, float, float, float]]:
    nx1 = max(x_min - crop.x1, 0.0)
    ny1 = max(y_min - crop.y1, 0.0)
    nx2 = min(x_max - crop.x1, float(crop.w - 1))
    ny2 = min(y_max - crop.y1, float(crop.h - 1))
    if nx2 <= nx1 or ny2 <= ny1:
        return None
    return nx1, ny1, nx2, ny2


def bbox_center(box: Tuple[float, float, float, float]) -> Tuple[float, float]:
    x1, y1, x2, y2 = box
    return (x1 + x2) * 0.5, (y1 + y2) * 0.5


def region_3x3(cx: float, cy: float, w: int, h: int) -> str:
    x_bin = min(2, max(0, int(cx / (w / 3.0))))
    y_bin = min(2, max(0, int(cy / (h / 3.0))))
    xs = ["left", "center", "right"]
    ys = ["top", "middle", "bottom"]
    return f"{ys[y_bin]}-{xs[x_bin]}"


def read_shape_bbox(shape: Dict) -> Optional[Tuple[str, float, float, float, float, float]]:
    points = shape.get("points", None)
    label = shape.get("label", "target")
    if not points or len(points) < 2:
        return None
    pts = np.array(points, dtype=np.float32)
    x_min, y_min = float(pts[:, 0].min()), float(pts[:, 1].min())
    x_max, y_max = float(pts[:, 0].max()), float(pts[:, 1].max())
    area = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)
    return label, x_min, y_min, x_max, y_max, area


def select_target_box(annotation: Dict, crop: CropBox) -> Optional[Tuple[str, Tuple[float, float, float, float]]]:
    shapes = annotation.get("shapes", [])
    best = None
    for shape in shapes:
        parsed = read_shape_bbox(shape)
        if parsed is None:
            continue
        label, x1, y1, x2, y2, area = parsed
        clipped = clip_bbox_to_crop(x1, y1, x2, y2, crop)
        if clipped is None:
            continue
        c_area = max(0.0, clipped[2] - clipped[0]) * max(0.0, clipped[3] - clipped[1])
        if best is None or c_area > best[2]:
            best = (label, clipped, c_area)
    if best is None:
        return None
    return best[0], best[1]


def instruction_from_annotation(
    annotation: Dict,
    crop: CropBox,
    template: str,
    empty_instruction: str,
) -> str:
    picked = select_target_box(annotation, crop)
    if picked is None:
        return empty_instruction
    label, box = picked
    cx, cy = bbox_center(box)
    region = region_3x3(cx, cy, crop.w, crop.h)
    return template.format(label=label, region=region)
```
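To make the instruction wording concrete, a standalone check of the 3x3 region naming (the same binning as `region_3x3` above, repeated so it runs on its own, with a 224x224 crop as in the plan):

```python
# Same binning as region_3x3 above, repeated so this sketch is self-contained.
def region_3x3(cx, cy, w, h):
    x_bin = min(2, max(0, int(cx / (w / 3.0))))
    y_bin = min(2, max(0, int(cy / (h / 3.0))))
    xs = ["left", "center", "right"]
    ys = ["top", "middle", "bottom"]
    return f"{ys[y_bin]}-{xs[x_bin]}"

names = [region_3x3(cx, cy, 224, 224) for cx, cy in [(10, 10), (112, 112), (220, 220)]]
```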
```python
def extract_text_features(
    instructions: Sequence[str],
    model_name: str,
    batch_size: int = 32,
) -> np.ndarray:
    try:
        import torch
        from transformers import DistilBertTokenizerFast
    except ImportError as exc:
        raise ImportError(
            "Text feature encoding requires transformers. Please install: pip install transformers"
        ) from exc

    repo_root = Path(__file__).resolve().parents[1]
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from models.text_encoder import DistilBERTTextEncoder

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    model = DistilBERTTextEncoder(model_name=model_name, output_dim=768, freeze=True).to(device)
    model.eval()

    feats: List[np.ndarray] = []
    with torch.no_grad():
        for i in range(0, len(instructions), batch_size):
            batch = list(instructions[i:i + batch_size])
            tok = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=32,
                return_tensors="pt",
            )
            input_ids = tok["input_ids"].to(device)
            attention_mask = tok["attention_mask"].to(device)
            cls = model(input_ids=input_ids, attention_mask=attention_mask).detach().cpu().numpy().astype(np.float32)
            feats.append(cls)
    return np.concatenate(feats, axis=0)


def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
    scale = max_v - min_v
    if scale <= 0:
        return np.zeros_like(v, dtype=np.float32)
    return ((v - min_v) / scale).astype(np.float32)
```
```python
def override_stationary_edge_instructions(
    instructions: List[str],
    motor_pos_y: np.ndarray,
    motor_pos_x: np.ndarray,
    motor_cmd_0: np.ndarray,
    motor_cmd_1: np.ndarray,
    stop_instruction: str,
    motion_window: int,
    motion_threshold: float,
) -> Tuple[List[str], int, int]:
    """
    Override instruction text at head/tail using qpos velocity.
    Start side: once qpos speed is above threshold for consecutive frames,
    stop applying stop_instruction from that point onward.
    End side: similarly scan backward from the end.
    """
    num = len(instructions)
    if num == 0:
        return instructions, 0, 0

    start_count, end_count = detect_stationary_edge_counts_from_qpos(
        motor_pos_y=motor_pos_y,
        motor_pos_x=motor_pos_x,
        motion_window=motion_window,
        motion_threshold=motion_threshold,
    )

    updated = list(instructions)
    for i in range(start_count):
        updated[i] = stop_instruction
    for i in range(num - end_count, num):
        updated[i] = stop_instruction

    return updated, start_count, end_count


def detect_stationary_edge_counts_from_qpos(
    motor_pos_y: np.ndarray,
    motor_pos_x: np.ndarray,
    motion_window: int,
    motion_threshold: float,
) -> Tuple[int, int]:
    """Return stationary frame counts on head and tail using the qpos velocity rule."""
    num = int(len(motor_pos_y))
    if num == 0:
        return 0, 0
    if num == 1:
        return 1, 1

    py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
    px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)

    consecutive = max(1, int(motion_window))
    dt = 1.0 / 30.0

    frame_speed = np.zeros((num,), dtype=np.float32)
    dy = np.abs(np.diff(py)) / dt
    dx = np.abs(np.diff(px)) / dt
    frame_speed[1:] = np.maximum(dy, dx)

    high_run = 0
    start_count = num
    for i in range(1, num):
        if frame_speed[i] > motion_threshold:
            high_run += 1
            if high_run >= consecutive:
                start_count = i - consecutive + 1
                break
        else:
            high_run = 0

    high_run = 0
    end_count = num
    for i in range(num - 1, 0, -1):
        if frame_speed[i] > motion_threshold:
            high_run += 1
            if high_run >= consecutive:
                tail_start = i + consecutive
                end_count = max(0, num - tail_start)
                break
        else:
            high_run = 0

    return start_count, end_count
```
```python
def find_segment_csv(segment_dir: Path) -> Path:
    csvs = sorted(segment_dir.glob("*.csv"))
    if not csvs:
        raise FileNotFoundError(f"No csv file found in {segment_dir}")
    return csvs[0]


def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
    """Convert boolean mask to closed-open segments [start, end)."""
    segments: List[Tuple[int, int]] = []
    if mask.size == 0:
        return segments
    in_seg = False
    start = 0
    for i, v in enumerate(mask.tolist()):
        if v and not in_seg:
            in_seg = True
            start = i
        elif (not v) and in_seg:
            in_seg = False
            segments.append((start, i))
    if in_seg:
        segments.append((start, int(mask.size)))
    return segments
```
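A quick standalone check of the mask-to-segments conversion (the same scan as `_mask_to_segments` above, repeated so it runs on its own):

```python
import numpy as np

# Same run-length scan as _mask_to_segments above, repeated so this check is self-contained.
def mask_to_segments(mask):
    segments = []
    in_seg = False
    start = 0
    for i, v in enumerate(mask.tolist()):
        if v and not in_seg:
            in_seg, start = True, i
        elif (not v) and in_seg:
            in_seg = False
            segments.append((start, i))
    if in_seg:
        segments.append((start, int(mask.size)))
    return segments

segs = mask_to_segments(np.array([True, True, False, False, True]))
```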
```python
def save_episode_plot_with_stop_segments(
    qpos: np.ndarray,
    action: np.ndarray,
    instructions: Sequence[str],
    stop_instruction: str,
    plot_path: Path,
) -> None:
    """
    Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
    """
    qpos = np.asarray(qpos)
    action = np.asarray(action)
    if qpos.ndim != 2 or action.ndim != 2:
        return

    num_ts, num_dim = qpos.shape
    if num_ts == 0 or num_dim == 0:
        return

    stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
    stop_segments = _mask_to_segments(stop_mask)

    fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
    axs_list = np.atleast_1d(axs).reshape(-1).tolist()

    for dim_idx in range(num_dim):
        ax = axs_list[dim_idx]
        ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
        ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
        for seg_idx, (st, ed) in enumerate(stop_segments):
            ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
                       label='stop instruction' if seg_idx == 0 else None)
        ax.set_ylabel(f'dim {dim_idx}')
        ax.legend(loc='best')
        ax.grid(alpha=0.25, linestyle='--')

    axs_list[-1].set_xlabel('timestep')
    fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
    fig.tight_layout()
    fig.savefig(str(plot_path), dpi=140)
    plt.close(fig)
```
```python
def main() -> None:
    args = parse_args()

    if args.no_text_instruction and args.encode_text_features:
        raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')

    segment_dir = Path(args.segment_dir).resolve()
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    frames_dir = segment_dir / "frames"
    if not frames_dir.exists():
        raise FileNotFoundError(f"frames dir not found: {frames_dir}")

    csv_path = find_segment_csv(segment_dir)
    csv_rows = load_csv_rows(csv_path)
    if len(csv_rows) == 0:
        raise ValueError(f"CSV has no rows: {csv_path}")

    crop = CropBox(*args.crop)
    resize_w, resize_h = int(args.resize[0]), int(args.resize[1])

    json_files = sorted_frame_jsons(frames_dir)
    if not json_files:
        raise FileNotFoundError(f"No frame json found in: {frames_dir}")

    max_aligned = min(len(json_files), len(csv_rows))
    num = max_aligned if args.max_frames <= 0 else min(args.max_frames, max_aligned)
    if num <= 0:
        raise ValueError("No aligned frames available.")

    images = np.zeros((num, resize_h, resize_w, 3), dtype=np.uint8)
    qpos = np.zeros((num, 2), dtype=np.float32)    # [y, x]
    action = np.zeros((num, 2), dtype=np.float32)  # [cmd0(y), cmd1(x)]
    instructions: List[str] = []
    motor_pos_y_series = np.zeros((num,), dtype=np.float32)
    motor_pos_x_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
    motor_cmd_1_series = np.zeros((num,), dtype=np.float32)

    y_min, y_max = 8000.0, 18884.0
    x_min, x_max = 7000.0, 17384.0
    cmd_min, cmd_max = 0.0, 65535.0

    for i in range(num):
        json_path = json_files[i]
        with json_path.open("r", encoding="utf-8") as f:
            ann = json.load(f)

        image_path = frames_dir / ann["imagePath"]
        if not image_path.exists():
            alt = json_path.with_suffix(".jpg")
            if alt.exists():
                image_path = alt
            else:
                raise FileNotFoundError(f"Image not found for {json_path.name}")

        img = Image.open(image_path).convert("RGB")
        img_crop = img.crop((crop.x1, crop.y1, crop.x2, crop.y2))
        img_resize = img_crop.resize((resize_w, resize_h), RESAMPLE_BILINEAR)
        images[i] = np.asarray(img_resize, dtype=np.uint8)

        row = csv_rows[i]
        motor_pos_y = float(row["motor_pos_y"])
        motor_pos_x = float(row["motor_pos_x"])
        motor_cmd_0 = float(row["motor_command_0"])
        motor_cmd_1 = float(row["motor_command_1"])

        motor_pos_y_series[i] = motor_pos_y
        motor_pos_x_series[i] = motor_pos_x
        motor_cmd_0_series[i] = motor_cmd_0
        motor_cmd_1_series[i] = motor_cmd_1

        qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
        qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
        action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
        action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]

        if not args.no_text_instruction:
            ins = instruction_from_annotation(
                ann,
                crop,
                args.instruction_template,
                args.instruction_empty,
            )
            instructions.append(ins)

    start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
        motor_pos_y=motor_pos_y_series,
        motor_pos_x=motor_pos_x_series,
        motion_window=args.motion_window,
        motion_threshold=args.motion_threshold,
    )

    if args.trim_stationary_edges:
        keep_start = int(start_stop_count)
        keep_end = int(num - end_stop_count)
        if keep_end <= keep_start:
            raise ValueError(
                f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
                f'Consider lowering --motion_threshold or --motion_window.'
            )
        images = images[keep_start:keep_end]
        qpos = qpos[keep_start:keep_end]
        action = action[keep_start:keep_end]
        motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
        motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
        motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
        motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
        if not args.no_text_instruction:
            instructions = instructions[keep_start:keep_end]

        print(
            f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
            f'kept={keep_end - keep_start}'
        )
        # After trimming, full kept segment is the moving region.
        start_stop_count, end_stop_count = 0, 0

    if (not args.disable_stop_override) and (not args.no_text_instruction):
        instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
            instructions=instructions,
            motor_pos_y=motor_pos_y_series,
            motor_pos_x=motor_pos_x_series,
            motor_cmd_0=motor_cmd_0_series,
            motor_cmd_1=motor_cmd_1_series,
            stop_instruction=args.stop_instruction,
            motion_window=args.motion_window,
            motion_threshold=args.motion_threshold,
        )

    text_features = None
```
||||||
|
if args.encode_text_features:
|
||||||
|
text_features = extract_text_features(
|
||||||
|
instructions,
|
||||||
|
model_name=args.text_model_name,
|
||||||
|
batch_size=args.text_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_path = output_dir / f"episode_{args.episode_idx}.hdf5"
|
||||||
|
dt = 1.0 / 30.0
|
||||||
|
|
||||||
|
with h5py.File(out_path, "w") as root:
|
||||||
|
root.attrs["sim"] = False
|
||||||
|
root.attrs["source_segment"] = str(segment_dir)
|
||||||
|
root.attrs["frame_rate"] = 30
|
||||||
|
root.attrs["dt"] = dt
|
||||||
|
root.attrs["state_norm_mode"] = args.state_norm
|
||||||
|
root.attrs["action_norm_mode"] = args.action_norm
|
||||||
|
root.attrs["qpos_order"] = "[motor_pos_y, motor_pos_x]"
|
||||||
|
root.attrs["action_order"] = "[motor_command_0(y), motor_command_1(x)]"
|
||||||
|
root.attrs["crop_xyxy"] = np.array(args.crop, dtype=np.int32)
|
||||||
|
|
||||||
|
obs = root.create_group("observations")
|
||||||
|
obs.create_dataset("qpos", data=qpos, dtype=np.float32)
|
||||||
|
images_group = obs.create_group("images")
|
||||||
|
images_group.create_dataset(args.camera_name, data=images, dtype=np.uint8)
|
||||||
|
|
||||||
|
root.create_dataset("action", data=action, dtype=np.float32)
|
||||||
|
|
||||||
|
if not args.no_text_instruction:
|
||||||
|
str_dtype = h5py.string_dtype(encoding="utf-8")
|
||||||
|
root.create_dataset(
|
||||||
|
"instruction_timestep",
|
||||||
|
shape=(len(instructions),),
|
||||||
|
dtype=str_dtype,
|
||||||
|
data=np.asarray(instructions, dtype=object),
|
||||||
|
)
|
||||||
|
root.create_dataset(
|
||||||
|
"instruction",
|
||||||
|
shape=(),
|
||||||
|
dtype=str_dtype,
|
||||||
|
data=instructions[0] if len(instructions) > 0 else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
if text_features is not None:
|
||||||
|
root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
|
||||||
|
root.create_dataset("instruction_features", data=text_features[0], dtype=np.float32)
|
||||||
|
|
||||||
|
print(f"Saved: {out_path}")
|
||||||
|
print(f"Frames used: {num}")
|
||||||
|
print(f"Image shape: {images.shape}")
|
||||||
|
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
|
||||||
|
if not args.disable_stop_override:
|
||||||
|
print(
|
||||||
|
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
|
||||||
|
f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
|
||||||
|
f"threshold={args.motion_threshold}, "
|
||||||
|
f"instruction='{args.stop_instruction}'"
|
||||||
|
)
|
||||||
|
if text_features is not None:
|
||||||
|
print(f"instruction_features_timestep shape: {text_features.shape}")
|
||||||
|
|
||||||
|
# Save a same-basename plot next to the generated hdf5
|
||||||
|
plot_path = out_path.with_suffix('.png')
|
||||||
|
save_episode_plot_with_stop_segments(
|
||||||
|
qpos=qpos,
|
||||||
|
action=action,
|
||||||
|
instructions=instructions,
|
||||||
|
stop_instruction=args.stop_instruction,
|
||||||
|
plot_path=plot_path,
|
||||||
|
)
|
||||||
|
print(f"Saved episode plot to: {plot_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
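The `--state_norm minus1_1` / `--action_norm minus1_1` modes used above map raw motor values into [-1, 1] before they are written to the HDF5 file. A minimal sketch of such a min-max mapping and its inverse; the repo's actual `normalize_value` signature may differ, and `to_minus1_1` / `from_minus1_1` are illustrative names:

```python
def to_minus1_1(v, vmin, vmax):
    # Map a raw motor value in [vmin, vmax] linearly into [-1, 1].
    return 2.0 * (v - vmin) / (vmax - vmin) - 1.0

def from_minus1_1(n, vmin, vmax):
    # Inverse mapping, needed when converting normalized model outputs
    # back into raw motor commands at deployment time.
    return (n + 1.0) * 0.5 * (vmax - vmin) + vmin

mid = to_minus1_1(75.0, 50.0, 100.0)   # mid-range value maps to 0.0
low = to_minus1_1(50.0, 50.0, 100.0)   # lower bound maps to -1.0
```

Storing `state_norm_mode` / `action_norm_mode` as HDF5 attributes (as the script does) is what makes this inverse recoverable later.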
25
build_no_text_dataset.sh
Executable file
@@ -0,0 +1,25 @@
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/00-follow"
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/follow-no-text"
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"

mkdir -p "$OUT_DIR"
i=50
for d in "$SEG_ROOT"/follow_seg_*; do
  [ -d "$d" ] || continue
  echo "Building $d -> episode_$i"
  python "$SCRIPT" \
    --segment_dir "$d" \
    --output_dir "$OUT_DIR" \
    --episode_idx "$i" \
    --max_frames -1 \
    --camera_name top \
    --crop 733 30 1754 1051 \
    --resize 224 224 \
    --motion_window 3 \
    --motion_threshold 0.05 \
    --state_norm minus1_1 \
    --action_norm minus1_1 \
    --trim_stationary_edges \
    --no_text_instruction
  i=$((i+1))
done
29
build_text_dataset.sh
Executable file
@@ -0,0 +1,29 @@
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/01-cannulation"
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/cannulation"
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"

mkdir -p "$OUT_DIR"
i=12
for d in "$SEG_ROOT"/seg_*; do
  [ -d "$d" ] || continue
  echo "Building $d -> episode_$i"
  python "$SCRIPT" \
    --segment_dir "$d" \
    --output_dir "$OUT_DIR" \
    --episode_idx "$i" \
    --max_frames -1 \
    --camera_name top \
    --crop 733 30 1754 1051 \
    --resize 224 224 \
    --instruction_template 'Cannulate the {label} on the phantom located at the {region} with the sphincterotome.' \
    --instruction_empty 'No target visible.' \
    --stop_instruction 'Stop move.' \
    --motion_window 3 \
    --motion_threshold 0.05 \
    --state_norm minus1_1 \
    --action_norm minus1_1 \
    --encode_text_features \
    --text_model_name distilbert-base-uncased \
    --text_batch_size 32
  i=$((i+1))
done
@@ -21,3 +21,5 @@ dependencies:
   - packaging=23.0
   - h5py=3.8.0
   - ipython=8.12.0
+  - pip:
+    - transformers==4.38.2
81
constants.py
@@ -1,7 +1,7 @@
 import pathlib

 ### Task parameters
-DATA_DIR = '<put your data dir here>'
+DATA_DIR = str(pathlib.Path(__file__).parent.resolve() / 'data')
 SIM_TASK_CONFIGS = {
     'sim_transfer_cube_scripted':{
         'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
@@ -32,6 +32,85 @@ SIM_TASK_CONFIGS = {
     },
 }

+ENDOSCOPE_TASK_CONFIGS = {
+    'endoscope_default': {
+        'dataset_dir': DATA_DIR + '/endoscope_default',
+        'num_episodes': 50,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': True,
+        'instruction_mode': 'timestep-level',
+        'use_cached_text_features': True,
+        'text_encoder_type': 'distilbert',
+        'text_feature_dim': 768,
+        'text_fusion_type': 'concat_transformer_input',
+        'freeze_text_encoder': True,
+        'text_max_length': 32,
+        'text_tokenizer_name': 'distilbert-base-uncased',
+    },
+    'endoscope_follow': {
+        'dataset_dir': DATA_DIR + '/follow',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': True,
+        'instruction_mode': 'timestep-level',
+        'use_cached_text_features': True,
+        'text_encoder_type': 'distilbert',
+        'text_feature_dim': 768,
+        'text_fusion_type': 'concat_transformer_input',
+        'freeze_text_encoder': True,
+        'text_max_length': 32,
+        'text_tokenizer_name': 'distilbert-base-uncased',
+    },
+    'endoscope_both_no_text': {
+        'dataset_dir': DATA_DIR + '/both-no-text',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': False,
+    },
+    'endoscope_sanity_check': {
+        'dataset_dir': DATA_DIR + '/sanity-check',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': False,
+    },
+    'endoscope_cannulation_no_text': {
+        'dataset_dir': DATA_DIR + '/cannulation-no-text',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': False,
+    },
+    'endoscope_follow_no_text': {
+        'dataset_dir': DATA_DIR + '/follow-no-text',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': False,
+    },
+}
+
 ### Simulation envs fixed constants
 DT = 0.02
 JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
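Downstream training code can dispatch on these config dicts with a plain lookup. A minimal sketch of that pattern, with entries abridged from the table above; `get_task_config` is a hypothetical helper name, not part of the repo:

```python
# Abridged copy of one entry from ENDOSCOPE_TASK_CONFIGS for illustration.
ENDOSCOPE_TASK_CONFIGS = {
    'endoscope_follow_no_text': {
        'dataset_dir': 'data/follow-no-text',
        'camera_names': ['top'],
        'state_dim': 2,
        'action_dim': 2,
        'use_text_instruction': False,
    },
}

def get_task_config(task_name):
    # Fail early with the list of known task names instead of a bare KeyError.
    try:
        return ENDOSCOPE_TASK_CONFIGS[task_name]
    except KeyError:
        raise ValueError(
            f"Unknown task {task_name!r}; known: {sorted(ENDOSCOPE_TASK_CONFIGS)}"
        )

cfg = get_task_config('endoscope_follow_no_text')
```

Keeping `state_dim` / `action_dim` in the config (rather than hardcoding 14) is what lets the same training entry point serve both the 2-DOF endoscope tasks and the original ALOHA tasks.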
21
detr/main.py
@@ -4,6 +4,7 @@ from pathlib import Path

 import numpy as np
 import torch
+from torch.optim.adamw import AdamW
 from .models import build_ACT_model, build_CNNMLP_model

 import IPython
@@ -30,6 +31,15 @@ def get_args_parser():
                         help="Type of positional embedding to use on top of the image features")
     parser.add_argument('--camera_names', default=[], type=list, # will be overridden
                         help="A list of camera names")
+    parser.add_argument('--state_dim', default=14, type=int)
+    parser.add_argument('--action_dim', default=14, type=int)
+    parser.add_argument('--use_text', action='store_true')
+    parser.add_argument('--text_encoder_type', default='distilbert', type=str)
+    parser.add_argument('--text_feature_dim', default=768, type=int)
+    parser.add_argument('--text_fusion_type', default='concat_transformer_input', type=str)
+    parser.add_argument('--freeze_text_encoder', action='store_true')
+    parser.add_argument('--text_max_length', default=32, type=int)
+    parser.add_argument('--text_tokenizer_name', default='distilbert-base-uncased', type=str)

     # * Transformer
     parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
@@ -60,16 +70,17 @@ def get_args_parser():
     parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
     parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
     parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
-    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
+    parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
     parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
     parser.add_argument('--temporal_agg', action='store_true')
+    parser.add_argument('--image_aug', action='store_true')

     return parser


 def build_ACT_model_and_optimizer(args_override):
     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()

     for k, v in args_override.items():
         setattr(args, k, v)
@@ -84,7 +95,7 @@ def build_ACT_model_and_optimizer(args_override):
             "lr": args.lr_backbone,
         },
     ]
-    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
+    optimizer = AdamW(param_dicts, lr=args.lr,
                                   weight_decay=args.weight_decay)

     return model, optimizer
@@ -92,7 +103,7 @@ def build_ACT_model_and_optimizer(args_override):

 def build_CNNMLP_model_and_optimizer(args_override):
     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()

     for k, v in args_override.items():
         setattr(args, k, v)
@@ -107,7 +118,7 @@ def build_CNNMLP_model_and_optimizer(args_override):
             "lr": args.lr_backbone,
         },
     ]
-    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
+    optimizer = AdamW(param_dicts, lr=args.lr,
                                   weight_decay=args.weight_decay)

     return model, optimizer
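The switch from `parse_args()` to `parse_known_args()` above, combined with the `args_override` loop, lets the outer training script pass config values (e.g. `state_dim=2`, `use_text=True`) without the inner DETR parser choking on flags it does not own. A standalone sketch of that pattern; the parser flags here mirror two of the new arguments, and `build_from_override` is an illustrative name:

```python
import argparse

def build_from_override(args_override, argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--state_dim', default=14, type=int)
    parser.add_argument('--use_text', action='store_true')
    # parse_known_args tolerates unrelated flags owned by the outer script
    # instead of exiting with "unrecognized arguments".
    args, _unknown = parser.parse_known_args(argv or [])
    for k, v in args_override.items():
        setattr(args, k, v)
    return args

args = build_from_override({'state_dim': 2, 'use_text': True})
```

Note the override is applied after parsing, so it wins over both defaults and any command-line values.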
@@ -89,9 +89,32 @@ class Backbone(BackboneBase):
                  train_backbone: bool,
                  return_interm_layers: bool,
                  dilation: bool):
-        backbone = getattr(torchvision.models, name)(
-            replace_stride_with_dilation=[False, False, dilation],
-            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
+        backbone_builder = getattr(torchvision.models, name)
+        weights = None
+        if is_main_process():
+            weight_enum_name_map = {
+                'resnet18': 'ResNet18_Weights',
+                'resnet34': 'ResNet34_Weights',
+                'resnet50': 'ResNet50_Weights',
+                'resnet101': 'ResNet101_Weights',
+            }
+            enum_name = weight_enum_name_map.get(name)
+            if enum_name is not None and hasattr(torchvision.models, enum_name):
+                weights = getattr(getattr(torchvision.models, enum_name), 'DEFAULT')
+
+        try:
+            backbone = backbone_builder(
+                replace_stride_with_dilation=[False, False, dilation],
+                weights=weights,
+                norm_layer=FrozenBatchNorm2d,
+            )
+        except TypeError:
+            # Backward compatibility for older torchvision that still expects `pretrained`.
+            backbone = backbone_builder(
+                replace_stride_with_dilation=[False, False, dilation],
+                pretrained=(weights is not None),
+                norm_layer=FrozenBatchNorm2d,
+            )
         num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
         super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
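The weights-enum lookup in the hunk above (resolving e.g. `resnet18` to `ResNet18_Weights.DEFAULT` when the enum exists, otherwise falling back to the legacy `pretrained` path) can be exercised without torchvision installed. A sketch of the same getattr/hasattr resolution against a stand-in module namespace; `resolve_default_weights` and `fake_models` are illustrative:

```python
from types import SimpleNamespace

def resolve_default_weights(models_ns, name):
    # Mirrors the backbone builder's mapping: model name -> weights enum name.
    enum_name = {'resnet18': 'ResNet18_Weights', 'resnet50': 'ResNet50_Weights'}.get(name)
    if enum_name is not None and hasattr(models_ns, enum_name):
        return getattr(getattr(models_ns, enum_name), 'DEFAULT')
    # None signals the caller to fall back to the old `pretrained` keyword.
    return None

# Stand-in for torchvision.models: only ResNet18_Weights is "available".
fake_models = SimpleNamespace(ResNet18_Weights=SimpleNamespace(DEFAULT='IMAGENET1K_V1'))
found = resolve_default_weights(fake_models, 'resnet18')    # 'IMAGENET1K_V1'
missing = resolve_default_weights(fake_models, 'resnet50')  # None -> legacy path
```

The try/except on `TypeError` in the diff covers the remaining case where the builder itself predates the `weights=` keyword.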
@@ -33,7 +33,8 @@ def get_sinusoid_encoding_table(n_position, d_hid):

 class DETRVAE(nn.Module):
     """ This is the DETR module that performs object detection """
-    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names):
+    def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries, camera_names,
+                 use_text=False, text_feature_dim=768, text_fusion_type='concat_transformer_input'):
         """ Initializes the model.
         Parameters:
             backbones: torch module of the backbone to be used. See backbone.py
@@ -48,17 +49,18 @@ class DETRVAE(nn.Module):
         self.camera_names = camera_names
         self.transformer = transformer
         self.encoder = encoder
+        self.use_text = use_text
+        self.text_fusion_type = text_fusion_type
         hidden_dim = transformer.d_model
-        self.action_head = nn.Linear(hidden_dim, state_dim)
+        self.action_head = nn.Linear(hidden_dim, action_dim)
         self.is_pad_head = nn.Linear(hidden_dim, 1)
         self.query_embed = nn.Embedding(num_queries, hidden_dim)
         if backbones is not None:
             self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
             self.backbones = nn.ModuleList(backbones)
-            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
         else:
-            # input_dim = 14 + 7 # robot_state + env_state
-            self.input_proj_robot_state = nn.Linear(14, hidden_dim)
+            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
             self.input_proj_env_state = nn.Linear(7, hidden_dim)
             self.pos = torch.nn.Embedding(2, hidden_dim)
             self.backbones = None
@@ -66,16 +68,18 @@ class DETRVAE(nn.Module):
         # encoder extra parameters
         self.latent_dim = 32 # final size of latent z # TODO tune
         self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
-        self.encoder_action_proj = nn.Linear(14, hidden_dim) # project action to embedding
-        self.encoder_joint_proj = nn.Linear(14, hidden_dim)  # project qpos to embedding
+        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
+        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim)  # project qpos to embedding
         self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
         self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq

         # decoder extra parameters
         self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
-        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
+        num_extra_tokens = 3 if self.use_text else 2
+        self.additional_pos_embed = nn.Embedding(num_extra_tokens, hidden_dim) # latent, proprio, optional text
+        self.text_proj = nn.Linear(text_feature_dim, hidden_dim) if self.use_text else None

-    def forward(self, qpos, image, env_state, actions=None, is_pad=None):
+    def forward(self, qpos, image, env_state, text_features=None, actions=None, is_pad=None):
         """
         qpos: batch, qpos_dim
         image: batch, num_cam, channel, height, width
@@ -125,10 +129,25 @@ class DETRVAE(nn.Module):
                 all_cam_pos.append(pos)
             # proprioception features
             proprio_input = self.input_proj_robot_state(qpos)
+            extra_input_tokens = None
+            if self.use_text and text_features is not None:
+                if self.text_fusion_type != 'concat_transformer_input':
+                    raise NotImplementedError(f'Unsupported text fusion type: {self.text_fusion_type}')
+                text_input = self.text_proj(text_features)
+                extra_input_tokens = text_input.unsqueeze(0)
             # fold camera dimension into width dimension
             src = torch.cat(all_cam_features, axis=3)
             pos = torch.cat(all_cam_pos, axis=3)
-            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
+            hs = self.transformer(
+                src,
+                None,
+                self.query_embed.weight,
+                pos,
+                latent_input,
+                proprio_input,
+                self.additional_pos_embed.weight,
+                extra_input_tokens=extra_input_tokens,
+            )[0]
         else:
             qpos = self.input_proj_robot_state(qpos)
             env_state = self.input_proj_env_state(env_state)
@@ -141,7 +160,7 @@ class DETRVAE(nn.Module):


 class CNNMLP(nn.Module):
-    def __init__(self, backbones, state_dim, camera_names):
+    def __init__(self, backbones, state_dim, action_dim, camera_names):
         """ Initializes the model.
         Parameters:
             backbones: torch module of the backbone to be used. See backbone.py
@@ -153,7 +172,7 @@ class CNNMLP(nn.Module):
         """
         super().__init__()
         self.camera_names = camera_names
-        self.action_head = nn.Linear(1000, state_dim) # TODO add more
+        self.action_head = nn.Linear(1000, action_dim) # TODO add more
         if backbones is not None:
             self.backbones = nn.ModuleList(backbones)
             backbone_down_projs = []
@@ -166,8 +185,8 @@ class CNNMLP(nn.Module):
                 backbone_down_projs.append(down_proj)
             self.backbone_down_projs = nn.ModuleList(backbone_down_projs)

-            mlp_in_dim = 768 * len(backbones) + 14
-            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=14, hidden_depth=2)
+            mlp_in_dim = 768 * len(backbones) + state_dim
+            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=action_dim, hidden_depth=2)
         else:
             raise NotImplementedError

@@ -192,7 +211,7 @@ class CNNMLP(nn.Module):
         for cam_feature in all_cam_features:
             flattened_features.append(cam_feature.reshape([bs, -1]))
         flattened_features = torch.cat(flattened_features, axis=1) # 768 each
-        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
+        features = torch.cat([flattened_features, qpos], axis=1)
         a_hat = self.mlp(features)
         return a_hat

@@ -227,7 +246,8 @@ def build_encoder(args):


 def build(args):
-    state_dim = 14 # TODO hardcode
+    state_dim = args.state_dim
+    action_dim = args.action_dim

     # From state
     # backbone = None # from state for now, no need for conv nets
@@ -245,8 +265,12 @@ def build(args):
         transformer,
         encoder,
         state_dim=state_dim,
+        action_dim=action_dim,
         num_queries=args.num_queries,
         camera_names=args.camera_names,
+        use_text=args.use_text,
+        text_feature_dim=args.text_feature_dim,
+        text_fusion_type=args.text_fusion_type,
     )

     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -255,7 +279,8 @@ def build(args):
     return model


 def build_cnnmlp(args):
-    state_dim = 14 # TODO hardcode
+    state_dim = args.state_dim
+    action_dim = args.action_dim

     # From state
     # backbone = None # from state for now, no need for conv nets
@@ -268,6 +293,7 @@ def build_cnnmlp(args):
     model = CNNMLP(
         backbones,
         state_dim=state_dim,
+        action_dim=action_dim,
         camera_names=args.camera_names,
     )
@@ -46,7 +46,7 @@ class Transformer(nn.Module):
             if p.dim() > 1:
                 nn.init.xavier_uniform_(p)

-    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None):
+    def forward(self, src, mask, query_embed, pos_embed, latent_input=None, proprio_input=None, additional_pos_embed=None, extra_input_tokens=None):
         # TODO flatten only when input has H and W
         if len(src.shape) == 4: # has H and W
             # flatten NxCxHxW to HWxNxC
@@ -56,10 +56,19 @@ class Transformer(nn.Module):
             query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
             # mask = mask.flatten(1)

+            additional_inputs = [latent_input, proprio_input]
+            if extra_input_tokens is not None:
+                if len(extra_input_tokens.shape) == 2:
+                    extra_input_tokens = extra_input_tokens.unsqueeze(0)
+                for i in range(extra_input_tokens.shape[0]):
+                    additional_inputs.append(extra_input_tokens[i])
+
+            addition_input = torch.stack(additional_inputs, axis=0)
+            if additional_pos_embed is not None:
+                additional_pos_embed = additional_pos_embed[:addition_input.shape[0]]
                 additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(1, bs, 1) # seq, bs, dim
                 pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
-
-            addition_input = torch.stack([latent_input, proprio_input], axis=0)
+
             src = torch.cat([addition_input, src], axis=0)
         else:
             assert len(src.shape) == 3
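The shape bookkeeping in the new `extra_input_tokens` path can be checked in isolation: a `(bs, dim)` text feature is lifted to `(1, bs, dim)` and appended token-by-token after the latent and proprio tokens, and the positional-embedding table is then sliced to the resulting token count. A numpy sketch of that bookkeeping, under assumed sizes `bs=4`, `dim=512`:

```python
import numpy as np

bs, dim = 4, 512
latent = np.zeros((bs, dim))
proprio = np.zeros((bs, dim))
text = np.zeros((bs, dim))

tokens = [latent, proprio]
extra = text[None]                    # (1, bs, dim), like unsqueeze(0)
for i in range(extra.shape[0]):       # supports more than one extra token
    tokens.append(extra[i])

addition_input = np.stack(tokens, axis=0)   # (num_tokens, bs, dim)

# Pos-embed table sized for the max token count (3 when use_text=True),
# then sliced to however many tokens were actually stacked.
pos_table = np.zeros((3, dim))
pos_used = pos_table[:addition_input.shape[0]]
```

The slice is what keeps a text-enabled checkpoint usable on batches where no text features are supplied (only 2 tokens stacked).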
@@ -10,17 +10,25 @@ from einops import rearrange

 from constants import DT
 from constants import PUPPET_GRIPPER_JOINT_OPEN
+from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
 from utils import load_data # data functions
 from utils import sample_box_pose, sample_insertion_pose # robot functions
 from utils import compute_dict_mean, set_seed, detach_dict # helper functions
 from policy import ACTPolicy, CNNMLPPolicy
 from visualize_episodes import save_videos

-from sim_env import BOX_POSE

 import IPython
 e = IPython.embed


+def load_checkpoint_state_dict(ckpt_path):
+    """Load checkpoint state_dict safely across different torch versions."""
+    try:
+        return torch.load(ckpt_path, map_location='cpu', weights_only=True)
+    except TypeError:
+        # For older PyTorch versions that do not support `weights_only`.
+        return torch.load(ckpt_path, map_location='cpu')
+
 def main(args):
     set_seed(1)
     # command line parameters
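The `load_checkpoint_state_dict` helper added above relies on a try/except `TypeError` fallback, so the same call works whether or not the installed torch accepts `weights_only`. The pattern in isolation, with a stub loader standing in for an old `torch.load` (names here are illustrative):

```python
def load_with_fallback(loader, path):
    """Call the loader with the newer keyword first; fall back when the
    running version does not accept it (raises TypeError)."""
    try:
        return loader(path, weights_only=True)
    except TypeError:
        return loader(path)

# stub standing in for an old torch.load that lacks `weights_only`
def old_loader(path):
    return {'path': path}

print(load_with_fallback(old_loader, 'policy_best.ckpt'))  # {'path': 'policy_best.ckpt'}
```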
@@ -34,25 +42,50 @@ def main(args):
     num_epochs = args['num_epochs']

     # get task parameters
-    is_sim = task_name[:4] == 'sim_'
-    if is_sim:
-        from constants import SIM_TASK_CONFIGS
+    is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
+    if is_endoscope:
+        task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
+        is_sim = False
+    elif task_name in SIM_TASK_CONFIGS:
         task_config = SIM_TASK_CONFIGS[task_name]
+        is_sim = True
     else:
         from aloha_scripts.constants import TASK_CONFIGS
         task_config = TASK_CONFIGS[task_name]
+        is_sim = False

     dataset_dir = task_config['dataset_dir']
     num_episodes = task_config['num_episodes']
     episode_len = task_config['episode_len']
     camera_names = task_config['camera_names']
+    state_dim = task_config.get('state_dim', 14)
+    action_dim = task_config.get('action_dim', state_dim)
+    use_text_instruction = task_config.get('use_text_instruction', False)
+    instruction_mode = task_config.get('instruction_mode', 'timestep-level')
+    use_cached_text_features = task_config.get('use_cached_text_features', True)
+    text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
+    text_feature_dim = task_config.get('text_feature_dim', 768)
+    text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
+    text_max_length = task_config.get('text_max_length', 32)
+    text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
+    freeze_text_encoder = task_config.get('freeze_text_encoder', True)
+    real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
+
+    if args.get('text_encoder_type') is not None:
+        text_encoder_type = args['text_encoder_type']
+    if args.get('text_max_length') is not None:
+        text_max_length = args['text_max_length']
+    if args.get('freeze_text_encoder', False):
+        freeze_text_encoder = True
+    if args.get('disable_real_action_shift', False):
+        real_action_t_minus_1 = False

     # fixed parameters
-    state_dim = 14
     lr_backbone = 1e-5
     backbone = 'resnet18'
     if policy_class == 'ACT':
-        enc_layers = 4
-        dec_layers = 7
+        enc_layers = 2
+        dec_layers = 4
         nheads = 8
         policy_config = {'lr': args['lr'],
                          'num_queries': args['chunk_size'],
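For reference, the new task-config lookup implies entries shaped roughly like the following. This is a hypothetical `ENDOSCOPE_TASK_CONFIGS` entry assembled from the `task_config.get(...)` keys above; the key name `endoscope_navigate` and all values are placeholders, not from the repo:

```python
# Hypothetical entry; field names follow the `task_config.get(...)` keys above,
# values are illustrative only.
ENDOSCOPE_TASK_CONFIGS = {
    'endoscope_navigate': {
        'dataset_dir': 'data_local/endoscope_navigate',
        'num_episodes': 50,
        'episode_len': 400,
        'camera_names': ['endoscope'],
        'state_dim': 2,            # 2-DOF: two motor positions
        'action_dim': 2,
        'use_text_instruction': True,
        'instruction_mode': 'timestep-level',
    },
}

cfg = ENDOSCOPE_TASK_CONFIGS['endoscope_navigate']
# keys that are omitted fall back to the same defaults main() uses
print(cfg.get('text_feature_dim', 768), cfg.get('freeze_text_encoder', True))  # 768 True
```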
@@ -65,18 +98,36 @@ def main(args):
                          'dec_layers': dec_layers,
                          'nheads': nheads,
                          'camera_names': camera_names,
+                         'state_dim': state_dim,
+                         'action_dim': action_dim,
+                         'use_text': use_text_instruction,
+                         'text_encoder_type': text_encoder_type,
+                         'text_feature_dim': text_feature_dim,
+                         'text_fusion_type': text_fusion_type,
+                         'freeze_text_encoder': freeze_text_encoder,
+                         'instruction_mode': instruction_mode,
+                         'use_cached_text_features': use_cached_text_features,
+                         'text_max_length': text_max_length,
+                         'text_tokenizer_name': text_tokenizer_name,
                          }
     elif policy_class == 'CNNMLP':
         policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
-                         'camera_names': camera_names,}
+                         'camera_names': camera_names,
+                         'state_dim': state_dim,
+                         'action_dim': action_dim,
+                         'use_text': use_text_instruction,
+                         }
     else:
         raise NotImplementedError

     config = {
         'num_epochs': num_epochs,
+        'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
+        'resume_ckpt_path': args.get('resume_ckpt', None),
         'ckpt_dir': ckpt_dir,
         'episode_len': episode_len,
         'state_dim': state_dim,
+        'action_dim': action_dim,
         'lr': args['lr'],
         'policy_class': policy_class,
         'onscreen_render': onscreen_render,
@@ -85,9 +136,25 @@ def main(args):
         'seed': args['seed'],
         'temporal_agg': args['temporal_agg'],
         'camera_names': camera_names,
-        'real_robot': not is_sim
+        'real_robot': (not is_sim) and (not is_endoscope),
+        'use_text_instruction': use_text_instruction,
+        'instruction_mode': instruction_mode,
+        'use_cached_text_features': use_cached_text_features,
+        'text_tokenizer_name': text_tokenizer_name,
+        'text_max_length': text_max_length,
+        'debug_input': args.get('debug_input', False),
     }

+    if config['resume_ckpt_path']:
+        resume_ckpt_path = config['resume_ckpt_path']
+        if not os.path.isabs(resume_ckpt_path):
+            candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
+            if os.path.isfile(candidate_path):
+                resume_ckpt_path = candidate_path
+        if not os.path.isfile(resume_ckpt_path):
+            raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
+        config['resume_ckpt_path'] = resume_ckpt_path
+
     if is_eval:
         ckpt_names = [f'policy_best.ckpt']
         results = []
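The resume-path handling added above first tries relative names under `ckpt_dir` and fails fast when nothing exists. The same logic as a standalone sketch (`resolve_ckpt` is a hypothetical name):

```python
import os
import tempfile

def resolve_ckpt(ckpt_dir, path):
    """Resolve a checkpoint path: relative names are first tried under
    ckpt_dir, and a missing file raises instead of failing later."""
    if not os.path.isabs(path):
        candidate = os.path.join(ckpt_dir, path)
        if os.path.isfile(candidate):
            path = candidate
    if not os.path.isfile(path):
        raise FileNotFoundError(f'--resume_ckpt not found: {path}')
    return path

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, 'policy_best.ckpt'), 'w').close()
    print(os.path.basename(resolve_ckpt(d, 'policy_best.ckpt')))  # policy_best.ckpt
```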
@@ -100,7 +167,21 @@ def main(args):
             print()
         exit()

-    train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
+    train_dataloader, val_dataloader, stats, _ = load_data(
+        dataset_dir,
+        num_episodes,
+        camera_names,
+        batch_size_train,
+        batch_size_val,
+        use_text_instruction=use_text_instruction,
+        instruction_mode=instruction_mode,
+        use_cached_text_features=use_cached_text_features,
+        text_feature_dim=text_feature_dim,
+        text_tokenizer_name=text_tokenizer_name,
+        text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
+        image_augment=args['image_aug'],
+    )

     # save dataset stats
     if not os.path.isdir(ckpt_dir):
@@ -152,6 +233,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
     set_seed(1000)
     ckpt_dir = config['ckpt_dir']
     state_dim = config['state_dim']
+    action_dim = config['action_dim']
     real_robot = config['real_robot']
     policy_class = config['policy_class']
     onscreen_render = config['onscreen_render']
@@ -161,11 +243,12 @@ def eval_bc(config, ckpt_name, save_episode=True):
     task_name = config['task_name']
     temporal_agg = config['temporal_agg']
     onscreen_cam = 'angle'
+    BOX_POSE = None

     # load policy and stats
     ckpt_path = os.path.join(ckpt_dir, ckpt_name)
     policy = make_policy(policy_class, policy_config)
-    loading_status = policy.load_state_dict(torch.load(ckpt_path))
+    loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
     print(loading_status)
     policy.cuda()
     policy.eval()
@@ -202,8 +285,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
         rollout_id += 0
         ### set task
         if 'sim_transfer_cube' in task_name:
+            if BOX_POSE is None:
+                from sim_env import BOX_POSE as _BOX_POSE
+                BOX_POSE = _BOX_POSE
             BOX_POSE[0] = sample_box_pose() # used in sim reset
         elif 'sim_insertion' in task_name:
+            if BOX_POSE is None:
+                from sim_env import BOX_POSE as _BOX_POSE
+                BOX_POSE = _BOX_POSE
             BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset

         ts = env.reset()
@@ -216,7 +305,7 @@ def eval_bc(config, ckpt_name, save_episode=True):

         ### evaluation loop
         if temporal_agg:
-            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
+            all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()

         qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
         image_list = [] # for visualization
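The `all_time_actions` buffer feeds ACT-style temporal ensembling: every chunk that predicted an action for the current step is averaged with exponentially decaying weights. A hedged numpy sketch of that averaging; the `k = 0.01` decay follows the upstream ACT default, which this diff does not show:

```python
import numpy as np

def aggregate(actions_for_t, k=0.01):
    """Temporal ensembling sketch: all chunks that predicted an action for
    the current step are averaged with exponential weights; index 0 is the
    oldest prediction. actions_for_t: (n_preds, action_dim)."""
    w = np.exp(-k * np.arange(len(actions_for_t)))
    w = w / w.sum()                                 # normalize the weights
    return (actions_for_t * w[:, None]).sum(axis=0)

# three overlapping predictions of a 2-DOF action for the same timestep
preds = np.array([[0.10, 0.20], [0.12, 0.18], [0.11, 0.22]])
print(aggregate(preds).shape)  # (2,)
```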
@@ -313,23 +402,72 @@ def eval_bc(config, ckpt_name, save_episode=True):
     return success_rate, avg_return


-def forward_pass(data, policy):
-    image_data, qpos_data, action_data, is_pad = data
-    image_data, qpos_data, action_data, is_pad = image_data.cuda(), qpos_data.cuda(), action_data.cuda(), is_pad.cuda()
-    return policy(qpos_data, image_data, action_data, is_pad) # TODO remove None
+def forward_pass(data, policy, debug_input=False, debug_tag=''):
+    image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid = data
+    image_data = image_data.cuda()
+    qpos_data = qpos_data.cuda()
+    action_data = action_data.cuda()
+    is_pad = is_pad.cuda()
+    text_input_ids = text_input_ids.cuda()
+    text_attention_mask = text_attention_mask.cuda()
+    text_feature_data = text_feature_data.cuda()
+    text_feature_valid = text_feature_valid.cuda()
+
+    text_features = None
+    if torch.any(text_feature_valid):
+        text_features = text_feature_data
+
+    if debug_input:
+        image_min = float(image_data.min().item())
+        image_max = float(image_data.max().item())
+        qpos_mean = float(qpos_data.mean().item())
+        qpos_std = float(qpos_data.std().item())
+        action_mean = float(action_data.mean().item())
+        action_std = float(action_data.std().item())
+        pad_ratio = float(is_pad.float().mean().item())
+
+        print(f'[debug_input] {debug_tag} image shape={tuple(image_data.shape)} range=[{image_min:.4f}, {image_max:.4f}]')
+        print(f'[debug_input] {debug_tag} qpos shape={tuple(qpos_data.shape)} mean/std=({qpos_mean:.4f}, {qpos_std:.4f})')
+        print(f'[debug_input] {debug_tag} action shape={tuple(action_data.shape)} mean/std=({action_mean:.4f}, {action_std:.4f})')
+        print(f'[debug_input] {debug_tag} is_pad shape={tuple(is_pad.shape)} pad_ratio={pad_ratio:.4f}')
+        print(
+            f'[debug_input] {debug_tag} has_nan_or_inf: '
+            f'image={bool(torch.logical_not(torch.isfinite(image_data)).any().item())}, '
+            f'qpos={bool(torch.logical_not(torch.isfinite(qpos_data)).any().item())}, '
+            f'action={bool(torch.logical_not(torch.isfinite(action_data)).any().item())}'
+        )
+
+    return policy(
+        qpos_data,
+        image_data,
+        text_input_ids=text_input_ids,
+        text_attention_mask=text_attention_mask,
+        text_features=text_features,
+        actions=action_data,
+        is_pad=is_pad,
+    )


 def train_bc(train_dataloader, val_dataloader, config):
     num_epochs = config['num_epochs']
+    train_steps_per_epoch = config.get('train_steps_per_epoch', None)
+    resume_ckpt_path = config.get('resume_ckpt_path', None)
     ckpt_dir = config['ckpt_dir']
     seed = config['seed']
     policy_class = config['policy_class']
     policy_config = config['policy_config']
+    debug_input = config.get('debug_input', False)

     set_seed(seed)

     policy = make_policy(policy_class, policy_config)
     policy.cuda()

+    if resume_ckpt_path:
+        loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
+        print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
+        print(loading_status)
+
     optimizer = make_optimizer(policy_class, policy)

     train_history = []
@@ -343,7 +481,8 @@ def train_bc(train_dataloader, val_dataloader, config):
         policy.eval()
         epoch_dicts = []
         for batch_idx, data in enumerate(val_dataloader):
-            forward_dict = forward_pass(data, policy)
+            should_debug = debug_input and epoch == 0 and batch_idx == 0
+            forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='val/epoch0/batch0')
             epoch_dicts.append(forward_dict)
         epoch_summary = compute_dict_mean(epoch_dicts)
         validation_history.append(epoch_summary)
@@ -361,15 +500,31 @@ def train_bc(train_dataloader, val_dataloader, config):
         # training
         policy.train()
         optimizer.zero_grad()
-        for batch_idx, data in enumerate(train_dataloader):
-            forward_dict = forward_pass(data, policy)
+        epoch_train_dicts = []
+        if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
+            train_steps_this_epoch = len(train_dataloader)
+            train_iterator = iter(train_dataloader)
+        else:
+            train_steps_this_epoch = int(train_steps_per_epoch)
+            train_iterator = iter(train_dataloader)
+
+        for step_idx in range(train_steps_this_epoch):
+            try:
+                data = next(train_iterator)
+            except StopIteration:
+                train_iterator = iter(train_dataloader)
+                data = next(train_iterator)
+
+            should_debug = debug_input and epoch == 0 and step_idx == 0
+            forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
             # backward
             loss = forward_dict['loss']
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
             train_history.append(detach_dict(forward_dict))
-        epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
+            epoch_train_dicts.append(detach_dict(forward_dict))
+        epoch_summary = compute_dict_mean(epoch_train_dicts)
         epoch_train_loss = epoch_summary['loss']
         print(f'Train loss: {epoch_train_loss:.5f}')
         summary_string = ''
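The fixed-steps-per-epoch loop above cycles the dataloader via `iter`/`next` and restarts it on `StopIteration`, so an epoch runs a fixed number of optimizer steps regardless of dataset size. The pattern reduces to a small generator (the name `cycle_steps` is illustrative):

```python
def cycle_steps(dataloader, steps):
    """Yield exactly `steps` batches, restarting the dataloader when it
    runs out, so epoch length is fixed regardless of dataset size."""
    it = iter(dataloader)
    for _ in range(steps):
        try:
            yield next(it)
        except StopIteration:
            it = iter(dataloader)   # restart and take the next batch
            yield next(it)

# a 3-batch "dataloader" cycled for 5 fixed steps
print(list(cycle_steps([1, 2, 3], 5)))  # [1, 2, 3, 1, 2]
```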
@@ -426,10 +581,23 @@ if __name__ == '__main__':
     parser.add_argument('--lr', action='store', type=float, help='lr', required=True)

     # for ACT
-    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
+    parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
     parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
     parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
     parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
     parser.add_argument('--temporal_agg', action='store_true')
+    parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
+    parser.add_argument('--freeze_text_encoder', action='store_true')
+    parser.add_argument('--text_max_length', action='store', type=int, required=False)
+    parser.add_argument('--image_aug', action='store_true',
+                        help='Enable training-time image augmentation (color/highlight/noise/blur)')
+    parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
+                        help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
+    parser.add_argument('--disable_real_action_shift', action='store_true',
+                        help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
+    parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
+                        help='Optional checkpoint path to initialize model weights for fine-tuning')
+    parser.add_argument('--debug_input', action='store_true',
+                        help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')

     main(vars(parser.parse_args()))
0 models/__init__.py Normal file
31 models/text_encoder.py Normal file
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+
+
+class DistilBERTTextEncoder(nn.Module):
+    def __init__(self, model_name='distilbert-base-uncased', output_dim=768, freeze=True):
+        super().__init__()
+        try:
+            from transformers import DistilBertModel
+        except ImportError as exc:
+            raise ImportError(
+                'transformers is required for DistilBERT text encoding. '
+                'Install it with: pip install transformers'
+            ) from exc
+
+        self.encoder = DistilBertModel.from_pretrained(model_name)
+        self.output_dim = output_dim
+        self.freeze = freeze
+
+        if self.freeze:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+            self.encoder.eval()
+
+    def forward(self, input_ids, attention_mask):
+        if self.freeze:
+            self.encoder.eval()
+        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+        # DistilBERT has no pooled output; use [CLS] token embedding
+        cls_feature = outputs.last_hidden_state[:, 0, :]
+        return cls_feature
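The encoder's `forward` pools the sequence by taking the `[CLS]` position, i.e. `last_hidden_state[:, 0, :]`. A numpy illustration of that slice, plus the common mask-aware mean-pooling alternative for comparison (toy data, not real encoder output):

```python
import numpy as np

# last_hidden_state from a toy "encoder": (batch=2, seq_len=4, hidden=3)
last_hidden_state = np.arange(24, dtype=float).reshape(2, 4, 3)

# [CLS] pooling keeps only the first token's embedding per sequence
cls_feature = last_hidden_state[:, 0, :]
print(cls_feature.shape)  # (2, 3)

# a common alternative: mean-pool over valid tokens using the attention mask
attention_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=float)
mean_feature = (last_hidden_state * attention_mask[:, :, None]).sum(1) / attention_mask.sum(1, keepdims=True)
print(mean_feature.shape)  # (2, 3)
```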
35 policy.py
@@ -3,6 +3,7 @@ from torch.nn import functional as F
 import torchvision.transforms as transforms

 from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
+from models.text_encoder import DistilBERTTextEncoder
 import IPython
 e = IPython.embed

@@ -13,18 +14,44 @@ class ACTPolicy(nn.Module):
         self.model = model # CVAE decoder
         self.optimizer = optimizer
         self.kl_weight = args_override['kl_weight']
+        self.use_text = args_override.get('use_text', False)
+        self.text_encoder = None
+        if self.use_text:
+            text_encoder_type = args_override.get('text_encoder_type', 'distilbert')
+            if text_encoder_type != 'distilbert':
+                raise NotImplementedError(f'Unsupported text encoder: {text_encoder_type}')
+            self.text_encoder = DistilBERTTextEncoder(
+                model_name=args_override.get('text_tokenizer_name', 'distilbert-base-uncased'),
+                output_dim=args_override.get('text_feature_dim', 768),
+                freeze=args_override.get('freeze_text_encoder', True),
+            )
         print(f'KL Weight {self.kl_weight}')

-    def __call__(self, qpos, image, actions=None, is_pad=None):
+    def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
         env_state = None
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
         image = normalize(image)

+        if self.use_text and text_features is None and text_input_ids is not None and text_attention_mask is not None:
+            if self.text_encoder is None:
+                raise RuntimeError('Text encoder is not initialized while use_text=True.')
+            text_features = self.text_encoder(text_input_ids, text_attention_mask)
+
         if actions is not None: # training time
+            if is_pad is None:
+                raise ValueError('`is_pad` must be provided during training when `actions` is not None.')
             actions = actions[:, :self.model.num_queries]
             is_pad = is_pad[:, :self.model.num_queries]

-            a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
+            a_hat, is_pad_hat, (mu, logvar) = self.model(
+                qpos,
+                image,
+                env_state,
+                text_features=text_features,
+                actions=actions,
+                is_pad=is_pad,
+            )
             total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
             loss_dict = dict()
             all_l1 = F.l1_loss(actions, a_hat, reduction='none')
@@ -34,7 +61,7 @@ class ACTPolicy(nn.Module):
             loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
             return loss_dict
         else: # inference time
-            a_hat, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
+            a_hat, _, (_, _) = self.model(qpos, image, env_state, text_features=text_features) # no action, sample from prior
             return a_hat

     def configure_optimizers(self):
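The training branch combines the L1 term with `kl_divergence(mu, logvar)`. That helper is not shown in this diff; assuming it is the standard closed-form KL between a diagonal Gaussian posterior and a unit Gaussian prior, a numpy sketch:

```python
import numpy as np

def kl_divergence(mu, logvar):
    """Standard VAE KL between N(mu, exp(logvar)) and N(0, I); assumed to be
    the closed form behind the `kl_divergence` call above."""
    klds = -0.5 * (1 + logvar - mu ** 2 - np.exp(logvar))   # (batch, latent_dim)
    total_kld = klds.sum(axis=1).mean()                     # sum over dims, mean over batch
    return total_kld

# a posterior that matches the prior has zero KL
mu = np.zeros((4, 32)); logvar = np.zeros((4, 32))
print(kl_divergence(mu, logvar))  # 0.0
```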
@@ -48,7 +75,7 @@ class CNNMLPPolicy(nn.Module):
         self.model = model # decoder
         self.optimizer = optimizer

-    def __call__(self, qpos, image, actions=None, is_pad=None):
+    def __call__(self, qpos, image, text_input_ids=None, text_attention_mask=None, text_features=None, actions=None, is_pad=None):
         env_state = None # TODO
         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
337 utils.py
@@ -2,21 +2,170 @@ import numpy as np
 import torch
 import os
 import h5py
+import re
 from torch.utils.data import TensorDataset, DataLoader
+import torchvision.transforms.functional as TF

 import IPython
 e = IPython.embed

 class EpisodicDataset(torch.utils.data.Dataset):
-    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats):
+    def __init__(self, episode_ids, dataset_dir, camera_names, norm_stats,
+                 use_text_instruction=False,
+                 instruction_mode='timestep-level',
+                 use_cached_text_features=True,
+                 text_feature_dim=768,
+                 text_tokenizer_name='distilbert-base-uncased',
+                 text_max_length=32,
+                 real_action_t_minus_1=True,
+                 image_augment=False,
+                 image_aug_cfg=None):
         super(EpisodicDataset).__init__()
         self.episode_ids = episode_ids
         self.dataset_dir = dataset_dir
         self.camera_names = camera_names
         self.norm_stats = norm_stats
+        self.use_text_instruction = use_text_instruction
+        self.instruction_mode = instruction_mode
+        self.use_cached_text_features = use_cached_text_features
+        self.text_feature_dim = text_feature_dim
+        self.text_max_length = text_max_length
+        self.real_action_t_minus_1 = real_action_t_minus_1
+        self.image_augment = image_augment
+        self.image_aug_cfg = {
+            'p_color': 0.4,
+            'p_highlight': 0.3,
+            'p_noise': 0.35,
+            'p_blur': 0.15,
+            'brightness': 0.12,
+            'contrast': 0.12,
+            'saturation': 0.12,
+            'hue': 0.03,
+            'highlight_strength': (0.08, 0.25),
+            'noise_std': (0.003, 0.015),
+            'blur_sigma': (0.1, 0.8),
+            'blur_kernel_choices': (3, ),
+        }
+        if image_aug_cfg is not None:
+            self.image_aug_cfg.update(image_aug_cfg)
         self.is_sim = None
+        self.max_episode_len = None
+        self.action_dim = None
+
+        self.text_tokenizer = None
+        if self.use_text_instruction:
+            try:
+                from transformers import DistilBertTokenizerFast
+            except ImportError as exc:
+                raise ImportError(
+                    'transformers is required for text instruction loading. '
+                    'Install it with: pip install transformers'
+                ) from exc
+            self.text_tokenizer = DistilBertTokenizerFast.from_pretrained(text_tokenizer_name)
+
+        self._init_episode_shapes()
+
         self.__getitem__(0) # initialize self.is_sim

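The augmentation method that follows draws each augmentation parameter once per sample and applies it to every camera view. The idea in miniature, with a single shared brightness factor (function name and values illustrative):

```python
import numpy as np

def shared_brightness(all_cam_images, rng):
    """Sample ONE brightness factor per training sample and apply it to every
    camera view, so multi-view appearance stays photometrically consistent.
    all_cam_images: float array [K, H, W, C] in [0, 1]."""
    factor = 1.0 + rng.uniform(-0.12, 0.12)   # same draw for all K cameras
    return np.clip(all_cam_images * factor, 0.0, 1.0)

rng = np.random.default_rng(0)
imgs = np.full((2, 4, 4, 3), 0.5)             # two camera views
out = shared_brightness(imgs, rng)
# both camera views received the identical factor
print(np.allclose(out[0], out[1]))  # True
```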
|
def _apply_image_augmentation(self, all_cam_images):
|
||||||
|
"""
|
||||||
|
Apply identical augmentation parameters to all camera images for one sample.
|
||||||
|
all_cam_images: np.ndarray [K, H, W, C], uint8
|
||||||
|
"""
|
||||||
|
imgs = torch.from_numpy(all_cam_images).float() / 255.0
|
||||||
|
imgs = torch.einsum('k h w c -> k c h w', imgs)
|
||||||
|
|
||||||
|
cfg = self.image_aug_cfg
|
||||||
|
# color jitter (shared params)
|
||||||
|
if np.random.rand() < cfg['p_color']:
|
||||||
|
b = 1.0 + np.random.uniform(-cfg['brightness'], cfg['brightness'])
|
||||||
|
c = 1.0 + np.random.uniform(-cfg['contrast'], cfg['contrast'])
|
||||||
|
s = 1.0 + np.random.uniform(-cfg['saturation'], cfg['saturation'])
|
||||||
|
h = np.random.uniform(-cfg['hue'], cfg['hue'])
|
||||||
|
for cam_idx in range(imgs.shape[0]):
|
||||||
|
img = imgs[cam_idx]
|
||||||
|
img = TF.adjust_brightness(img, b)
|
||||||
|
img = TF.adjust_contrast(img, c)
|
||||||
|
img = TF.adjust_saturation(img, s)
|
||||||
|
img = TF.adjust_hue(img, h)
|
||||||
|
imgs[cam_idx] = img
|
||||||
|
|
||||||
|
# synthetic highlight / glare (shared parameters)
|
||||||
|
if np.random.rand() < cfg['p_highlight']:
|
||||||
|
_, h_img, w_img = imgs[0].shape
|
||||||
|
cx = np.random.uniform(0.2 * w_img, 0.8 * w_img)
|
||||||
|
cy = np.random.uniform(0.2 * h_img, 0.8 * h_img)
|
||||||
|
sigma = np.random.uniform(0.08, 0.2) * min(h_img, w_img)
|
||||||
|
strength = np.random.uniform(*cfg['highlight_strength'])
|
||||||
|
yy, xx = torch.meshgrid(
|
||||||
|
torch.arange(h_img, dtype=torch.float32),
|
||||||
|
torch.arange(w_img, dtype=torch.float32),
|
||||||
|
indexing='ij',
|
||||||
|
)
|
||||||
|
gauss = torch.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
|
||||||
|
gauss = (gauss * strength).unsqueeze(0)
|
||||||
|
imgs = imgs + gauss
|
||||||
|
|
||||||
|
# gaussian noise
|
||||||
|
if np.random.rand() < cfg['p_noise']:
|
||||||
|
noise_std = np.random.uniform(*cfg['noise_std'])
|
||||||
|
imgs = imgs + torch.randn_like(imgs) * noise_std
|
||||||
|
|
||||||
|
# gaussian blur
|
||||||
|
if np.random.rand() < cfg['p_blur']:
|
||||||
|
kernel = int(np.random.choice(cfg['blur_kernel_choices']))
|
||||||
|
sigma = float(np.random.uniform(*cfg['blur_sigma']))
|
||||||
|
for cam_idx in range(imgs.shape[0]):
|
||||||
|
imgs[cam_idx] = TF.gaussian_blur(
|
||||||
|
imgs[cam_idx],
|
||||||
|
kernel_size=[kernel, kernel],
|
||||||
|
sigma=[sigma, sigma],
|
||||||
|
)
|
||||||
|
|
||||||
|
imgs = imgs.clamp(0.0, 1.0)
|
||||||
|
imgs = torch.einsum('k c h w -> k h w c', imgs)
|
||||||
|
imgs = (imgs * 255.0).byte().cpu().numpy()
|
||||||
|
return imgs
|
||||||
|
|
||||||
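The highlight branch above draws one Gaussian glare spot and adds it to every camera view with shared parameters, so all cameras see the same synthetic specular reflection. A numpy-only sketch of that geometry (function name and values are illustrative; the dataset code itself operates on torch tensors in [K, C, H, W]):

```python
import numpy as np

def add_shared_highlight(imgs, cx, cy, sigma, strength):
    """Add the same Gaussian glare spot to every camera image.

    imgs: float array [K, H, W, C] in [0, 1].
    """
    k, h, w, c = imgs.shape
    yy, xx = np.meshgrid(np.arange(h, dtype=np.float32),
                         np.arange(w, dtype=np.float32), indexing='ij')
    gauss = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2.0 * sigma * sigma))
    # Broadcast one [H, W] spot over all cameras and channels, then clamp.
    out = imgs + (gauss * strength)[None, :, :, None]
    return np.clip(out, 0.0, 1.0)

imgs = np.zeros((2, 32, 32, 3), dtype=np.float32)
out = add_shared_highlight(imgs, cx=16, cy=16, sigma=4.0, strength=0.8)
```

Sampling the spot parameters once per sample (rather than per camera) is what keeps the multi-view observation geometrically consistent.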
+
+    def _init_episode_shapes(self):
+        max_len = 0
+        action_dim = None
+        for episode_id in self.episode_ids:
+            dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
+            with h5py.File(dataset_path, 'r') as root:
+                shape = root['/action'].shape
+                if len(shape) != 2:
+                    raise ValueError(f'Expected /action to have shape [T, D], got {shape} in {dataset_path}')
+                max_len = max(max_len, int(shape[0]))
+                if action_dim is None:
+                    action_dim = int(shape[1])
+                elif int(shape[1]) != action_dim:
+                    raise ValueError(
+                        f'Inconsistent action dim in dataset. Expected {action_dim}, got {shape[1]} in {dataset_path}'
+                    )
+
+        if max_len <= 0 or action_dim is None:
+            raise ValueError(f'Invalid dataset metadata in {self.dataset_dir}')
+
+        self.max_episode_len = max_len
+        self.action_dim = action_dim
+
+    @staticmethod
+    def _decode_instruction(raw_value):
+        if raw_value is None:
+            return ''
+        if isinstance(raw_value, bytes):
+            return raw_value.decode('utf-8')
+        if isinstance(raw_value, np.bytes_):
+            return raw_value.tobytes().decode('utf-8')
+        if isinstance(raw_value, np.ndarray):
+            if raw_value.shape == ():
+                return EpisodicDataset._decode_instruction(raw_value.item())
+            if raw_value.size == 0:
+                return ''
+            return EpisodicDataset._decode_instruction(raw_value.reshape(-1)[0])
+        return str(raw_value)
+
     def __len__(self):
         return len(self.episode_ids)
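`_decode_instruction` above has to cope with the several forms h5py can hand back for a string dataset: Python `bytes`, numpy bytes scalars, 0-d or 1-d arrays, or plain `str`. A standalone sketch of the same dispatch, trimmed of the class plumbing (note `np.bytes_` subclasses `bytes`, so the bytes branch also catches numpy scalars here):

```python
import numpy as np

def decode_instruction(raw):
    # Normalize the payload of an HDF5 string dataset to a plain str.
    if raw is None:
        return ''
    if isinstance(raw, bytes):          # also catches np.bytes_
        return raw.decode('utf-8')
    if isinstance(raw, np.ndarray):
        if raw.shape == ():             # 0-d scalar dataset
            return decode_instruction(raw.item())
        if raw.size == 0:
            return ''
        return decode_instruction(raw.reshape(-1)[0])
    return str(raw)
```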
@@ -26,7 +175,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
         episode_id = self.episode_ids[index]
         dataset_path = os.path.join(self.dataset_dir, f'episode_{episode_id}.hdf5')
         with h5py.File(dataset_path, 'r') as root:
-            is_sim = root.attrs['sim']
+            is_sim = bool(root.attrs.get('sim', False))
             original_action_shape = root['/action'].shape
             episode_len = original_action_shape[0]
             if sample_full_episode:
@@ -35,29 +184,62 @@ class EpisodicDataset(torch.utils.data.Dataset):
                 start_ts = np.random.choice(episode_len)
             # get observation at start_ts only
             qpos = root['/observations/qpos'][start_ts]
-            qvel = root['/observations/qvel'][start_ts]
             image_dict = dict()
             for cam_name in self.camera_names:
                 image_dict[cam_name] = root[f'/observations/images/{cam_name}'][start_ts]
+
+            instruction = ''
+            text_feature = None
+            if self.use_text_instruction:
+                effective_mode = self.instruction_mode
+                if effective_mode == 'timestep-level' and '/instruction_timestep' in root:
+                    instruction = self._decode_instruction(root['/instruction_timestep'][start_ts])
+                elif '/instruction' in root:
+                    instruction_node = root['/instruction']
+                    if getattr(instruction_node, 'shape', ()) == ():
+                        instruction = self._decode_instruction(instruction_node[()])
+                    else:
+                        if len(instruction_node.shape) == 1 and instruction_node.shape[0] == episode_len:
+                            instruction = self._decode_instruction(instruction_node[start_ts])
+                        else:
+                            instruction = self._decode_instruction(instruction_node[0])
+
+                if self.use_cached_text_features:
+                    if effective_mode == 'timestep-level' and '/instruction_features_timestep' in root:
+                        text_feature = root['/instruction_features_timestep'][start_ts]
+                    elif '/instruction_features' in root:
+                        feat_node = root['/instruction_features']
+                        if getattr(feat_node, 'shape', ()) == ():
+                            text_feature = np.array(feat_node[()])
+                        elif len(feat_node.shape) == 1:
+                            text_feature = feat_node[()]
+                        elif len(feat_node.shape) == 2 and feat_node.shape[0] == episode_len:
+                            text_feature = feat_node[start_ts]
+                        else:
+                            text_feature = feat_node[0]
+
             # get all actions after and including start_ts
             if is_sim:
                 action = root['/action'][start_ts:]
                 action_len = episode_len - start_ts
             else:
-                action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned
-                action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
+                action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
+                action = root['/action'][action_start:]
+                action_len = episode_len - action_start
 
         self.is_sim = is_sim
-        padded_action = np.zeros(original_action_shape, dtype=np.float32)
+        padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)
         padded_action[:action_len] = action
-        is_pad = np.zeros(episode_len)
-        is_pad[action_len:] = 1
+        is_pad = np.ones(self.max_episode_len)
+        is_pad[:action_len] = 0
 
         # new axis for different cameras
         all_cam_images = []
         for cam_name in self.camera_names:
             all_cam_images.append(image_dict[cam_name])
         all_cam_images = np.stack(all_cam_images, axis=0)
+        if self.image_augment:
+            all_cam_images = self._apply_image_augmentation(all_cam_images)
 
         # construct observations
         image_data = torch.from_numpy(all_cam_images)
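The padding change in this hunk switches from per-episode lengths to a dataset-wide `max_episode_len`, initializing `is_pad` to ones and clearing only the real steps, so variable-length episodes batch to a fixed shape. The arithmetic in isolation (lengths here are illustrative):

```python
import numpy as np

max_episode_len, action_dim = 10, 2   # e.g. 2-DOF endoscope actions
action = np.arange(8, dtype=np.float32).reshape(4, 2)  # 4 real steps

# Fixed-size buffers: zeros beyond the real steps, mask = 1 where padded.
padded_action = np.zeros((max_episode_len, action_dim), dtype=np.float32)
padded_action[:len(action)] = action
is_pad = np.ones(max_episode_len)
is_pad[:len(action)] = 0
```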
@@ -73,55 +255,146 @@ class EpisodicDataset(torch.utils.data.Dataset):
         action_data = (action_data - self.norm_stats["action_mean"]) / self.norm_stats["action_std"]
         qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats["qpos_std"]
 
-        return image_data, qpos_data, action_data, is_pad
+        if self.use_text_instruction and text_feature is not None:
+            text_feature_data = torch.from_numpy(np.array(text_feature)).float()
+            text_feature_valid = torch.tensor(True, dtype=torch.bool)
+            text_input_ids = torch.zeros(1, dtype=torch.long)
+            text_attention_mask = torch.zeros(1, dtype=torch.long)
+        elif self.use_text_instruction:
+            tokenized = self.text_tokenizer(
+                instruction,
+                padding='max_length',
+                truncation=True,
+                max_length=self.text_max_length,
+                return_tensors='pt',
+            )
+            text_input_ids = tokenized['input_ids'].squeeze(0).long()
+            text_attention_mask = tokenized['attention_mask'].squeeze(0).long()
+            text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
+            text_feature_valid = torch.tensor(False, dtype=torch.bool)
+        else:
+            text_input_ids = torch.zeros(1, dtype=torch.long)
+            text_attention_mask = torch.zeros(1, dtype=torch.long)
+            text_feature_data = torch.zeros(self.text_feature_dim, dtype=torch.float32)
+            text_feature_valid = torch.tensor(False, dtype=torch.bool)
+
+        return image_data, qpos_data, action_data, is_pad, text_input_ids, text_attention_mask, text_feature_data, text_feature_valid
 
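When no cached feature exists, the sample falls back to tokenizing the instruction with `padding='max_length'` and `truncation=True`, so every sample yields fixed-length `input_ids` and `attention_mask` regardless of instruction length. A toy whitespace tokenizer (not DistilBERT; the vocabulary and ids here are made up) showing that contract:

```python
def toy_tokenize(text, vocab, max_length):
    # Fixed-length ids plus a mask marking real tokens, mirroring the
    # HF tokenizer call with padding='max_length' and truncation=True.
    ids = [vocab.get(w, 0) for w in text.split()][:max_length]
    mask = [1] * len(ids)
    pad = max_length - len(ids)
    return ids + [0] * pad, mask + [0] * pad

vocab = {'advance': 1, 'retract': 2, 'scope': 3}
ids, mask = toy_tokenize('advance scope', vocab, max_length=4)
```

The fixed length is what lets the dataloader collate text alongside images and qpos without a custom collate function.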
-def get_norm_stats(dataset_dir, num_episodes):
+def _discover_episode_ids(dataset_dir, num_episodes=None):
+    pattern = re.compile(r'^episode_(\d+)\.hdf5$')
+    episode_ids = []
+    for fname in os.listdir(dataset_dir):
+        m = pattern.match(fname)
+        if m:
+            episode_ids.append(int(m.group(1)))
+    episode_ids.sort()
+    if num_episodes is not None:
+        episode_ids = episode_ids[:num_episodes]
+    return episode_ids
+
+
+def get_norm_stats(dataset_dir, episode_ids):
     all_qpos_data = []
     all_action_data = []
-    for episode_idx in range(num_episodes):
+    example_qpos = None
+    for episode_idx in episode_ids:
         dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
         with h5py.File(dataset_path, 'r') as root:
             qpos = root['/observations/qpos'][()]
-            qvel = root['/observations/qvel'][()]
             action = root['/action'][()]
-        all_qpos_data.append(torch.from_numpy(qpos))
-        all_action_data.append(torch.from_numpy(action))
-    all_qpos_data = torch.stack(all_qpos_data)
-    all_action_data = torch.stack(all_action_data)
-    all_action_data = all_action_data
+        qpos_t = torch.from_numpy(qpos)
+        action_t = torch.from_numpy(action)
+        all_qpos_data.append(qpos_t)
+        all_action_data.append(action_t)
+        if example_qpos is None and len(qpos) > 0:
+            example_qpos = qpos[0]
+
+    # Episodes may have different lengths; concatenate over time axis.
+    all_qpos_data = torch.cat(all_qpos_data, dim=0)
+    all_action_data = torch.cat(all_action_data, dim=0)
 
     # normalize action data
-    action_mean = all_action_data.mean(dim=[0, 1], keepdim=True)
-    action_std = all_action_data.std(dim=[0, 1], keepdim=True)
+    action_mean = all_action_data.mean(dim=0, keepdim=True)
+    action_std = all_action_data.std(dim=0, keepdim=True)
     action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
 
     # normalize qpos data
-    qpos_mean = all_qpos_data.mean(dim=[0, 1], keepdim=True)
-    qpos_std = all_qpos_data.std(dim=[0, 1], keepdim=True)
+    qpos_mean = all_qpos_data.mean(dim=0, keepdim=True)
+    qpos_std = all_qpos_data.std(dim=0, keepdim=True)
     qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
 
     stats = {"action_mean": action_mean.numpy().squeeze(), "action_std": action_std.numpy().squeeze(),
              "qpos_mean": qpos_mean.numpy().squeeze(), "qpos_std": qpos_std.numpy().squeeze(),
-             "example_qpos": qpos}
+             "example_qpos": example_qpos}
 
     return stats
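Because episodes can now differ in length, the statistics are taken over the concatenated time axis (`dim=0`) instead of a stacked `[N, T, D]` tensor, and the std is clipped at 1e-2 so a near-constant joint does not explode the normalization. The same computation in numpy with illustrative values (`ddof=1` matches `torch.std`'s unbiased default):

```python
import numpy as np

ep1 = np.array([[0.0, 5.0], [2.0, 5.0]])   # 2 steps, 2-DOF
ep2 = np.array([[4.0, 5.0]])               # 1 step: lengths may differ

all_qpos = np.concatenate([ep1, ep2], axis=0)        # [T_total, D]
mean = all_qpos.mean(axis=0)
# Clip the std: dim 1 is constant, so its raw std is 0.
std = np.clip(all_qpos.std(axis=0, ddof=1), 1e-2, np.inf)
normalized = (all_qpos - mean) / std
```

Without the clip, the constant second joint would divide by zero; with it, that dimension simply normalizes to all zeros.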
-def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val):
+def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val,
+              use_text_instruction=False,
+              instruction_mode='timestep-level',
+              use_cached_text_features=True,
+              text_feature_dim=768,
+              text_tokenizer_name='distilbert-base-uncased',
+              text_max_length=32,
+              real_action_t_minus_1=True,
+              image_augment=False,
+              image_aug_cfg=None):
     print(f'\nData from: {dataset_dir}\n')
-    # obtain train test split
-    train_ratio = 0.8
-    shuffled_indices = np.random.permutation(num_episodes)
-    train_indices = shuffled_indices[:int(train_ratio * num_episodes)]
-    val_indices = shuffled_indices[int(train_ratio * num_episodes):]
+    episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
+    if len(episode_ids) == 0:
+        raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
+
+    # obtain train/val split
+    if len(episode_ids) == 1:
+        # sanity-check mode: reuse the same episode for both train and val
+        # so training/evaluation loops remain unchanged.
+        train_episode_ids = np.array(episode_ids)
+        val_episode_ids = np.array(episode_ids)
+        print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
+    else:
+        train_ratio = 0.9
+        shuffled_indices = np.random.permutation(len(episode_ids))
+        train_count = int(train_ratio * len(episode_ids))
+        train_count = max(1, min(len(episode_ids) - 1, train_count))
+        train_indices = shuffled_indices[:train_count]
+        val_indices = shuffled_indices[train_count:]
+        train_episode_ids = np.array(episode_ids)[train_indices]
+        val_episode_ids = np.array(episode_ids)[val_indices]
 
     # obtain normalization stats for qpos and action
-    norm_stats = get_norm_stats(dataset_dir, num_episodes)
+    norm_stats = get_norm_stats(dataset_dir, episode_ids)
 
     # construct dataset and dataloader
-    train_dataset = EpisodicDataset(train_indices, dataset_dir, camera_names, norm_stats)
-    val_dataset = EpisodicDataset(val_indices, dataset_dir, camera_names, norm_stats)
+    train_dataset = EpisodicDataset(
+        train_episode_ids,
+        dataset_dir,
+        camera_names,
+        norm_stats,
+        use_text_instruction=use_text_instruction,
+        instruction_mode=instruction_mode,
+        use_cached_text_features=use_cached_text_features,
+        text_feature_dim=text_feature_dim,
+        text_tokenizer_name=text_tokenizer_name,
+        text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
+        image_augment=image_augment,
+        image_aug_cfg=image_aug_cfg,
+    )
+    val_dataset = EpisodicDataset(
+        val_episode_ids,
+        dataset_dir,
+        camera_names,
+        norm_stats,
+        use_text_instruction=use_text_instruction,
+        instruction_mode=instruction_mode,
+        use_cached_text_features=use_cached_text_features,
+        text_feature_dim=text_feature_dim,
+        text_tokenizer_name=text_tokenizer_name,
+        text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
+        image_augment=False,
+    )
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
     val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)
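The new split logic clamps `train_count` into `[1, N-1]`, guaranteeing non-empty train and val loaders for any N ≥ 2 (the single-episode case falls back to reusing that episode for both). The clamp by itself:

```python
def split_counts(n_episodes, train_ratio=0.9):
    # Guarantee at least one train and one val episode for n >= 2.
    train_count = int(train_ratio * n_episodes)
    train_count = max(1, min(n_episodes - 1, train_count))
    return train_count, n_episodes - train_count
```

Without the clamp, `int(0.9 * 2) == 1` is fine but `int(0.9 * 1) == 0` would leave the train set empty, and for small N the val set could end up with zero episodes.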
@@ -21,32 +21,33 @@ def load_hdf5(dataset_dir, dataset_name):
     with h5py.File(dataset_path, 'r') as root:
         is_sim = root.attrs['sim']
+        dt = float(root.attrs.get('dt', DT))
         qpos = root['/observations/qpos'][()]
-        qvel = root['/observations/qvel'][()]
+        qvel = root['/observations/qvel'][()] if '/observations/qvel' in root else np.zeros_like(qpos)
         action = root['/action'][()]
         image_dict = dict()
         for cam_name in root[f'/observations/images/'].keys():
             image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
 
-    return qpos, qvel, action, image_dict
+    return qpos, qvel, action, image_dict, dt
 
 
 def main(args):
     dataset_dir = args['dataset_dir']
     episode_idx = args['episode_idx']
     dataset_name = f'episode_{episode_idx}'
 
-    qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
-    save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
+    qpos, qvel, action, image_dict, dt = load_hdf5(dataset_dir, dataset_name)
+    save_videos(image_dict, dt, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
     visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
     # visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
 
 
-def save_videos(video, dt, video_path=None):
+def save_videos(video, dt, video_path):
     if isinstance(video, list):
         cam_names = list(video[0].keys())
         h, w, _ = video[0][cam_names[0]].shape
         w = w * len(cam_names)
-        fps = int(1/dt)
+        fps = max(1, int(round(1 / dt)))
         out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
         for ts, image_dict in enumerate(video):
             images = []
@@ -66,7 +67,7 @@ def save_videos(video, dt, video_path=None):
         all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
 
         n_frames, h, w, _ = all_cam_videos.shape
-        fps = int(1 / dt)
+        fps = max(1, int(round(1 / dt)))
         out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
         for t in range(n_frames):
             image = all_cam_videos[t]
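Both fps computations change from `int(1 / dt)`, which truncates toward zero (and yields fps = 0 for any dt > 1, which `cv2.VideoWriter` cannot handle), to a rounded value floored at 1. The guard in isolation (the `safe_fps` name is illustrative; the diff inlines the expression):

```python
def safe_fps(dt):
    # Round to the nearest integer fps and never emit 0.
    return max(1, int(round(1 / dt)))
```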