构建no-text数据集
This commit is contained in:
@@ -152,6 +152,16 @@ def parse_args() -> argparse.Namespace:
|
||||
action="store_true",
|
||||
help="Disable head/tail stationary instruction override",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim_stationary_edges",
|
||||
action="store_true",
|
||||
help="Trim stationary head/tail segments and keep only the middle moving segment",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_text_instruction",
|
||||
action="store_true",
|
||||
help="Do not save instruction/instruction_timestep (and disable text feature encoding)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -317,48 +327,21 @@ def override_stationary_edge_instructions(
|
||||
motion_threshold: float,
|
||||
) -> Tuple[List[str], int, int]:
|
||||
"""
|
||||
Override instruction text at head/tail when deviation is below threshold.
|
||||
Start side: use the first frame as reference and expand forward until
|
||||
deviation exceeds threshold.
|
||||
End side: use the last frame as reference and expand backward until
|
||||
deviation exceeds threshold.
|
||||
Override instruction text at head/tail using qpos velocity.
|
||||
Start side: once qpos speed is above threshold for consecutive frames,
|
||||
stop applying stop_instruction from that point onward.
|
||||
End side: similarly scan backward from the end.
|
||||
"""
|
||||
num = len(instructions)
|
||||
if num == 0:
|
||||
return instructions, 0, 0
|
||||
|
||||
# normalize to comparable scales (0~1)
|
||||
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
|
||||
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
|
||||
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0)
|
||||
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0)
|
||||
|
||||
if num == 1:
|
||||
return [stop_instruction], 1, 1
|
||||
|
||||
# keep argument for backward CLI compatibility
|
||||
_ = motion_window
|
||||
|
||||
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
|
||||
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
|
||||
|
||||
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
|
||||
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
|
||||
return float(np.max(np.abs(cur - ref)))
|
||||
|
||||
start_count = 0
|
||||
for i in range(num):
|
||||
if deviation_to_ref(i, start_ref) <= motion_threshold:
|
||||
start_count += 1
|
||||
else:
|
||||
break
|
||||
|
||||
end_count = 0
|
||||
for i in range(num - 1, -1, -1):
|
||||
if deviation_to_ref(i, end_ref) <= motion_threshold:
|
||||
end_count += 1
|
||||
else:
|
||||
break
|
||||
start_count, end_count = detect_stationary_edge_counts_from_qpos(
|
||||
motor_pos_y=motor_pos_y,
|
||||
motor_pos_x=motor_pos_x,
|
||||
motion_window=motion_window,
|
||||
motion_threshold=motion_threshold,
|
||||
)
|
||||
|
||||
updated = list(instructions)
|
||||
for i in range(start_count):
|
||||
@@ -369,6 +352,56 @@ def override_stationary_edge_instructions(
|
||||
return updated, start_count, end_count
|
||||
|
||||
|
||||
def detect_stationary_edge_counts_from_qpos(
|
||||
motor_pos_y: np.ndarray,
|
||||
motor_pos_x: np.ndarray,
|
||||
motion_window: int,
|
||||
motion_threshold: float,
|
||||
) -> Tuple[int, int]:
|
||||
"""Return stationary frame counts on head and tail using qpos velocity rule."""
|
||||
num = int(len(motor_pos_y))
|
||||
if num == 0:
|
||||
return 0, 0
|
||||
if num == 1:
|
||||
return 1, 1
|
||||
|
||||
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
|
||||
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
|
||||
|
||||
consecutive = max(1, int(motion_window))
|
||||
dt = 1.0 / 30.0
|
||||
|
||||
frame_speed = np.zeros((num,), dtype=np.float32)
|
||||
dy = np.abs(np.diff(py)) / dt
|
||||
dx = np.abs(np.diff(px)) / dt
|
||||
frame_speed[1:] = np.maximum(dy, dx)
|
||||
|
||||
high_run = 0
|
||||
start_count = num
|
||||
for i in range(1, num):
|
||||
if frame_speed[i] > motion_threshold:
|
||||
high_run += 1
|
||||
if high_run >= consecutive:
|
||||
start_count = i - consecutive + 1
|
||||
break
|
||||
else:
|
||||
high_run = 0
|
||||
|
||||
high_run = 0
|
||||
end_count = num
|
||||
for i in range(num - 1, 0, -1):
|
||||
if frame_speed[i] > motion_threshold:
|
||||
high_run += 1
|
||||
if high_run >= consecutive:
|
||||
tail_start = i + consecutive
|
||||
end_count = max(0, num - tail_start)
|
||||
break
|
||||
else:
|
||||
high_run = 0
|
||||
|
||||
return start_count, end_count
|
||||
|
||||
|
||||
def find_segment_csv(segment_dir: Path) -> Path:
|
||||
csvs = sorted(segment_dir.glob("*.csv"))
|
||||
if not csvs:
|
||||
@@ -441,6 +474,9 @@ def save_episode_plot_with_stop_segments(
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if args.no_text_instruction and args.encode_text_features:
|
||||
raise ValueError('--no_text_instruction and --encode_text_features cannot be used together.')
|
||||
|
||||
segment_dir = Path(args.segment_dir).resolve()
|
||||
output_dir = Path(args.output_dir).resolve()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -513,17 +549,48 @@ def main() -> None:
|
||||
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||
action[i, 1] = normalize_value(np.array([motor_cmd_1], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||
|
||||
ins = instruction_from_annotation(
|
||||
ann,
|
||||
crop,
|
||||
args.instruction_template,
|
||||
args.instruction_empty,
|
||||
)
|
||||
instructions.append(ins)
|
||||
if not args.no_text_instruction:
|
||||
ins = instruction_from_annotation(
|
||||
ann,
|
||||
crop,
|
||||
args.instruction_template,
|
||||
args.instruction_empty,
|
||||
)
|
||||
instructions.append(ins)
|
||||
|
||||
start_stop_count = 0
|
||||
end_stop_count = 0
|
||||
if not args.disable_stop_override:
|
||||
start_stop_count, end_stop_count = detect_stationary_edge_counts_from_qpos(
|
||||
motor_pos_y=motor_pos_y_series,
|
||||
motor_pos_x=motor_pos_x_series,
|
||||
motion_window=args.motion_window,
|
||||
motion_threshold=args.motion_threshold,
|
||||
)
|
||||
|
||||
if args.trim_stationary_edges:
|
||||
keep_start = int(start_stop_count)
|
||||
keep_end = int(num - end_stop_count)
|
||||
if keep_end <= keep_start:
|
||||
raise ValueError(
|
||||
f'No moving segment left after trim: start={start_stop_count}, end={end_stop_count}, num={num}. '
|
||||
f'Consider lowering --motion_threshold or --motion_window.'
|
||||
)
|
||||
images = images[keep_start:keep_end]
|
||||
qpos = qpos[keep_start:keep_end]
|
||||
action = action[keep_start:keep_end]
|
||||
motor_pos_y_series = motor_pos_y_series[keep_start:keep_end]
|
||||
motor_pos_x_series = motor_pos_x_series[keep_start:keep_end]
|
||||
motor_cmd_0_series = motor_cmd_0_series[keep_start:keep_end]
|
||||
motor_cmd_1_series = motor_cmd_1_series[keep_start:keep_end]
|
||||
if not args.no_text_instruction:
|
||||
instructions = instructions[keep_start:keep_end]
|
||||
|
||||
print(
|
||||
f'Trim stationary edges: removed head={start_stop_count}, tail={end_stop_count}, '
|
||||
f'kept={keep_end - keep_start}'
|
||||
)
|
||||
# After trimming, full kept segment is the moving region.
|
||||
start_stop_count, end_stop_count = 0, 0
|
||||
|
||||
if (not args.disable_stop_override) and (not args.no_text_instruction):
|
||||
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
|
||||
instructions=instructions,
|
||||
motor_pos_y=motor_pos_y_series,
|
||||
@@ -564,19 +631,20 @@ def main() -> None:
|
||||
|
||||
root.create_dataset("action", data=action, dtype=np.float32)
|
||||
|
||||
str_dtype = h5py.string_dtype(encoding="utf-8")
|
||||
root.create_dataset(
|
||||
"instruction_timestep",
|
||||
shape=(num,),
|
||||
dtype=str_dtype,
|
||||
data=np.asarray(instructions, dtype=object),
|
||||
)
|
||||
root.create_dataset(
|
||||
"instruction",
|
||||
shape=(),
|
||||
dtype=str_dtype,
|
||||
data=instructions[0] if len(instructions) > 0 else "",
|
||||
)
|
||||
if not args.no_text_instruction:
|
||||
str_dtype = h5py.string_dtype(encoding="utf-8")
|
||||
root.create_dataset(
|
||||
"instruction_timestep",
|
||||
shape=(len(instructions),),
|
||||
dtype=str_dtype,
|
||||
data=np.asarray(instructions, dtype=object),
|
||||
)
|
||||
root.create_dataset(
|
||||
"instruction",
|
||||
shape=(),
|
||||
dtype=str_dtype,
|
||||
data=instructions[0] if len(instructions) > 0 else "",
|
||||
)
|
||||
|
||||
if text_features is not None:
|
||||
root.create_dataset("instruction_features_timestep", data=text_features, dtype=np.float32)
|
||||
@@ -589,7 +657,8 @@ def main() -> None:
|
||||
if not args.disable_stop_override:
|
||||
print(
|
||||
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
|
||||
f"mode=endpoint_reference, threshold={args.motion_threshold}, "
|
||||
f"mode=qpos_velocity_consecutive, consecutive={args.motion_window}, "
|
||||
f"threshold={args.motion_threshold}, "
|
||||
f"instruction='{args.stop_instruction}'"
|
||||
)
|
||||
if text_features is not None:
|
||||
|
||||
Reference in New Issue
Block a user