代码可以跑起来了 (commit message: "The code can now run")

This commit is contained in:
2026-02-19 15:32:28 +08:00
parent b701d939c2
commit 88d14221ae
11 changed files with 503 additions and 89 deletions

View File

@@ -10,14 +10,13 @@ from einops import rearrange
from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN
from constants import SIM_TASK_CONFIGS, ENDOSCOPE_TASK_CONFIGS
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
from policy import ACTPolicy, CNNMLPPolicy
from visualize_episodes import save_videos
from sim_env import BOX_POSE
import IPython
e = IPython.embed
@@ -34,25 +33,47 @@ def main(args):
num_epochs = args['num_epochs']
# get task parameters
is_sim = task_name[:4] == 'sim_'
if is_sim:
from constants import SIM_TASK_CONFIGS
is_endoscope = task_name in ENDOSCOPE_TASK_CONFIGS
if is_endoscope:
task_config = ENDOSCOPE_TASK_CONFIGS[task_name]
is_sim = False
elif task_name in SIM_TASK_CONFIGS:
task_config = SIM_TASK_CONFIGS[task_name]
is_sim = True
else:
from aloha_scripts.constants import TASK_CONFIGS
task_config = TASK_CONFIGS[task_name]
is_sim = False
dataset_dir = task_config['dataset_dir']
num_episodes = task_config['num_episodes']
episode_len = task_config['episode_len']
camera_names = task_config['camera_names']
state_dim = task_config.get('state_dim', 14)
action_dim = task_config.get('action_dim', state_dim)
use_text_instruction = task_config.get('use_text_instruction', False)
instruction_mode = task_config.get('instruction_mode', 'timestep-level')
use_cached_text_features = task_config.get('use_cached_text_features', True)
text_encoder_type = task_config.get('text_encoder_type', 'distilbert')
text_feature_dim = task_config.get('text_feature_dim', 768)
text_fusion_type = task_config.get('text_fusion_type', 'concat_transformer_input')
text_max_length = task_config.get('text_max_length', 32)
text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
freeze_text_encoder = task_config.get('freeze_text_encoder', True)
if args.get('text_encoder_type') is not None:
text_encoder_type = args['text_encoder_type']
if args.get('text_max_length') is not None:
text_max_length = args['text_max_length']
if args.get('freeze_text_encoder', False):
freeze_text_encoder = True
# fixed parameters
state_dim = 14
lr_backbone = 1e-5
backbone = 'resnet18'
if policy_class == 'ACT':
enc_layers = 4
dec_layers = 7
enc_layers = 2
dec_layers = 4
nheads = 8
policy_config = {'lr': args['lr'],
'num_queries': args['chunk_size'],
@@ -65,10 +86,25 @@ def main(args):
'dec_layers': dec_layers,
'nheads': nheads,
'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
'text_encoder_type': text_encoder_type,
'text_feature_dim': text_feature_dim,
'text_fusion_type': text_fusion_type,
'freeze_text_encoder': freeze_text_encoder,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_max_length': text_max_length,
'text_tokenizer_name': text_tokenizer_name,
}
elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,}
'camera_names': camera_names,
'state_dim': state_dim,
'action_dim': action_dim,
'use_text': use_text_instruction,
}
else:
raise NotImplementedError
@@ -77,6 +113,7 @@ def main(args):
'ckpt_dir': ckpt_dir,
'episode_len': episode_len,
'state_dim': state_dim,
'action_dim': action_dim,
'lr': args['lr'],
'policy_class': policy_class,
'onscreen_render': onscreen_render,
@@ -85,7 +122,12 @@ def main(args):
'seed': args['seed'],
'temporal_agg': args['temporal_agg'],
'camera_names': camera_names,
'real_robot': not is_sim
'real_robot': (not is_sim) and (not is_endoscope),
'use_text_instruction': use_text_instruction,
'instruction_mode': instruction_mode,
'use_cached_text_features': use_cached_text_features,
'text_tokenizer_name': text_tokenizer_name,
'text_max_length': text_max_length,
}
if is_eval:
@@ -100,7 +142,19 @@ def main(args):
print()
exit()
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)
train_dataloader, val_dataloader, stats, _ = load_data(
dataset_dir,
num_episodes,
camera_names,
batch_size_train,
batch_size_val,
use_text_instruction=use_text_instruction,
instruction_mode=instruction_mode,
use_cached_text_features=use_cached_text_features,
text_feature_dim=text_feature_dim,
text_tokenizer_name=text_tokenizer_name,
text_max_length=text_max_length,
)
# save dataset stats
if not os.path.isdir(ckpt_dir):
@@ -152,6 +206,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
set_seed(1000)
ckpt_dir = config['ckpt_dir']
state_dim = config['state_dim']
action_dim = config['action_dim']
real_robot = config['real_robot']
policy_class = config['policy_class']
onscreen_render = config['onscreen_render']
@@ -161,6 +216,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
task_name = config['task_name']
temporal_agg = config['temporal_agg']
onscreen_cam = 'angle'
BOX_POSE = None
# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
@@ -202,8 +258,14 @@ def eval_bc(config, ckpt_name, save_episode=True):
rollout_id += 0
### set task
if 'sim_transfer_cube' in task_name:
if BOX_POSE is None:
from sim_env import BOX_POSE as _BOX_POSE
BOX_POSE = _BOX_POSE
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif 'sim_insertion' in task_name:
if BOX_POSE is None:
from sim_env import BOX_POSE as _BOX_POSE
BOX_POSE = _BOX_POSE
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
ts = env.reset()
@@ -216,7 +278,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
### evaluation loop
if temporal_agg:
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, state_dim]).cuda()
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, action_dim]).cuda()
qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
image_list = [] # for visualization
@@ -314,9 +376,29 @@ def eval_bc(config, ckpt_name, save_episode=True):
def forward_pass(data, policy):
    """Move one training batch onto the GPU and run the policy forward.

    Args:
        data: 8-tuple from the dataloader: (image_data, qpos_data,
            action_data, is_pad, text_input_ids, text_attention_mask,
            text_feature_data, text_feature_valid).
        policy: the policy module (e.g. ACTPolicy / CNNMLPPolicy); invoked
            with actions/is_pad, i.e. in training mode.

    Returns:
        Whatever ``policy(...)`` returns in training mode (for ACT, a loss
        dict).
    """
    (image_data, qpos_data, action_data, is_pad,
     text_input_ids, text_attention_mask,
     text_feature_data, text_feature_valid) = data
    image_data = image_data.cuda()
    qpos_data = qpos_data.cuda()
    action_data = action_data.cuda()
    is_pad = is_pad.cuda()
    text_input_ids = text_input_ids.cuda()
    text_attention_mask = text_attention_mask.cuda()
    text_feature_data = text_feature_data.cuda()
    text_feature_valid = text_feature_valid.cuda()
    # Pass cached text features only when at least one sample in the batch
    # has a valid cached feature; otherwise hand the raw token ids to the
    # policy and let it encode the text itself.
    # NOTE(review): this gate is all-or-nothing per batch — if only SOME
    # samples are valid, rows of text_feature_data flagged invalid are still
    # forwarded. Confirm the policy masks them via text_feature_valid, or
    # consider a per-sample merge here.
    text_features = text_feature_data if torch.any(text_feature_valid) else None
    return policy(
        qpos_data,
        image_data,
        text_input_ids=text_input_ids,
        text_attention_mask=text_attention_mask,
        text_features=text_features,
        actions=action_data,
        is_pad=is_pad,
    )
def train_bc(train_dataloader, val_dataloader, config):
@@ -431,5 +513,8 @@ if __name__ == '__main__':
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
parser.add_argument('--temporal_agg', action='store_true')
parser.add_argument('--text_encoder_type', action='store', type=str, required=False)
parser.add_argument('--freeze_text_encoder', action='store_true')
parser.add_argument('--text_max_length', action='store', type=int, required=False)
main(vars(parser.parse_args()))