1. Apply Quantization

class EncQuantDec[source]

EncQuantDec(args) :: Module

Helper class for the generic generated network with a variable number of quantization layers. It contains:

- `Enc`: list of encoders
- `Dec`: list of decoders
- `Quant`: list of quantizations
- if required, `Cla`: list of classifiers

q_pre_2d = (0, 2, 3, 1)  # NCHW -> NHWC: move channels last before quantization
q_pos_2d = (0, 3, 1, 2)  # NHWC -> NCHW: restore the channels-first layout

quant = Quantize(128, 512)          # embedding dim 128, codebook of 512 entries
inp2d = torch.randn(20, 128, 8, 8)  # (N, C, H, W) feature map
print(inp2d.shape)
inp2d = inp2d.permute(q_pre_2d)
print(inp2d.shape)
output, _, _ = quant(inp2d)
print(output.shape)
output = output.permute(q_pos_2d)
print(output.shape)
torch.Size([20, 128, 8, 8])
torch.Size([20, 8, 8, 128])
torch.Size([20, 8, 8, 128])
torch.Size([20, 128, 8, 8])
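Why the permutes are needed: the quantizer treats the last axis as the embedding dimension and replaces each spatial vector with its nearest codebook entry. A minimal numpy sketch of that nearest-neighbor lookup (an illustration of the idea, not the library's `Quantize` implementation, which also maintains the codebook with an EMA update):

```python
import numpy as np

def quantize_nn(x, codebook):
    """Nearest-neighbor vector quantization (illustrative sketch).

    x:        (..., dim) array of input vectors (channels-last, as above)
    codebook: (n_embed, dim) array of embedding vectors
    returns:  quantized array (same shape as x) and codebook indices
    """
    flat = x.reshape(-1, x.shape[-1])                          # (N, dim)
    # Squared distance to every codebook entry, expanded as a^2 - 2ab + b^2
    # to avoid materializing an (N, n_embed, dim) intermediate:
    dist = ((flat ** 2).sum(1, keepdims=True)
            - 2 * flat @ codebook.T
            + (codebook ** 2).sum(1))                          # (N, n_embed)
    idx = dist.argmin(axis=1)                                  # nearest entry per vector
    quantized = codebook[idx].reshape(x.shape)
    return quantized, idx.reshape(x.shape[:-1])

rng = np.random.default_rng(0)
codebook = rng.normal(size=(512, 128))   # n_embed=512, dim=128, as in Quantize above
x = rng.normal(size=(20, 8, 8, 128))     # channels-last input
q, idx = quantize_nn(x, codebook)
print(q.shape, idx.shape)                # (20, 8, 8, 128) (20, 8, 8)
```

Every output vector is an exact row of the codebook; only the indices need to be stored to reconstruct the quantized tensor.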
q_pre_3d = (0, 2, 3, 4, 1)  # NCDHW -> NDHWC: move channels last
q_pos_3d = (0, 4, 1, 2, 3)  # NDHWC -> NCDHW: channels first again

inp3d = torch.randn(20, 128, 4, 8, 8)  # (N, C, D, H, W) volumetric feature map
print(inp3d.shape)
inp3d = inp3d.permute(q_pre_3d)
print(inp3d.shape)
output, _, _ = quant(inp3d)
print(output.shape)
output = output.permute(q_pos_3d)
print(output.shape)
torch.Size([20, 128, 4, 8, 8])
torch.Size([20, 4, 8, 8, 128])
torch.Size([20, 4, 8, 8, 128])
torch.Size([20, 128, 4, 8, 8])
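The same `quant` instance handles 2D and 3D inputs because it only looks at the last axis; only the permutation tuples change with the number of spatial dimensions. A small helper (hypothetical, not part of the library) that generates both tuples for any tensor rank:

```python
def quant_permutes(ndim):
    """Permutation tuples that move the channel axis (axis 1) last and back.

    ndim=4 (NCHW) yields (0, 2, 3, 1) / (0, 3, 1, 2), matching q_pre_2d and
    q_pos_2d above; ndim=5 (NCDHW) yields the q_pre_3d / q_pos_3d tuples.
    """
    pre = (0,) + tuple(range(2, ndim)) + (1,)       # channels to the end
    pos = (0, ndim - 1) + tuple(range(1, ndim - 1)) # channels back to axis 1
    return pre, pos

print(quant_permutes(4))  # ((0, 2, 3, 1), (0, 3, 1, 2))
print(quant_permutes(5))  # ((0, 2, 3, 4, 1), (0, 4, 1, 2, 3))
```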

2. The Complete VQVAE class

class VQVAE2[source]

VQVAE2(device, args) :: AbsModel

Vector Quantized Variational Autoencoder based on https://arxiv.org/abs/1906.00446, adapted from https://github.com/rosinality/vq-vae-2-pytorch.
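For context, `quant(...)` above returns three values (`output, _, _ = quant(inp2d)`); in the rosinality implementation the two discarded ones are a commitment-style loss term and the codebook indices. A numpy sketch (values only, no autograd) of that loss term and of the straight-through gradient trick from the VQ-VAE papers:

```python
import numpy as np

rng = np.random.default_rng(0)
z_e = rng.normal(size=(20, 8, 8, 128))        # encoder output, channels-last
quantized = rng.normal(size=(20, 8, 8, 128))  # stand-in for the codebook lookup

# Commitment term: mean squared distance between the encoder output and its
# quantized version. In the PyTorch code the quantized side is detached, so
# the gradient only pushes the encoder toward the codebook.
diff = ((quantized - z_e) ** 2).mean()

# Straight-through estimator: the forward pass emits the quantized values,
# while the backward pass treats quantization as the identity on z_e.
# In PyTorch: quantized_st = z_e + (quantized - z_e).detach()
quantized_st = z_e + (quantized - z_e)        # value-wise identical to `quantized`

print(diff >= 0, np.allclose(quantized_st, quantized))
```

This is why the decoder can be trained end to end despite the argmin lookup being non-differentiable.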

args.model_type = "vqvae"
args.dim = 3
args = compat_args(args)
test_one_batch(args)
Model-Type: vqvae
args.dim = 2
args = compat_args(args)
test_one_batch(args)
Model-Type: vqvae