Uczenie maszynowe dla programistów front-end z Tensorflow.js

Opublikowany: 2022-03-10
Szybkie podsumowanie ↬ Korzystanie z JavaScript i frameworków, takich jak Tensorflow.js, to świetny sposób na rozpoczęcie pracy i dowiedzenie się więcej o uczeniu maszynowym. W tym artykule Charlie Gerard omawia trzy główne funkcje dostępne obecnie przy użyciu Tensorflow.js i rzuca światło na ograniczenia korzystania z uczenia maszynowego w interfejsie użytkownika.

Często wydaje się, że uczenie maszynowe należy do dziedziny naukowców zajmujących się danymi i programistów Pythona. Jednak w ciągu ostatnich kilku lat stworzono frameworki typu open source, aby uczynić je bardziej dostępnymi w różnych językach programowania, w tym w JavaScript. W tym artykule wykorzystamy Tensorflow.js do zbadania różnych możliwości wykorzystania uczenia maszynowego w przeglądarce poprzez kilka przykładowych projektów.

Co to jest uczenie maszynowe?

Zanim zaczniemy zagłębiać się w kod, porozmawiajmy krótko o tym, czym jest uczenie maszynowe, a także o kilku podstawowych pojęciach i terminologii.

Definicja

Powszechną definicją jest to, że jest to zdolność komputerów do uczenia się na podstawie danych bez wyraźnego programowania.

Jeśli porównamy to z tradycyjnym programowaniem, oznacza to, że pozwalamy komputerom identyfikować wzorce w danych i generować prognozy bez konieczności dokładnego informowania ich, czego szukać.

Weźmy przykład wykrywania oszustw. Nie ma ustalonych kryteriów, aby wiedzieć, co sprawia, że ​​transakcja jest oszukańcza, a co nie; oszustwa mogą być dokonywane w dowolnym kraju, na dowolnym koncie, celując w dowolnego klienta, w dowolnym czasie i tak dalej. Ręczne śledzenie tego wszystkiego byłoby prawie niemożliwe.

Jednak korzystając z wcześniejszych danych dotyczących nieuczciwych wydatków zebranych przez lata, możemy wytrenować algorytm uczenia maszynowego, aby zrozumieć wzorce w tych danych, aby wygenerować model, który może otrzymać każdą nową transakcję i przewidzieć prawdopodobieństwo, że jest to oszustwo lub nie, bez mówiąc dokładnie, czego szukać.

Więcej po skoku! Kontynuuj czytanie poniżej ↓

Podstawowe pojęcia

Aby zrozumieć poniższe przykłady kodu, musimy najpierw omówić kilka typowych terminów.

Model

Podczas uczenia algorytmu uczenia maszynowego za pomocą zestawu danych model jest wynikiem tego procesu uczenia. Przypomina to trochę funkcję, która pobiera nowe dane jako dane wejściowe i generuje prognozę jako dane wyjściowe.

Etykiety i funkcje

Etykiety i funkcje odnoszą się do danych, które podajesz algorytmowi w procesie uczenia.

Etykieta reprezentuje sposób, w jaki sklasyfikowałbyś każdy wpis w zestawie danych i jak go oznaczyłeś. Na przykład, jeśli nasz zbiór danych był plikiem CSV opisującym różne zwierzęta, naszymi etykietami mogą być słowa takie jak „kot”, „pies” lub „wąż” (w zależności od tego, co reprezentuje każde zwierzę).

Z drugiej strony funkcje to cechy charakterystyczne każdego wpisu w zestawie danych. W przypadku naszych zwierząt mogą to być takie rzeczy jak „wąsy, miauczy”, „zabawny, szczeka”, „gad, szalejący” i tak dalej.

Korzystając z tego, algorytm uczenia maszynowego będzie w stanie znaleźć pewną korelację między cechami a ich etykietą, której użyje do przyszłych prognoz.

Sieci neuronowe

Sieci neuronowe to zestaw algorytmów uczenia maszynowego, które próbują naśladować sposób działania mózgu za pomocą warstw sztucznych neuronów.

Nie musimy szczegółowo omawiać ich działania w tym artykule, ale jeśli chcesz dowiedzieć się więcej, oto naprawdę dobry film:

Teraz, gdy zdefiniowaliśmy kilka terminów powszechnie używanych w uczeniu maszynowym, porozmawiajmy o tym, co można zrobić za pomocą JavaScript i frameworka Tensorflow.js.

Cechy

Obecnie dostępne są trzy funkcje:

  1. Używając wstępnie wytrenowanego modelu,
  2. Transfer nauki,
  3. Definiowanie, uruchamianie i używanie własnego modelu.

Zacznijmy od najprostszego.

1. Korzystanie z przeszkolonego modelu

W zależności od problemu, który próbujesz rozwiązać, może istnieć model już wytrenowany z określonym zestawem danych i do określonego celu, który możesz wykorzystać i zaimportować w swoim kodzie.

Załóżmy na przykład, że budujemy witrynę internetową, aby przewidzieć, czy obraz przedstawia kota. Popularny model klasyfikacji obrazów nazywa się MobileNet i jest dostępny jako wstępnie wytrenowany model z Tensorflow.js.

Kod do tego wyglądałby mniej więcej tak:

 <html lang="en"> <head> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta http-equiv="X-UA-Compatible" content="ie=edge"> <title>Cat detection</title> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/[email protected]"> </script> </head> <body> <img alt="cat laying down" src="cat.jpeg"/> <script> const img = document.getElementById('image'); const predictImage = async () => { console.log("Model loading..."); const model = await mobilenet.load(); console.log("Model is loaded!") const predictions = await model.classify(img); console.log('Predictions: ', predictions); } predictImage(); </script> </body> </html>

Zaczynamy od zaimportowania Tensorflow.js i modelu MobileNet w nagłówku naszego HTML:

 <script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.0.1/tf.js"> </script> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/[email protected]"> </script>

Następnie wewnątrz ciała mamy element obrazu, który posłuży do prognozowania:

 <img alt="cat laying down" src="cat.jpeg"/>

I wreszcie, wewnątrz tagu script , mamy kod JavaScript, który ładuje wstępnie wytrenowany model MobileNet i klasyfikuje obraz znaleziony w tagu image . Zwraca tablicę 3 prognoz, które są uporządkowane według wyniku prawdopodobieństwa (pierwszy element to najlepsza prognoza).

 const predictImage = async () => { console.log("Model loading..."); const model = await mobilenet.load(); console.log("Model is loaded!") const predictions = await model.classify(img); console.log('Predictions: ', predictions); } predictImage();

I to wszystko! W ten sposób możesz użyć wstępnie wytrenowanego modelu w przeglądarce z Tensorflow.js!

Uwaga : jeśli chcesz zobaczyć, co jeszcze może klasyfikować model MobileNet, możesz znaleźć listę różnych klas dostępnych na Github.

Ważną rzeczą, o której należy wiedzieć, jest to, że ładowanie wstępnie wytrenowanego modelu w przeglądarce może zająć trochę czasu (czasami do 10 s), więc prawdopodobnie będziesz chciał wstępnie załadować lub dostosować swój interfejs, aby nie miało to wpływu na użytkowników.

Jeśli wolisz używać Tensorflow.js jako modułu NPM, możesz to zrobić, importując moduł w ten sposób:

 import * as mobilenet from '@tensorflow-models/mobilenet';

Zapraszam do zabawy z tym przykładem na CodeSandbox.

Teraz, gdy widzieliśmy, jak używać wstępnie wytrenowanego modelu, przyjrzyjmy się drugiej dostępnej funkcji: uczeniu transferu.

2. Transfer nauki

Transfer uczenia się to możliwość połączenia wstępnie wytrenowanego modelu z niestandardowymi danymi treningowymi. Oznacza to, że możesz wykorzystać funkcjonalność modelu i dodawać własne próbki bez konieczności tworzenia wszystkiego od zera.

Na przykład algorytm został przeszkolony z tysiącami obrazów w celu utworzenia modelu klasyfikacji obrazów, a zamiast tworzenia własnego, uczenie transferu umożliwia łączenie nowych próbek niestandardowych obrazów ze wstępnie wytrenowanym modelem w celu utworzenia nowego klasyfikatora obrazu. Ta funkcja sprawia, że ​​uzyskanie bardziej spersonalizowanego klasyfikatora jest naprawdę szybkie i łatwe.

Aby dać przykład tego, jak to wyglądałoby w kodzie, zmodyfikujmy nasz poprzedni przykład i zmodyfikujmy go tak, abyśmy mogli klasyfikować nowe obrazy.

Uwaga : Efektem końcowym jest poniższy eksperyment, który możesz wypróbować na żywo tutaj.

(Demo na żywo) (Duży podgląd)

Poniżej znajduje się kilka przykładów kodu z najważniejszej części tej konfiguracji, ale jeśli chcesz rzucić okiem na cały kod, możesz go znaleźć na tej CodeSandbox.

Nadal musimy zacząć od zaimportowania Tensorflow.js i MobileNet, ale tym razem musimy również dodać klasyfikator KNN (k-nearest near):

 <!-- Load TensorFlow.js --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script> <!-- Load MobileNet --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script> <!-- Load KNN Classifier --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>

Powodem, dla którego potrzebujemy klasyfikatora, jest to, że (zamiast tylko korzystania z modułu MobileNet) dodajemy niestandardowe próbki, których nigdy wcześniej nie widziałem, więc klasyfikator KNN pozwoli nam połączyć wszystko razem i uruchomić prognozy na połączonych danych.

Następnie możemy zastąpić obraz kota tagiem video , aby wykorzystać obrazy z kamery.

 <video autoplay width="227" height="227"></video>

Na koniec musimy dodać kilka przycisków na stronie, których użyjemy jako etykiet, aby nagrać kilka próbek wideo i rozpocząć prognozy.

 <section> <button class="button">Left</button> <button class="button">Right</button> <button class="test-predictions">Test</button> </section>

Przejdźmy teraz do pliku JavaScript, w którym zaczniemy od ustawienia kilku ważnych zmiennych:

 // Number of classes to classify const NUM_CLASSES = 2; // Labels for our classes const classes = ["Left", "Right"]; // Webcam Image size. Must be 227. const IMAGE_SIZE = 227; // K value for KNN const TOPK = 10; const video = document.getElementById("webcam");

W tym konkretnym przykładzie chcemy móc sklasyfikować wejście z kamery internetowej między przechyloną głową w lewo lub w prawo, więc potrzebujemy dwóch klas oznaczonych jako left i right .

Rozmiar obrazu ustawiony na 227 to rozmiar elementu wideo w pikselach. Na podstawie przykładów Tensorflow.js ta wartość musi być ustawiona na 227, aby odpowiadała formatowi danych, z którymi trenowano model MobileNet. Aby móc klasyfikować nasze nowe dane, te ostatnie muszą mieć ten sam format.

Jeśli naprawdę potrzebujesz, aby był większy, jest to możliwe, ale będziesz musiał przekształcić i zmienić rozmiar danych przed przesłaniem ich do klasyfikatora KNN.

Następnie ustawiamy wartość K na 10. Wartość K w algorytmie KNN jest ważna, ponieważ reprezentuje liczbę instancji, które bierzemy pod uwagę przy określaniu klasy naszego nowego wejścia.

W tym przypadku wartość 10 oznacza, że ​​podczas przewidywania etykiety dla niektórych nowych danych przyjrzymy się 10 najbliższym sąsiadom z danych uczących, aby określić, jak sklasyfikować nasze nowe dane wejściowe.

Wreszcie otrzymujemy element video . Dla logiki zacznijmy od załadowania modelu i klasyfikatora:

 async load() { const knn = knnClassifier.create(); const mobilenetModule = await mobilenet.load(); console.log("model loaded"); }

Następnie przejdźmy do kanału wideo:

 navigator.mediaDevices .getUserMedia({ video: true, audio: false }) .then(stream => { video.srcObject = stream; video.width = IMAGE_SIZE; video.height = IMAGE_SIZE; });

Następnie skonfigurujmy kilka zdarzeń przycisków, aby zarejestrować nasze przykładowe dane:

 setupButtonEvents() { for (let i = 0; i < NUM_CLASSES; i++) { let button = document.getElementsByClassName("button")[i]; button.onmousedown = () => { this.training = i; this.recordSamples = true; }; button.onmouseup = () => (this.training = -1); } }

Napiszmy naszą funkcję, która pobierze próbki obrazów z kamery internetowej, przeformatuje je i połączy z modułem MobileNet:

 // Get image data from video element const image = tf.browser.fromPixels(video); let logits; // 'conv_preds' is the logits activation of MobileNet. const infer = () => this.mobilenetModule.infer(image, "conv_preds"); // Train class if one of the buttons is held down if (this.training != -1) { logits = infer(); // Add current image to classifier this.knn.addExample(logits, this.training); }

I na koniec, gdy już zebraliśmy kilka obrazów z kamery internetowej, możemy przetestować nasze przewidywania za pomocą następującego kodu:

 logits = infer(); const res = await this.knn.predictClass(logits, TOPK); const prediction = classes[res.classIndex];

I wreszcie, możesz pozbyć się danych z kamery internetowej, ponieważ już ich nie potrzebujemy:

 // Dispose image when done image.dispose(); if (logits != null) { logits.dispose(); }

Jeszcze raz, jeśli chcesz rzucić okiem na cały kod, możesz go znaleźć we wspomnianej wcześniej CodeSandbox.

3. Szkolenie modelu w przeglądarce

Ostatnią funkcją jest zdefiniowanie, przeszkolenie i uruchomienie modelu w całości w przeglądarce. Aby to zilustrować, zbudujemy klasyczny przykład rozpoznawania Irysów.

W tym celu stworzymy sieć neuronową, która może klasyfikować irysy w trzech kategoriach: Setosa, Virginica i Versicolor, w oparciu o zbiór danych o otwartym kodzie źródłowym.

Zanim zaczniemy, oto link do demonstracji na żywo, a oto CodeSandbox, jeśli chcesz poeksperymentować z pełnym kodem.

U podstaw każdego projektu uczenia maszynowego leży zbiór danych. Jednym z pierwszych kroków, które musimy wykonać, jest podzielenie tego zestawu danych na zestaw treningowy i zestaw testowy.

Powodem tego jest to, że zamierzamy użyć naszego zestawu treningowego do trenowania naszego algorytmu, a naszego zestawu testowego do sprawdzenia dokładności naszych przewidywań, aby sprawdzić, czy nasz model jest gotowy do użycia, czy też musi zostać poprawiony.

Uwaga : Aby było łatwiej, podzieliłem już zestaw szkoleniowy i zestaw testowy na dwa pliki JSON, które można znaleźć w CodeSanbox.

Zestaw uczący zawiera 130 pozycji, a zestaw testowy 14. Jeśli spojrzysz na to, jak wyglądają te dane, zobaczysz coś takiego:

 { "sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2, "species": "setosa" }

Możemy zobaczyć cztery różne cechy długości i szerokości działki i płatka, a także etykietę gatunku.

Aby móc używać tego z Tensorflow.js, musimy ukształtować te dane w formacie zrozumiałym dla frameworka, w tym przypadku dla danych treningowych będzie to [130, 4] dla 130 próbek z czterema funkcjami na irys.

 import * as trainingSet from "training.json"; import * as testSet from "testing.json"; const trainingData = tf.tensor2d( trainingSet.map(item => [ item.sepal_length, item.sepal_width, item.petal_length, item.petal_width ]), [130, 4] ); const testData = tf.tensor2d( testSet.map(item => [ item.sepal_length, item.sepal_width, item.petal_length, item.petal_width ]), [14, 4] );

Następnie musimy również ukształtować nasze dane wyjściowe:

 const output = tf.tensor2d(trainingSet.map(item => [ item.species === 'setosa' ? 1 : 0, item.species === 'virginica' ? 1 : 0, item.species === 'versicolor' ? 1 : 0 ]), [130,3])

Następnie, gdy nasze dane są gotowe, możemy przejść do tworzenia modelu:

 const model = tf.sequential(); model.add(tf.layers.dense( { inputShape: 4, activation: 'sigmoid', units: 10 } )); model.add(tf.layers.dense( { inputShape: 10, units: 3, activation: 'softmax' } ));

W powyższym przykładzie kodu zaczynamy od utworzenia instancji modelu sekwencyjnego, dodając warstwę wejściową i wyjściową.

Parametry, które możesz zobaczyć w środku ( inputShape , activation i units ) są poza zakresem tego postu, ponieważ mogą się różnić w zależności od tworzonego modelu, typu używanych danych i tak dalej.

Gdy nasz model jest gotowy, możemy go trenować za pomocą naszych danych:

 async function train_data(){ for(let i=0;i<15;i++){ const res = await model.fit(trainingData, outputData,{epochs: 40}); } } async function main() { await train_data(); model.predict(testSet).print(); }

Jeśli to działa dobrze, możesz zacząć zastępować dane testowe niestandardowymi danymi wejściowymi użytkownika.

Gdy wywołamy naszą główną funkcję, wynik prognozy będzie wyglądał jak jedna z tych trzech opcji:

 [1,0,0] // Setosa [0,1,0] // Virginica [0,0,1] // Versicolor

Przewidywanie zwraca tablicę trzech liczb reprezentujących prawdopodobieństwo danych należących do jednej z trzech klas. Liczba najbliższa 1 to najwyższa prognoza.

Na przykład, jeśli wynik klasyfikacji to [0.0002, 0.9494, 0.0503] , drugi element tablicy jest najwyższy, więc model przewidział, że nowym wejściem prawdopodobnie będzie Virginica.

I to tyle w przypadku prostej sieci neuronowej w Tensorflow.js!

Rozmawialiśmy tylko o małym zestawie danych Irysów, ale jeśli chcesz przejść do większych zestawów danych lub pracować z obrazami, kroki będą takie same:

  • Zbieranie danych;
  • Podział na zestaw szkoleniowy i testowy;
  • Ponowne formatowanie danych, aby Tensorflow.js mógł je zrozumieć;
  • Wybór algorytmu;
  • Dopasowanie danych;
  • Przewidywanie.

Jeśli chcesz zapisać utworzony model, aby móc go załadować w innej aplikacji i przewidzieć nowe dane, możesz to zrobić za pomocą następującego wiersza:

 await model.save('file:///path/to/my-model'); // in Node.js

Uwaga : aby uzyskać więcej opcji zapisywania modelu, zajrzyj do tego zasobu.

Limity

Otóż ​​to! Właśnie omówiliśmy trzy główne funkcje dostępne obecnie za pomocą Tensorflow.js!

Myślę, że zanim skończymy, warto krótko wspomnieć o niektórych ograniczeniach korzystania z uczenia maszynowego we frontendzie.

1. Wydajność

Importowanie wstępnie wytrenowanego modelu ze źródła zewnętrznego może mieć wpływ na wydajność aplikacji. Na przykład niektóre modele wykrywania obiektów mają więcej niż 10 MB, co znacznie spowolni Twoją witrynę. Zastanów się nad wrażeniami użytkownika i zoptymalizuj ładowanie zasobów, aby poprawić postrzeganą skuteczność.

2. Jakość danych wejściowych

Jeśli zbudujesz model od podstaw, będziesz musiał zebrać własne dane lub znaleźć jakiś zbiór danych o otwartym kodzie źródłowym.

Zanim zaczniesz przetwarzać dane lub spróbujesz różnych algorytmów, sprawdź jakość danych wejściowych. Na przykład, jeśli próbujesz zbudować model analizy sentymentu do rozpoznawania emocji w fragmentach tekstu, upewnij się, że dane używane do trenowania modelu są dokładne i zróżnicowane. Jeśli jakość wykorzystywanych danych jest niska, wyniki treningu będą bezużyteczne.

3. Odpowiedzialność

Korzystanie z wstępnie wytrenowanego modelu typu open source może być bardzo szybkie i bezproblemowe. Oznacza to jednak również, że nie zawsze wiesz, w jaki sposób został on wygenerowany, z czego został utworzony zbiór danych, a nawet jakiego algorytmu użyto. Niektóre modele nazywane są „czarnymi skrzynkami”, co oznacza, że ​​tak naprawdę nie wiesz, w jaki sposób przewidziały określone wyniki.

W zależności od tego, co próbujesz zbudować, może to stanowić problem. Na przykład, jeśli używasz modelu uczenia maszynowego, aby pomóc wykryć prawdopodobieństwo zachorowania na raka na podstawie obrazów skanu, w przypadku wyniku fałszywie negatywnego (model przewidywał, że dana osoba nie miała raka, kiedy faktycznie miała), tam może być realną odpowiedzialnością prawną i musiałbyś być w stanie wyjaśnić, dlaczego model dokonał określonej prognozy.

Streszczenie

Podsumowując, korzystanie z JavaScript i frameworków, takich jak Tensorflow.js, to świetny sposób na rozpoczęcie pracy i dowiedzenie się więcej o uczeniu maszynowym. Mimo że aplikacja gotowa do produkcji powinna być prawdopodobnie zbudowana w języku takim jak Python, JavaScript sprawia, że ​​jest ona naprawdę dostępna dla programistów, którzy mogą bawić się różnymi funkcjami i lepiej zrozumieć podstawowe pojęcia, zanim w końcu przejdą dalej i zainwestują czas w naukę innego język.

W tym samouczku omówiliśmy tylko to, co było możliwe przy użyciu Tensorflow.js, jednak ekosystem innych bibliotek i narzędzi rośnie. Dostępne są również bardziej szczegółowe frameworki, które pozwalają eksplorować korzystanie z uczenia maszynowego z innymi domenami, takimi jak muzyka z Magenta.js lub przewidywanie nawigacji użytkownika w witrynie za pomocą guess.js!

W miarę jak narzędzia stają się coraz wydajniejsze, możliwości tworzenia aplikacji obsługujących uczenie maszynowe w języku JavaScript będą prawdopodobnie coraz bardziej ekscytujące, a teraz jest dobry czas, aby dowiedzieć się więcej na ten temat, ponieważ społeczność stara się, aby były dostępne.

Dalsze zasoby

Jeśli chcesz dowiedzieć się więcej, oto kilka zasobów:

Inne struktury i narzędzia

  • ml5.js
  • ml.js
  • mózg.js
  • Keras.js
  • PoseNet
  • Plac zabaw Tensorflow

Przykłady, modele i zbiory danych

  • Modele Tensorflow.js
  • Przykłady Tensorflow.js
  • Zbiory danych

Inspiracja

  • Maszyna z możliwością uczenia
  • Eksperymenty AI
  • AIJS.rocks
  • Kreatywność

Dziękuje za przeczytanie!