增加了stop命令
This commit is contained in:
@@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
|
||||
import h5py
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
|
||||
|
||||
@@ -96,6 +97,12 @@ def parse_args() -> argparse.Namespace:
|
||||
default="No target visible.",
|
||||
help="Instruction when no valid target after crop",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop_instruction",
|
||||
type=str,
|
||||
default="Stop move.",
|
||||
help="Instruction used for stationary head/tail frames",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--state_norm",
|
||||
choices=["minus1_1", "0_1", "raw"],
|
||||
@@ -125,6 +132,26 @@ def parse_args() -> argparse.Namespace:
|
||||
default=32,
|
||||
help="Batch size for text feature extraction",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--motion_window",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Sliding window size used for stationary detection at beginning/end",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--motion_threshold",
|
||||
type=float,
|
||||
default=0.002,
|
||||
help=(
|
||||
"Motion threshold in normalized delta space (0~1). "
|
||||
"Smaller value means stricter stationary detection"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable_stop_override",
|
||||
action="store_true",
|
||||
help="Disable head/tail stationary instruction override",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -272,6 +299,76 @@ def extract_text_features(
|
||||
return np.concatenate(feats, axis=0)
|
||||
|
||||
|
||||
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
|
||||
scale = max_v - min_v
|
||||
if scale <= 0:
|
||||
return np.zeros_like(v, dtype=np.float32)
|
||||
return ((v - min_v) / scale).astype(np.float32)
|
||||
|
||||
|
||||
def override_stationary_edge_instructions(
|
||||
instructions: List[str],
|
||||
motor_pos_y: np.ndarray,
|
||||
motor_pos_x: np.ndarray,
|
||||
motor_cmd_0: np.ndarray,
|
||||
motor_cmd_1: np.ndarray,
|
||||
stop_instruction: str,
|
||||
motion_window: int,
|
||||
motion_threshold: float,
|
||||
) -> Tuple[List[str], int, int]:
|
||||
"""
|
||||
Override instruction text at head/tail when deviation is below threshold.
|
||||
Start side: use the first frame as reference and expand forward until
|
||||
deviation exceeds threshold.
|
||||
End side: use the last frame as reference and expand backward until
|
||||
deviation exceeds threshold.
|
||||
"""
|
||||
num = len(instructions)
|
||||
if num == 0:
|
||||
return instructions, 0, 0
|
||||
|
||||
# normalize to comparable scales (0~1)
|
||||
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
|
||||
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
|
||||
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0)
|
||||
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0)
|
||||
|
||||
if num == 1:
|
||||
return [stop_instruction], 1, 1
|
||||
|
||||
# keep argument for backward CLI compatibility
|
||||
_ = motion_window
|
||||
|
||||
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
|
||||
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
|
||||
|
||||
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
|
||||
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
|
||||
return float(np.max(np.abs(cur - ref)))
|
||||
|
||||
start_count = 0
|
||||
for i in range(num):
|
||||
if deviation_to_ref(i, start_ref) <= motion_threshold:
|
||||
start_count += 1
|
||||
else:
|
||||
break
|
||||
|
||||
end_count = 0
|
||||
for i in range(num - 1, -1, -1):
|
||||
if deviation_to_ref(i, end_ref) <= motion_threshold:
|
||||
end_count += 1
|
||||
else:
|
||||
break
|
||||
|
||||
updated = list(instructions)
|
||||
for i in range(start_count):
|
||||
updated[i] = stop_instruction
|
||||
for i in range(num - end_count, num):
|
||||
updated[i] = stop_instruction
|
||||
|
||||
return updated, start_count, end_count
|
||||
|
||||
|
||||
def find_segment_csv(segment_dir: Path) -> Path:
|
||||
csvs = sorted(segment_dir.glob("*.csv"))
|
||||
if not csvs:
|
||||
@@ -279,6 +376,68 @@ def find_segment_csv(segment_dir: Path) -> Path:
|
||||
return csvs[0]
|
||||
|
||||
|
||||
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
|
||||
"""Convert boolean mask to closed-open segments [start, end)."""
|
||||
segments: List[Tuple[int, int]] = []
|
||||
if mask.size == 0:
|
||||
return segments
|
||||
in_seg = False
|
||||
start = 0
|
||||
for i, v in enumerate(mask.tolist()):
|
||||
if v and not in_seg:
|
||||
in_seg = True
|
||||
start = i
|
||||
elif (not v) and in_seg:
|
||||
in_seg = False
|
||||
segments.append((start, i))
|
||||
if in_seg:
|
||||
segments.append((start, int(mask.size)))
|
||||
return segments
|
||||
|
||||
|
||||
def save_episode_plot_with_stop_segments(
|
||||
qpos: np.ndarray,
|
||||
action: np.ndarray,
|
||||
instructions: Sequence[str],
|
||||
stop_instruction: str,
|
||||
plot_path: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
|
||||
"""
|
||||
qpos = np.asarray(qpos)
|
||||
action = np.asarray(action)
|
||||
if qpos.ndim != 2 or action.ndim != 2:
|
||||
return
|
||||
|
||||
num_ts, num_dim = qpos.shape
|
||||
if num_ts == 0 or num_dim == 0:
|
||||
return
|
||||
|
||||
stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
|
||||
stop_segments = _mask_to_segments(stop_mask)
|
||||
|
||||
fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
|
||||
axs_list = np.atleast_1d(axs).reshape(-1).tolist()
|
||||
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs_list[dim_idx]
|
||||
ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
|
||||
ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
|
||||
for seg_idx, (st, ed) in enumerate(stop_segments):
|
||||
ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
|
||||
label='stop instruction' if seg_idx == 0 else None)
|
||||
ax.set_ylabel(f'dim {dim_idx}')
|
||||
ax.legend(loc='best')
|
||||
ax.grid(alpha=0.25, linestyle='--')
|
||||
|
||||
axs_list[-1].set_xlabel('timestep')
|
||||
fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
|
||||
fig.tight_layout()
|
||||
fig.savefig(str(plot_path), dpi=140)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
@@ -311,6 +470,10 @@ def main() -> None:
|
||||
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
|
||||
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
|
||||
instructions: List[str] = []
|
||||
motor_pos_y_series = np.zeros((num,), dtype=np.float32)
|
||||
motor_pos_x_series = np.zeros((num,), dtype=np.float32)
|
||||
motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
|
||||
motor_cmd_1_series = np.zeros((num,), dtype=np.float32)
|
||||
|
||||
y_min, y_max = 8000.0, 18884.0
|
||||
x_min, x_max = 7000.0, 17384.0
|
||||
@@ -340,6 +503,11 @@ def main() -> None:
|
||||
motor_cmd_0 = float(row["motor_command_0"])
|
||||
motor_cmd_1 = float(row["motor_command_1"])
|
||||
|
||||
motor_pos_y_series[i] = motor_pos_y
|
||||
motor_pos_x_series[i] = motor_pos_x
|
||||
motor_cmd_0_series[i] = motor_cmd_0
|
||||
motor_cmd_1_series[i] = motor_cmd_1
|
||||
|
||||
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
|
||||
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
|
||||
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||
@@ -353,6 +521,20 @@ def main() -> None:
|
||||
)
|
||||
instructions.append(ins)
|
||||
|
||||
start_stop_count = 0
|
||||
end_stop_count = 0
|
||||
if not args.disable_stop_override:
|
||||
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
|
||||
instructions=instructions,
|
||||
motor_pos_y=motor_pos_y_series,
|
||||
motor_pos_x=motor_pos_x_series,
|
||||
motor_cmd_0=motor_cmd_0_series,
|
||||
motor_cmd_1=motor_cmd_1_series,
|
||||
stop_instruction=args.stop_instruction,
|
||||
motion_window=args.motion_window,
|
||||
motion_threshold=args.motion_threshold,
|
||||
)
|
||||
|
||||
text_features = None
|
||||
if args.encode_text_features:
|
||||
text_features = extract_text_features(
|
||||
@@ -404,9 +586,26 @@ def main() -> None:
|
||||
print(f"Frames used: {num}")
|
||||
print(f"Image shape: {images.shape}")
|
||||
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
|
||||
if not args.disable_stop_override:
|
||||
print(
|
||||
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
|
||||
f"mode=endpoint_reference, threshold={args.motion_threshold}, "
|
||||
f"instruction='{args.stop_instruction}'"
|
||||
)
|
||||
if text_features is not None:
|
||||
print(f"instruction_features_timestep shape: {text_features.shape}")
|
||||
|
||||
# Save a same-basename plot next to the generated hdf5
|
||||
plot_path = out_path.with_suffix('.png')
|
||||
save_episode_plot_with_stop_segments(
|
||||
qpos=qpos,
|
||||
action=action,
|
||||
instructions=instructions,
|
||||
stop_instruction=args.stop_instruction,
|
||||
plot_path=plot_path,
|
||||
)
|
||||
print(f"Saved episode plot to: {plot_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user