Capsule Networks (#capsnets)
CapsNets Architecture

In my previous article on the Handwriting Decoder (#ocr), we touched on how we can read handwriting using computer vision. Although the CNN-driven CV library does a fantastic job at identifying separated letters, not all styles of handwriting are interpreted accurately. While researching this issue, I found a few advancements in this area. This is where Capsule Networks come into play, overcoming some of the drawbacks of CNNs.

Capsule Networks are a type of neural network architecture proposed by Geoffrey Hinton and his team in the 2017 paper "Dynamic Routing Between Capsules", and they have drawn attention in the deep learning community ever since. CapsNets are built from "capsules": groups of neurons that encode the pose and properties of an object. The idea is that, by representing objects as capsules, the model can better capture the hierarchical structure of objects and their relationships to each other. CNNs, by contrast, rely on pooling layers, which discard information: they require massive amounts of data to learn, and because pooling reduces spatial resolution, the output is not sensitive to small changes in the input. The capsule architecture addresses these issues.

Now let's see what a capsule network is actually made of.

I will use TensorFlow with a famous dataset that many are already familiar with: MNIST.

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/")

Step 1: Primary Capsules

The primary capsule layer consists of 32 maps of 6 × 6 capsules, each an 8-dimensional vector (1,152 capsules in total).
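The snippet below builds on a conv2 tensor that the article never defines. A minimal sketch of the missing setup, assuming the configuration from the paper: an input placeholder X (plus the labels placeholder y used later), followed by two 9 × 9 convolutional layers whose 6 × 6 × 256 output is reshaped into the primary capsules.

# Hypothetical input placeholders and convolutional stem, assumed
# from the paper's setup (the article omits them).
X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")
y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")

conv1 = tf.layers.conv2d(X, filters=256, kernel_size=9, strides=1,
                         padding="valid", activation=tf.nn.relu,
                         name="conv1")
conv2 = tf.layers.conv2d(conv1, filters=256, kernel_size=9, strides=2,
                         padding="valid", activation=tf.nn.relu,
                         name="conv2")  # output: 6 x 6 x 256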

caps1_n_maps = 32
caps1_n_caps = caps1_n_maps * 6 * 6  # 1152 primary capsules
caps1_n_dims = 8

# Reshape the second convolutional layer's output into
# 1152 eight-dimensional raw capsule vectors.
caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims],
                       name="caps1_raw")
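The raw vectors then go through the squash nonlinearity from the paper, which preserves a vector's orientation while shrinking its length into [0, 1). The article skips this step; a sketch, with caps1_output assumed to be the squashed primary capsule output used in Step 2:

def squash(s, axis=-1, epsilon=1e-7, name=None):
    # Squash: (||s||^2 / (1 + ||s||^2)) * (s / ||s||), computed
    # with a small epsilon to avoid dividing by zero.
    with tf.name_scope(name, default_name="squash"):
        squared_norm = tf.reduce_sum(tf.square(s), axis=axis,
                                     keepdims=True)
        safe_norm = tf.sqrt(squared_norm + epsilon)
        squash_factor = squared_norm / (1. + squared_norm)
        return squash_factor * s / safe_norm

caps1_output = squash(caps1_raw, name="caps1_output")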

Step 2: Digit Capsules

The digit capsule layer contains 10 capsules (one for each digit) of 16 dimensions each.

caps2_n_caps = 10
caps2_n_dims = 16
init_sigma = 0.1

# One transformation matrix per (primary capsule, digit capsule) pair.
W_init = tf.random_normal(
    shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims),
    stddev=init_sigma, dtype=tf.float32, name="W_init")
W = tf.Variable(W_init, name="W")

# Tile W across the batch so every sample gets its own copy.
batch_size = tf.shape(X)[0]
W_tiled = tf.tile(W, [batch_size, 1, 1, 1, 1], name="W_tiled")

# Expand and tile the primary capsule outputs so they can be
# matrix-multiplied with W_tiled.
caps1_output_expanded = tf.expand_dims(caps1_output, -1,
                                       name="caps1_output_expanded")
caps1_output_tile = tf.expand_dims(caps1_output_expanded, 2,
                                   name="caps1_output_tile")
caps1_output_tiled = tf.tile(caps1_output_tile, [1, 1, caps2_n_caps, 1, 1],
                             name="caps1_output_tiled")
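The article stops short of the routing step that actually produces the digit capsule outputs. A sketch, assuming the paper's routing-by-agreement procedure (two rounds, reusing the squash function above) and deriving y_pred as the digit capsule with the longest output vector:

# Predicted output vectors: one 16-D prediction per
# (primary capsule, digit capsule) pair.
caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled,
                            name="caps2_predicted")

# Routing by agreement, 2 rounds as in the paper.
raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],
                       dtype=tf.float32, name="raw_weights")
for round_ in range(2):
    routing_weights = tf.nn.softmax(raw_weights, axis=2)
    weighted_sum = tf.reduce_sum(routing_weights * caps2_predicted,
                                 axis=1, keepdims=True)
    caps2_output = squash(weighted_sum, axis=-2)
    if round_ < 1:
        # Increase the weight of primary capsules whose predictions
        # agree with the current digit capsule outputs.
        agreement = tf.matmul(caps2_predicted,
                              tf.tile(caps2_output,
                                      [1, caps1_n_caps, 1, 1, 1]),
                              transpose_a=True)
        raw_weights = raw_weights + agreement

# Class prediction: the digit capsule with the longest output vector.
y_proba = tf.sqrt(tf.reduce_sum(tf.square(caps2_output), axis=-2) + 1e-7)
y_pred = tf.squeeze(tf.argmax(y_proba, axis=2), axis=[1, 2], name="y_pred")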

Step 3: Mask

During training, we must send the decoder only the output vector of the capsule that corresponds to the target digit; at inference time, we use the predicted digit instead.

mask_with_labels = tf.placeholder_with_default(False, shape=(),
                                               name="mask_with_labels")

reconstruction_targets = tf.cond(mask_with_labels,  # condition
                                 lambda: y,         # if True
                                 lambda: y_pred,    # if False
                                 name="reconstruction_targets")
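The mask itself is not shown in the article. A sketch of the assumed missing step: one-hot-encode the chosen digit, zero out every other capsule, and flatten the result into the decoder_input consumed in Step 4.

# One-hot mask over the 10 digit capsules.
reconstruction_mask = tf.one_hot(reconstruction_targets,
                                 depth=caps2_n_caps,
                                 name="reconstruction_mask")
# Reshape so it broadcasts against caps2_output's
# [batch, 1, 10, 16, 1] shape.
reconstruction_mask_reshaped = tf.reshape(
    reconstruction_mask, [-1, 1, caps2_n_caps, 1, 1],
    name="reconstruction_mask_reshaped")
# Zero out every capsule except the selected one, then flatten.
caps2_output_masked = tf.multiply(caps2_output,
                                  reconstruction_mask_reshaped,
                                  name="caps2_output_masked")
decoder_input = tf.reshape(caps2_output_masked,
                           [-1, caps2_n_caps * caps2_n_dims],
                           name="decoder_input")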

Step 4: Decoder

The decoder tries to reconstruct the input image from the masked capsule output.

n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28

# Two fully connected ReLU layers followed by a sigmoid output layer
# that produces the 784 reconstructed pixel intensities.
with tf.name_scope("decoder"):
    hidden1 = tf.layers.dense(decoder_input, n_hidden1,
                              activation=tf.nn.relu,
                              name="hidden1")
    hidden2 = tf.layers.dense(hidden1, n_hidden2,
                              activation=tf.nn.relu,
                              name="hidden2")
    decoder_output = tf.layers.dense(hidden2, n_output,
                                     activation=tf.nn.sigmoid,
                                     name="decoder_output")
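The training loop below uses loss, accuracy, training_op, init, and saver, none of which the article defines. A minimal sketch of the missing pieces, assuming the loss from the paper (a margin loss plus a down-weighted reconstruction loss) and the tensors built in the previous steps:

m_plus = 0.9
m_minus = 0.1
lambda_ = 0.5

# Margin loss: each digit capsule's vector length should be long
# for the target class and short for all the others.
T = tf.one_hot(y, depth=caps2_n_caps, name="T")
caps2_output_norm = tf.sqrt(
    tf.reduce_sum(tf.square(caps2_output), axis=-2, keepdims=True) + 1e-7)
present_error = tf.reshape(
    tf.square(tf.maximum(0., m_plus - caps2_output_norm)), shape=(-1, 10))
absent_error = tf.reshape(
    tf.square(tf.maximum(0., caps2_output_norm - m_minus)), shape=(-1, 10))
L = T * present_error + lambda_ * (1. - T) * absent_error
margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name="margin_loss")

# Reconstruction loss: squared difference between the decoder's
# output and the original flattened image.
X_flat = tf.reshape(X, [-1, n_output], name="X_flat")
reconstruction_loss = tf.reduce_sum(
    tf.square(X_flat - decoder_output), name="reconstruction_loss")

# Total loss, with the reconstruction term scaled down (0.0005 in
# the paper) so it acts only as a regularizer.
alpha = 0.0005
loss = tf.add(margin_loss, alpha * reconstruction_loss, name="loss")

# Accuracy, training op and bookkeeping used by the loop below.
correct = tf.equal(y, y_pred, name="correct")
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
training_op = tf.train.AdamOptimizer().minimize(loss, name="training_op")
init = tf.global_variables_initializer()
saver = tf.train.Saver()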

Step 5: Training

This step follows the standard TensorFlow training loop.

n_epochs = 10
batch_size = 50
restore_checkpoint = True

n_iterations_per_epoch = mnist.train.num_examples // batch_size
n_iterations_validation = mnist.validation.num_examples // batch_size
best_loss_val = np.infty
checkpoint_path = "./my_capsule_network"

with tf.Session() as sess:
    if restore_checkpoint and tf.train.checkpoint_exists(checkpoint_path):
        saver.restore(sess, checkpoint_path)
    else:
        init.run()

    for epoch in range(n_epochs):
        for iteration in range(1, n_iterations_per_epoch + 1):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            # Run the training operation and measure the loss:
            _, loss_train = sess.run(
                [training_op, loss],
                feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                           y: y_batch,
                           mask_with_labels: True})
            print("\rIteration: {}/{} ({:.1f}%)  Loss: {:.5f}".format(
                      iteration, n_iterations_per_epoch,
                      iteration * 100 / n_iterations_per_epoch,
                      loss_train),
                  end="")

        # At the end of each epoch,
        # measure the validation loss and accuracy:
        loss_vals = []
        acc_vals = []
        for iteration in range(1, n_iterations_validation + 1):
            X_batch, y_batch = mnist.validation.next_batch(batch_size)
            loss_val, acc_val = sess.run(
                    [loss, accuracy],
                    feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                               y: y_batch})
            loss_vals.append(loss_val)
            acc_vals.append(acc_val)
            print("\rEvaluating the model: {}/{} ({:.1f}%)".format(
                      iteration, n_iterations_validation,
                      iteration * 100 / n_iterations_validation),
                  end=" " * 10)
        loss_val = np.mean(loss_vals)
        acc_val = np.mean(acc_vals)
        print("\rEpoch: {}  Val accuracy: {:.4f}%  Loss: {:.6f}{}".format(
            epoch + 1, acc_val * 100, loss_val,
            " (improved)" if loss_val < best_loss_val else ""))

        # And save the model if it improved:
        if loss_val < best_loss_val:
            save_path = saver.save(sess, checkpoint_path)
            best_loss_val = loss_val

Step 6: Prediction

n_samples = 5

sample_images = mnist.test.images[:n_samples].reshape([-1, 28, 28, 1])

with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    caps2_output_value, decoder_output_value, y_pred_value = sess.run(
            [caps2_output, decoder_output, y_pred],
            feed_dict={X: sample_images,
                       y: np.array([], dtype=np.int64)})
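To sanity-check the run, you can print the predicted labels next to the ground truth (assuming the default loader, where mnist.test.labels holds integer class labels):

print("Predicted labels:", y_pred_value)
print("Actual labels:   ", mnist.test.labels[:n_samples])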

Output:

[Image: predictions for the five sample test digits]

Not bad for a simple CapsNet. If you would like to learn more, please refer to the GitHub link below.

Reference: https://github.com/naturomics/CapsLayer
