Capsule Networks (#capsnets)
In my previous article on Handwriting Decoder (#ocr), we looked at how we can read handwriting using computer vision. Although the CNN-based computer vision library does a fantastic job of identifying separated letters, it does not interpret all types of handwriting accurately. While researching this issue, I came across a few advancements in this area. This is where Capsule Networks come into play: they overcome some of the drawbacks of CNNs.
Capsule Networks (CapsNets) are a type of neural network architecture proposed by Geoffrey Hinton and his team in the 2017 paper "Dynamic Routing Between Capsules", and they have attracted attention in the deep learning community ever since. CapsNets are built from "capsules": groups of neurons whose output vector encodes the pose and other properties of an object, while the length of that vector represents the probability that the object is present. The idea is that by representing objects as capsules, the model can better capture the hierarchical structure of objects and their relationships to each other. CNNs, by contrast, rely on pooling layers, which throw away information such as the precise position of a feature. CNNs also need massive amounts of training data, and because their layers progressively reduce spatial resolution, their outputs are insensitive to small changes in the input. The capsule architecture is designed to address these issues.
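Because a capsule's vector length is read as a probability, the paper squashes every capsule output with a non-linearity that keeps the vector's direction but maps its length into the range [0, 1). Below is a minimal NumPy sketch of that "squash" function; the `epsilon` guard against division by zero is my own addition, not part of the paper's formula.

```python
import numpy as np


def squash(s, axis=-1, epsilon=1e-7):
    """Rescale vector s to length ||s||^2 / (1 + ||s||^2), keeping its direction."""
    squared_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    scale = squared_norm / (1.0 + squared_norm)
    # epsilon avoids dividing by zero when s is the zero vector
    return scale * s / np.sqrt(squared_norm + epsilon)


# A short vector shrinks toward length 0, a long one saturates just below length 1.
short_vec = squash(np.array([0.1, 0.0]))
long_vec = squash(np.array([10.0, 0.0]))
print(np.linalg.norm(short_vec), np.linalg.norm(long_vec))  # roughly 0.0099 and 0.99
```

The design choice here is that "how likely" and "what it looks like" live in one vector: length carries the first, orientation the second.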
Now let's see what a capsule network is actually made of.
I will use TensorFlow (the 1.x API, with sessions and placeholders) and a famous dataset that many readers are already familiar with: MNIST.
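import numpy as np
import tensorflow as tf  # TF 1.x API
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")  # input images
y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")  # digit labels (0-9)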
Step 1: Primary Capsules
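The primary capsule layer consists of 32 maps of 6 × 6 capsules, where each capsule outputs an 8-dimensional vector, giving 32 × 6 × 6 = 1,152 primary capsules in total. They are obtained by reshaping the output of a convolutional layer (conv2) into 8-dimensional vectors and then squashing each vector.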
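caps1_n_maps = 32
caps1_n_caps = caps1_n_maps * 6 * 6  # 1152 primary capsules
caps1_n_dims = 8
conv1 = tf.layers.conv2d(X, 256, 9, strides=1, activation=tf.nn.relu, name="conv1")  # 20x20x256
conv2 = tf.layers.conv2d(conv1, caps1_n_maps * caps1_n_dims, 9, strides=2,
                         activation=tf.nn.relu, name="conv2")  # 6x6x256
caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims], name="caps1_raw")
def squash(s, axis=-1, epsilon=1e-7):  # keeps direction, maps length into [0, 1)
    squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
    return squared_norm / (1. + squared_norm) * s / tf.sqrt(squared_norm + epsilon)
caps1_output = squash(caps1_raw)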
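Where do the 6 × 6 and 1,152 come from? A minimal sketch of the shape arithmetic, assuming the paper's layout of a 9 × 9 convolution with stride 1 followed by a 9 × 9 convolution with stride 2, both without padding. The helper `conv_output_size` and the batch of zeros are hypothetical, for illustration only.

```python
import numpy as np


def conv_output_size(size, kernel, stride):
    """Output width of a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1


after_conv1 = conv_output_size(28, kernel=9, stride=1)           # 20
after_conv2 = conv_output_size(after_conv1, kernel=9, stride=2)  # 6

n_maps, n_dims = 32, 8
n_primary_caps = n_maps * after_conv2 * after_conv2  # 1152 capsules
channels_needed = n_maps * n_dims                     # 256 channels in conv2

# Reshaping groups every 8 consecutive channels at one spatial position into one capsule.
conv2_out = np.zeros((2, after_conv2, after_conv2, channels_needed))  # hypothetical batch of 2
caps1_raw = conv2_out.reshape(-1, n_primary_caps, n_dims)
print(after_conv1, after_conv2, n_primary_caps, caps1_raw.shape)
```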
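Step 2: Digit Capsules
The digit capsule layer contains 10 capsules (one for each digit), each outputting a 16-dimensional vector. Every primary capsule i predicts the output of every digit capsule j by multiplying its 8-dimensional output u_i by its own trainable 16 × 8 matrix W_ij. The code below creates all 1,152 × 10 of these matrices and tiles W across the batch, and the primary capsule outputs across the 10 digit capsules, so that all predictions can be computed with a single batched matrix multiplication. The digit capsule outputs are then determined by "routing by agreement": a primary capsule's prediction gets more weight for the digit capsules whose outputs it agrees with.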
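caps2_n_caps = 10
caps2_n_dims = 16
init_sigma = 0.1
W_init = tf.random_normal(
    shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims),
    stddev=init_sigma, dtype=tf.float32, name="W_init")
W = tf.Variable(W_init, name="W")
batch_size = tf.shape(X)[0]
W_tiled = tf.tile(W, [batch_size, 1, 1, 1, 1], name="W_tiled")
caps1_output_expanded = tf.expand_dims(caps1_output, -1,
                                       name="caps1_output_expanded")
caps1_output_tile = tf.expand_dims(caps1_output_expanded, 2,
                                   name="caps1_output_tile")
caps1_output_tiled = tf.tile(caps1_output_tile, [1, 1, caps2_n_caps, 1, 1],
                             name="caps1_output_tiled")
# Prediction of every digit capsule by every primary capsule: [batch, 1152, 10, 16, 1]
caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled, name="caps2_predicted")
# Routing by agreement (3 iterations, as in the paper)
raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1], dtype=tf.float32)
for routing_round in range(3):
    routing_weights = tf.nn.softmax(raw_weights, axis=2)
    weighted_sum = tf.reduce_sum(routing_weights * caps2_predicted, axis=1, keepdims=True)
    sq_norm = tf.reduce_sum(tf.square(weighted_sum), axis=-2, keepdims=True)  # squash over 16 dims
    caps2_output = sq_norm / (1. + sq_norm) * weighted_sum / tf.sqrt(sq_norm + 1e-7)  # [batch, 1, 10, 16, 1]
    if routing_round < 2:
        caps2_output_tiled = tf.tile(caps2_output, [1, caps1_n_caps, 1, 1, 1])
        raw_weights += tf.matmul(caps2_predicted, caps2_output_tiled, transpose_a=True)
# The predicted digit is the capsule with the longest output vector
y_proba = tf.sqrt(tf.reduce_sum(tf.square(caps2_output), axis=-2) + 1e-7)
y_pred = tf.squeeze(tf.argmax(y_proba, axis=2), axis=[1, 2], name="y_pred")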
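The tiled tensors in this step exist so that every primary capsule can make a prediction for every digit capsule; routing by agreement then decides how those predictions are combined. Here is a minimal single-example NumPy sketch of both stages, assuming 3 routing iterations as in the paper and random weights. The function names `softmax` and `dynamic_routing` are my own, not from a library.

```python
import numpy as np


def squash(s, axis=-1, epsilon=1e-7):
    squared_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return squared_norm / (1.0 + squared_norm) * s / np.sqrt(squared_norm + epsilon)


def softmax(x, axis):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def dynamic_routing(u, W, n_iterations=3):
    """Routing by agreement for one example.

    u: primary capsule outputs, shape (n_in, d_in)
    W: transformation matrices, shape (n_in, n_out, d_out, d_in)
    Returns digit capsule outputs (n_out, d_out) and coupling coefficients (n_in, n_out).
    """
    # Each primary capsule i predicts each digit capsule j: u_hat[i, j] = W[i, j] @ u[i]
    u_hat = np.einsum("ijdk,ik->ijd", W, u)
    b = np.zeros(W.shape[:2])  # routing logits, so coupling starts out uniform
    for r in range(n_iterations):
        c = softmax(b, axis=1)                 # each primary capsule splits its vote
        s = np.einsum("ij,ijd->jd", c, u_hat)  # weighted sum of predictions
        v = squash(s)                          # digit capsule outputs
        if r < n_iterations - 1:
            # Raise the logit wherever a prediction agrees (dot product) with the output
            b = b + np.einsum("ijd,jd->ij", u_hat, v)
    return v, c


rng = np.random.default_rng(0)
u = squash(rng.normal(size=(1152, 8)))
W = rng.normal(scale=0.1, size=(1152, 10, 16, 8))
v, c = dynamic_routing(u, W)
lengths = np.linalg.norm(v, axis=-1)  # one "existence probability" per digit
print(v.shape, lengths.round(3))
```

The softmax runs over the digit capsules, so each primary capsule distributes a total weight of 1 among the 10 possible parents.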
Step 3: Mask
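We must send the decoder only the output vector of the capsule that corresponds to the target digit; the outputs of all other digit capsules are masked out (set to zero). During training the target is the true label, and at prediction time, when no label is available, it is the predicted digit.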
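mask_with_labels = tf.placeholder_with_default(False, shape=(),
                                               name="mask_with_labels")
reconstruction_targets = tf.cond(mask_with_labels,  # condition
                                 lambda: y,         # if True: true labels (training)
                                 lambda: y_pred,    # if False: predicted labels
                                 name="reconstruction_targets")
reconstruction_mask = tf.one_hot(reconstruction_targets, depth=caps2_n_caps)
reconstruction_mask = tf.reshape(reconstruction_mask, [-1, 1, caps2_n_caps, 1, 1])
caps2_output_masked = caps2_output * reconstruction_mask  # zero every other capsule
decoder_input = tf.reshape(caps2_output_masked, [-1, caps2_n_caps * caps2_n_dims],
                           name="decoder_input")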
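The masking itself is just a one-hot multiplication followed by flattening. A minimal NumPy sketch, assuming digit capsule outputs of shape (batch, 10, 16); the function name `mask_capsules` is hypothetical. During training `target` would hold the true labels, at prediction time the predicted ones.

```python
import numpy as np


def mask_capsules(caps2_output, target):
    """Zero out every digit capsule except the target one, then flatten.

    caps2_output: shape (batch, 10, 16); target: shape (batch,) of digit labels.
    Returns the decoder input of shape (batch, 10 * 16).
    """
    n_caps = caps2_output.shape[1]
    mask = np.eye(n_caps)[target]             # one-hot rows, shape (batch, 10)
    masked = caps2_output * mask[:, :, None]  # broadcast the mask over the 16 dims
    return masked.reshape(len(target), -1)


outputs = np.ones((2, 10, 16))
decoder_input = mask_capsules(outputs, np.array([3, 7]))
print(decoder_input.shape)  # (2, 160)
```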
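Step 4: Decoder
The decoder is a stack of three fully connected layers (512, 1,024 and 784 units) that tries to reconstruct the 28 × 28 input image from the masked digit capsule output. Its reconstruction error acts as a regularizer during training, pushing the digit capsules to encode useful information about the image.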
n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28
with tf.name_scope("decoder"):
    hidden1 = tf.layers.dense(decoder_input, n_hidden1,
                              activation=tf.nn.relu,
                              name="hidden1")
    hidden2 = tf.layers.dense(hidden1, n_hidden2,
                              activation=tf.nn.relu,
                              name="hidden2")
    decoder_output = tf.layers.dense(hidden2, n_output,
                                     activation=tf.nn.sigmoid,
                                     name="decoder_output")
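To make the layer sizes concrete, here is a NumPy forward pass through the same 160 → 512 → 1024 → 784 stack with random, untrained weights and no biases (a simplification, since `tf.layers.dense` adds biases by default), followed by the sum-of-squared-errors reconstruction loss. The helper names `dense`, `relu` and `sigmoid` and the random "images" are stand-ins for illustration only.

```python
import numpy as np


def dense(x, n_out, rng, activation):
    """Fully connected layer with random, untrained weights (shape illustration only)."""
    weights = rng.normal(scale=0.01, size=(x.shape[1], n_out))
    return activation(x @ weights)


def relu(z):
    return np.maximum(z, 0.0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
decoder_input = rng.normal(size=(4, 10 * 16))  # 4 masked capsule outputs, flattened
hidden1 = dense(decoder_input, 512, rng, relu)
hidden2 = dense(hidden1, 1024, rng, relu)
decoder_output = dense(hidden2, 28 * 28, rng, sigmoid)  # pixel intensities in (0, 1)

images = rng.uniform(size=(4, 28 * 28))  # stand-in for flattened MNIST images
# Sum of squared pixel errors per image, averaged over the batch
reconstruction_loss = np.mean(np.sum((images - decoder_output) ** 2, axis=1))
print(decoder_output.shape, reconstruction_loss)
```

The sigmoid on the last layer matches the MNIST pixel range, which `read_data_sets` scales to [0, 1].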
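Step 5: Training
This step follows the usual TensorFlow training pattern: train on mini-batches (feeding the true labels so that the decoder reconstructs the target digit), measure the validation loss and accuracy at the end of each epoch, and save a checkpoint whenever the validation loss improves.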
# Margin loss (m+ = 0.9, m- = 0.1, lambda = 0.5) plus scaled reconstruction loss, as in the paper
T = tf.one_hot(y, depth=caps2_n_caps)
caps2_length = tf.reshape(tf.sqrt(tf.reduce_sum(tf.square(caps2_output), axis=-2) + 1e-7),
                          [-1, caps2_n_caps])
margin_loss = tf.reduce_mean(tf.reduce_sum(
    T * tf.square(tf.maximum(0., 0.9 - caps2_length))
    + 0.5 * (1. - T) * tf.square(tf.maximum(0., caps2_length - 0.1)), axis=1))
X_flat = tf.reshape(X, [-1, n_output])
reconstruction_loss = tf.reduce_mean(tf.reduce_sum(tf.square(X_flat - decoder_output), axis=1))
loss = margin_loss + 0.0005 * reconstruction_loss
accuracy = tf.reduce_mean(tf.cast(tf.equal(y, y_pred), tf.float32))
training_op = tf.train.AdamOptimizer().minimize(loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()

n_epochs = 10
batch_size = 50
restore_checkpoint = True
n_iterations_per_epoch = mnist.train.num_examples // batch_size
n_iterations_validation = mnist.validation.num_examples // batch_size
best_loss_val = np.inf
checkpoint_path = "./my_capsule_network"
with tf.Session() as sess:
    if restore_checkpoint and tf.train.checkpoint_exists(checkpoint_path):
        saver.restore(sess, checkpoint_path)
    else:
        init.run()
    for epoch in range(n_epochs):
        for iteration in range(1, n_iterations_per_epoch + 1):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            # Run the training operation and measure the loss:
            _, loss_train = sess.run(
                [training_op, loss],
                feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                           y: y_batch,
                           mask_with_labels: True})
            print("\rIteration: {}/{} ({:.1f}%)  Loss: {:.5f}".format(
                      iteration, n_iterations_per_epoch,
                      iteration * 100 / n_iterations_per_epoch,
                      loss_train),
                  end="")
        # At the end of each epoch,
        # measure the validation loss and accuracy:
        loss_vals = []
        acc_vals = []
        for iteration in range(1, n_iterations_validation + 1):
            X_batch, y_batch = mnist.validation.next_batch(batch_size)
            loss_val, acc_val = sess.run(
                    [loss, accuracy],
                    feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                               y: y_batch})
            loss_vals.append(loss_val)
            acc_vals.append(acc_val)
            print("\rEvaluating the model: {}/{} ({:.1f}%)".format(
                      iteration, n_iterations_validation,
                      iteration * 100 / n_iterations_validation),
                  end=" " * 10)
        loss_val = np.mean(loss_vals)
        acc_val = np.mean(acc_vals)
        print("\rEpoch: {}  Val accuracy: {:.4f}%  Loss: {:.6f}{}".format(
            epoch + 1, acc_val * 100, loss_val,
            " (improved)" if loss_val < best_loss_val else ""))
        # And save the model if it improved:
        if loss_val < best_loss_val:
            save_path = saver.save(sess, checkpoint_path)
            best_loss_val = loss_val
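In the paper, the loss minimized during training is a margin loss on the digit capsule lengths plus a reconstruction loss scaled down by 0.0005. For each digit k, the margin loss is T_k · max(0, 0.9 − ‖v_k‖)² + 0.5 · (1 − T_k) · max(0, ‖v_k‖ − 0.1)², where T_k is 1 only for the true digit. A minimal NumPy sketch of that margin term, with hypothetical capsule lengths:

```python
import numpy as np


def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss from the CapsNet paper, averaged over the batch.

    lengths: digit capsule output lengths, shape (batch, 10)
    labels: true digits, shape (batch,)
    """
    T = np.eye(lengths.shape[1])[labels]  # 1 for the true digit, 0 elsewhere
    present = T * np.maximum(0.0, m_plus - lengths) ** 2
    absent = lam * (1.0 - T) * np.maximum(0.0, lengths - m_minus) ** 2
    return np.mean(np.sum(present + absent, axis=1))


# A confident, correct prediction costs nothing...
good = np.full((1, 10), 0.05)
good[0, 3] = 0.95
# ...while a long vector on the wrong digit is penalised twice.
bad = np.full((1, 10), 0.05)
bad[0, 5] = 0.95
print(margin_loss(good, np.array([3])), margin_loss(bad, np.array([3])))
```

The down-weighting factor 0.5 on absent digits stops the early training from shrinking all capsule vectors toward zero.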
Step 6: Prediction
Finally, we restore the best checkpoint and run the network on a few test images.
n_samples = 5
sample_images = mnist.test.images[:n_samples].reshape([-1, 28, 28, 1])
with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    # y is not needed for prediction; an empty array is fed for the y placeholder
    caps2_output_value, decoder_output_value, y_pred_value = sess.run(
            [caps2_output, decoder_output, y_pred],
            feed_dict={X: sample_images,
                       y: np.array([], dtype=np.int64)})
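The predicted digit is simply the digit capsule whose output vector is longest. A minimal NumPy sketch of that rule, assuming an array like `caps2_output_value` with its singleton axes squeezed away to shape (batch, 10, 16); the function name `predict_digits` and the toy outputs are hypothetical.

```python
import numpy as np


def predict_digits(caps2_output, epsilon=1e-7):
    """Return, per example, the digit whose capsule output is longest, plus all lengths.

    caps2_output: shape (batch, 10, 16)
    """
    lengths = np.sqrt(np.sum(caps2_output ** 2, axis=-1) + epsilon)  # (batch, 10)
    return np.argmax(lengths, axis=1), lengths


outputs = np.zeros((2, 10, 16))
outputs[0, 4, 0] = 0.9   # example 0: capsule 4 has length 0.9
outputs[1, 1, :] = 0.2   # example 1: capsule 1 has length sqrt(16 * 0.04) = 0.8
preds, lengths = predict_digits(outputs)
print(preds)  # [4 1]
```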
Output:
Not bad for a simple CapsNet. If you would like to learn more, please refer to the GitHub repository below.
Reference: https://github.com/naturomics/CapsLayer