增加了stop命令
This commit is contained in:
@@ -12,6 +12,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
|
|||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
|
RESAMPLE_BILINEAR = getattr(getattr(Image, "Resampling", Image), "BILINEAR")
|
||||||
|
|
||||||
@@ -96,6 +97,12 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default="No target visible.",
|
default="No target visible.",
|
||||||
help="Instruction when no valid target after crop",
|
help="Instruction when no valid target after crop",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stop_instruction",
|
||||||
|
type=str,
|
||||||
|
default="Stop move.",
|
||||||
|
help="Instruction used for stationary head/tail frames",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--state_norm",
|
"--state_norm",
|
||||||
choices=["minus1_1", "0_1", "raw"],
|
choices=["minus1_1", "0_1", "raw"],
|
||||||
@@ -125,6 +132,26 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default=32,
|
default=32,
|
||||||
help="Batch size for text feature extraction",
|
help="Batch size for text feature extraction",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--motion_window",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Sliding window size used for stationary detection at beginning/end",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--motion_threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.002,
|
||||||
|
help=(
|
||||||
|
"Motion threshold in normalized delta space (0~1). "
|
||||||
|
"Smaller value means stricter stationary detection"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable_stop_override",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable head/tail stationary instruction override",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -272,6 +299,76 @@ def extract_text_features(
|
|||||||
return np.concatenate(feats, axis=0)
|
return np.concatenate(feats, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_series(v: np.ndarray, min_v: float, max_v: float) -> np.ndarray:
|
||||||
|
scale = max_v - min_v
|
||||||
|
if scale <= 0:
|
||||||
|
return np.zeros_like(v, dtype=np.float32)
|
||||||
|
return ((v - min_v) / scale).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def override_stationary_edge_instructions(
|
||||||
|
instructions: List[str],
|
||||||
|
motor_pos_y: np.ndarray,
|
||||||
|
motor_pos_x: np.ndarray,
|
||||||
|
motor_cmd_0: np.ndarray,
|
||||||
|
motor_cmd_1: np.ndarray,
|
||||||
|
stop_instruction: str,
|
||||||
|
motion_window: int,
|
||||||
|
motion_threshold: float,
|
||||||
|
) -> Tuple[List[str], int, int]:
|
||||||
|
"""
|
||||||
|
Override instruction text at head/tail when deviation is below threshold.
|
||||||
|
Start side: use the first frame as reference and expand forward until
|
||||||
|
deviation exceeds threshold.
|
||||||
|
End side: use the last frame as reference and expand backward until
|
||||||
|
deviation exceeds threshold.
|
||||||
|
"""
|
||||||
|
num = len(instructions)
|
||||||
|
if num == 0:
|
||||||
|
return instructions, 0, 0
|
||||||
|
|
||||||
|
# normalize to comparable scales (0~1)
|
||||||
|
py = _normalize_series(motor_pos_y.astype(np.float32), 8000.0, 18884.0)
|
||||||
|
px = _normalize_series(motor_pos_x.astype(np.float32), 7000.0, 17384.0)
|
||||||
|
c0 = _normalize_series(motor_cmd_0.astype(np.float32), 0.0, 65535.0)
|
||||||
|
c1 = _normalize_series(motor_cmd_1.astype(np.float32), 0.0, 65535.0)
|
||||||
|
|
||||||
|
if num == 1:
|
||||||
|
return [stop_instruction], 1, 1
|
||||||
|
|
||||||
|
# keep argument for backward CLI compatibility
|
||||||
|
_ = motion_window
|
||||||
|
|
||||||
|
start_ref = np.array([py[0], px[0], c0[0], c1[0]], dtype=np.float32)
|
||||||
|
end_ref = np.array([py[-1], px[-1], c0[-1], c1[-1]], dtype=np.float32)
|
||||||
|
|
||||||
|
def deviation_to_ref(i: int, ref: np.ndarray) -> float:
|
||||||
|
cur = np.array([py[i], px[i], c0[i], c1[i]], dtype=np.float32)
|
||||||
|
return float(np.max(np.abs(cur - ref)))
|
||||||
|
|
||||||
|
start_count = 0
|
||||||
|
for i in range(num):
|
||||||
|
if deviation_to_ref(i, start_ref) <= motion_threshold:
|
||||||
|
start_count += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
end_count = 0
|
||||||
|
for i in range(num - 1, -1, -1):
|
||||||
|
if deviation_to_ref(i, end_ref) <= motion_threshold:
|
||||||
|
end_count += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
updated = list(instructions)
|
||||||
|
for i in range(start_count):
|
||||||
|
updated[i] = stop_instruction
|
||||||
|
for i in range(num - end_count, num):
|
||||||
|
updated[i] = stop_instruction
|
||||||
|
|
||||||
|
return updated, start_count, end_count
|
||||||
|
|
||||||
|
|
||||||
def find_segment_csv(segment_dir: Path) -> Path:
|
def find_segment_csv(segment_dir: Path) -> Path:
|
||||||
csvs = sorted(segment_dir.glob("*.csv"))
|
csvs = sorted(segment_dir.glob("*.csv"))
|
||||||
if not csvs:
|
if not csvs:
|
||||||
@@ -279,6 +376,68 @@ def find_segment_csv(segment_dir: Path) -> Path:
|
|||||||
return csvs[0]
|
return csvs[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_to_segments(mask: np.ndarray) -> List[Tuple[int, int]]:
|
||||||
|
"""Convert boolean mask to closed-open segments [start, end)."""
|
||||||
|
segments: List[Tuple[int, int]] = []
|
||||||
|
if mask.size == 0:
|
||||||
|
return segments
|
||||||
|
in_seg = False
|
||||||
|
start = 0
|
||||||
|
for i, v in enumerate(mask.tolist()):
|
||||||
|
if v and not in_seg:
|
||||||
|
in_seg = True
|
||||||
|
start = i
|
||||||
|
elif (not v) and in_seg:
|
||||||
|
in_seg = False
|
||||||
|
segments.append((start, i))
|
||||||
|
if in_seg:
|
||||||
|
segments.append((start, int(mask.size)))
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
|
def save_episode_plot_with_stop_segments(
|
||||||
|
qpos: np.ndarray,
|
||||||
|
action: np.ndarray,
|
||||||
|
instructions: Sequence[str],
|
||||||
|
stop_instruction: str,
|
||||||
|
plot_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Save a diagnostics plot (qpos/action) and highlight stop-instruction spans.
|
||||||
|
"""
|
||||||
|
qpos = np.asarray(qpos)
|
||||||
|
action = np.asarray(action)
|
||||||
|
if qpos.ndim != 2 or action.ndim != 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
num_ts, num_dim = qpos.shape
|
||||||
|
if num_ts == 0 or num_dim == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
stop_mask = np.array([ins == stop_instruction for ins in instructions], dtype=bool)
|
||||||
|
stop_segments = _mask_to_segments(stop_mask)
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(num_dim, 1, figsize=(10, 3.0 * num_dim), sharex=True)
|
||||||
|
axs_list = np.atleast_1d(axs).reshape(-1).tolist()
|
||||||
|
|
||||||
|
for dim_idx in range(num_dim):
|
||||||
|
ax = axs_list[dim_idx]
|
||||||
|
ax.plot(qpos[:, dim_idx], label=f'qpos[{dim_idx}]', linewidth=1.4)
|
||||||
|
ax.plot(action[:, dim_idx], label=f'action[{dim_idx}]', linewidth=1.2)
|
||||||
|
for seg_idx, (st, ed) in enumerate(stop_segments):
|
||||||
|
ax.axvspan(st, ed - 1, color='orange', alpha=0.2,
|
||||||
|
label='stop instruction' if seg_idx == 0 else None)
|
||||||
|
ax.set_ylabel(f'dim {dim_idx}')
|
||||||
|
ax.legend(loc='best')
|
||||||
|
ax.grid(alpha=0.25, linestyle='--')
|
||||||
|
|
||||||
|
axs_list[-1].set_xlabel('timestep')
|
||||||
|
fig.suptitle('Episode diagnostics with stop-instruction spans', y=1.02)
|
||||||
|
fig.tight_layout()
|
||||||
|
fig.savefig(str(plot_path), dpi=140)
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
@@ -311,6 +470,10 @@ def main() -> None:
|
|||||||
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
|
qpos = np.zeros((num, 2), dtype=np.float32) # [y, x]
|
||||||
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
|
action = np.zeros((num, 2), dtype=np.float32) # [cmd0(y), cmd1(x)]
|
||||||
instructions: List[str] = []
|
instructions: List[str] = []
|
||||||
|
motor_pos_y_series = np.zeros((num,), dtype=np.float32)
|
||||||
|
motor_pos_x_series = np.zeros((num,), dtype=np.float32)
|
||||||
|
motor_cmd_0_series = np.zeros((num,), dtype=np.float32)
|
||||||
|
motor_cmd_1_series = np.zeros((num,), dtype=np.float32)
|
||||||
|
|
||||||
y_min, y_max = 8000.0, 18884.0
|
y_min, y_max = 8000.0, 18884.0
|
||||||
x_min, x_max = 7000.0, 17384.0
|
x_min, x_max = 7000.0, 17384.0
|
||||||
@@ -340,6 +503,11 @@ def main() -> None:
|
|||||||
motor_cmd_0 = float(row["motor_command_0"])
|
motor_cmd_0 = float(row["motor_command_0"])
|
||||||
motor_cmd_1 = float(row["motor_command_1"])
|
motor_cmd_1 = float(row["motor_command_1"])
|
||||||
|
|
||||||
|
motor_pos_y_series[i] = motor_pos_y
|
||||||
|
motor_pos_x_series[i] = motor_pos_x
|
||||||
|
motor_cmd_0_series[i] = motor_cmd_0
|
||||||
|
motor_cmd_1_series[i] = motor_cmd_1
|
||||||
|
|
||||||
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
|
qpos[i, 0] = normalize_value(np.array([motor_pos_y], dtype=np.float32), y_min, y_max, args.state_norm)[0]
|
||||||
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
|
qpos[i, 1] = normalize_value(np.array([motor_pos_x], dtype=np.float32), x_min, x_max, args.state_norm)[0]
|
||||||
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
action[i, 0] = normalize_value(np.array([motor_cmd_0], dtype=np.float32), cmd_min, cmd_max, args.action_norm)[0]
|
||||||
@@ -353,6 +521,20 @@ def main() -> None:
|
|||||||
)
|
)
|
||||||
instructions.append(ins)
|
instructions.append(ins)
|
||||||
|
|
||||||
|
start_stop_count = 0
|
||||||
|
end_stop_count = 0
|
||||||
|
if not args.disable_stop_override:
|
||||||
|
instructions, start_stop_count, end_stop_count = override_stationary_edge_instructions(
|
||||||
|
instructions=instructions,
|
||||||
|
motor_pos_y=motor_pos_y_series,
|
||||||
|
motor_pos_x=motor_pos_x_series,
|
||||||
|
motor_cmd_0=motor_cmd_0_series,
|
||||||
|
motor_cmd_1=motor_cmd_1_series,
|
||||||
|
stop_instruction=args.stop_instruction,
|
||||||
|
motion_window=args.motion_window,
|
||||||
|
motion_threshold=args.motion_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
text_features = None
|
text_features = None
|
||||||
if args.encode_text_features:
|
if args.encode_text_features:
|
||||||
text_features = extract_text_features(
|
text_features = extract_text_features(
|
||||||
@@ -404,9 +586,26 @@ def main() -> None:
|
|||||||
print(f"Frames used: {num}")
|
print(f"Frames used: {num}")
|
||||||
print(f"Image shape: {images.shape}")
|
print(f"Image shape: {images.shape}")
|
||||||
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
|
print(f"qpos shape: {qpos.shape}, action shape: {action.shape}")
|
||||||
|
if not args.disable_stop_override:
|
||||||
|
print(
|
||||||
|
f"stationary override: head={start_stop_count}, tail={end_stop_count}, "
|
||||||
|
f"mode=endpoint_reference, threshold={args.motion_threshold}, "
|
||||||
|
f"instruction='{args.stop_instruction}'"
|
||||||
|
)
|
||||||
if text_features is not None:
|
if text_features is not None:
|
||||||
print(f"instruction_features_timestep shape: {text_features.shape}")
|
print(f"instruction_features_timestep shape: {text_features.shape}")
|
||||||
|
|
||||||
|
# Save a same-basename plot next to the generated hdf5
|
||||||
|
plot_path = out_path.with_suffix('.png')
|
||||||
|
save_episode_plot_with_stop_segments(
|
||||||
|
qpos=qpos,
|
||||||
|
action=action,
|
||||||
|
instructions=instructions,
|
||||||
|
stop_instruction=args.stop_instruction,
|
||||||
|
plot_path=plot_path,
|
||||||
|
)
|
||||||
|
print(f"Saved episode plot to: {plot_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user