构建no-text数据集
This commit is contained in:
@@ -152,6 +152,16 @@ def parse_args() -> argparse.Namespace:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable head/tail stationary instruction override",
|
help="Disable head/tail stationary instruction override",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--trim_stationary_edges",
|
||||||
|
action="store_true",
|
||||||
|
help="Trim stationary head/tail segments and keep only the middle moving segment",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_text_instruction",
|
||||||
|
action="store_true",
|
||||||
|
help="Do not save instruction/instruction_timestep (and disable text feature encoding)",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -317,48 +327,21 @@ def override_stationary_edge_instructions(
|
|||||||
motion_threshold: float,
|
motion_threshold: float,
|
||||||
) -> Tuple[List[str], int, int]:
|
) -> Tuple[List[str], int, int]:
|
||||||
"""
|
"""
|
||||||
Override instruction text at head/tail when deviation is below threshold.
|
Override instruction text at head/tail using qpos velocity.
|
||||||
Start side: use the first frame as reference and expand forward until
|
Start side: once qpos speed is above threshold for consecutive frames,
|
||||||
deviation exceeds threshold.
|
stop applying stop_instruction from that point onward.
|
||||||
End side: use the last frame as reference and expand backward until
|
End side: similarly scan backward from the end.
|
||||||
deviation exceeds threshold.
|
|
||||||
"""
|
"""
|
||||||
num = len(instructions)
|
num = len(instructions)
|
||||||
if num == 0:
|
if num == 0:
|
||||||
return instructions, 0, 0
|
return instructions, 0, 0
|
||||||
|
|
||||||
# normalize to comparable scales (0~1)
|
start_count, end_count = detect_stationary_edge_counts_from_qpos(
|
||||||
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
|
motor_pos_y=motor_pos_y,
|
||||||
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
|
motor_pos_x=motor_pos_x,
|
||||||
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0)
|
motion_window=motion_window,
|
||||||
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0)
|
motion_threshold=motion_threshold,
|
||||||
|
)
|
||||||
if num == 1:
|
|
||||||
return [stop_instruction], 1, 1
|
|
||||||
|
|
||||||
# keep argument for backward CLI compatibility
|
|
||||||
_ = motion_window
|
|
||||||
|
|
||||||
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
|
|
||||||
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
|
|
||||||
|
|
||||||
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
|
|
||||||
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
|
|
||||||
return float(np.max(np.abs(cur - ref)))
|
|
||||||
|
|
||||||
start_count = 0
|
|
||||||
for i in range(num):
|
|
||||||
if deviation_to_ref(i, start_ref) <= motion_threshold:
|
|
||||||
start_count += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
end_count = 0
|
|
||||||
for i in range(num - 1, -1, -1):
|
|
||||||
if deviation_to_ref(i, end_ref) <= motion_threshold:
|
|
||||||
end_count += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
updated = list(instructions)
|
updated = list(instructions)
|
||||||
for i in range(start_count):
|
for i in range(start_count):
|
||||||
@@ -369,6 +352,56 @@ def override_stationary_edge_instructions(
|
|||||||
return updated, start_count, end_count
|
return updated, start_count, end_count
|
||||||
|
|
||||||
|
|
||||||
|
def detect_stationary_edge_counts_from_qpos(
|
||||||
|
motor_pos_y: np.ndarray,
|
||||||
|
motor_pos_x: np.ndarray,
|
||||||
|
motion_window: int,
|
||||||
|
motion_threshold: float,
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""Return stationary frame counts on head and tail using qpos velocity rule."""
|
||||||
|
num = int(len(motor_pos_y))
|
||||||
|
if num == 0:
|
||||||
|
return 0, 0
|
||||||
|
if num == 1:
|
||||||
|
return 1, 1
|
||||||
|
|
||||||
|
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
|
||||||
|
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
|
||||||
|
|
||||||
|
consecutive = max(1, int(motion_window))
|
||||||
|
dt = 1.0 / 30.0
|
||||||
|
|
||||||
|
frame_speed = np.zeros((num,), dtype=np.float32)
|
||||||
|
dy = np.abs(np.diff(py)) / dt
|
||||||
|
dx = np.abs(np.diff(px)) / dt
|
||||||
|
frame_speed[1:] = np.maximum(dy, dx)
|
||||||
|
|
||||||
|
high_run = 0
|
||||||
|
start_count = num
|
||||||
|
for i in range(1, num):
|
||||||
|
if frame_speed[i] > motion_threshold:
|
||||||
|
high_run += 1
|
||||||
|
if high_run >= consecutive:
|
||||||
|
start_count = i - consecutive + 1
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
high_run = 0
|
||||||
|
|
||||||
|
high_run = 0
|
||||||
|
end_count = num
|
||||||
|
for i in range(num - 1, 0, -1):
|
||||||
|
if frame_speed[i] > motion_threshold:
|
||||||
|
high_run += 1
|
||||||
|
if high_run >= consecutive:
|
||||||
|
tail_start = i + consecutive
|
||||||
|
end_count = max(0, num - tail_start)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
high_run = 0
|
||||||
|
|
||||||
|
return start_count, end_count
|
||||||
|
|
||||||
|
|
||||||
def find_segment_csv(segment_dir: Path) -> Path:
|
def find_segment_csv(segment_dir: Path) -> Path:
|
||||||
csvs = sorted(segment_dir.glob("*.csv"))
|
csvs = sorted(segment_dir.glob("*.csv"))
|
||||||
if not csvs:
|
if not csvs:
|
||||||
@@ -441,6 +474,9 @@ def save_episode_plot_with_stop_segments(
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
|
if args.no_text_instruction and args.encode_text_features:
|
||||||
|
raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')
|
||||||
|
|
||||||
segment_dir = Path(args.segment_dir).resolve()
|
segment_dir = Path(args.segment_dir).resolve()
|
||||||
output_dir = Path(args.output_dir).resolve()
|
output_dir = Path(args.output_dir).resolve()
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -513,17 +549,48 @@ def main() -> None:
|
|||||||
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||||
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||||
|
|
||||||
ins = instruction_from_annotation(
|
if not args.no_text_instruction:
|
||||||
ann,
|
ins = instruction_from_annotation(
|
||||||
crop,
|
ann,
|
||||||
args.instruction_template,
|
crop,
|
||||||
args.instruction_empty,
|
args.instruction_template,
|
||||||
)
|
args.instruction_empty,
|
||||||
instructions.append(ins)
|
)
|
||||||
|
instructions.append(ins)
|
||||||
|
|
||||||
start_stop_count = 0
|
start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
|
||||||
end_stop_count = 0
|
motor_pos_y=motor_pos_y_series,
|
||||||
if not args.disable_stop_override:
|
motor_pos_x=motor_pos_x_series,
|
||||||
|
motion_window=args.motion_window,
|
||||||
|
motion_threshold=args.motion_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.trim_stationary_edges:
|
||||||
|
keep_start = int(start_stop_count)
|
||||||
|
keep_end = int(num - end_stop_count)
|
||||||
|
if keep_end <= keep_start:
|
||||||
|
raise ValueError(
|
||||||
|
f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
|
||||||
|
f'Consider lowering --motion_threshold or --motion_window.'
|
||||||
|
)
|
||||||
|
images = images[keep_start:keep_end]
|
||||||
|
qpos = qpos[keep_start:keep_end]
|
||||||
|
action = action[keep_start:keep_end]
|
||||||
|
motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
|
||||||
|
motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
|
||||||
|
motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
|
||||||
|
motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
|
||||||
|
if not args.no_text_instruction:
|
||||||
|
instructions = instructions[keep_start:keep_end]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
|
||||||
|
f'kept={keep_end - keep_start}'
|
||||||
|
)
|
||||||
|
# After trimming, full kept segment is the moving region.
|
||||||
|
start_stop_count, end_stop_count = 0, 0
|
||||||
|
|
||||||
|
if (not args.disable_stop_override) and (not args.no_text_instruction):
|
||||||
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
|
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
motor_pos_y=motor_pos_y_series,
|
motor_pos_y=motor_pos_y_series,
|
||||||
@@ -564,19 +631,20 @@ def main() -> None:
|
|||||||
|
|
||||||
root.create_dataset("action", data=action, dtype=np.float32)
|
root.create_dataset("action", data=action, dtype=np.float32)
|
||||||
|
|
||||||
str_dtype = h5py.string_dtype(encoding="utf-8")
|
if not args.no_text_instruction:
|
||||||
root.create_dataset(
|
str_dtype = h5py.string_dtype(encoding="utf-8")
|
||||||
"instruction_timestep",
|
root.create_dataset(
|
||||||
shape=(num,),
|
"instruction_timestep",
|
||||||
dtype=str_dtype,
|
shape=(len(instructions),),
|
||||||
data=np.asarray(instructions, dtype=object),
|
dtype=str_dtype,
|
||||||
)
|
data=np.asarray(instructions, dtype=object),
|
||||||
root.create_dataset(
|
)
|
||||||
"instruction",
|
root.create_dataset(
|
||||||
shape=(),
|
"instruction",
|
||||||
dtype=str_dtype,
|
shape=(),
|
||||||
data=instructions[0] if len(instructions) > 0 else "",
|
dtype=str_dtype,
|
||||||
)
|
data=instructions[0] if len(instructions) > 0 else "",
|
||||||
|
)
|
||||||
|
|
||||||
if text_features is not None:
|
if text_features is not None:
|
||||||
root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
|
root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
|
||||||
@@ -589,7 +657,8 @@ def main() -> None:
|
|||||||
if not args.disable_stop_override:
|
if not args.disable_stop_override:
|
||||||
print(
|
print(
|
||||||
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
|
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
|
||||||
f"mode=endpoint_reference, threshold={args.motion_threshold}, "
|
f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
|
||||||
|
f"threshold={args.motion_threshold}, "
|
||||||
f"instruction='{args.stop_instruction}'"
|
f"instruction='{args.stop_instruction}'"
|
||||||
)
|
)
|
||||||
if text_features is not None:
|
if text_features is not None:
|
||||||
|
|||||||
25
build_no_text_dataset.sh
Executable file
25
build_no_text_dataset.sh
Executable file
@@ -0,0 +1,25 @@
|
|||||||
|
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/00-follow"
|
||||||
|
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/follow-no-text"
|
||||||
|
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"
|
||||||
|
|
||||||
|
mkdir -p "$OUT_DIR"
|
||||||
|
i=50
|
||||||
|
for d in "$SEG_ROOT"/follow_seg_*; do
|
||||||
|
[ -d "$d" ] || continue
|
||||||
|
echo "Building $d -> episode_$i"
|
||||||
|
python "$SCRIPT" \
|
||||||
|
--segment_dir "$d" \
|
||||||
|
--output_dir "$OUT_DIR" \
|
||||||
|
--episode_idx "$i" \
|
||||||
|
--max_frames -1 \
|
||||||
|
--camera_name top \
|
||||||
|
--crop 733 30 1754 1051 \
|
||||||
|
--resize 224 224 \
|
||||||
|
--motion_window 3 \
|
||||||
|
--motion_threshold 0.05 \
|
||||||
|
--state_norm minus1_1 \
|
||||||
|
--action_norm minus1_1 \
|
||||||
|
--trim_stationary_edges \
|
||||||
|
--no_text_instruction
|
||||||
|
i=$((i+1))
|
||||||
|
done
|
||||||
29
build_text_dataset.sh
Executable file
29
build_text_dataset.sh
Executable file
@@ -0,0 +1,29 @@
|
|||||||
|
SEG_ROOT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/01-cannulation"
|
||||||
|
OUT_DIR="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/cannulation"
|
||||||
|
SCRIPT="/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py"
|
||||||
|
|
||||||
|
mkdir -p "$OUT_DIR"
|
||||||
|
i=12
|
||||||
|
for d in "$SEG_ROOT"/seg_*; do
|
||||||
|
[ -d "$d" ] || continue
|
||||||
|
echo "Building $d -> episode_$i"
|
||||||
|
python "$SCRIPT" \
|
||||||
|
--segment_dir "$d" \
|
||||||
|
--output_dir "$OUT_DIR" \
|
||||||
|
--episode_idx "$i" \
|
||||||
|
--max_frames -1 \
|
||||||
|
--camera_name top \
|
||||||
|
--crop 733 30 1754 1051 \
|
||||||
|
--resize 224 224 \
|
||||||
|
--instruction_template 'Cannulate the {label} on the phantom located at the {region} with the sphincterotome.' \
|
||||||
|
--instruction_empty 'No target visible.' \
|
||||||
|
--stop_instruction 'Stop move.' \
|
||||||
|
--motion_window 3 \
|
||||||
|
--motion_threshold 0.05 \
|
||||||
|
--state_norm minus1_1 \
|
||||||
|
--action_norm minus1_1 \
|
||||||
|
--encode_text_features \
|
||||||
|
--text_model_name distilbert-base-uncased \
|
||||||
|
--text_batch_size 32
|
||||||
|
i=$((i+1))
|
||||||
|
done
|
||||||
@@ -67,6 +67,15 @@ ENDOSCOPE_TASK_CONFIGS = {
|
|||||||
'text_max_length': 32,
|
'text_max_length': 32,
|
||||||
'text_tokenizer_name': 'distilbert-base-uncased',
|
'text_tokenizer_name': 'distilbert-base-uncased',
|
||||||
},
|
},
|
||||||
|
'endoscope_both_no_text': {
|
||||||
|
'dataset_dir': DATA_DIR + '/both-no-text',
|
||||||
|
'num_episodes': 3,
|
||||||
|
'episode_len': 400,
|
||||||
|
'camera_names': ['top'],
|
||||||
|
'state_dim': 2,
|
||||||
|
'action_dim': 2,
|
||||||
|
'use_text_instruction': False,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
### Simulation envs fixed constants
|
### Simulation envs fixed constants
|
||||||
|
|||||||
2
utils.py
2
utils.py
@@ -344,7 +344,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
|
|||||||
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
|
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
|
||||||
|
|
||||||
# obtain train test split
|
# obtain train test split
|
||||||
train_ratio = 0.8
|
train_ratio = 0.9
|
||||||
shuffled_indices = np.random.permutation(len(episode_ids))
|
shuffled_indices = np.random.permutation(len(episode_ids))
|
||||||
train_count = int(train_ratio * len(episode_ids))
|
train_count = int(train_ratio * len(episode_ids))
|
||||||
train_indices = shuffled_indices[:train_count]
|
train_indices = shuffled_indices[:train_count]
|
||||||
|
|||||||
Reference in New Issue
Block a user