follow: use policy_last

constants.py (33 changed lines)

@@ -40,6 +40,7 @@ ENDOSCOPE_TASK_CONFIGS = {
         'camera_names': ['top'],
         'state_dim': 2,
         'action_dim': 2,
+        'real_action_t_minus_1': False,
         'use_text_instruction': True,
         'instruction_mode': 'timestep-level',
         'use_cached_text_features': True,

@@ -57,6 +58,7 @@ ENDOSCOPE_TASK_CONFIGS = {
         'camera_names': ['top'],
         'state_dim': 2,
         'action_dim': 2,
+        'real_action_t_minus_1': False,
         'use_text_instruction': True,
         'instruction_mode': 'timestep-level',
         'use_cached_text_features': True,

@@ -74,6 +76,37 @@ ENDOSCOPE_TASK_CONFIGS = {
         'camera_names': ['top'],
         'state_dim': 2,
         'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': False,
+    },
+    'endoscope_sanity_check': {
+        'dataset_dir': DATA_DIR + '/sanity-check',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': False,
+    },
+    'endoscope_cannulation_no_text': {
+        'dataset_dir': DATA_DIR + '/cannulation-no-text',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
+        'use_text_instruction': False,
+    },
+    'endoscope_follow_no_text': {
+        'dataset_dir': DATA_DIR + '/follow-no-text',
+        'num_episodes': 3,
+        'episode_len': 400,
+        'camera_names': ['top'],
+        'state_dim': 2,
+        'action_dim': 2,
+        'real_action_t_minus_1': False,
         'use_text_instruction': False,
     },
 }

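The new 'real_action_t_minus_1' key is optional by design: task configs that omit it keep the old behavior, because the training script reads it with dict.get and a default of True. A minimal consumption sketch (the DATA_DIR value here is hypothetical, for illustration only):

DATA_DIR = '/data/endoscope'  # hypothetical path, not from the commit

task_config = {
    'dataset_dir': DATA_DIR + '/sanity-check',
    'num_episodes': 3,
    'episode_len': 400,
    'camera_names': ['top'],
    'state_dim': 2,
    'action_dim': 2,
    'real_action_t_minus_1': False,
    'use_text_instruction': False,
}

# Missing keys fall back to defaults, so older configs stay valid.
real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
use_text_instruction = task_config.get('use_text_instruction', False)
assert real_action_t_minus_1 is False and use_text_instruction is False
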
@@ -70,7 +70,7 @@ def get_args_parser():
     parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
     parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
     parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
-    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
+    parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
     parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
     parser.add_argument('--temporal_agg', action='store_true')
     parser.add_argument('--image_aug', action='store_true')

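The kl_weight type fix matters because argparse rejects any value its type callable cannot parse: with type=int, a fractional weight such as 0.5 fails at parse time. A self-contained check of that behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--kl_weight', action='store', type=float, required=False)
print(parser.parse_args(['--kl_weight', '0.5']).kl_weight)  # 0.5
print(parser.parse_args(['--kl_weight', '10']).kl_weight)   # 10.0 (integer inputs still accepted)
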
@@ -20,6 +20,15 @@ from visualize_episodes import save_videos
 import IPython
 e = IPython.embed
 
+
+def load_checkpoint_state_dict(ckpt_path):
+    """Load checkpoint state_dict safely across different torch versions."""
+    try:
+        return torch.load(ckpt_path, map_location='cpu', weights_only=True)
+    except TypeError:
+        # For older PyTorch versions that do not support `weights_only`.
+        return torch.load(ckpt_path, map_location='cpu')
+
 def main(args):
     set_seed(1)
     # command line parameters

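torch.load grew a weights_only keyword in newer PyTorch releases; on older versions the unknown keyword raises TypeError, which the helper catches to fall back to a plain load. A runnable round-trip sketch (the file path and tiny model are illustrative):

import torch
import torch.nn as nn

def load_checkpoint_state_dict(ckpt_path):
    try:
        return torch.load(ckpt_path, map_location='cpu', weights_only=True)
    except TypeError:
        # Older PyTorch without the `weights_only` keyword.
        return torch.load(ckpt_path, map_location='cpu')

model = nn.Linear(4, 2)
torch.save(model.state_dict(), '/tmp/policy_last.ckpt')
print(model.load_state_dict(load_checkpoint_state_dict('/tmp/policy_last.ckpt')))
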
@@ -60,6 +69,7 @@ def main(args):
     text_max_length = task_config.get('text_max_length', 32)
     text_tokenizer_name = task_config.get('text_tokenizer_name', 'distilbert-base-uncased')
     freeze_text_encoder = task_config.get('freeze_text_encoder', True)
+    real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
 
     if args.get('text_encoder_type') is not None:
         text_encoder_type = args['text_encoder_type']

@@ -67,6 +77,8 @@ def main(args):
     text_max_length = args['text_max_length']
     if args.get('freeze_text_encoder', False):
         freeze_text_encoder = True
+    if args.get('disable_real_action_shift', False):
+        real_action_t_minus_1 = False
 
     # fixed parameters
     lr_backbone = 1e-5

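Together the two hunks above give the flag a clear precedence: the task config supplies the default (True when the key is absent), and the CLI switch can only force it off, never on. A compact illustration with plain dicts standing in for task_config and args:

task_config = {}                            # key absent -> defaults to True
args = {'disable_real_action_shift': True}  # operator opts out on the CLI

real_action_t_minus_1 = task_config.get('real_action_t_minus_1', True)
if args.get('disable_real_action_shift', False):
    real_action_t_minus_1 = False
print(real_action_t_minus_1)  # False
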
@@ -110,6 +122,8 @@ def main(args):
 
     config = {
         'num_epochs': num_epochs,
+        'train_steps_per_epoch': args.get('train_steps_per_epoch', None),
+        'resume_ckpt_path': args.get('resume_ckpt', None),
         'ckpt_dir': ckpt_dir,
         'episode_len': episode_len,
         'state_dim': state_dim,

@@ -131,6 +145,16 @@ def main(args):
         'debug_input': args.get('debug_input', False),
     }
 
+    if config['resume_ckpt_path']:
+        resume_ckpt_path = config['resume_ckpt_path']
+        if not os.path.isabs(resume_ckpt_path):
+            candidate_path = os.path.join(ckpt_dir, resume_ckpt_path)
+            if os.path.isfile(candidate_path):
+                resume_ckpt_path = candidate_path
+        if not os.path.isfile(resume_ckpt_path):
+            raise FileNotFoundError(f'--resume_ckpt not found: {config["resume_ckpt_path"]}')
+        config['resume_ckpt_path'] = resume_ckpt_path
+
     if is_eval:
         ckpt_names = [f'policy_best.ckpt']
         results = []

@@ -155,6 +179,7 @@ def main(args):
         text_feature_dim=text_feature_dim,
         text_tokenizer_name=text_tokenizer_name,
         text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
         image_augment=args['image_aug'],
     )
 

@@ -223,7 +248,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
     # load policy and stats
     ckpt_path = os.path.join(ckpt_dir, ckpt_name)
     policy = make_policy(policy_class, policy_config)
-    loading_status = policy.load_state_dict(torch.load(ckpt_path))
+    loading_status = policy.load_state_dict(load_checkpoint_state_dict(ckpt_path))
     print(loading_status)
     policy.cuda()
     policy.eval()

@@ -425,6 +450,8 @@ def forward_pass(data, policy, debug_input=False, debug_tag=''):
 
 def train_bc(train_dataloader, val_dataloader, config):
     num_epochs = config['num_epochs']
+    train_steps_per_epoch = config.get('train_steps_per_epoch', None)
+    resume_ckpt_path = config.get('resume_ckpt_path', None)
     ckpt_dir = config['ckpt_dir']
     seed = config['seed']
     policy_class = config['policy_class']

@@ -435,6 +462,12 @@ def train_bc(train_dataloader, val_dataloader, config):
 
     policy = make_policy(policy_class, policy_config)
     policy.cuda()
+
+    if resume_ckpt_path:
+        loading_status = policy.load_state_dict(load_checkpoint_state_dict(resume_ckpt_path))
+        print(f'Loaded finetune init ckpt: {resume_ckpt_path}')
+        print(loading_status)
+
     optimizer = make_optimizer(policy_class, policy)
 
     train_history = []

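Printing loading_status is a cheap sanity check when fine-tuning: nn.Module.load_state_dict returns a named tuple of missing and unexpected keys, and its repr reads '<All keys matched successfully>' when both lists are empty. A small demonstration:

import torch.nn as nn

src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)
status = dst.load_state_dict(src.state_dict())
print(status)  # <All keys matched successfully>
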
@@ -467,8 +500,22 @@ def train_bc(train_dataloader, val_dataloader, config):
         # training
         policy.train()
         optimizer.zero_grad()
-        for batch_idx, data in enumerate(train_dataloader):
-            should_debug = debug_input and epoch == 0 and batch_idx == 0
+        epoch_train_dicts = []
+        if train_steps_per_epoch is None or train_steps_per_epoch <= 0:
+            train_steps_this_epoch = len(train_dataloader)
+            train_iterator = iter(train_dataloader)
+        else:
+            train_steps_this_epoch = int(train_steps_per_epoch)
+            train_iterator = iter(train_dataloader)
+
+        for step_idx in range(train_steps_this_epoch):
+            try:
+                data = next(train_iterator)
+            except StopIteration:
+                train_iterator = iter(train_dataloader)
+                data = next(train_iterator)
+
+            should_debug = debug_input and epoch == 0 and step_idx == 0
             forward_dict = forward_pass(data, policy, debug_input=should_debug, debug_tag='train/epoch0/batch0')
             # backward
             loss = forward_dict['loss']

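The cycling pattern above decouples "epoch" from dataset size: when --train_steps_per_epoch exceeds one pass over the loader, the iterator is rebuilt on StopIteration and batches repeat. The pattern in isolation, with a plain list standing in for the DataLoader:

loader = [1, 2, 3]  # stand-in for a DataLoader
train_iterator = iter(loader)
seen = []
for step_idx in range(5):
    try:
        batch = next(train_iterator)
    except StopIteration:
        # One pass exhausted: restart the iterator and keep stepping.
        train_iterator = iter(loader)
        batch = next(train_iterator)
    seen.append(batch)
print(seen)  # [1, 2, 3, 1, 2]
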
@@ -476,7 +523,8 @@ def train_bc(train_dataloader, val_dataloader, config):
             optimizer.step()
             optimizer.zero_grad()
             train_history.append(detach_dict(forward_dict))
-        epoch_summary = compute_dict_mean(train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)])
+            epoch_train_dicts.append(detach_dict(forward_dict))
+        epoch_summary = compute_dict_mean(epoch_train_dicts)
         epoch_train_loss = epoch_summary['loss']
         print(f'Train loss: {epoch_train_loss:.5f}')
         summary_string = ''

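The old slice train_history[(batch_idx+1)*epoch:(batch_idx+1)*(epoch+1)] only indexed the current epoch correctly when every epoch ran the same number of batches; with a variable step count it averaged the wrong window. Collecting epoch_train_dicts per epoch sidesteps the arithmetic entirely. Assuming compute_dict_mean is the usual key-wise reduction (its definition is not in this diff), it behaves like this:

def compute_dict_mean(dicts):
    # Key-wise mean over a list of dicts with identical keys.
    return {k: sum(d[k] for d in dicts) / len(dicts) for k in dicts[0]}

epoch_train_dicts = [{'loss': 1.0, 'kl': 0.2}, {'loss': 0.5, 'kl': 0.1}]
print(compute_dict_mean(epoch_train_dicts))  # {'loss': 0.75, 'kl': ~0.15}
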
@@ -533,7 +581,7 @@ if __name__ == '__main__':
     parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
 
     # for ACT
-    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
+    parser.add_argument('--kl_weight', action='store', type=float, help='KL Weight', required=False)
     parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
     parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
     parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)

@@ -543,6 +591,12 @@ if __name__ == '__main__':
     parser.add_argument('--text_max_length', action='store', type=int, required=False)
     parser.add_argument('--image_aug', action='store_true',
                         help='Enable training-time image augmentation (color/highlight/noise/blur)')
+    parser.add_argument('--train_steps_per_epoch', action='store', type=int, required=False,
+                        help='If set > 0, run a fixed number of optimizer steps per epoch by cycling over the train dataloader')
+    parser.add_argument('--disable_real_action_shift', action='store_true',
+                        help='Disable real-data action alignment shift (use action[start_ts:] instead of action[start_ts-1:])')
+    parser.add_argument('--resume_ckpt', action='store', type=str, required=False,
+                        help='Optional checkpoint path to initialize model weights for fine-tuning')
     parser.add_argument('--debug_input', action='store_true',
                         help='Print one-batch input sanity checks (shape/range/nan) for val/train at epoch 0')
 

utils.py (46 changed lines)

@@ -17,6 +17,7 @@ class EpisodicDataset(torch.utils.data.Dataset):
                  text_feature_dim=768,
                  text_tokenizer_name='distilbert-base-uncased',
                  text_max_length=32,
+                 real_action_t_minus_1=True,
                  image_augment=False,
                  image_aug_cfg=None):
         super(EpisodicDataset).__init__()

@@ -29,20 +30,21 @@ class EpisodicDataset(torch.utils.data.Dataset):
         self.use_cached_text_features = use_cached_text_features
         self.text_feature_dim = text_feature_dim
         self.text_max_length = text_max_length
+        self.real_action_t_minus_1 = real_action_t_minus_1
         self.image_augment = image_augment
         self.image_aug_cfg = {
-            'p_color': 0.8,
-            'p_highlight': 0.5,
-            'p_noise': 0.5,
-            'p_blur': 0.3,
-            'brightness': 0.25,
-            'contrast': 0.25,
-            'saturation': 0.25,
-            'hue': 0.08,
-            'highlight_strength': (0.15, 0.5),
-            'noise_std': (0.005, 0.03),
-            'blur_sigma': (0.1, 1.5),
-            'blur_kernel_choices': (3, 5),
+            'p_color': 0.4,
+            'p_highlight': 0.3,
+            'p_noise': 0.35,
+            'p_blur': 0.15,
+            'brightness': 0.12,
+            'contrast': 0.12,
+            'saturation': 0.12,
+            'hue': 0.03,
+            'highlight_strength': (0.08, 0.25),
+            'noise_std': (0.003, 0.015),
+            'blur_sigma': (0.1, 0.8),
+            'blur_kernel_choices': (3, ),
         }
         if image_aug_cfg is not None:
             self.image_aug_cfg.update(image_aug_cfg)

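All augmentation probabilities and magnitudes are roughly halved here, a deliberate toning-down. The dataset's actual augmentation code is not part of this diff, but a hedged sketch of how such a config could drive torchvision transforms (the helper name maybe_color_jitter is illustrative):

import random
import torchvision.transforms as T

cfg = {'p_color': 0.4, 'brightness': 0.12, 'contrast': 0.12,
       'saturation': 0.12, 'hue': 0.03}

color_jitter = T.ColorJitter(brightness=cfg['brightness'],
                             contrast=cfg['contrast'],
                             saturation=cfg['saturation'],
                             hue=cfg['hue'])

def maybe_color_jitter(img):
    # Apply jitter with probability p_color; pass the image through otherwise.
    return color_jitter(img) if random.random() < cfg['p_color'] else img
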
@@ -221,8 +223,9 @@ class EpisodicDataset(torch.utils.data.Dataset):
             action = root['/action'][start_ts:]
             action_len = episode_len - start_ts
         else:
-            action = root['/action'][max(0, start_ts - 1):] # hack, to make timesteps more aligned
-            action_len = episode_len - max(0, start_ts - 1) # hack, to make timesteps more aligned
+            action_start = max(0, start_ts - 1) if self.real_action_t_minus_1 else start_ts
+            action = root['/action'][action_start:]
+            action_len = episode_len - action_start
 
         self.is_sim = is_sim
         padded_action = np.zeros((self.max_episode_len, self.action_dim), dtype=np.float32)

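A tiny numeric sketch of what the flag toggles: with the shift enabled (the original "hack"), the action chunk starts one step before start_ts; disabled, it starts exactly at start_ts and is one element shorter.

episode_len = 5
actions = list(range(episode_len))  # stand-in for root['/action']
start_ts = 2

for real_action_t_minus_1 in (True, False):
    action_start = max(0, start_ts - 1) if real_action_t_minus_1 else start_ts
    print(actions[action_start:], episode_len - action_start)
# [1, 2, 3, 4] 4
# [2, 3, 4] 3
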
@@ -334,19 +337,26 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
                  text_feature_dim=768,
                  text_tokenizer_name='distilbert-base-uncased',
                  text_max_length=32,
+                 real_action_t_minus_1=True,
                  image_augment=False,
                  image_aug_cfg=None):
     print(f'\nData from: {dataset_dir}\n')
     episode_ids = _discover_episode_ids(dataset_dir, num_episodes)
     if len(episode_ids) == 0:
         raise FileNotFoundError(f'No episode_*.hdf5 found in {dataset_dir}')
-    if len(episode_ids) < 2:
-        raise ValueError(f'Need at least 2 episodes for train/val split, found {len(episode_ids)} in {dataset_dir}')
 
-    # obtain train test split
+    # obtain train/val split
+    if len(episode_ids) == 1:
+        # sanity-check mode: reuse the same episode for both train and val
+        # so training/evaluation loops remain unchanged.
+        train_episode_ids = np.array(episode_ids)
+        val_episode_ids = np.array(episode_ids)
+        print('[load_data] Only 1 episode found. Reusing the same episode for both train and val (sanity-check mode).')
+    else:
         train_ratio = 0.9
         shuffled_indices = np.random.permutation(len(episode_ids))
         train_count = int(train_ratio * len(episode_ids))
+        train_count = max(1, min(len(episode_ids) - 1, train_count))
         train_indices = shuffled_indices[:train_count]
         val_indices = shuffled_indices[train_count:]
         train_episode_ids = np.array(episode_ids)[train_indices]

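The hard failure on fewer than two episodes is replaced with graceful handling, and the split is clamped so neither side is empty. int(0.9 * n) rounds down, so tiny datasets could otherwise end up with a degenerate split (n=1 would give a train count of 0); the clamp keeps train_count in [1, n-1] for n >= 2, while n == 1 is special-cased to reuse the episode. A quick check of the clamped arithmetic:

for n in (2, 3, 5, 10):
    train_count = int(0.9 * n)
    train_count = max(1, min(n - 1, train_count))
    print(n, train_count, n - train_count)
# 2 1 1
# 3 2 1
# 5 4 1
# 10 9 1
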
@@ -367,6 +377,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
         text_feature_dim=text_feature_dim,
         text_tokenizer_name=text_tokenizer_name,
         text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
         image_augment=image_augment,
         image_aug_cfg=image_aug_cfg,
     )

@@ -381,6 +392,7 @@ def load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_s
         text_feature_dim=text_feature_dim,
         text_tokenizer_name=text_tokenizer_name,
         text_max_length=text_max_length,
+        real_action_t_minus_1=real_action_t_minus_1,
         image_augment=False,
     )
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=True, num_workers=1, prefetch_factor=1)