# Smoke test: round-trip a 4-D tensor through the quantizer.
# q_pre_2d moves axis 1 (size 128 below, matching Quantize's first
# argument -- presumably the channel/embedding dim, TODO confirm) to the
# last position; q_pos_2d is its exact inverse and restores the order.
q_pre_2d = (0, 2, 3, 1)
q_pos_2d = (0, 3, 1, 2)
# NOTE(review): Quantize is defined elsewhere; the (128, 512) arguments are
# presumably (embedding dim, codebook size) -- confirm against its definition.
quant = Quantize(128, 512)

inp2d = torch.randn(20, 128, 8, 8)
print(inp2d.size())
inp2d = inp2d.permute(*q_pre_2d)
print(inp2d.size())

# The unpacking requires quant(...) to return exactly three values; only the
# first is inspected here.
output, _, _ = quant(inp2d)
print(output.size())
output = output.permute(*q_pos_2d)
print(output.size())
# Same round trip for a 5-D tensor, reusing the `quant` instance created
# earlier in the script. q_pre_3d moves axis 1 (size 128) to the end and
# q_pos_3d is its exact inverse.
q_pre_3d = (0, 2, 3, 4, 1)
q_pos_3d = (0, 4, 1, 2, 3)

inp3d = torch.randn(20, 128, 4, 8, 8)
print(inp3d.size())
inp3d = inp3d.permute(*q_pre_3d)
print(inp3d.size())

# As above: exactly three return values are expected; only the first is used.
output, _, _ = quant(inp3d)
print(output.size())
output = output.permute(*q_pos_3d)
print(output.size())
# Run one batch with the "vqvae" model type, first with dim=3 and then with
# dim=2. `args`, compat_args and test_one_batch are defined elsewhere.
# `args` is re-bound to compat_args' return value on every pass, so the
# dim=2 run starts from whatever object the dim=3 pass returned.
# NOTE(review): if compat_args derives fields from `dim` only when they are
# missing, the second pass may see stale dim=3 values -- confirm.
args.model_type = "vqvae"
for _vqvae_dim in (3, 2):
    args.dim = _vqvae_dim
    args = compat_args(args)
    test_one_batch(args)
del _vqvae_dim  # keep the module namespace unchanged by the loop