Matrix Capsules with EM Routing - Implementation
Matrix Capsules Layers
The core functions are implemented in
/src/capsules/core.py
The initialization operation to connect a regular layer to a matrix capsule layer
capsules_init()
This function constructs a matrix capsule layer (e.g., primaryCaps) from a regular layer.
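As a rough illustration of the shapes involved, the sketch below computes how many convolution channels a primary capsule layer needs and what the resulting poses and activations tensors look like. It is plain Python, not the code in core.py; the helper name capsules_init_shapes and the [N, H, W, capsules, 4, 4] layout are assumptions made for this example.

```python
def capsules_init_shapes(input_shape, num_caps=32, pose_shape=(4, 4)):
    """Return (conv_channels, poses_shape, activations_shape).

    Hypothetical helper: assumes a 1x1, stride-1 convolution, so the
    spatial size of the input feature map is preserved.
    """
    n, h, w, _ = input_shape
    pose_size = pose_shape[0] * pose_shape[1]
    # Each capsule needs pose_size pose entries plus one activation.
    conv_channels = num_caps * (pose_size + 1)
    # Assumed layout: one pose matrix and one scalar per capsule per pixel.
    poses_shape = (n, h, w, num_caps, pose_shape[0], pose_shape[1])
    activations_shape = (n, h, w, num_caps)
    return conv_channels, poses_shape, activations_shape


channels, poses_shape, activations_shape = capsules_init_shapes((8, 14, 14, 32))
print(channels, poses_shape, activations_shape)
```

With 32 capsules and 4x4 poses this gives 32 x (4x4 + 1) = 544 channels, matching the comment on the capsule_init call in the network definition.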
The convolution operation between matrix capsule layers
capsules_conv()
This function constructs a matrix capsule layer (e.g., ConvCaps1, ConvCaps2) from a matrix capsule layer (e.g., primaryCaps, ConvCaps1).
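In the paper, each input capsule casts a vote for each output capsule by multiplying its pose matrix with a learned transformation matrix, V_ij = M_i W_ij. The sketch below shows that single step in plain Python for one receptive field. The names matmul and compute_votes are hypothetical, and the actual capsules_conv() presumably batches this with tensor operations.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    inner = len(b)
    cols = len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def compute_votes(poses, weights):
    """Compute votes[i][j] = poses[i] @ weights[i][j].

    poses: list of I pose matrices (one per input capsule).
    weights: weights[i][j] is the transformation matrix from input
    capsule i to output capsule j (assumed layout for this sketch).
    """
    return [
        [matmul(pose, w_ij) for w_ij in weights[i]]
        for i, pose in enumerate(poses)
    ]


identity = [[1.0 if r == c else 0.0 for c in range(4)] for r in range(4)]
pose = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
votes = compute_votes([pose], [[identity]])
print(votes[0][0] == pose)
```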
The fully-connected operation between matrix capsule layers, with a shared view transformation weight matrix
capsules_fc()
This function constructs an output matrix capsule layer with poses and activations (e.g., Class Capsules) from a matrix capsule layer (e.g., ConvCaps2).
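To see what sharing the transformation matrices buys, the sketch below counts transformation weights with and without sharing across spatial positions. The helper fc_transform_params is hypothetical, and the 4x4 spatial grid in the usage line is an assumed example size, not a value taken from the source.

```python
def fc_transform_params(height, width, in_caps, num_classes,
                        pose_shape=(4, 4), shared=True):
    """Count transformation weights for a capsule fully-connected layer.

    Hypothetical helper: with shared=True, one matrix per
    (input capsule type, class) pair is reused at every position.
    """
    per_matrix = pose_shape[0] * pose_shape[1]
    positions = 1 if shared else height * width
    return positions * in_caps * num_classes * per_matrix


shared = fc_transform_params(4, 4, 32, 10, shared=True)
unshared = fc_transform_params(4, 4, 32, 10, shared=False)
print(shared, unshared)
```

The shared count, 32 x 10 x 4 x 4, corresponds to the "1x1x32x10x4x4" weight shape noted in the network definition.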
The EM routing algorithm
matrix_capsules_em_routing()
This function implements the matrix capsules EM routing algorithm.
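For orientation, here is a minimal sketch of the EM routing procedure as described in the paper, written in plain Python for a single group of input capsules. It is not the code in core.py: the function name em_routing, the scalar beta_u and beta_a defaults, the fixed inverse temperature, and the eps constant are all assumptions made so the example runs on its own.

```python
import math


def _sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def em_routing(votes, in_activations, iterations=3, beta_u=0.0, beta_a=0.0,
               inverse_temperature=1.0, eps=1e-9):
    """Route I input capsules to J output capsules.

    votes[i][j] is a flat list of vote components from input i for output j.
    in_activations[i] is the activation of input capsule i.
    Returns (means, out_activations) from the final M-step.
    """
    num_in, num_out, dim = len(votes), len(votes[0]), len(votes[0][0])
    # Start from uniform assignment probabilities.
    r = [[1.0 / num_out] * num_out for _ in range(num_in)]
    means, variances, out_act = [], [], []

    for _ in range(iterations):
        # M-step: fit one axis-aligned Gaussian per output capsule.
        means, variances, out_act = [], [], []
        for j in range(num_out):
            w = [r[i][j] * in_activations[i] for i in range(num_in)]
            total = sum(w) + eps
            mu = [sum(w[i] * votes[i][j][h] for i in range(num_in)) / total
                  for h in range(dim)]
            var = [sum(w[i] * (votes[i][j][h] - mu[h]) ** 2
                       for i in range(num_in)) / total + eps
                   for h in range(dim)]
            # cost_h = (beta_u + log sigma_h) * total, with log sigma = 0.5 log var.
            cost = sum((beta_u + 0.5 * math.log(v)) * total for v in var)
            means.append(mu)
            variances.append(var)
            out_act.append(_sigmoid(inverse_temperature * (beta_a - cost)))

        # E-step: reassign each input capsule using log densities.
        for i in range(num_in):
            log_p = []
            for j in range(num_out):
                lp = sum(
                    -(votes[i][j][h] - means[j][h]) ** 2 / (2.0 * variances[j][h])
                    - 0.5 * math.log(2.0 * math.pi * variances[j][h])
                    for h in range(dim)
                )
                log_p.append(math.log(out_act[j] + eps) + lp)
            top = max(log_p)
            exps = [math.exp(x - top) for x in log_p]
            norm = sum(exps)
            r[i] = [e / norm for e in exps]

    return means, out_act


# Three inputs agree on output 0 and disagree on output 1.
votes = [
    [[1.0, 1.0], [0.0, 5.0]],
    [[1.0, 1.0], [3.0, -2.0]],
    [[1.0, 1.0], [-4.0, 1.0]],
]
means, activations = em_routing(votes, [1.0, 1.0, 1.0])
print(means[0], activations)
```

In the usage lines, the output capsule whose votes agree ends up with the higher activation, which is the behaviour routing-by-agreement is meant to produce. In the paper, beta_u and beta_a are learned and the inverse temperature follows a schedule; they are fixed here only to keep the sketch short.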
Matrix Capsules Nets
The network and loss functions are implemented in
/src/capsules/nets.py
Build a matrix capsule network the same way you would build a CNN:
def capsules_net(inputs, num_classes, iterations, name='CapsuleEM-V0'):
    """Replicate the network in `Matrix Capsules with EM Routing.`
    """
    with tf.variable_scope(name) as scope:
        # inputs [N, H, W, C] -> conv2d, 5x5, strides 2, channels 32 -> nets [N, OH, OW, 32]
        nets = _conv2d_wrapper(
            inputs, shape=[5, 5, 1, 32], strides=[1, 2, 2, 1], padding='SAME', add_bias=True, activation_fn=tf.nn.relu, name='conv1'
        )
        # inputs [N, H, W, C] -> conv2d, 1x1, strides 1, channels 32x(4x4+1) -> (poses, activations)
        nets = capsules_init(
            nets, shape=[1, 1, 32, 32], strides=[1, 1, 1, 1], padding='VALID', pose_shape=[4, 4], name='capsule_init'
        )
        # inputs: (poses, activations) -> capsule-conv 3x3x32x32x4x4, strides 2 -> (poses, activations)
        nets = capsules_conv(
            nets, shape=[3, 3, 32, 32], strides=[1, 2, 2, 1], iterations=iterations, name='capsule_conv1'
        )
        # inputs: (poses, activations) -> capsule-conv 3x3x32x32x4x4, strides 1 -> (poses, activations)
        nets = capsules_conv(
            nets, shape=[3, 3, 32, 32], strides=[1, 1, 1, 1], iterations=iterations, name='capsule_conv2'
        )
        # inputs: (poses, activations) -> capsule-fc 1x1x32x10x4x4 shared view transform matrix within each channel -> (poses, activations)
        nets = capsules_fc(
            nets, num_classes, iterations=iterations, name='capsule_fc'
        )
        poses, activations = nets
    return poses, activations
Check out the source code for detailed documentation.
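To follow how the spatial size changes through the layers of capsules_net above, the sketch below applies TensorFlow's SAME and VALID output-size rules to each kernel and stride. The helper names and the 28x28 input are assumptions for this example, as is VALID padding for the two capsules_conv layers, whose padding is not shown in the calls above.

```python
import math


def conv_output_size(size, kernel, stride, padding):
    """Spatial output size under TensorFlow's SAME / VALID conventions."""
    if padding == 'SAME':
        return math.ceil(size / stride)
    if padding == 'VALID':
        return math.ceil((size - kernel + 1) / stride)
    raise ValueError('unknown padding: %r' % (padding,))


def trace_spatial_sizes(input_size=28):
    """Return the spatial size after each convolution-like layer."""
    layers = [
        # (name, kernel, stride, padding)
        ('conv1', 5, 2, 'SAME'),
        ('capsule_init', 1, 1, 'VALID'),
        ('capsule_conv1', 3, 2, 'VALID'),  # padding assumed, not in the source
        ('capsule_conv2', 3, 1, 'VALID'),  # padding assumed, not in the source
    ]
    sizes = {}
    size = input_size
    for name, kernel, stride, padding in layers:
        size = conv_output_size(size, kernel, stride, padding)
        sizes[name] = size
    return sizes


print(trace_spatial_sizes(28))
```

Under these assumptions a 28x28 input shrinks to 14x14 after conv1, stays 14x14 after capsule_init, and becomes 6x6 and then 4x4 after the two capsule convolutions, so capsules_fc would see a 4x4 grid of 32 capsule types.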
TODO:
Add examples.
Generate the documentation from the code.