Maschinelles Lernen für Front-End-Entwickler mit Tensorflow.js
Veröffentlicht: 2022-03-10Maschinelles Lernen fühlt sich oft so an, als ob es in das Reich von Datenwissenschaftlern und Python-Entwicklern gehört. In den letzten Jahren wurden jedoch Open-Source-Frameworks entwickelt, um es in verschiedenen Programmiersprachen, einschließlich JavaScript, zugänglicher zu machen. In diesem Artikel werden wir anhand von Tensorflow.js anhand einiger Beispielprojekte die verschiedenen Möglichkeiten des Einsatzes von maschinellem Lernen im Browser untersuchen.
Was ist maschinelles Lernen?
Bevor wir in Code eintauchen, lassen Sie uns kurz darüber sprechen, was maschinelles Lernen ist, sowie einige Kernkonzepte und Terminologie.
Definition
Eine gängige Definition ist, dass es die Fähigkeit von Computern ist, aus Daten zu lernen, ohne explizit programmiert zu werden.
Wenn wir es mit traditioneller Programmierung vergleichen, bedeutet dies, dass wir Computer Muster in Daten erkennen und Vorhersagen erstellen lassen, ohne dass wir ihnen genau sagen müssen, wonach sie suchen sollen.
Nehmen wir das Beispiel Betrugserkennung. Es gibt keine festgelegten Kriterien, um zu wissen, was eine Transaktion betrügerisch macht oder nicht; Betrügereien können in jedem Land, auf jedem Konto, gegen jeden Kunden, zu jeder Zeit usw. ausgeführt werden. Es wäre so gut wie unmöglich, all dies manuell zu verfolgen.
Anhand früherer Daten über betrügerische Ausgaben, die im Laufe der Jahre gesammelt wurden, können wir jedoch einen maschinellen Lernalgorithmus trainieren, um Muster in diesen Daten zu verstehen, um ein Modell zu erstellen, das jeder neuen Transaktion gegeben werden kann und die Wahrscheinlichkeit vorhersagt, ob es sich um Betrug handelt oder nicht, ohne ihm genau sagen, wonach er suchen soll.
Kernkonzepte
Um die folgenden Codebeispiele zu verstehen, müssen wir zunächst einige allgemeine Begriffe behandeln.
Modell
Wenn Sie einen maschinellen Lernalgorithmus mit einem Datensatz trainieren, ist das Modell das Ergebnis dieses Trainingsprozesses. Es ist ein bisschen wie eine Funktion, die neue Daten als Eingabe nimmt und eine Vorhersage als Ausgabe erzeugt.
Etiketten und Funktionen
Labels und Features beziehen sich auf die Daten, die Sie einem Algorithmus im Trainingsprozess zuführen.
Eine Bezeichnung stellt dar, wie Sie jeden Eintrag in Ihrem Datensatz klassifizieren und wie Sie ihn beschriften würden. Wenn unser Datensatz beispielsweise eine CSV-Datei wäre, die verschiedene Tiere beschreibt, könnten unsere Bezeichnungen Wörter wie „Katze“, „Hund“ oder „Schlange“ sein (je nachdem, was jedes Tier darstellt).
Merkmale hingegen sind die Merkmale jedes Eintrags in Ihrem Datensatz. Für unser Tierbeispiel könnten das Dinge sein wie „Schnurrhaare, Miauen“, „verspielt, bellt“, „Reptilien, wild“ und so weiter.
Auf diese Weise kann ein maschineller Lernalgorithmus eine gewisse Korrelation zwischen Merkmalen und ihrer Bezeichnung finden, die er für zukünftige Vorhersagen verwendet.
Neuronale Netze
Neuronale Netze sind eine Reihe von Algorithmen für maschinelles Lernen, die versuchen, die Funktionsweise des Gehirns nachzuahmen, indem sie Schichten künstlicher Neuronen verwenden.
Wir müssen in diesem Artikel nicht näher darauf eingehen, wie sie funktionieren, aber wenn Sie mehr erfahren möchten, finden Sie hier ein wirklich gutes Video:
Nachdem wir nun einige Begriffe definiert haben, die häufig beim maschinellen Lernen verwendet werden, wollen wir darüber sprechen, was mit JavaScript und dem Tensorflow.js-Framework getan werden kann.
Merkmale
Derzeit sind drei Funktionen verfügbar:
- Mit einem vortrainierten Modell,
- Lernen übertragen,
- Definieren, Ausführen und Verwenden Ihres eigenen Modells.
Beginnen wir mit dem einfachsten.
1. Verwenden eines vortrainierten Modells
Abhängig von dem Problem, das Sie zu lösen versuchen, gibt es möglicherweise ein Modell, das bereits mit einem bestimmten Datensatz und für einen bestimmten Zweck trainiert wurde, das Sie nutzen und in Ihren Code importieren können.
Nehmen wir zum Beispiel an, wir erstellen eine Website, um vorherzusagen, ob ein Bild ein Bild einer Katze ist. Ein beliebtes Bildklassifizierungsmodell heißt MobileNet und ist als vortrainiertes Modell mit Tensorflow.js verfügbar.
Der Code dafür würde in etwa so aussehen:
<html lang="en"> <head> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta http-equiv="X-UA-Compatible" content="ie=edge"> <title>Cat detection</title> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/[email protected]"> </script> </head> <body> <img alt="cat laying down" src="cat.jpeg"/> <script> const img = document.getElementById('image'); const predictImage = async () => { console.log("Model loading..."); const model = await mobilenet.load(); console.log("Model is loaded!") const predictions = await model.classify(img); console.log('Predictions: ', predictions); } predictImage(); </script> </body> </html>
Wir beginnen mit dem Import von Tensorflow.js und dem MobileNet-Modell in den Kopf unseres HTML:
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.0.1/tf.js"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/[email protected]"> </script>
Dann haben wir im Körper ein Bildelement, das für Vorhersagen verwendet wird:
<img alt="cat laying down" src="cat.jpeg"/>
Und schließlich haben wir im script
-Tag den JavaScript-Code, der das vortrainierte MobileNet-Modell lädt und das im image
Tag gefundene Bild klassifiziert. Es gibt ein Array von 3 Vorhersagen zurück, die nach Wahrscheinlichkeitswert geordnet sind (das erste Element ist die beste Vorhersage).
const predictImage = async () => { console.log("Model loading..."); const model = await mobilenet.load(); console.log("Model is loaded!") const predictions = await model.classify(img); console.log('Predictions: ', predictions); } predictImage();
Und das ist es! So können Sie mit Tensorflow.js ein vortrainiertes Modell im Browser verwenden!
Hinweis : Wenn Sie sich ansehen möchten, was das MobileNet-Modell sonst noch klassifizieren kann, finden Sie eine Liste der verschiedenen verfügbaren Klassen auf Github.
Es ist wichtig zu wissen, dass das Laden eines vorab trainierten Modells im Browser einige Zeit in Anspruch nehmen kann (manchmal bis zu 10 Sekunden), sodass Sie Ihre Benutzeroberfläche wahrscheinlich vorab laden oder anpassen möchten, damit die Benutzer nicht beeinträchtigt werden.
Wenn Sie Tensorflow.js lieber als NPM-Modul verwenden möchten, können Sie dies tun, indem Sie das Modul auf diese Weise importieren:
import * as mobilenet from '@tensorflow-models/mobilenet';
Fühlen Sie sich frei, mit diesem Beispiel auf CodeSandbox herumzuspielen.
Nachdem wir nun gesehen haben, wie ein vortrainiertes Modell verwendet wird, wollen wir uns die zweite verfügbare Funktion ansehen: Lerntransfer.
2. Lernen übertragen
Transfer Learning ist die Fähigkeit, ein vortrainiertes Modell mit benutzerdefinierten Trainingsdaten zu kombinieren. Das bedeutet, dass Sie die Funktionalität eines Modells nutzen und Ihre eigenen Beispiele hinzufügen können, ohne alles von Grund auf neu erstellen zu müssen.
Beispielsweise wurde ein Algorithmus mit Tausenden von Bildern trainiert, um ein Bildklassifizierungsmodell zu erstellen, und anstatt Ihr eigenes zu erstellen, können Sie mit Transfer Learning neue benutzerdefinierte Bildbeispiele mit dem vortrainierten Modell kombinieren, um einen neuen Bildklassifizierer zu erstellen. Diese Funktion macht es wirklich schnell und einfach, einen individuelleren Klassifikator zu haben.
Um ein Beispiel dafür zu geben, wie dies im Code aussehen würde, lassen Sie uns unser vorheriges Beispiel umfunktionieren und es so ändern, dass wir neue Bilder klassifizieren können.
Hinweis : Das Endergebnis ist das folgende Experiment, das Sie hier live ausprobieren können.
Nachfolgend finden Sie einige Codebeispiele des wichtigsten Teils dieses Setups, aber wenn Sie sich den gesamten Code ansehen müssen, finden Sie ihn in dieser CodeSandbox.
Wir müssen immer noch mit dem Import von Tensorflow.js und MobileNet beginnen, aber dieses Mal müssen wir auch einen KNN-Klassifikator (k-nächster Nachbar) hinzufügen:
<!-- Load TensorFlow.js --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script> <!-- Load MobileNet --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script> <!-- Load KNN Classifier --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
Der Grund, warum wir einen Klassifikator benötigen, ist, dass wir (anstatt nur das MobileNet-Modul zu verwenden) benutzerdefinierte Beispiele hinzufügen, die es noch nie zuvor gesehen hat, sodass der KNN-Klassifikator es uns ermöglicht, alles miteinander zu kombinieren und Vorhersagen für die kombinierten Daten auszuführen.
Dann können wir das Bild der Katze durch ein video
Tag ersetzen, um Bilder aus dem Kamera-Feed zu verwenden.
<video autoplay width="227" height="227"></video>
Schließlich müssen wir auf der Seite einige Schaltflächen hinzufügen, die wir als Beschriftungen verwenden, um einige Videobeispiele aufzunehmen und die Vorhersagen zu starten.
<section> <button class="button">Left</button> <button class="button">Right</button> <button class="test-predictions">Test</button> </section>
Kommen wir nun zur JavaScript-Datei, in der wir mit der Einrichtung einiger wichtiger Variablen beginnen:
// Number of classes to classify const NUM_CLASSES = 2; // Labels for our classes const classes = ["Left", "Right"]; // Webcam Image size. Must be 227. const IMAGE_SIZE = 227; // K value for KNN const TOPK = 10; const video = document.getElementById("webcam");
In diesem speziellen Beispiel möchten wir in der Lage sein, die Webcam-Eingabe zwischen unserem nach links oder rechts geneigten Kopf zu klassifizieren, also brauchen wir zwei Klassen mit den Bezeichnungen left
und right
.
Die auf 227 eingestellte Bildgröße ist die Größe des Videoelements in Pixel. Basierend auf den Tensorflow.js-Beispielen muss dieser Wert auf 227 festgelegt werden, um dem Format der Daten zu entsprechen, mit denen das MobileNet-Modell trainiert wurde. Damit es unsere neuen Daten klassifizieren kann, müssen diese in dasselbe Format passen.
Wenn Sie wirklich eine größere Größe benötigen, ist dies möglich, aber Sie müssen die Daten transformieren und ihre Größe ändern, bevor Sie sie dem KNN-Klassifikator zuführen.
Dann setzen wir den Wert von K auf 10. Der K-Wert im KNN-Algorithmus ist wichtig, weil er die Anzahl der Instanzen darstellt, die wir berücksichtigen, wenn wir die Klasse unserer neuen Eingabe bestimmen.
In diesem Fall bedeutet der Wert 10, dass wir bei der Vorhersage des Labels für einige neue Daten die 10 nächsten Nachbarn aus den Trainingsdaten betrachten, um zu bestimmen, wie unsere neue Eingabe zu klassifizieren ist.
Schließlich erhalten wir das video
. Beginnen wir für die Logik mit dem Laden des Modells und des Klassifikators:
async load() { const knn = knnClassifier.create(); const mobilenetModule = await mobilenet.load(); console.log("model loaded"); }
Lassen Sie uns dann auf den Video-Feed zugreifen:
navigator.mediaDevices .getUserMedia({ video: true, audio: false }) .then(stream => { video.srcObject = stream; video.width = IMAGE_SIZE; video.height = IMAGE_SIZE; });
Anschließend richten wir einige Schaltflächenereignisse ein, um unsere Beispieldaten aufzuzeichnen:
setupButtonEvents() { for (let i = 0; i < NUM_CLASSES; i++) { let button = document.getElementsByClassName("button")[i]; button.onmousedown = () => { this.training = i; this.recordSamples = true; }; button.onmouseup = () => (this.training = -1); } }
Lassen Sie uns unsere Funktion schreiben, die die Webcam-Bildbeispiele nimmt, sie neu formatiert und sie mit dem MobileNet-Modul kombiniert:
// Get image data from video element const image = tf.browser.fromPixels(video); let logits; // 'conv_preds' is the logits activation of MobileNet. const infer = () => this.mobilenetModule.infer(image, "conv_preds"); // Train class if one of the buttons is held down if (this.training != -1) { logits = infer(); // Add current image to classifier this.knn.addExample(logits, this.training); }
Und schließlich, nachdem wir einige Webcam-Bilder gesammelt haben, können wir unsere Vorhersagen mit dem folgenden Code testen:
logits = infer(); const res = await this.knn.predictClass(logits, TOPK); const prediction = classes[res.classIndex];
Und schließlich können Sie die Webcam-Daten entsorgen, da wir sie nicht mehr benötigen:
// Dispose image when done image.dispose(); if (logits != null) { logits.dispose(); }
Wenn Sie sich den vollständigen Code noch einmal ansehen möchten, finden Sie ihn in der zuvor erwähnten CodeSandbox.
3. Trainieren eines Modells im Browser
Die letzte Funktion besteht darin, ein Modell vollständig im Browser zu definieren, zu trainieren und auszuführen. Um dies zu veranschaulichen, bauen wir das klassische Beispiel zum Erkennen von Schwertlilien.
Dazu erstellen wir ein neuronales Netzwerk, das Schwertlilien in drei Kategorien klassifizieren kann: Setosa, Virginica und Versicolor, basierend auf einem Open-Source-Datensatz.
Bevor wir beginnen, hier ist ein Link zur Live-Demo und hier ist die CodeSandbox, wenn Sie mit dem vollständigen Code herumspielen möchten.
Der Kern jedes maschinellen Lernprojekts ist ein Datensatz. Einer der ersten Schritte, die wir unternehmen müssen, besteht darin, diesen Datensatz in einen Trainingssatz und einen Testsatz aufzuteilen.
Der Grund dafür ist, dass wir unser Trainingsset verwenden werden, um unseren Algorithmus zu trainieren, und unser Testset, um die Genauigkeit unserer Vorhersagen zu überprüfen, um zu validieren, ob unser Modell einsatzbereit ist oder optimiert werden muss.
Hinweis : Um es einfacher zu machen, habe ich das Trainingsset und das Testset bereits in zwei JSON-Dateien aufgeteilt, die Sie in der CodeSanbox finden.
Das Trainingsset enthält 130 Elemente und das Testset 14. Wenn Sie sich ansehen, wie diese Daten aussehen, sehen Sie ungefähr Folgendes:
{ "sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2, "species": "setosa" }
Was wir sehen können, sind vier verschiedene Merkmale für die Länge und Breite des Kelch- und Blütenblatts sowie eine Bezeichnung für die Art.
Um dies mit Tensorflow.js verwenden zu können, müssen wir diese Daten in ein Format umwandeln, das das Framework versteht. In diesem Fall sind dies für die Trainingsdaten [130, 4]
für 130 Samples mit vier Features pro Iris.
import * as trainingSet from "training.json"; import * as testSet from "testing.json"; const trainingData = tf.tensor2d( trainingSet.map(item => [ item.sepal_length, item.sepal_width, item.petal_length, item.petal_width ]), [130, 4] ); const testData = tf.tensor2d( testSet.map(item => [ item.sepal_length, item.sepal_width, item.petal_length, item.petal_width ]), [14, 4] );
Als nächstes müssen wir auch unsere Ausgabedaten formen:
const output = tf.tensor2d(trainingSet.map(item => [ item.species === 'setosa' ? 1 : 0, item.species === 'virginica' ? 1 : 0, item.species === 'versicolor' ? 1 : 0 ]), [130,3])
Sobald unsere Daten fertig sind, können wir mit der Erstellung des Modells fortfahren:
const model = tf.sequential(); model.add(tf.layers.dense( { inputShape: 4, activation: 'sigmoid', units: 10 } )); model.add(tf.layers.dense( { inputShape: 10, units: 3, activation: 'softmax' } ));
Im obigen Codebeispiel beginnen wir mit der Instanziierung eines sequentiellen Modells und fügen eine Eingabe- und Ausgabeschicht hinzu.
Die darin verwendeten Parameter ( inputShape
, activation
und units
) sind nicht Gegenstand dieses Beitrags, da sie je nach erstelltem Modell, verwendetem Datentyp usw. variieren können.
Sobald unser Modell fertig ist, können wir es mit unseren Daten trainieren:
async function train_data(){ for(let i=0;i<15;i++){ const res = await model.fit(trainingData, outputData,{epochs: 40}); } } async function main() { await train_data(); model.predict(testSet).print(); }
Wenn dies gut funktioniert, können Sie damit beginnen, die Testdaten durch benutzerdefinierte Benutzereingaben zu ersetzen.
Sobald wir unsere Hauptfunktion aufgerufen haben, sieht die Ausgabe der Vorhersage wie eine dieser drei Optionen aus:
[1,0,0] // Setosa [0,1,0] // Virginica [0,0,1] // Versicolor
Die Vorhersage gibt ein Array aus drei Zahlen zurück, die die Wahrscheinlichkeit darstellen, dass die Daten zu einer der drei Klassen gehören. Die Zahl, die 1 am nächsten liegt, ist die höchste Vorhersage.
Wenn die Ausgabe der Klassifizierung beispielsweise [0.0002, 0.9494, 0.0503]
ist, ist das zweite Element des Arrays das höchste, sodass das Modell vorhersagte, dass die neue Eingabe wahrscheinlich ein Virginica sein wird.
Und das war's für ein einfaches neuronales Netzwerk in Tensorflow.js!
Wir haben nur über einen kleinen Datensatz von Schwertlilien gesprochen, aber wenn Sie zu größeren Datensätzen übergehen oder mit Bildern arbeiten möchten, sind die Schritte die gleichen:
- Sammeln der Daten;
- Aufteilung zwischen Trainings- und Testset;
- Neuformatierung der Daten, damit Tensorflow.js sie verstehen kann;
- Auswahl Ihres Algorithmus;
- Anpassen der Daten;
- Vorhersagen.
Wenn Sie das erstellte Modell speichern möchten, um es in einer anderen Anwendung laden und neue Daten vorhersagen zu können, können Sie dies mit der folgenden Zeile tun:
await model.save('file:///path/to/my-model'); // in Node.js
Hinweis : Weitere Optionen zum Speichern eines Modells finden Sie in dieser Ressource.
Grenzen
Das ist es! Wir haben gerade die drei Hauptfunktionen behandelt, die derzeit mit Tensorflow.js verfügbar sind!
Bevor wir zum Schluss kommen, ist es meiner Meinung nach wichtig, kurz einige der Grenzen der Verwendung von maschinellem Lernen im Frontend zu erwähnen.
1. Leistung
Das Importieren eines vorab trainierten Modells aus einer externen Quelle kann sich auf die Leistung Ihrer Anwendung auswirken. Einige Objekterkennungsmodelle sind beispielsweise mehr als 10 MB groß, was Ihre Website erheblich verlangsamen wird. Denken Sie an Ihre Benutzererfahrung und optimieren Sie das Laden Ihrer Assets, um Ihre wahrgenommene Leistung zu verbessern.
2. Qualität der Eingabedaten
Wenn Sie ein Modell von Grund auf neu erstellen, müssen Sie Ihre eigenen Daten sammeln oder einen Open-Source-Datensatz finden.
Bevor Sie irgendeine Art von Datenverarbeitung durchführen oder verschiedene Algorithmen ausprobieren, überprüfen Sie unbedingt die Qualität Ihrer Eingabedaten. Wenn Sie beispielsweise versuchen, ein Stimmungsanalysemodell zu erstellen, um Emotionen in Textteilen zu erkennen, stellen Sie sicher, dass die Daten, die Sie zum Trainieren Ihres Modells verwenden, genau und vielfältig sind. Wenn die Qualität der verwendeten Daten niedrig ist, ist der Output Ihres Trainings nutzlos.
3. Haftung
Die Verwendung eines vortrainierten Open-Source-Modells kann sehr schnell und mühelos sein. Es bedeutet jedoch auch, dass Sie nicht immer wissen, wie es generiert wurde, woraus der Datensatz bestand oder sogar welcher Algorithmus verwendet wurde. Einige Modelle werden „Black Boxes“ genannt, was bedeutet, dass Sie nicht wirklich wissen, wie sie eine bestimmte Ausgabe vorhergesagt haben.
Je nachdem, was Sie zu bauen versuchen, kann dies ein Problem sein. Wenn Sie beispielsweise ein maschinelles Lernmodell verwenden, um die Wahrscheinlichkeit zu erkennen, dass jemand Krebs hat, basierend auf Scanbildern, im Falle eines falschen Negativs (das Modell sagte voraus, dass eine Person keinen Krebs hatte, als sie es tatsächlich hatte), dort könnte eine echte gesetzliche Haftung sein, und Sie müssten erklären können, warum das Modell eine bestimmte Vorhersage getroffen hat.
Zusammenfassung
Zusammenfassend lässt sich sagen, dass die Verwendung von JavaScript und Frameworks wie Tensorflow.js eine großartige Möglichkeit ist, loszulegen und mehr über maschinelles Lernen zu erfahren. Auch wenn eine produktionsreife Anwendung wahrscheinlich in einer Sprache wie Python erstellt werden sollte, macht JavaScript es Entwicklern wirklich zugänglich, mit den verschiedenen Funktionen herumzuspielen und ein besseres Verständnis der grundlegenden Konzepte zu erlangen, bevor sie schließlich weitermachen und Zeit in das Erlernen einer anderen investieren Sprache.
In diesem Tutorial haben wir nur behandelt, was mit Tensorflow.js möglich war, aber das Ökosystem anderer Bibliotheken und Tools wächst. Es sind auch spezifischere Frameworks verfügbar, mit denen Sie die Verwendung von maschinellem Lernen mit anderen Domänen wie Musik mit Magenta.js oder die Vorhersage der Benutzernavigation auf einer Website mit rate.js erkunden können!
Da Tools immer leistungsfähiger werden, werden die Möglichkeiten zum Erstellen maschinell lernender Anwendungen in JavaScript wahrscheinlich immer spannender, und jetzt ist ein guter Zeitpunkt, mehr darüber zu erfahren, da die Community Anstrengungen unternimmt, um es zugänglich zu machen.
Weitere Ressourcen
Wenn Sie daran interessiert sind, mehr zu erfahren, finden Sie hier einige Ressourcen:
Andere Frameworks und Tools
- ml5.js
- ml.js
- brain.js
- Keras.js
- PoseNet
- Tensorflow-Spielplatz
Beispiele, Modelle und Datensätze
- Tensorflow.js-Modelle
- Tensorflow.js-Beispiele
- Datensätze
Inspiration
- Lehrbare Maschine
- KI-Experimente
- AIJS.rockt
- Erstellbarkeit
Danke fürs Lesen!