Define the Dataset and Test Set: load one test batch per model configuration and inspect the image shapes
# Configuration 1: VQ-VAE model with dim=3.
# NOTE(review): dim presumably selects 3-D (volumetric) inputs — confirm.
args.model_type, args.dim = "vqvae", 3  # assigned left to right: model_type, then dim
# compat_args is assumed to reconcile dependent settings with model_type/dim — verify.
args = compat_args(args)
batch = load_test_batch(args)
# Bare expression: a notebook displays a cell's last expression; a plain script discards it.
batch["img"].shape
# Configuration 2: VQ-VAE model with dim=2.
# NOTE(review): dim presumably selects 2-D (slice) inputs — confirm.
# NOTE(review): `args` was already passed through compat_args for the previous
# configuration; if compat_args only fills missing fields, values derived for
# dim=3 could carry over here — confirm.
args.model_type, args.dim = "vqvae", 2  # assigned left to right: model_type, then dim
args = compat_args(args)
batch = load_test_batch(args)
# Bare expression: a notebook displays a cell's last expression; a plain script discards it.
batch["img"].shape
# Configuration 3: diagnosis model with dim=3.
# NOTE(review): as in the previous cell, `args` carries state from earlier
# compat_args calls — confirm no stale derived fields leak into this config.
args.model_type, args.dim = "diagnosis", 3  # assigned left to right: model_type, then dim
args = compat_args(args)
batch = load_test_batch(args)
# Unlike the vqvae batches, batch["img"] is indexed here by a key ("axial");
# presumably it holds one entry per view — verify against load_test_batch.
# Bare expression: a notebook displays a cell's last expression; a plain script discards it.
batch["img"]["axial"].shape