Matrix Capsules with EM Routing - Implementation

Matrix Capsule Layers

The core functions are implemented in

/src/capsules/core.py
  • The initialization operation that connects a regular layer to a matrix capsule layer: capsules_init()

    This function constructs a matrix capsule layer (e.g., primaryCaps) from a regular layer.

  • The convolution operation between matrix capsule layers: capsules_conv()

    This function constructs a matrix capsule layer (e.g., ConvCaps1, ConvCaps2) from a matrix capsule layer (e.g., primaryCaps, ConvCaps1).

  • The fully-connected operation between matrix capsule layers, with a shared view transformation weight matrix: capsules_fc()

    This function constructs an output matrix capsule layer with poses and activations (e.g., Class Capsules) from a matrix capsule layer (e.g., ConvCaps2).

  • The EM routing algorithm: matrix_capsules_em_routing()

    This function implements the EM routing algorithm between matrix capsule layers.
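To make the routing step concrete, here is a minimal NumPy sketch of EM routing for one group of input capsules voting for J output capsules, following the M-step/E-step structure described in the paper. It is an illustration under stated assumptions, not the project's matrix_capsules_em_routing(): the function name em_routing is hypothetical, beta_u and beta_a are plain scalars (the paper learns them per output capsule type), the inverse temperature is held fixed instead of following a schedule, and convolutional windows and coordinate addition are left out.

```python
import numpy as np


def em_routing(votes, activations_in, beta_u, beta_a,
               inverse_temperature=1.0, iterations=3, epsilon=1e-9):
    """Hypothetical NumPy sketch of EM routing for one group of capsules.

    votes:          [I, J, D] votes V_ij (D = 16 for 4x4 pose matrices)
    activations_in: [I] activations a_i of the input capsules
    beta_u, beta_a: scalars here (learned per output capsule type in the paper)
    Returns (mu, a_out): output poses [J, D] and output activations [J].
    """
    num_in, num_out, _ = votes.shape
    # R_ij starts uniform: each input capsule splits its vote evenly.
    r = np.full((num_in, num_out), 1.0 / num_out)
    for it in range(iterations):
        # M-step: fit an axis-aligned Gaussian to the votes for each output capsule.
        r_a = r * activations_in[:, None]                  # R_ij * a_i
        r_sum = r_a.sum(axis=0)                            # [J]
        weights = r_a / (r_sum + epsilon)                  # normalize over inputs i
        mu = np.einsum('ij,ijd->jd', weights, votes)       # [J, D]
        diff2 = (votes - mu[None]) ** 2                    # [I, J, D]
        var = np.einsum('ij,ijd->jd', weights, diff2) + epsilon
        # cost_h = (beta_u + log sigma_h) * sum_i R_ij, using log sigma = 0.5 * log var
        cost = (beta_u + 0.5 * np.log(var)) * r_sum[:, None]
        logits = inverse_temperature * (beta_a - cost.sum(axis=1))
        a_out = 0.5 * (1.0 + np.tanh(0.5 * logits))        # overflow-free logistic
        if it == iterations - 1:
            break  # the final M-step produces the outputs
        # E-step: R_ij proportional to a_j * N(V_ij; mu_j, sigma_j^2), in log space.
        log_p = -0.5 * np.sum(np.log(2.0 * np.pi * var)[None] + diff2 / var[None], axis=2)
        log_ap = np.log(a_out + epsilon)[None] + log_p     # [I, J]
        log_ap -= log_ap.max(axis=1, keepdims=True)        # stabilize before exp
        r = np.exp(log_ap)
        r /= r.sum(axis=1, keepdims=True)                  # each input's R_ij sums to 1
    return mu, a_out
```

In a convolutional capsule layer, a routine like this would be applied once per output position, to the votes coming from that position's 3x3 window of input capsules; how the project batches this in TensorFlow is documented in its source.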

Matrix Capsule Nets

The network and loss functions are implemented in

/src/capsules/nets.py

Build a matrix capsules neural network in the same way as a CNN:

import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope)

# capsules_init(), capsules_conv() and capsules_fc() come from /src/capsules/core.py;
# _conv2d_wrapper() is a helper from the project source.

def capsules_net(inputs, num_classes, iterations, name='CapsuleEM-V0'):
  """Replicate the network in `Matrix Capsules with EM Routing`.

  inputs: images [N, H, W, 1]; num_classes: number of class capsules;
  iterations: number of EM routing iterations per capsule layer.
  Returns the (poses, activations) of the class capsules.
  """

  with tf.variable_scope(name) as scope:

    # inputs [N, H, W, 1] -> conv2d, 5x5, strides 2, channels 32 -> nets [N, OH, OW, 32]
    nets = _conv2d_wrapper(
      inputs, shape=[5, 5, 1, 32], strides=[1, 2, 2, 1], padding='SAME', add_bias=True, activation_fn=tf.nn.relu, name='conv1'
    )
    # nets [N, OH, OW, 32] -> conv2d, 1x1, strides 1, channels 32x(4x4+1) -> (poses, activations)
    nets = capsules_init(
      nets, shape=[1, 1, 32, 32], strides=[1, 1, 1, 1], padding='VALID', pose_shape=[4, 4], name='capsule_init'
    )
    # (poses, activations) -> capsule-conv 3x3x32x32x4x4, strides 2 -> (poses, activations)
    nets = capsules_conv(
      nets, shape=[3, 3, 32, 32], strides=[1, 2, 2, 1], iterations=iterations, name='capsule_conv1'
    )
    # (poses, activations) -> capsule-conv 3x3x32x32x4x4, strides 1 -> (poses, activations)
    nets = capsules_conv(
      nets, shape=[3, 3, 32, 32], strides=[1, 1, 1, 1], iterations=iterations, name='capsule_conv2'
    )
    # (poses, activations) -> capsule-fc 1x1x32x10x4x4, view transformation matrix shared within each channel -> (poses, activations)
    nets = capsules_fc(
      nets, num_classes, iterations=iterations, name='capsule_fc'
    )

    poses, activations = nets

  return poses, activations
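The spatial size of each layer follows from the kernel sizes and strides above. The pure-Python sketch below traces the per-example shapes through capsules_net. The helper names conv_out and trace_capsules_net are hypothetical, and two things are assumptions: the padding used inside capsules_conv (not stated here, so it is a parameter defaulting to 'VALID'), and the class capsule poses being laid out as [num_classes, 4, 4].

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a 2-D convolution along one axis (TensorFlow padding rules)."""
    if padding == 'SAME':
        return -(-size // stride)              # ceil(size / stride)
    if padding == 'VALID':
        return (size - kernel) // stride + 1   # only windows fully inside the input
    raise ValueError('unknown padding: %r' % padding)


def trace_capsules_net(height, width, num_classes=10, channels=32,
                       pose=(4, 4), capsule_padding='VALID'):
    """Hypothetical shape trace (batch dimension N omitted) for capsules_net.

    capsule_padding is an assumption: the padding used by capsules_conv.
    """
    shapes = {}
    # conv1: 5x5, stride 2, 'SAME' padding, 32 channels
    h, w = conv_out(height, 5, 2, 'SAME'), conv_out(width, 5, 2, 'SAME')
    shapes['conv1'] = (h, w, channels)
    # capsule_init: 1x1, stride 1, 'VALID' keeps H and W; 32 capsule types with 4x4 poses
    shapes['capsule_init'] = (h, w, channels) + pose
    # capsule_conv1: 3x3 kernel, stride 2
    h, w = conv_out(h, 3, 2, capsule_padding), conv_out(w, 3, 2, capsule_padding)
    shapes['capsule_conv1'] = (h, w, channels) + pose
    # capsule_conv2: 3x3 kernel, stride 1
    h, w = conv_out(h, 3, 1, capsule_padding), conv_out(w, 3, 1, capsule_padding)
    shapes['capsule_conv2'] = (h, w, channels) + pose
    # capsule_fc: one 4x4 pose per class (layout assumed)
    shapes['capsule_fc'] = (num_classes,) + pose
    return shapes
```

For example, trace_capsules_net(28, 28) gives 14x14 after conv1, 6x6 after capsule_conv1 and 4x4 after capsule_conv2 under the 'VALID' assumption; with 'SAME' the last two would be 7x7.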

Check out the source code for detailed documentation.

TODO:

  • add examples

  • generate the documentation from the code.
