构建no-text数据集

This commit is contained in:
2026-02-20 14:13:25 +08:00
parent ee257bcb6c
commit d85cce8a52
5 changed files with 194 additions and 62 deletions

View File

@@ -152,6 +152,16 @@ def parse_args() -> argparse.Namespace:
action="store_true", action="store_true",
help="Disable head/tail stationary instruction override", help="Disable head/tail stationary instruction override",
) )
parser.add_argument(
"--trim_stationary_edges",
action="store_true",
help="Trim stationary head/tail segments and keep only the middle moving segment",
)
parser.add_argument(
"--no_text_instruction",
action="store_true",
help="Do not save instruction/instruction_timestep (and disable text feature encoding)",
)
return parser.parse_args() return parser.parse_args()
@@ -317,48 +327,21 @@ def override_stationary_edge_instructions(
motion_threshold: float, motion_threshold: float,
) -> Tuple[List[str], int, int]: ) -> Tuple[List[str], int, int]:
""" """
Override instruction text at head/tail when deviation is below threshold. Override instruction text at head/tail using qpos velocity.
Start side: use the first frame as reference and expand forward until Start side: once qpos speed is above threshold for consecutive frames,
deviation exceeds threshold. stop applying stop_instruction from that point onward.
End side: use the last frame as reference and expand backward until End side: similarly scan backward from the end.
deviation exceeds threshold.
""" """
num = len(instructions) num = len(instructions)
if num == 0: if num == 0:
return instructions, 0, 0 return instructions, 0, 0
# normalize to comparable scales (0~1) start_count, end_count = detect_stationary_edge_counts_from_qpos(
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0) motor_pos_y=motor_pos_y,
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0) motor_pos_x=motor_pos_x,
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0) motion_window=motion_window,
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0) motion_threshold=motion_threshold,
)
if num == 1:
return [stop_instruction], 1, 1
# keep argument for backward CLI compatibility
_ = motion_window
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
return float(np.max(np.abs(cur - ref)))
start_count = 0
for i in range(num):
if deviation_to_ref(i, start_ref) <= motion_threshold:
start_count += 1
else:
break
end_count = 0
for i in range(num - 1, -1, -1):
if deviation_to_ref(i, end_ref) <= motion_threshold:
end_count += 1
else:
break
updated = list(instructions) updated = list(instructions)
for i in range(start_count): for i in range(start_count):
@@ -369,6 +352,56 @@ def override_stationary_edge_instructions(
return updated, start_count, end_count return updated, start_count, end_count
def detect_stationary_edge_counts_from_qpos(
    motor_pos_y: np.ndarray,
    motor_pos_x: np.ndarray,
    motion_window: int,
    motion_threshold: float,
) -> Tuple[int, int]:
    """Count stationary frames at the head and at the tail of an episode.

    Motion is judged from the per-frame speed of the two normalized motor
    positions: a frame is "moving" when the larger of the two absolute
    position deltas (divided by ``dt``) exceeds ``motion_threshold``.
    An edge ends at the first run of ``motion_window`` consecutive moving
    frames found when scanning inward from that edge.

    Args:
        motor_pos_y: Raw Y motor position per frame.
        motor_pos_x: Raw X motor position per frame.
        motion_window: Number of consecutive moving frames required
            (values below 1 are treated as 1).
        motion_threshold: Speed above which a frame counts as moving.

    Returns:
        ``(head_count, tail_count)``. Both equal the number of frames when
        no qualifying run of moving frames exists on that side.
    """
    total = int(len(motor_pos_y))
    if total <= 1:
        # Zero frames -> (0, 0); a lone frame is stationary on both sides.
        return total, total

    # NOTE(review): the bounds passed here are presumably the raw motor
    # ranges used to bring both axes onto a comparable scale -- confirm
    # against _normalize_series.
    norm_y = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
    norm_x = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)

    required_run = max(1, int(motion_window))
    # NOTE(review): fixed time step; presumably a 30 fps recording -- confirm.
    dt = 1.0 / 30.0

    # speed[k] is the speed between frame k-1 and frame k; speed[0] stays 0.
    speed = np.zeros((total,), dtype=np.float32)
    speed_y = np.abs(np.diff(norm_y)) / dt
    speed_x = np.abs(np.diff(norm_x)) / dt
    speed[1:] = np.maximum(speed_y, speed_x)

    # Forward scan: the head ends where the first long-enough moving run starts.
    head_count = total
    run_length = 0
    for idx in range(1, total):
        if not (speed[idx] > motion_threshold):
            run_length = 0
            continue
        run_length += 1
        if run_length >= required_run:
            head_count = idx - required_run + 1
            break

    # Backward scan: the tail starts right after the last long-enough moving run.
    tail_count = total
    run_length = 0
    for idx in reversed(range(1, total)):
        if not (speed[idx] > motion_threshold):
            run_length = 0
            continue
        run_length += 1
        if run_length >= required_run:
            first_tail_frame = idx + required_run
            tail_count = max(0, total - first_tail_frame)
            break

    return head_count, tail_count
def find_segment_csv(segment_dir: Path) -> Path: def find_segment_csv(segment_dir: Path) -> Path:
csvs = sorted(segment_dir.glob("*.csv")) csvs = sorted(segment_dir.glob("*.csv"))
if not csvs: if not csvs:
@@ -441,6 +474,9 @@ def save_episode_plot_with_stop_segments(
def main() -> None: def main() -> None:
args = parse_args() args = parse_args()
if args.no_text_instruction and args.encode_text_features:
raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')
segment_dir = Path(args.segment_dir).resolve() segment_dir = Path(args.segment_dir).resolve()
output_dir = Path(args.output_dir).resolve() output_dir = Path(args.output_dir).resolve()
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
@@ -513,6 +549,7 @@ def main() -> None:
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0] action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0] action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
if not args.no_text_instruction:
ins = instruction_from_annotation( ins = instruction_from_annotation(
ann, ann,
crop, crop,
@@ -521,9 +558,39 @@ def main() -> None:
) )
instructions.append(ins) instructions.append(ins)
start_stop_count = 0 start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
end_stop_count = 0 motor_pos_y=motor_pos_y_series,
if not args.disable_stop_override: motor_pos_x=motor_pos_x_series,
motion_window=args.motion_window,
motion_threshold=args.motion_threshold,
)
if args.trim_stationary_edges:
keep_start = int(start_stop_count)
keep_end = int(num - end_stop_count)
if keep_end <= keep_start:
raise ValueError(
f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
f'Consider lowering --motion_threshold or --motion_window.'
)
images = images[keep_start:keep_end]
qpos = qpos[keep_start:keep_end]
action = action[keep_start:keep_end]
motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
if not args.no_text_instruction:
instructions = instructions[keep_start:keep_end]
print(
f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
f'kept={keep_end - keep_start}'
)
# After trimming, full kept segment is the moving region.
start_stop_count, end_stop_count = 0, 0
if (not args.disable_stop_override) and (not args.no_text_instruction):
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions( instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
instructions=instructions, instructions=instructions,
motor_pos_y=motor_pos_y_series, motor_pos_y=motor_pos_y_series,
@@ -564,10 +631,11 @@ def main() -> None:
root.create_dataset("action", data=action, dtype=np.float32) root.create_dataset("action", data=action, dtype=np.float32)
if not args.no_text_instruction:
str_dtype = h5py.string_dtype(encoding="utf-8") str_dtype = h5py.string_dtype(encoding="utf-8")
root.create_dataset( root.create_dataset(
"instruction_timestep", "instruction_timestep",
shape=(num,), shape=(len(instructions),),
dtype=str_dtype, dtype=str_dtype,
data=np.asarray(instructions, dtype=object), data=np.asarray(instructions, dtype=object),
) )
@@ -589,7 +657,8 @@ def main() -> None:
if not args.disable_stop_override: if not args.disable_stop_override:
print( print(
f"stationary override: head={start_stop_count}, tail={end_stop_count}, " f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
f"mode=endpoint_reference, threshold={args.motion_threshold}, " f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
f"threshold={args.motion_threshold}, "
f"instruction='{args.stop_instruction}'" f"instruction='{args.stop_instruction}'"
) )
if text_features is not None: if text_features is not None:

25
build_no_text_dataset.sh Executable file
View File

@@ -0,0 +1,25 @@
#!/usr/bin/env bash
# Build the no-text ACT dataset: run the dataset builder once for every
# follow_seg_* directory under SEG_ROOT, writing one episode per segment
# into OUT_DIR with stationary edges trimmed and no instruction text.
#
# SEG_ROOT, OUT_DIR and SCRIPT can be overridden from the environment; the
# defaults are the original hard-coded locations.
# Exit status is non-zero when at least one segment failed to build.
set -u

SEG_ROOT="${SEG_ROOT:-/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/00-follow}"
OUT_DIR="${OUT_DIR:-/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/follow-no-text}"
SCRIPT="${SCRIPT:-/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py}"

mkdir -p "$OUT_DIR"

# First episode index to assign.
# NOTE(review): presumably continues after already-built episodes -- confirm.
i=50
failed=0

for d in "$SEG_ROOT"/follow_seg_*; do
    # An unmatched glob stays literal; this also skips any non-directory entry.
    [ -d "$d" ] || continue

    echo "Building $d -> episode_$i"
    if ! python "$SCRIPT" \
        --segment_dir "$d" \
        --output_dir "$OUT_DIR" \
        --episode_idx "$i" \
        --max_frames -1 \
        --camera_name top \
        --crop 733 30 1754 1051 \
        --resize 224 224 \
        --motion_window 3 \
        --motion_threshold 0.05 \
        --state_norm minus1_1 \
        --action_norm minus1_1 \
        --trim_stationary_edges \
        --no_text_instruction
    then
        # Keep going with the remaining segments, but do not hide the failure.
        echo "WARNING: build failed for $d (episode_$i)" >&2
        failed=$((failed + 1))
    fi

    # The index advances even on failure, so each segment keeps a stable episode id.
    i=$((i + 1))
done

if [ "$failed" -gt 0 ]; then
    echo "ERROR: $failed segment(s) failed to build" >&2
    exit 1
fi

29
build_text_dataset.sh Executable file
View File

@@ -0,0 +1,29 @@
#!/usr/bin/env bash
# Build the text-conditioned cannulation ACT dataset: run the dataset builder
# once for every seg_* directory under SEG_ROOT, writing one episode per
# segment into OUT_DIR with per-timestep instructions and encoded text features.
#
# SEG_ROOT, OUT_DIR and SCRIPT can be overridden from the environment; the
# defaults are the original hard-coded locations.
# Exit status is non-zero when at least one segment failed to build.
set -u

SEG_ROOT="${SEG_ROOT:-/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/raw_data/01-cannulation}"
OUT_DIR="${OUT_DIR:-/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/data/cannulation}"
SCRIPT="${SCRIPT:-/home/cyx6123/DuodenoVLA/data/ACT/aloha/act/build_endoscope_act_dataset.py}"

mkdir -p "$OUT_DIR"

# First episode index to assign.
# NOTE(review): presumably continues after already-built episodes -- confirm.
i=12
failed=0

for d in "$SEG_ROOT"/seg_*; do
    # An unmatched glob stays literal; this also skips any non-directory entry.
    [ -d "$d" ] || continue

    echo "Building $d -> episode_$i"
    if ! python "$SCRIPT" \
        --segment_dir "$d" \
        --output_dir "$OUT_DIR" \
        --episode_idx "$i" \
        --max_frames -1 \
        --camera_name top \
        --crop 733 30 1754 1051 \
        --resize 224 224 \
        --instruction_template 'Cannulate the {label} on the phantom located at the {region} with the sphincterotome.' \
        --instruction_empty 'No target visible.' \
        --stop_instruction 'Stop move.' \
        --motion_window 3 \
        --motion_threshold 0.05 \
        --state_norm minus1_1 \
        --action_norm minus1_1 \
        --encode_text_features \
        --text_model_name distilbert-base-uncased \
        --text_batch_size 32
    then
        # Keep going with the remaining segments, but do not hide the failure.
        echo "WARNING: build failed for $d (episode_$i)" >&2
        failed=$((failed + 1))
    fi

    # The index advances even on failure, so each segment keeps a stable episode id.
    i=$((i + 1))
done

if [ "$failed" -gt 0 ]; then
    echo "ERROR: $failed segment(s) failed to build" >&2
    exit 1
fi

View File

@@ -67,6 +67,15 @@ ENDOSCOPE_TASK_CONFIGS = {
'text_max_length': 32, 'text_max_length': 32,
'text_tokenizer_name': 'distilbert-base-uncased', 'text_tokenizer_name': 'distilbert-base-uncased',
}, },
'endoscope_both_no_text': {
'dataset_dir': DATA_DIR + '/both-no-text',
'num_episodes': 3,
'episode_len': 400,
'camera_names': ['top'],
'state_dim': 2,
'action_dim': 2,
'use_text_instruction': False,
},
} }
### Simulation envs fixed constants ### Simulation envs fixed constants

View File

@@ -344,7 +344,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}') raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
# obtain train test split # obtain train test split
train_ratio = 0.8 train_ratio = 0.9
shuffled_indices = np.random.permutation(len(episode_ids)) shuffled_indices = np.random.permutation(len(episode_ids))
train_count = int(train_ratio * len(episode_ids)) train_count = int(train_ratio * len(episode_ids))
train_indices = shuffled_indices[:train_count] train_indices = shuffled_indices[:train_count]