构建no-text数据集

This commit is contained in:
2026-02-20 14:13:25 +08:00
parent ee257bcb6c
commit d85cce8a52
5 changed files with 194 additions and 62 deletions

View File

@@ -152,6 +152,16 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Disable head/tail stationary instruction override",
)
parser.add_argument(
"--trim_stationary_edges",
action="store_true",
help="Trim stationary head/tail segments and keep only the middle moving segment",
)
parser.add_argument(
"--no_text_instruction",
action="store_true",
help="Do not save instruction/instruction_timestep (and disable text feature encoding)",
)
return parser.parse_args()
@@ -317,48 +327,21 @@ def override_stationary_edge_instructions(
motion_threshold: float,
) -> Tuple[List[str], int, int]:
"""
Override instruction text at head/tail when deviation is below threshold.
Start side: use the first frame as reference and expand forward until
deviation exceeds threshold.
End side: use the last frame as reference and expand backward until
deviation exceeds threshold.
Override instruction text at head/tail using qpos velocity.
Start side: once qpos speed is above threshold for consecutive frames,
stop applying stop_instruction from that point onward.
End side: similarly scan backward from the end.
"""
num = len(instructions)
if num == 0:
return instructions, 0, 0
# normalize to comparable scales (0~1)
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0)
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0)
if num == 1:
return [stop_instruction], 1, 1
# keep argument for backward CLI compatibility
_ = motion_window
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
return float(np.max(np.abs(cur - ref)))
start_count = 0
for i in range(num):
if deviation_to_ref(i, start_ref) <= motion_threshold:
start_count += 1
else:
break
end_count = 0
for i in range(num - 1, -1, -1):
if deviation_to_ref(i, end_ref) <= motion_threshold:
end_count += 1
else:
break
start_count, end_count = detect_stationary_edge_counts_from_qpos(
motor_pos_y=motor_pos_y,
motor_pos_x=motor_pos_x,
motion_window=motion_window,
motion_threshold=motion_threshold,
)
updated = list(instructions)
for i in range(start_count):
@@ -369,6 +352,56 @@ def override_stationary_edge_instructions(
return updated, start_count, end_count
def detect_stationary_edge_counts_from_qpos(
motor_pos_y: np.ndarray,
motor_pos_x: np.ndarray,
motion_window: int,
motion_threshold: float,
) -> Tuple[int, int]:
"""Return stationary frame counts on head and tail using qpos velocity rule."""
num = int(len(motor_pos_y))
if num == 0:
return 0, 0
if num == 1:
return 1, 1
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
consecutive = max(1, int(motion_window))
dt = 1.0 / 30.0
frame_speed = np.zeros((num,), dtype=np.float32)
dy = np.abs(np.diff(py)) / dt
dx = np.abs(np.diff(px)) / dt
frame_speed[1:] = np.maximum(dy, dx)
high_run = 0
start_count = num
for i in range(1, num):
if frame_speed[i] > motion_threshold:
high_run += 1
if high_run >= consecutive:
start_count = i - consecutive + 1
break
else:
high_run = 0
high_run = 0
end_count = num
for i in range(num - 1, 0, -1):
if frame_speed[i] > motion_threshold:
high_run += 1
if high_run >= consecutive:
tail_start = i + consecutive
end_count = max(0, num - tail_start)
break
else:
high_run = 0
return start_count, end_count
def find_segment_csv(segment_dir: Path) -> Path:
csvs = sorted(segment_dir.glob("*.csv"))
if not csvs:
@@ -441,6 +474,9 @@ def save_episode_plot_with_stop_segments(
def main() -> None:
args = parse_args()
if args.no_text_instruction and args.encode_text_features:
raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')
segment_dir = Path(args.segment_dir).resolve()
output_dir = Path(args.output_dir).resolve()
output_dir.mkdir(parents=True, exist_ok=True)
@@ -513,17 +549,48 @@ def main() -> None:
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
ins = instruction_from_annotation(
ann,
crop,
args.instruction_template,
args.instruction_empty,
)
instructions.append(ins)
if not args.no_text_instruction:
ins = instruction_from_annotation(
ann,
crop,
args.instruction_template,
args.instruction_empty,
)
instructions.append(ins)
start_stop_count = 0
end_stop_count = 0
if not args.disable_stop_override:
start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
motor_pos_y=motor_pos_y_series,
motor_pos_x=motor_pos_x_series,
motion_window=args.motion_window,
motion_threshold=args.motion_threshold,
)
if args.trim_stationary_edges:
keep_start = int(start_stop_count)
keep_end = int(num - end_stop_count)
if keep_end <= keep_start:
raise ValueError(
f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
f'Consider lowering --motion_threshold or --motion_window.'
)
images = images[keep_start:keep_end]
qpos = qpos[keep_start:keep_end]
action = action[keep_start:keep_end]
motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
if not args.no_text_instruction:
instructions = instructions[keep_start:keep_end]
print(
f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
f'kept={keep_end - keep_start}'
)
# After trimming, full kept segment is the moving region.
start_stop_count, end_stop_count = 0, 0
if (not args.disable_stop_override) and (not args.no_text_instruction):
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
instructions=instructions,
motor_pos_y=motor_pos_y_series,
@@ -564,19 +631,20 @@ def main() -> None:
root.create_dataset("action", data=action, dtype=np.float32)
str_dtype = h5py.string_dtype(encoding="utf-8")
root.create_dataset(
"instruction_timestep",
shape=(num,),
dtype=str_dtype,
data=np.asarray(instructions, dtype=object),
)
root.create_dataset(
"instruction",
shape=(),
dtype=str_dtype,
data=instructions[0] if len(instructions) > 0 else "",
)
if not args.no_text_instruction:
str_dtype = h5py.string_dtype(encoding="utf-8")
root.create_dataset(
"instruction_timestep",
shape=(len(instructions),),
dtype=str_dtype,
data=np.asarray(instructions, dtype=object),
)
root.create_dataset(
"instruction",
shape=(),
dtype=str_dtype,
data=instructions[0] if len(instructions) > 0 else "",
)
if text_features is not None:
root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
@@ -589,7 +657,8 @@ def main() -> None:
if not args.disable_stop_override:
print(
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
f"mode=endpoint_reference, threshold={args.motion_threshold}, "
f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
f"threshold={args.motion_threshold}, "
f"instruction='{args.stop_instruction}'"
)
if text_features is not None: