import torch
from torch import nn

from deeptool.dataloader import load_test_batch
from deeptool.parameters import get_all_args
# Build the default argument namespace and override it for a single-sample
# "rnnvae" run with tracking disabled.
args = get_all_args()
args.model_type = "rnnvae"
args.batch_size = 1
args.track = False

# Load one test batch and inspect the image shape before and after
# `mod_batch_3d`. A bare `.shape` expression only displays inside a notebook,
# so print it to make the check visible when this file runs as a script.
# NOTE(review): `mod_batch_3d` is not imported in the visible file — TODO
# confirm which deeptool module provides it.
batch = load_test_batch(args)
print(batch["img"].shape)
batchmod = mod_batch_3d(batch)
print(batchmod["img"].shape)
# An empty `nn.Sequential` contains no modules, so calling it returns its
# input unchanged. The bare comparison `trans(a) == a` discarded its result
# and could never fail; assert it so a regression is actually reported.
trans = nn.Sequential()
a = 5
assert trans(a) == a
# Smoke-test the "cnn" transition variant on random input and report the
# output shape and module structure.
# NOTE(review): the (100, 34) input shape is kept as-is; what its two
# dimensions mean for `Transition` is not visible here — TODO confirm.
# `Transition` is also not imported in the visible file.
args.rnn_transition = "cnn"
x = torch.randn(100, 34)
tran = Transition(args)
print(tran(x).shape)
# A bare `tran` only renders in a notebook; print the module summary instead.
print(tran)
# Run `test_one_batch` on the MRNet dataset once for each RNN variant.
# The dataset type is set a single time up front. Model type, RNN type and
# dimensionality are re-assigned before every `compat_args` call, because the
# namespace returned by the previous call is reused for the next run.
# NOTE(review): `compat_args` and `test_one_batch` are not imported in the
# visible file — TODO confirm their source module.
args.model_type = "rnnvae"
args.dataset_type = "MRNet"
for rnn_variant in ("ae", "vae", "introvae"):
    args.model_type = "rnnvae"
    args.rnn_type = rnn_variant
    args.dim = 3
    args = compat_args(args)
    test_one_batch(args)
# Use the first CUDA device only if CUDA is available and `args.n_gpu` is
# positive; otherwise run on the CPU.
device = torch.device(
    "cuda:0" if (torch.cuda.is_available() and args.n_gpu > 0) else "cpu"
)

# Push one test batch through the RNN-BiGAN with `update=False` and report the
# shape of the first returned value.
# NOTE(review): `RNNBIGAN` is not imported in the visible file — TODO confirm
# its module. `update=False` presumably skips the optimisation step; verify.
rnn_bigan = RNNBIGAN(device, args)
data = load_test_batch(args)
x, tr = rnn_bigan(data, update=False)
# A bare `x.shape` only displays in a notebook; print it for script runs.
print(x.shape)