Învățare automată pentru dezvoltatorii front-end cu Tensorflow.js
Publicat: 2022-03-10Învățarea automată pare adesea că aparține domeniului cercetătorilor de date și al dezvoltatorilor Python. Cu toate acestea, în ultimii câțiva ani, cadrele open-source au fost create pentru a le face mai accesibile în diferite limbaje de programare, inclusiv JavaScript. În acest articol, vom folosi Tensorflow.js pentru a explora diferitele posibilități de utilizare a învățării automate în browser prin câteva exemple de proiecte.
Ce este învățarea automată?
Înainte de a începe să pătrundem în cod, să vorbim pe scurt despre ce este învățarea automată, precum și despre câteva concepte și terminologie de bază.
Definiție
O definiție comună este aceea că este capacitatea computerelor de a învăța din date fără a fi programate în mod explicit.
Dacă o comparăm cu programarea tradițională, înseamnă că lăsăm computerele să identifice modele în date și să genereze predicții fără a fi nevoie să îi spunem exact ce să caute.
Să luăm exemplul detectării fraudelor. Nu există criterii stabilite pentru a ști ce face o tranzacție frauduloasă sau nu; fraudele pot fi executate în orice țară, în orice cont, vizând orice client, în orice moment și așa mai departe. Ar fi aproape imposibil să urmăriți toate acestea manual.
Cu toate acestea, folosind datele anterioare despre cheltuielile frauduloase adunate de-a lungul anilor, putem antrena un algoritm de învățare automată pentru a înțelege tiparele din aceste date pentru a genera un model căruia i se poate da orice tranzacție nouă și pentru a prezice probabilitatea ca aceasta să fie fraudă sau nu, fără spunându-i exact ce să caute.
Concepte de baza
Pentru a înțelege următoarele exemple de cod, trebuie să acoperim mai întâi câțiva termeni comuni.
Model
Când antrenați un algoritm de învățare automată cu un set de date, modelul este rezultatul acestui proces de instruire. Este un pic ca o funcție care preia date noi ca intrare și produce o predicție ca ieșire.
Etichete și caracteristici
Etichetele și caracteristicile se referă la datele pe care le furnizați unui algoritm în procesul de antrenament.
O etichetă reprezintă modul în care ați clasifica fiecare intrare din setul de date și cum ați eticheta-o. De exemplu, dacă setul nostru de date era un fișier CSV care descrie diferite animale, etichetele noastre ar putea fi cuvinte precum „pisica”, „câine” sau „șarpe” (în funcție de ceea ce reprezintă fiecare animal).
Pe de altă parte, caracteristicile sunt caracteristicile fiecărei intrări din setul dvs. de date. Pentru animalele noastre, de exemplu, ar putea fi lucruri precum „muștați, miaunături”, „jucăușe, latră”, „reptile, rampante” și așa mai departe.
Folosind acest lucru, un algoritm de învățare automată va putea găsi o corelație între caracteristicile și eticheta lor pe care o va folosi pentru predicții viitoare.
Rețele neuronale
Rețelele neuronale sunt un set de algoritmi de învățare automată care încearcă să imite modul în care funcționează creierul folosind straturi de neuroni artificiali.
Nu trebuie să detaliem modul în care funcționează în acest articol, dar dacă doriți să aflați mai multe, iată un videoclip foarte bun:
Acum că am definit câțiva termeni folosiți în mod obișnuit în învățarea automată, să vorbim despre ce se poate face folosind JavaScript și cadrul Tensorflow.js.
Caracteristici
În prezent sunt disponibile trei funcții:
- Folosind un model pre-antrenat,
- Transferați învățarea,
- Definirea, rularea și utilizarea propriului model.
Să începem cu cel mai simplu.
1. Utilizarea unui model pre-antrenat
În funcție de problema pe care încercați să o rezolvați, ar putea exista un model deja antrenat cu un anumit set de date și pentru un anumit scop pe care îl puteți utiliza și importa în codul dvs.
De exemplu, să presupunem că construim un site web pentru a prezice dacă o imagine este o imagine a unei pisici. Un model popular de clasificare a imaginilor se numește MobileNet și este disponibil ca model pre-antrenat cu Tensorflow.js.
Codul pentru aceasta ar arăta cam așa:
<html lang="en"> <head> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta http-equiv="X-UA-Compatible" content="ie=edge"> <title>Cat detection</title> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/[email protected]"> </script> </head> <body> <img alt="cat laying down" src="cat.jpeg"/> <script> const img = document.getElementById('image'); const predictImage = async () => { console.log("Model loading..."); const model = await mobilenet.load(); console.log("Model is loaded!") const predictions = await model.classify(img); console.log('Predictions: ', predictions); } predictImage(); </script> </body> </html>
Începem prin a importa Tensorflow.js și modelul MobileNet în capul HTML-ului nostru:
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.0.1/tf.js"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/[email protected]"> </script>
Apoi, în interiorul corpului, avem un element de imagine care va fi folosit pentru predicții:
<img alt="cat laying down" src="cat.jpeg"/>
Și în sfârșit, în interiorul etichetei de script
, avem codul JavaScript care încarcă modelul MobileNet pre-antrenat și clasifică imaginea găsită în eticheta de image
. Returnează o serie de 3 predicții care sunt ordonate după scorul de probabilitate (primul element fiind cea mai bună predicție).
const predictImage = async () => { console.log("Model loading..."); const model = await mobilenet.load(); console.log("Model is loaded!") const predictions = await model.classify(img); console.log('Predictions: ', predictions); } predictImage();
Si asta e! Acesta este modul în care puteți utiliza un model pre-antrenat în browser cu Tensorflow.js!
Notă : Dacă doriți să aruncați o privire la ce altceva poate clasifica modelul MobileNet, puteți găsi o listă cu diferitele clase disponibile pe Github.
Un lucru important de știut este că încărcarea unui model pre-antrenat în browser poate dura ceva timp (uneori până la 10 secunde), așa că probabil că veți dori să preîncărcați sau să vă adaptați interfața, astfel încât utilizatorii să nu fie afectați.
Dacă preferați să utilizați Tensorflow.js ca modul NPM, puteți face acest lucru importând modulul astfel:
import * as mobilenet from '@tensorflow-models/mobilenet';
Simțiți-vă liber să vă jucați cu acest exemplu pe CodeSandbox.
Acum că am văzut cum să folosim un model pre-antrenat, să ne uităm la a doua caracteristică disponibilă: transferul de învățare.
2. Transfer de învățare
Învățarea prin transfer este capacitatea de a combina un model pre-antrenat cu date personalizate de antrenament. Acest lucru înseamnă că puteți profita de funcționalitatea unui model și puteți adăuga propriile mostre fără a fi nevoie să creați totul de la zero.
De exemplu, un algoritm a fost antrenat cu mii de imagini pentru a crea un model de clasificare a imaginilor și, în loc să vă creați propriul dvs., învățarea prin transfer vă permite să combinați noi mostre de imagini personalizate cu modelul pre-antrenat pentru a crea un nou clasificator de imagini. Această caracteristică face să fie foarte rapid și ușor să existe un clasificator mai personalizat.
Pentru a oferi un exemplu despre cum ar arăta în cod, să reproporăm exemplul nostru anterior și să-l modificăm astfel încât să putem clasifica imagini noi.
Notă : rezultatul final este experimentul de mai jos pe care îl puteți încerca live aici.
Mai jos sunt câteva exemple de cod ale celei mai importante părți a acestei configurații, dar dacă trebuie să aruncați o privire la întregul cod, îl puteți găsi pe acest CodeSandbox.
Mai trebuie să începem prin a importa Tensorflow.js și MobileNet, dar de data aceasta trebuie să adăugăm și un clasificator KNN (k-nearest neighbor):
<!-- Load TensorFlow.js --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script> <!-- Load MobileNet --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script> <!-- Load KNN Classifier --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
Motivul pentru care avem nevoie de un clasificator este pentru că (în loc să folosim doar modulul MobileNet) adăugăm mostre personalizate pe care nu le-am văzut niciodată până acum, astfel încât clasificatorul KNN ne va permite să combinăm totul împreună și să rulăm predicții asupra datelor combinate.
Apoi, putem înlocui imaginea pisicii cu o etichetă video
pentru a folosi imagini din fluxul camerei.
<video autoplay width="227" height="227"></video>
În cele din urmă, va trebui să adăugăm câteva butoane pe pagină pe care le vom folosi ca etichete pentru a înregistra câteva mostre video și a începe predicțiile.
<section> <button class="button">Left</button> <button class="button">Right</button> <button class="test-predictions">Test</button> </section>
Acum, să trecem la fișierul JavaScript de unde vom începe prin a configura câteva variabile importante:
// Number of classes to classify const NUM_CLASSES = 2; // Labels for our classes const classes = ["Left", "Right"]; // Webcam Image size. Must be 227. const IMAGE_SIZE = 227; // K value for KNN const TOPK = 10; const video = document.getElementById("webcam");
În acest exemplu particular, dorim să putem clasifica intrarea camerei web între capul înclinat la stânga sau la dreapta, așa că avem nevoie de două clase etichetate left
și right
.
Dimensiunea imaginii setată la 227 este dimensiunea elementului video în pixeli. Pe baza exemplelor Tensorflow.js, această valoare trebuie setată la 227 pentru a se potrivi cu formatul datelor cu care a fost antrenat modelul MobileNet. Pentru ca acesta să poată clasifica noile noastre date, acestea din urmă trebuie să se potrivească cu același format.
Dacă într-adevăr aveți nevoie să fie mai mare, este posibil, dar va trebui să transformați și să redimensionați datele înainte de a le furniza clasificatorului KNN.
Apoi, setăm valoarea lui K la 10. Valoarea K din algoritmul KNN este importantă deoarece reprezintă numărul de instanțe pe care le luăm în considerare atunci când determinăm clasa noii noastre intrări.
În acest caz, valoarea lui 10 înseamnă că, atunci când predicăm eticheta pentru unele date noi, ne vom uita la cei mai apropiați 10 vecini din datele de antrenament pentru a determina cum să clasificăm noua noastră intrare.
În sfârșit, obținem elementul video
. Pentru logică, să începem prin a încărca modelul și clasificatorul:
async load() { const knn = knnClassifier.create(); const mobilenetModule = await mobilenet.load(); console.log("model loaded"); }
Apoi, să accesăm fluxul video:
navigator.mediaDevices .getUserMedia({ video: true, audio: false }) .then(stream => { video.srcObject = stream; video.width = IMAGE_SIZE; video.height = IMAGE_SIZE; });
În continuare, să setăm câteva evenimente butoane pentru a înregistra datele noastre eșantion:
setupButtonEvents() { for (let i = 0; i < NUM_CLASSES; i++) { let button = document.getElementsByClassName("button")[i]; button.onmousedown = () => { this.training = i; this.recordSamples = true; }; button.onmouseup = () => (this.training = -1); } }
Să scriem funcția noastră care va lua mostre de imagini ale camerei web, le va reformata și le va combina cu modulul MobileNet:
// Get image data from video element const image = tf.browser.fromPixels(video); let logits; // 'conv_preds' is the logits activation of MobileNet. const infer = () => this.mobilenetModule.infer(image, "conv_preds"); // Train class if one of the buttons is held down if (this.training != -1) { logits = infer(); // Add current image to classifier this.knn.addExample(logits, this.training); }
Și, în cele din urmă, odată ce am adunat câteva imagini webcam, ne putem testa predicțiile cu următorul cod:
logits = infer(); const res = await this.knn.predictClass(logits, TOPK); const prediction = classes[res.classIndex];
Și, în sfârșit, puteți elimina datele camerei web, deoarece nu mai avem nevoie de ele:
// Dispose image when done image.dispose(); if (logits != null) { logits.dispose(); }
Încă o dată, dacă doriți să aruncați o privire la codul complet, îl puteți găsi în CodeSandbox menționat mai devreme.
3. Antrenarea unui model în browser
Ultima caracteristică este de a defini, antrena și rula un model în întregime în browser. Pentru a ilustra acest lucru, vom construi exemplul clasic de recunoaștere a Irisilor.
Pentru aceasta, vom crea o rețea neuronală care poate clasifica Irisurile în trei categorii: Setosa, Virginica și Versicolor, pe baza unui set de date open-source.
Înainte de a începe, iată un link către demonstrația live și aici este CodeSandbox dacă doriți să jucați cu codul complet.
La baza fiecărui proiect de învățare automată se află un set de date. Unul dintre primii pasi pe care trebuie să-l facem este împărțirea acestui set de date într-un set de antrenament și un set de testare.
Motivul pentru aceasta este că vom folosi setul nostru de antrenament pentru a ne antrena algoritmul și setul nostru de testare pentru a verifica acuratețea predicțiilor noastre, pentru a valida dacă modelul nostru este gata de utilizat sau trebuie modificat.
Notă : Pentru a fi mai ușor, am împărțit deja setul de antrenament și setul de testare în două fișiere JSON pe care le puteți găsi în CodeSanbox.
Setul de antrenament conține 130 de articole și setul de testare 14. Dacă vă uitați la cum arată aceste date, veți vedea ceva de genul acesta:
{ "sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2, "species": "setosa" }
Ceea ce putem vedea sunt patru caracteristici diferite pentru lungimea și lățimea sepalului și petalei, precum și o etichetă pentru specie.
Pentru a putea folosi acest lucru cu Tensorflow.js, trebuie să modelăm aceste date într-un format pe care cadrul îl va înțelege, în acest caz, pentru datele de antrenament, va fi [130, 4]
pentru 130 de mostre cu patru caracteristici per fiecare. iris.
import * as trainingSet from "training.json"; import * as testSet from "testing.json"; const trainingData = tf.tensor2d( trainingSet.map(item => [ item.sepal_length, item.sepal_width, item.petal_length, item.petal_width ]), [130, 4] ); const testData = tf.tensor2d( testSet.map(item => [ item.sepal_length, item.sepal_width, item.petal_length, item.petal_width ]), [14, 4] );
În continuare, trebuie să ne modelăm și datele de ieșire:
const output = tf.tensor2d(trainingSet.map(item => [ item.species === 'setosa' ? 1 : 0, item.species === 'virginica' ? 1 : 0, item.species === 'versicolor' ? 1 : 0 ]), [130,3])
Apoi, odată ce datele noastre sunt gata, putem trece la crearea modelului:
const model = tf.sequential(); model.add(tf.layers.dense( { inputShape: 4, activation: 'sigmoid', units: 10 } )); model.add(tf.layers.dense( { inputShape: 10, units: 3, activation: 'softmax' } ));
În exemplul de cod de mai sus, începem prin a instanția un model secvențial, adăugăm un strat de intrare și ieșire.
Parametrii pe care îi puteți vedea folosiți în interior ( inputShape
, activation
, and units
) nu fac obiectul acestei postări, deoarece pot varia în funcție de modelul pe care îl creați, de tipul de date utilizate și așa mai departe.
Odată ce modelul nostru este gata, îl putem antrena cu datele noastre:
async function train_data(){ for(let i=0;i<15;i++){ const res = await model.fit(trainingData, outputData,{epochs: 40}); } } async function main() { await train_data(); model.predict(testSet).print(); }
Dacă acest lucru funcționează bine, puteți începe să înlocuiți datele de testare cu intrări personalizate de utilizator.
Odată ce apelăm funcția noastră principală, rezultatul predicției va arăta ca una dintre aceste trei opțiuni:
[1,0,0] // Setosa [0,1,0] // Virginica [0,0,1] // Versicolor
Predicția returnează o matrice de trei numere reprezentând probabilitatea ca datele să aparțină uneia dintre cele trei clase. Numărul cel mai apropiat de 1 este cea mai mare predicție.
De exemplu, dacă rezultatul clasificării este [0.0002, 0.9494, 0.0503]
, al doilea element al matricei este cel mai înalt, astfel încât modelul a prezis că noua intrare este probabil să fie Virginica.
Și atât pentru o rețea neuronală simplă în Tensorflow.js!
Am vorbit doar despre un set mic de date de Iris, dar dacă doriți să treceți la seturi de date mai mari sau să lucrați cu imagini, pașii vor fi aceiași:
- Colectarea datelor;
- Împărțirea între setul de antrenament și setul de testare;
- Reformatarea datelor astfel încât Tensorflow.js să le poată înțelege;
- Alegerea algoritmului;
- Potrivirea datelor;
- Prezice.
Dacă doriți să salvați modelul creat pentru a-l putea încărca într-o altă aplicație și a prezice date noi, puteți face acest lucru cu următoarea linie:
await model.save('file:///path/to/my-model'); // in Node.js
Notă : Pentru mai multe opțiuni despre cum să salvați un model, aruncați o privire la această resursă.
Limite
Asta e! Tocmai am acoperit cele trei caracteristici principale disponibile în prezent folosind Tensorflow.js!
Înainte de a termina, cred că este important să menționăm pe scurt câteva dintre limitele utilizării învățării automate în front-end.
1. Performanță
Importarea unui model pre-antrenat dintr-o sursă externă poate avea un impact asupra performanței aplicației dvs. Unele modele de detectare a obiectelor, de exemplu, au mai mult de 10 MB, ceea ce va încetini în mod semnificativ site-ul dvs. Asigurați-vă că vă gândiți la experiența dvs. de utilizator și optimizați încărcarea activelor pentru a vă îmbunătăți performanța percepută.
2. Calitatea datelor de intrare
Dacă construiți un model de la zero, va trebui să vă adunați propriile date sau să găsiți un set de date open-source.
Înainte de a efectua orice fel de procesare a datelor sau de a încerca diferiți algoritmi, asigurați-vă că verificați calitatea datelor de intrare. De exemplu, dacă încercați să construiți un model de analiză a sentimentelor pentru a recunoaște emoțiile în fragmente de text, asigurați-vă că datele pe care le utilizați pentru a vă antrena modelul sunt exacte și diverse. Dacă calitatea datelor utilizate este scăzută, rezultatul antrenamentului dvs. va fi inutil.
3. Răspundere
Utilizarea unui model open-source pre-antrenat poate fi foarte rapidă și fără efort. Totuși, înseamnă, de asemenea, că nu știi întotdeauna cum a fost generat, din ce a fost format setul de date sau chiar ce algoritm a fost folosit. Unele modele sunt numite „cutii negre”, ceea ce înseamnă că nu știi cu adevărat cum au prezis o anumită ieșire.
În funcție de ceea ce încercați să construiți, aceasta poate fi o problemă. De exemplu, dacă utilizați un model de învățare automată pentru a ajuta la detectarea probabilității ca cineva să aibă cancer pe baza imaginilor scanate, în caz de negativ fals (modelul a prezis că o persoană nu a avut cancer atunci când a avut de fapt), există ar putea fi o răspundere legală reală și ar trebui să puteți explica de ce modelul a făcut o anumită predicție.
rezumat
În concluzie, utilizarea JavaScript și a cadrelor precum Tensorflow.js este o modalitate excelentă de a începe și de a afla mai multe despre învățarea automată. Chiar dacă o aplicație pregătită pentru producție ar trebui probabil să fie construită într-un limbaj precum Python, JavaScript îl face cu adevărat accesibil pentru dezvoltatori să se joace cu diferitele caracteristici și să înțeleagă mai bine conceptele fundamentale înainte de a trece mai departe și de a investi timp în învățarea altora. limba.
În acest tutorial, am acoperit doar ceea ce a fost posibil folosind Tensorflow.js, cu toate acestea, ecosistemul altor biblioteci și instrumente este în creștere. Sunt disponibile, de asemenea, cadre mai specificate, permițându-vă să explorați folosind învățarea automată cu alte domenii, cum ar fi muzica cu Magenta.js, sau să preziceți navigarea utilizatorului pe un site web folosind guess.js!
Pe măsură ce instrumentele devin mai performante, posibilitățile de a construi aplicații activate pentru învățarea automată în JavaScript vor fi probabil din ce în ce mai interesante și acum este un moment bun pentru a afla mai multe despre acestea, deoarece comunitatea depune eforturi pentru a le face accesibil.
Resurse suplimentare
Dacă sunteți interesat să aflați mai multe, iată câteva resurse:
Alte cadre și instrumente
- ml5.js
- ml.js
- creier.js
- Keras.js
- PoseNet
- Loc de joacă Tensorflow
Exemple, modele și seturi de date
- Modele Tensorflow.js
- Exemple Tensorflow.js
- Seturi de date
Inspirație
- Mașină de învățat
- Experimente AI
- AIJS.roci
- Creabilitate
Multumesc pentru lectura!