Getting Started with TensorFlow.js

TensorFlow.js is an open-source, WebGL-accelerated JavaScript library for machine intelligence. It puts high-performance machine learning building blocks at your fingertips, letting you train neural networks in the browser or run pre-trained models in inference mode. See the TensorFlow.js Getting Started guide for details on installing and configuring the library.

TensorFlow.js provides low-level building blocks for machine learning as well as a high-level, Keras-inspired API for constructing neural networks. Let's take a look at some of the core components of the library.

With TensorFlow.js, you can not only run machine-learned models in the browser to perform inference, but also train them there. In this super-simple tutorial, I’ll walk you through a basic ‘Hello World’ example that covers the scaffolding you need to get up and running.

Let’s start with the simplest web page imaginable:

<html>
<head></head>
<body></body>
</html>

Once you have that, the first thing you’ll need to do is add a reference to TensorFlow.js, so that you can use the TensorFlow APIs. The JS file is available on a CDN for your convenience:

<html>
<head>
 <!-- Load TensorFlow.js -->
 <!-- Get latest version at https://github.com/tensorflow/tfjs -->
 <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2"></script>
</head>
<body></body>
</html>

Right now I’m using version 0.11.2, but be sure to check GitHub for the most recent version.

Now that TensorFlow.js is loaded, let’s do something interesting with it.

Consider a straight line with the formula Y=2X-1. It gives you a set of points like (-1, -3), (0, -1), (1, 1), (2, 3), (3, 5) and (4, 7). We already know the formula gives us Y for any X, but it’s a nice exercise to train a model that is never explicitly given this formula and see whether, trained on this data alone, it can infer the value of Y for a given X.
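To make the training data concrete, here is a small plain-JavaScript sketch (no TensorFlow.js required) that generates the six points from Y = 2X - 1. The helper name `line` is just an illustrative choice.

```javascript
// Y = 2X - 1: the rule the model will have to discover from examples alone
function line(x) {
  return 2 * x - 1;
}

// The six X values used for training
const xValues = [-1, 0, 1, 2, 3, 4];

// Pair each X with its Y to get the (X, Y) training points
const points = xValues.map((x) => [x, line(x)]);

console.log(JSON.stringify(points)); // [[-1,-3],[0,-1],[1,1],[2,3],[3,5],[4,7]]
```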

So how would this work?

Well, first of all, we can create a super-simple neural network to do the inference. As there’s only one input value and one output value, it can be a single node. In JavaScript, I create a tf.sequential model and add my layer definition to it. It can’t get any more basic than this:

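// A sequential model is a linear stack of layers
const model = tf.sequential();
// One dense layer with a single unit (neuron) that takes one input value
model.add(tf.layers.dense({units: 1, inputShape: [1]}));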
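A dense layer with one unit and one input computes a single weighted sum plus a bias, y = w·x + b, with no extra transformation under its default linear activation. The plain-JavaScript sketch below is not TensorFlow.js code; `denseOneUnit`, `w`, and `b` are hypothetical names. It shows that with w = 2 and b = -1 this one neuron reproduces Y = 2X - 1 exactly, which is why a single node is enough for this problem.

```javascript
// Hypothetical sketch of what one dense unit with one input computes:
// output = weight * input + bias (linear activation, so nothing else is applied)
function denseOneUnit(x, w, b) {
  return w * x + b;
}

// Training searches for w and b; these are the values that match Y = 2X - 1
const w = 2;
const b = -1;

const inputs = [-1, 0, 1, 2, 3, 4];
const outputs = inputs.map((x) => denseOneUnit(x, w, b));

console.log(outputs.join(", ")); // -3, -1, 1, 3, 5, 7
console.log(denseOneUnit(10, w, b)); // 19
```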

To finish defining my model, I compile it, specifying my loss function and optimizer. I’ll pick the most basic loss, meanSquaredError, and my optimizer will be standard stochastic gradient descent (aka ‘sgd’):

model.compile({
  loss: 'meanSquaredError', // average of the squared prediction errors
  optimizer: 'sgd'          // stochastic gradient descent
});
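Mean squared error averages the squared differences between the model’s predictions and the true Y values, so large misses are penalized heavily. The sketch below spells out that formula in plain JavaScript; `meanSquaredError` here is a hypothetical helper, not TensorFlow.js’s implementation. For example, a model that predicts 0 for every X would score (9 + 1 + 1 + 9 + 25 + 49) / 6 ≈ 15.67 on our Y values, while a perfect model scores 0.

```javascript
// Mean squared error: the average of (prediction - target)^2 over all examples.
// Hypothetical helper for illustration; TensorFlow.js computes this on tensors.
function meanSquaredError(yTrue, yPred) {
  if (yTrue.length !== yPred.length) {
    throw new Error("yTrue and yPred must have the same length");
  }
  let sum = 0;
  for (let i = 0; i < yTrue.length; i++) {
    const diff = yPred[i] - yTrue[i];
    sum += diff * diff; // squaring makes every miss positive and punishes big misses
  }
  return sum / yTrue.length;
}

const ys = [-3, -1, 1, 3, 5, 7];

// An untrained "model" that predicts 0 for every X
const allZeros = ys.map(() => 0);
console.log(meanSquaredError(ys, allZeros).toFixed(2)); // 15.67

// A perfect model (predictions equal targets) has zero loss
console.log(meanSquaredError(ys, ys)); // 0
```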

To train the model, I’ll need a tensor with my input (i.e. ‘X’) values and another with my output (i.e. ‘Y’) values. With TensorFlow.js, I also need to define the shape of each tensor:

const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

So, my Xs are the values -1, 0, 1, 2, 3 and 4, defined as a 6×1 tensor (six rows of one value each). My Ys are -3, -1, 1, 3, 5 and 7, in the same shape. Note that the nth Y entry is the value for the nth X entry when Y=2X-1.
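tf.tensor2d takes a flat list of values plus a shape, and the values fill the tensor row by row. The hypothetical `toRows` helper below mimics that row-major layout in plain JavaScript (it is not part of TensorFlow.js), so you can see that a [6, 1] shape gives six rows of one value each and that row n of the Ys is the label for row n of the Xs.

```javascript
// Hypothetical helper: lay out a flat array as [rows, cols], filling row by row
function toRows(values, [rows, cols]) {
  if (values.length !== rows * cols) {
    throw new Error(`Expected ${rows * cols} values for shape [${rows}, ${cols}]`);
  }
  const out = [];
  for (let r = 0; r < rows; r++) {
    out.push(values.slice(r * cols, (r + 1) * cols));
  }
  return out;
}

const xs = toRows([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = toRows([-3, -1, 1, 3, 5, 7], [6, 1]);

console.log(JSON.stringify(xs)); // [[-1],[0],[1],[2],[3],[4]]

// Row n of ys is the label for row n of xs under Y = 2X - 1
const allMatch = xs.every(([x], n) => ys[n][0] === 2 * x - 1);
console.log(allMatch); // true
```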

To train the model, we use the fit method, passing it the X and Y values along with the number of epochs (passes through the data) to train for. Note that fit is asynchronous: it returns a Promise, so we should await it before proceeding, which means all this code needs to be in an async function (more on that later):

await model.fit(xs, ys, {epochs: 500});
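To give a feel for what fit() does during those 500 epochs, here is a plain-JavaScript sketch of gradient descent on the same six points. It is a simplification, not TensorFlow.js internals: it starts from w = 0 and b = 0 (TensorFlow.js starts the weight at a random value), it assumes a learning rate of 0.01 for the ‘sgd’ shortcut, and it takes one full-batch step per epoch, which is what happens here because all six examples fit in a single batch at fit()’s default batch size of 32. The names `train`, `epochs`, and `learningRate` are illustrative.

```javascript
// Simplified full-batch gradient descent for y = w * x + b with mean squared error.
// Assumptions (not TensorFlow.js internals): w and b start at 0,
// learning rate 0.01, one update per epoch over all examples.
function train(xs, ys, { epochs, learningRate }) {
  let w = 0;
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    // Gradients of MSE = (1/n) * sum((w*x + b - y)^2) with respect to w and b
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i];
      gradW += (2 / n) * error * xs[i];
      gradB += (2 / n) * error;
    }
    // Step downhill, against the gradient
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b };
}

const xs = [-1, 0, 1, 2, 3, 4];
const ys = [-3, -1, 1, 3, 5, 7];

const { w, b } = train(xs, ys, { epochs: 500, learningRate: 0.01 });
const prediction = w * 10 + b;

// Close to 19 but not exactly 19: gradient descent only approaches the best w and b step by step
console.log(prediction.toFixed(3));
```

In this sketch, raising `epochs` moves the prediction closer to 19, since each extra step shrinks the remaining error.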

Once that’s done, the model is trained, and we can predict a value for a new X. For example, to figure out the Y for X=10 and write it on the page in a <div>, the code would look like this:

// predict() returns a Tensor; assigning it to innerText displays its string form
document.getElementById('output_field').innerText =
  model.predict(tf.tensor2d([10], [1, 1]));

Note that the input is a tensor, where we specify that it’s a 1×1 tensor containing the value 10.
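In that 1×1 shape, the first dimension counts examples (a batch of one) and the second matches the layer’s inputShape: [1]. The plain-JavaScript sketch below is not TensorFlow.js code; `predictBatch`, `w`, and `b` are hypothetical names, and it uses the ideal values w = 2 and b = -1 rather than trained ones. It shows how the same single neuron handles a batch of several X values, returning one prediction per row.

```javascript
// Hypothetical sketch: apply y = w * x + b to each row of an [n, 1] batch.
// The result keeps the [n, 1] shape: one prediction per input row.
function predictBatch(batch, w, b) {
  return batch.map((row) => {
    if (row.length !== 1) {
      throw new Error("Each row must hold exactly one feature (inputShape: [1])");
    }
    return [w * row[0] + b];
  });
}

// Ideal parameters for Y = 2X - 1 (a trained model only gets close to these)
const w = 2;
const b = -1;

// A [1, 1] batch, the same layout as tf.tensor2d([10], [1, 1])
console.log(JSON.stringify(predictBatch([[10]], w, b))); // [[19]]

// A [3, 1] batch yields three predictions at once
console.log(JSON.stringify(predictBatch([[10], [20], [30]], w, b))); // [[19],[39],[59]]
```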

The result is written into the div on the page.

Wait, you might ask: why isn’t it 19? It’s pretty close, but it’s not exactly 19! That’s because the model has never been given the formula; it simply learns from the data it was given. More relevant data generally makes an ML model more accurate, but this one isn’t bad considering it had only six examples to learn from!

For your convenience, here’s the entire code for the page, including the declaration of all this code as an async function called ‘learnLinear’:

<html>
 <head>
  <!-- Load TensorFlow.js -->
  <!-- Get latest version at https://github.com/tensorflow/tfjs -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.11.2"></script>
 </head>
 <body>
  <div id="output_field"></div>
  <script>
   async function learnLinear() {
     // One dense layer with a single unit learns y = w * x + b
     const model = tf.sequential();
     model.add(tf.layers.dense({units: 1, inputShape: [1]}));
     model.compile({
       loss: 'meanSquaredError',
       optimizer: 'sgd'
     });

     // Six training examples from Y = 2X - 1, each tensor shaped [6, 1]
     const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
     const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

     // fit() returns a Promise, so wait for training to finish
     await model.fit(xs, ys, {epochs: 500});

     // Predict Y for X = 10 and write the result into the div
     document.getElementById('output_field').innerText =
       model.predict(tf.tensor2d([10], [1, 1]));
   }
   learnLinear();
  </script>
 </body>
</html>

And that’s all it takes to create a very simple machine-learned model with TensorFlow.js that runs in your browser. From here, you have the foundation to move on to more advanced concepts.

Have fun with it!

More articles by Vrijraj Singh
