Run Machine Learning models in your browser with TensorFlow.js

By Omri Grossman


TensorFlow.js (tfjs for short) is a library that lets you create, train, and run Machine Learning models in JavaScript!

Its main goal is to bring JavaScript developers into the Machine Learning and Deep Learning world by letting them build intelligent web applications that run in any major browser or on Node.js servers.

Introduction

TensorFlow.js can run almost everywhere: in all major browsers, on servers, on mobile phones, and even on IoT devices, which shows how much potential the library has. In the browser, the TensorFlow.js backend can run on the device's GPU through the WebGL API, which lets JavaScript code execute on the GPU. As a result, TensorFlow.js can deliver excellent performance even though it runs in the browser.
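As a rough illustration of the backend idea, the sketch below tries the WebGL backend first and falls back to the CPU backend. `tf.setBackend`, `tf.ready`, and `tf.getBackend` are TensorFlow.js calls; the helper name `chooseBackend` and the fallback order are assumptions made for this example. The `tf` namespace is passed in as a parameter so the snippet does not depend on how the library was imported.

```javascript
// Hypothetical helper: pick the first backend that initializes successfully.
// `tf` is the TensorFlow.js namespace, e.g. from `import * as tf from "@tensorflow/tfjs"`.
async function chooseBackend(tf, preferred = ["webgl", "cpu"]) {
  for (const name of preferred) {
    try {
      // setBackend resolves to true when the backend was initialized.
      const ok = await tf.setBackend(name);
      if (ok) {
        await tf.ready();
        return tf.getBackend();
      }
    } catch (err) {
      // This backend is unavailable in the current environment; try the next one.
    }
  }
  throw new Error("No usable TensorFlow.js backend found");
}
```

Calling `await chooseBackend(tf)` once at startup would return the name of the backend that ended up active.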

After reading this post, you will:

  • Learn about TensorFlow.js and the ways you can use it.
  • Know how to load Machine Learning models into your JavaScript project and start using them.
  • Gain the skills to create such a project by yourself.
  • And finally, gain more knowledge about Machine Learning.

So, how does it work?

There are several options that we can choose from:

1. Run existing models:

TensorFlow.js provides a few attractive pre-trained models that we can import into our project, feed with input, and whose output we can use for our own needs. You can explore the models they provide here: TensorFlow.js Models. They keep adding more models as time goes by.

In addition to that, you can find many attractive pre-trained models developed by the TensorFlow.js community all across the web.

2. Retrain existing models:

This option allows us to improve an existing model for our specific use-case. We can achieve that by using a method called: Transfer Learning.

Transfer learning is the improvement of learning in a new task by transferring knowledge from a related task that has already been learned.

For instance, in the real world, the sense of balance learned while riding a bicycle can be transferred to riding other two-wheeled vehicles. Similarly, in machine learning, transfer learning can be used to carry what one ML model has already learned over to another.
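One common form of transfer learning keeps the pre-trained model frozen as a feature extractor and trains only a small new classifier on its outputs. The sketch below illustrates that idea in plain JavaScript with a nearest-centroid classifier. Everything here is an assumption for illustration: `embed` is a hypothetical stand-in for the frozen model, and `trainCentroids` and `predictLabel` are names invented for this example, not TensorFlow.js APIs.

```javascript
// `embed` is a hypothetical function that maps an input to a numeric feature
// vector, for example a wrapper around a frozen pre-trained network.

// "Training": average the feature vectors of the examples for each label.
function trainCentroids(embed, examples) {
  const sums = new Map(); // label -> { sum: number[], count: number }
  for (const { input, label } of examples) {
    const features = embed(input);
    const entry =
      sums.get(label) || { sum: new Array(features.length).fill(0), count: 0 };
    features.forEach((value, i) => {
      entry.sum[i] += value;
    });
    entry.count += 1;
    sums.set(label, entry);
  }
  const centroids = new Map();
  for (const [label, { sum, count }] of sums) {
    centroids.set(label, sum.map((value) => value / count));
  }
  return centroids;
}

// Prediction: return the label whose centroid is closest to the input's features.
function predictLabel(embed, centroids, input) {
  const features = embed(input);
  let bestLabel = null;
  let bestDistance = Infinity;
  for (const [label, centroid] of centroids) {
    const distance = centroid.reduce(
      (acc, value, i) => acc + (value - features[i]) ** 2,
      0
    );
    if (distance < bestDistance) {
      bestDistance = distance;
      bestLabel = label;
    }
  }
  return bestLabel;
}
```

Only the centroids are learned from the new data; the feature extractor itself is never modified, which is what makes this kind of retraining cheap.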

3. Develop ML with JavaScript:

The third option is for situations where the developer wants to create a new Machine Learning model from scratch using the TensorFlow.js API, just like with the regular TensorFlow version.
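To give a feel for what defining a model from scratch can look like, here is a minimal sketch that builds a one-neuron linear model with the TensorFlow.js Layers API (`tf.sequential`, `tf.layers.dense`, `model.compile`). The helper name `buildLinearModel` and the choice of loss and optimizer are assumptions for illustration, and `tf` is passed in as a parameter rather than imported.

```javascript
// Hypothetical helper: define and compile a tiny model that learns y = a * x + b.
// `tf` is the TensorFlow.js namespace, e.g. from `import * as tf from "@tensorflow/tfjs"`.
function buildLinearModel(tf) {
  const model = tf.sequential();
  // A single dense unit with a one-dimensional input.
  model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
  // Mean squared error with plain stochastic gradient descent.
  model.compile({ loss: "meanSquaredError", optimizer: "sgd" });
  return model;
}
```

Training would then typically go through `model.fit` with input and target tensors, which is outside the scope of this article.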

Now let’s get our hands dirty and do some Machine Learning with Javascript

In this article, our primary focus will be on adding and running a pre-trained Machine Learning model to a Javascript project. You will see how easy it is to install, load, and run predictions on the machine learning model.

So let’s get started!

I built an application that demonstrates the use of a pre-trained image classification (tagging) model created by the TensorFlow.js team. The model is called MobileNet, and you can find more information about it here: MobileNet

The demo application is built with React.js, and Ant Design for the UI components.

React is an open-source, front end, JavaScript library for building user interfaces or UI components.

Let’s walk through the main parts of the application together:

First, dependencies

After we set up our React application, we will need to install tfjs and the image classification model — mobilenet, by running the commands below:

    $ npm i @tensorflow-models/mobilenet
    $ npm i @tensorflow/tfjs

Now that the packages are installed, we can import them into our `App.js` file:

    import "@tensorflow/tfjs";
    import * as mobileNet from "@tensorflow-models/mobilenet";

We imported the image classification model and the TensorFlow.js engine, which runs the machine learning models in the background every time we invoke the model.

Next up, we need to load the model into our component for future use. Please note that mobileNet.load() is an asynchronous function, so we need to wait for it to complete.

    // Inside the App component (useState and useEffect come from "react")
    const [model, setModel] = useState(null);

    useEffect(() => {
      // Load the model once, when the component mounts
      const loadModel = async () => {
        const model = await mobileNet.load();
        setModel(model);
      };
      loadModel();
    }, []);

The mobileNet model has a method called classify. Once the model is loaded, we can call this method:

    model.classify(
      img: tf.Tensor3D | ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement,
      topk?: number
    )

This method accepts 2 arguments:

  1. img: A Tensor or an image element to make a classification on.
  2. topk: How many of the top probabilities to return. Defaults to 3.
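To make the `topk` argument concrete, here is a plain-JavaScript sketch of picking the most probable classes out of a list of class probabilities. It only illustrates the idea and is not MobileNet's actual implementation; the function name `topK` and the sample labels are made up for this example.

```javascript
// Return the k most probable labels, highest probability first.
// probabilities[i] is assumed to be the score for labels[i].
function topK(probabilities, labels, k = 3) {
  return probabilities
    .map((probability, index) => ({ className: labels[index], probability }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, k);
}

// Example with four made-up classes; with the default k, three entries come back.
const sample = topK([0.1, 0.6, 0.05, 0.25], ["cat", "tiger", "dog", "jaguar"]);
console.log(sample.map((entry) => entry.className)); // [ 'tiger', 'jaguar', 'cat' ]
```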

In the next step, we want to read the user's input image and load the uploaded file into a canvas element of type HTMLCanvasElement:

    const onImageChange = async ({ target }) => {
      // Load the image into a canvas element
      const canvas = canvasRef.current;
      const ctx = canvas.getContext("2d");
      drawImageOnCanvas(target, canvas, ctx);
    };
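The article does not show the body of `drawImageOnCanvas`. As an assumption about what such a helper might do, the sketch below decodes the selected file and paints it onto the canvas. It deliberately uses a different, hypothetical name (`drawFileOnCanvas`) and relies on the browser's `createImageBitmap`; the `decode` parameter exists only so the function can be exercised outside a browser.

```javascript
// Hypothetical helper, not the author's implementation.
// Decodes an image file and draws it onto the canvas at its natural size.
async function drawFileOnCanvas(
  file,
  canvas,
  decode = (blob) => createImageBitmap(blob) // browser API by default
) {
  const bitmap = await decode(file);
  // Resize the canvas so the whole image fits.
  canvas.width = bitmap.width;
  canvas.height = bitmap.height;
  canvas.getContext("2d").drawImage(bitmap, 0, 0);
  return canvas;
}
```

In a file input handler this could be called as `await drawFileOnCanvas(target.files[0], canvasRef.current)`. Because decoding is asynchronous, the call should be awaited before the canvas is handed to the model.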

After the image is loaded into the canvas we can run the classification method.

    const onImageChange = async ({ target }) => {
      // Load the image into a canvas element
      const canvas = canvasRef.current;
      const ctx = canvas.getContext("2d");
      drawImageOnCanvas(target, canvas, ctx);

      // Classify the image
      const predictions = await model.classify(canvas, 5);

      // Set the results to the component's state
      setPredictions(predictions);
    };
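Because the model loads asynchronously, a user could pick an image before `model` has been set, in which case `model.classify` would fail on `null`. The following sketch guards against that case; `classifyIfReady` is a hypothetical name, and returning an empty list is just one possible way to handle it.

```javascript
// Hypothetical wrapper: returns an empty list until the model is ready,
// otherwise forwards to the model's classify method.
async function classifyIfReady(model, image, topK = 5) {
  if (!model) {
    return [];
  }
  return model.classify(image, topK);
}
```

Inside the handler, `model.classify(canvas, 5)` could then be swapped for `classifyIfReady(model, canvas, 5)`.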

The output of the model.classify method is an array of classified labels and their prediction scores. The output looks like this:

    [
      { className: "tiger, Panthera tigris", probability: 0.6370824575424194 },
      { className: "tiger cat", probability: 0.3609316051006317 },
      { className: "jaguar, panther, Panthera onca, Felis onca", probability: 0.0009806138696148992 }
    ]

Once we have saved the predictions array in our component, we can loop over it and render the results to the screen:

    <div className="tags-container">
      {predictions.map(
        ({ className, probability }) =>
          probability.toFixed(3) > 0 && (
            <Tag className="tag" key={className} color="geekblue">
              {className.split(",")[0]} {probability.toFixed(3)}
            </Tag>
          )
      )}
    </div>

(I’ve chosen to only render meaningful probabilities that are above 0.000)
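The same filtering and formatting can be pulled out of the JSX into a small pure function, which makes it easier to test. This is only a sketch: the name `toTags` and the shape of the returned objects are assumptions, while the rounding and the "first label before the comma" rule mirror the rendering code above.

```javascript
// Turn raw predictions into display-ready tags.
// Keeps only entries whose probability does not round down to zero.
function toTags(predictions, digits = 3) {
  return predictions
    .map(({ className, probability }) => ({
      label: className.split(",")[0], // first, most common name
      score: probability.toFixed(digits), // string such as "0.637"
    }))
    .filter(({ score }) => Number(score) > 0);
}

const tags = toTags([
  { className: "tiger, Panthera tigris", probability: 0.6370824575424194 },
  { className: "tiger cat", probability: 0.3609316051006317 },
  { className: "jaguar, panther, Panthera onca, Felis onca", probability: 0.0009806138696148992 },
  { className: "lynx, catamount", probability: 0.0002 }, // made-up entry, rounds to 0.000
]);
console.log(tags.map((tag) => `${tag.label} ${tag.score}`));
// [ 'tiger 0.637', 'tiger cat 0.361', 'jaguar 0.001' ]
```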

So that’s it, we have a living Machine Learning model in our browser, congratulations!

Please visit those links for:

You can upload your own images, get predictions, and even get creative and try adding new features :)

Conclusion

There is no doubt that the use of machine learning is continuously increasing. With JavaScript development becoming even more popular, the TensorFlow.js community will grow and become more powerful. I think we will see more and more production-grade applications running TensorFlow.js in the browser or on Node.js servers for simple, lightweight tasks that Machine Learning models can solve.

Now that you have seen how fast and easy it is to integrate TensorFlow.js into a JavaScript application, I invite you all to try it yourselves, create some cool projects, and share them with everyone.

Alibaba Tech

First hand and in-depth information about Alibaba’s latest technology → Facebook: “Alibaba Tech”. Twitter: “AlibabaTech”.

