Define the general Training Loop

get_model

get_model(device, args)

Return the required model depending on the arguments, e.g. the model type set in args.model_type.
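A minimal sketch of how a model could be selected from the arguments, assuming a dictionary registry keyed by args.model_type (the field set in the usage example on this page). The classes BiGAN and VAE, the registry MODEL_REGISTRY and the function get_model_sketch are hypothetical placeholders, not this library's implementation; a real version would build torch modules and move them to the given device.

```python
from argparse import Namespace


# Hypothetical placeholder models; a real version would build torch.nn.Module instances.
class BiGAN:
    def __init__(self, device):
        self.device = device


class VAE:
    def __init__(self, device):
        self.device = device


# Hypothetical registry mapping args.model_type to a model class.
MODEL_REGISTRY = {"bigan": BiGAN, "vae": VAE}


def get_model_sketch(device, args):
    """Return a model chosen by args.model_type (illustrative only)."""
    if args.model_type not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model type: {args.model_type!r}")
    print(f"Model-Type: {args.model_type}")
    return MODEL_REGISTRY[args.model_type](device)


model = get_model_sketch("cpu", Namespace(model_type="bigan"))  # prints "Model-Type: bigan"
```

A registry keeps the dispatch in one place, so supporting another model type means adding one dictionary entry rather than another if/elif branch.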

get_dataset

get_dataset(args)

Return the required datasets and dataloaders depending on the dataset selected in the arguments.
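A minimal sketch of the same dispatch idea for data, assuming the dataset is chosen by a hypothetical args.dataset field and batched with a hypothetical args.batch_size. ListLoader is a plain-Python stand-in for a PyTorch DataLoader: it yields mini-batches and reshuffles the training data on every pass. The toy dataset in DATASETS and the function get_dataset_sketch are invented for illustration.

```python
import random
from argparse import Namespace


class ListLoader:
    """Plain-Python stand-in for a DataLoader: yields lists of at most batch_size items."""

    def __init__(self, data, batch_size, shuffle=False, seed=0):
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.data)))
        if self.shuffle:
            self.rng.shuffle(indices)  # new order on every pass (epoch)
        for start in range(0, len(indices), self.batch_size):
            yield [self.data[i] for i in indices[start:start + self.batch_size]]

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (len(self.data) + self.batch_size - 1) // self.batch_size


# Hypothetical toy dataset registry: name -> (train_data, test_data).
DATASETS = {
    "toy_ramp": lambda: ([float(i) for i in range(10)], [float(i) for i in range(10, 14)]),
}


def get_dataset_sketch(args):
    """Return (train_data, test_data, train_loader, test_loader) for args.dataset."""
    train_data, test_data = DATASETS[args.dataset]()
    train_loader = ListLoader(train_data, args.batch_size, shuffle=True)
    test_loader = ListLoader(test_data, args.batch_size)
    return train_data, test_data, train_loader, test_loader


train_data, test_data, train_loader, test_loader = get_dataset_sketch(
    Namespace(dataset="toy_ramp", batch_size=4)
)
```

With 10 training samples and a batch size of 4, the training loader yields three batches per epoch, the last one holding the remaining 2 samples.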

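import torch

# get_all_args and get_model are defined elsewhere in this library
args = get_all_args()
args.track = False
args.model_type = "bigan"  # select the BiGAN model
# use the first GPU if one is available and requested, otherwise fall back to the CPU
device = torch.device(
    "cuda:0" if (torch.cuda.is_available() and args.n_gpu > 0) else "cpu"
)
model = get_model(device, args)  # prints the selected model type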
Model-Type: bigan

test_one_batch

test_one_batch(args)

Test a new model on a single demo data batch to check its compatibility.
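A minimal sketch of such a smoke test, under the assumption that a model maps a batch to one output per sample. ToyLinearModel and check_one_batch are hypothetical; the idea is simply to pull one batch from a loader, run a forward pass, and fail early if the output size or values are wrong, before starting a full training run.

```python
import math


class ToyLinearModel:
    """Hypothetical stand-in model computing w * x + b for every sample."""

    def __init__(self, w=2.0, b=1.0):
        self.w = w
        self.b = b

    def __call__(self, batch):
        return [self.w * x + self.b for x in batch]


def check_one_batch(model, loader):
    """Run the model on one batch and fail early on obvious incompatibilities."""
    batch = next(iter(loader))
    output = model(batch)
    if len(output) != len(batch):
        raise ValueError(f"expected {len(batch)} outputs, got {len(output)}")
    if not all(math.isfinite(y) for y in output):
        raise ValueError("model produced NaN or infinite values")
    return batch, output


demo_loader = [[0.0, 1.0, 2.0]]  # a single demo batch
batch, output = check_one_batch(ToyLinearModel(), demo_loader)
print(output)  # [1.0, 3.0, 5.0]
```

Raising ValueError instead of using assert keeps the checks active even when Python runs with optimizations (-O).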

batch_info

batch_info(test_x, args)

Display relevant information about the dataset format and statistics.
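A minimal sketch of the kind of statistics such a helper might report, assuming a flat batch of numbers (a real image or tensor batch would also report its shape and dtype). batch_info_sketch is hypothetical and omits the args parameter of the documented function.

```python
import statistics


def batch_info_sketch(test_x):
    """Print and return basic statistics for a flat batch of numbers."""
    info = {
        "batch_size": len(test_x),
        "min": min(test_x),
        "max": max(test_x),
        "mean": statistics.fmean(test_x),
        "std": statistics.pstdev(test_x),  # population standard deviation
    }
    for key, value in info.items():
        print(f"{key:>10}: {value}")
    return info


info = batch_info_sketch([1.0, 2.0, 3.0, 4.0])
```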

main_loop

main_loop(args, tq_nb=True)

Perform the training using the predefined arguments.
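A minimal sketch of a general training loop, using plain-Python mini-batch SGD to fit y = w * x + b instead of a neural network. main_loop_sketch, its parameters and the optional eval_fn hook are hypothetical; the sketch only illustrates the structure (epochs, shuffled mini-batches, a running batch count, periodic evaluation), not this library's loop.

```python
import random


def main_loop_sketch(data, epochs=200, lr=0.1, batch_size=4, seed=0,
                     eval_every=None, eval_fn=None):
    """Fit y = w * x + b to (x, y) pairs with mini-batch SGD (illustrative only)."""
    data = list(data)  # copy so shuffling does not modify the caller's list
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    batch_count = 0
    for epoch in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # Gradients of the batch mean squared error with respect to w and b.
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / len(batch)
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
            batch_count += 1
            # Optional periodic evaluation every eval_every batches.
            if eval_fn is not None and eval_every and batch_count % eval_every == 0:
                eval_fn(w, b, batch_count, epoch)
    return w, b


train = [(i / 10, 3 * (i / 10) + 1) for i in range(20)]
w, b = main_loop_sketch(train)
```

Because the toy data is exactly linear, the loop should recover w ≈ 3 and b ≈ 1; the batch_count passed to eval_fn plays the same role as the batch_count argument of evaluate_model.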

evaluate_model

evaluate_model(args, model, batch_count, epoch, test_data, valid_loader)

Evaluate the model every x training iterations, as counted by batch_count.
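A minimal sketch of periodic evaluation, assuming evaluation is triggered when batch_count is a multiple of a hypothetical eval_every setting and that the metric is the mean squared error over the validation loader. evaluate_model_sketch and model_fn are hypothetical; the sketch drops the args and test_data parameters of the documented signature, and no parameters are updated during evaluation.

```python
def evaluate_model_sketch(model_fn, batch_count, epoch, valid_loader, eval_every=100):
    """Return the validation MSE every eval_every batches, otherwise None.

    Returning None off-schedule lets the training loop call this after every step.
    """
    if batch_count % eval_every != 0:
        return None
    total, n = 0.0, 0
    for batch in valid_loader:  # evaluation only: no parameters are updated here
        for x, y in batch:
            total += (model_fn(x) - y) ** 2
            n += 1
    mse = total / n
    print(f"epoch {epoch} | batch {batch_count} | valid MSE {mse:.4f}")
    return mse


def model_fn(x):
    return 2 * x + 1.5  # hypothetical model that is off by 0.5 everywhere


valid_loader = [[(0.0, 1.0), (1.0, 3.0)], [(2.0, 5.0)]]
evaluate_model_sketch(model_fn, batch_count=100, epoch=0, valid_loader=valid_loader)
# prints: epoch 0 | batch 100 | valid MSE 0.2500
```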