Androidアプリケーションの構築
オンデバイストレーニング: Androidアプリケーションの構築
Section titled “オンデバイストレーニング: Androidアプリケーションの構築”このチュートリアルでは、ONNX Runtimeのオンデバイストレーニングソリューションを組み込んだAndroidアプリケーションを構築する方法を探ります。オンデバイストレーニングとは、クラウドサービスや外部サーバーに依存せず、エッジデバイス上で直接機械学習モデルをトレーニングするプロセスを指します。
このチュートリアルの最後でアプリケーションは次のようになります:
オンデバイストレーニング技術を使用して、簡単な画像分類モデルをトレーニングできるAndroidアプリを作成する手順を案内します。このチュートリアルでは、転移学習技術を紹介し、1つのタスクでトレーニングしたモデルが別の関連するタスクでパフォーマンスを向上させる方法を示します。学習プロセスをゼロから開始するのではなく、転移学習により、事前トレーニングされたモデルが学習した知識や特徴を新しいタスクに転移させることができます。
このチュートリアルでは、MobileNetV2モデルを活用します。このモデルは、ImageNet(1,000クラスを持つ)などの大規模画像データセットでトレーニングされています。このモデルを使用して、カスタムデータを4つのクラスのうちの1つに分類します。MobileNetV2の初期層は特徴抽出器として機能し、さまざまなタスクに適用可能な汎用的な視覚特徴をキャプチャします。分類器の最後の層のみが現在のタスクに対してトレーニングされます。
このチュートリアルでは、以下のデータを学習します:
- 事前パッケージされた動物データセットを使用して、動物を4つのカテゴリのいずれかに分類する
- カスタムの有名人データセットを使用して、有名人を4つのカテゴリのいずれかに分類する
- TOCプレースホルダー
このチュートリアルに従うには、JavaまたはKotlinを使用したAndroidアプリ開発の基本的な理解が必要です。C++とニューラルネットワークや画像分類などの機械学習概念についての知識も役立ちます。
- トレーニングアーティファクトを準備するためのPython開発環境
- Android Studio 4.1+
- Android SDK 29+
- Android NDK r21+
- カメラ付きのAndroidデバイス(開発者モードでUSBデバッグが有効)
注 Androidアプリケーション全体も
onnxruntime-training-examplesGitHubリポジトリで利用できます。
オフラインフェーズ - トレーニングアーティファクトの構築
Section titled “オフラインフェーズ - トレーニングアーティファクトの構築”-
事前トレーニングされたPyTorchモデルから開始し、ONNXにエクスポートします。
MobileNetV2モデルはimagenetデータセットで事前トレーニングされており、1000カテゴリのデータを持っています。画像分類タスクでは、画像を4クラスに分類するだけでよいため、モデルの最後の層を1,000ではなく4つのロジットを出力するように変更します。PyTorchモデルをONNXにエクスポートする方法の詳細についてはこちらを参照してください。
import torchimport torchvisionmodel = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)# オリジナルモデルはimagenetでトレーニングされており、1000クラスを持っています。# 画像分類シナリオでは、4つのカテゴリに分類する必要があります。# モデルの最後の層を4つの出力を持つように変更します。model.classifier[1] = torch.nn.Linear(1280, 4)# モデルをONNXにエクスポートします。model_name = "mobilenetv2"torch.onnx.export(model, torch.randn(1, 3, 224, 224),f"training_artifacts/{model_name}.onnx",input_names=["input"], output_names=["output"],dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}) -
トレーニング可能なパラメータとトレーニング不可能なパラメータを定義する
import onnx# onnxモデルを読み込みます。onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")# 勾配を計算する必要があるパラメータ(トレーニング可能なパラメータ)と# 必要ないパラメータ(凍結/トレーニング不可能なパラメータ)を定義します。requires_grad = ["classifier.1.weight", "classifier.1.bias"]frozen_params = [param.namefor param in onnx_model.graph.initializerif param.name not in requires_grad] -
このチュートリアルでは、
CrossEntropyLoss損失とAdamWオプティマイザーを使用します。アーティファクト生成の詳細についてはこちらを参照してください。from onnxruntime.training import artifacts# トレーニングアーティファクトを生成します。artifacts.generate_artifacts(onnx_model,requires_grad=requires_grad,frozen_params=frozen_params,loss=artifacts.LossType.CrossEntropyLoss,optimizer=artifacts.OptimType.AdamW,artifact_directory="training_artifacts")これでオフラインフェーズが終了しました。これらのアーティファクトは、トレーニングのためにAndroidデバイスにデプロイする準備ができています。
トレーニングフェーズ - Androidアプリケーション開発
Section titled “トレーニングフェーズ - Androidアプリケーション開発”-
a. Android Studioを開き、
New Projectをクリックします。
b.
Native C++->Nextをクリックします。New Projectの詳細を次のように入力します:- 名前 -
ORT Personalize - パッケージ名 -
com.example.ortpersonalize - 言語 -
Kotlin
Nextをクリックします。
c.
C++17ツールチェーンを選択 ->Finish
d. これで完了です!Android Studioプロジェクトがセットアップされました。現在、Android Studioエディターにボイラープレートコードが表示されているはずです。
- 名前 -
-
a. Android Studioプロジェクトのcppディレクトリの下に
libとinclude\onnxruntimeという2つの新しいフォルダーを作成します。
b. Maven Centralにアクセスします。
Versions->Browse->でonnxruntime-training-androidアーカイブパッケージ(aarファイル)をダウンロードします。c.
aar拡張子をzipに変更します。onnxruntime-training-android-1.15.0.aarをonnxruntime-training-android-1.15.0.zipに変更します。d. zipファイルの内容を抽出します。
e.
jni\arm64-v8aフォルダーからlibonnxruntime.so共有ライブラリを、新しく作成したlibフォルダーのAndroidプロジェクトにコピーします。f.
headersフォルダーの内容を新しく作成したinclude\onnxruntimeフォルダーにコピーします。g.
native-lib.cppファイルでトレーニングcxxヘッダーファイルを含めます。#include "onnxruntime_training_cxx_api.h"h.
build.gradle (Module)ファイルにabiFiltersを追加してarm64-v8aを選択します。この設定はbuild.gradleのdefaultConfigの下に追加する必要があります:ndk {abiFilters 'arm64-v8a'}build.gradleファイルのdefaultConfigセクションは次のようになります:defaultConfig {applicationId "com.example.ortpersonalize"minSdk 29targetSdk 33versionCode 1versionName "1.0"testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"externalNativeBuild {cmake {cppFlags '-std=c++17'}}ndk {abiFilters 'arm64-v8a'}}i.
cmakeが共有ライブラリを見つけてビルドできるように、CMakeLists.txtにonnxruntime共有ライブラリを追加します。これを行うには、CMakeLists.txtでortpersonalizeライブラリが追加された後に次の行を追加します:Terminal window add_library(onnxruntime SHARED IMPORTED)set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)上記の2行の直後に、ONNX Runtimeヘッダーファイルがどこにあるかを
CMakeに知らせます:Terminal window target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)Android C++プロジェクトを
onnxruntimeライブラリに対してリンクします:Terminal window target_link_libraries( # Specifies the target library.ortpersonalize# Links the target library to the log library# included in the NDK.${log-lib}onnxruntime)CMakeLists.txtファイルは次のようになります:project("ortpersonalize")add_library( # Sets the name of the library.ortpersonalize# Sets the library as a shared library.SHARED# Provides a relative path to your source file(s).native-lib.cpputils.cppinference.cpptrain.cpp)add_library(onnxruntime SHARED IMPORTED)set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)find_library( # Sets the name of the path variable.log-lib# Specifies the name of the NDK library that# you want CMake to locate.log)target_link_libraries( # Specifies the target library.ortpersonalize# Links the target library to the log library# included in the NDK.${log-lib}onnxruntime)j. アプリケーションをビルドし、アプリがONNX Runtimeヘッダーを正常にインクルードし、共有onnxruntimeライブラリに対してリンクできることを確認して成功するのを待ちます。
-
事前ビルドトレーニングアーティファクトとデータセットのパッケージング
a. Android Studioプロジェクトの
appの下の左ペインから新しいassetsフォルダーを作成します(右クリック app -> New -> Folder -> Assets Folderを選択し、mainの下に配置)。b. ステップ2で生成されたトレーニングアーティファクトをこのフォルダーにコピーします。
c. 次に、
onnxruntime-training-examplesリポジトリにアクセスし、マシンにデータセット(images.zip)をダウンロードして抽出します。このデータセットは、KaggleでCorrado Alessioによって作成された元のanimals-10データセットから変更されました。d. ダウンロードした
imagesフォルダーをAndroid Studioのassets/imagesディレクトリにコピーします。プロジェクトの左ペインは次のようになります:

-
ONNX Runtimeとのインターフェース - C++コード
a. アプリケーションから呼び出される4つの関数をC++で実装します:
createSession: アプリケーション起動時に呼び出されます。新しいCheckpointStateとTrainingSessionオブジェクトを作成します。releaseSession: アプリケーションが閉じようとしているときに呼び出されます。この関数は、開始時に割り当てられたリソースを解放します。performTraining: UIのTrainボタンをクリックしたときに呼び出されます。performInference: UIのInferボタンをクリックしたときに呼び出されます。
b. セッションを作成
この関数は、アプリケーションが起動されたときに呼び出されます。この関数は、トレーニングアーティファクトアセットを使用して、C++ CheckpointStateとTrainingSessionオブジェクトを作成します。これらのオブジェクトは、デバイス上でモデルをトレーニングするために使用されます。
createSessionへの引数は:checkpoint_path: チェックポイントアーティファクトへのキャッシュされたパス。train_model_path: トレーニングモデルアーティファクトへのキャッシュされたパス。eval_model_path: evalモデルアーティファクトへのキャッシュされたパス。optimizer_model_path: オプティマイザーモデルアーティファクトへのキャッシュされたパス。cache_dir_path: Androidデバイス上のキャッシュディレクトリへのパス。キャッシュディレクトリは、C++コードからトレーニングアーティファクトにアクセスするための方法として使用されます。
この関数は、
session_cacheオブジェクトへのポインタを表すlongを返します。このlongは、SessionCacheにアクセスする必要があるときはいつでもSessionCacheにキャストできます。extern "C" JNIEXPORT jlong JNICALLJava_com_example_ortpersonalize_MainActivity_createSession(JNIEnv *env, jobject /* this */,jstring checkpoint_path, jstring train_model_path, jstring eval_model_path,jstring optimizer_model_path, jstring cache_dir_path){std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>(utils::JString2String(env, checkpoint_path),utils::JString2String(env, train_model_path),utils::JString2String(env, eval_model_path),utils::JString2String(env, optimizer_model_path),utils::JString2String(env, cache_dir_path));return reinterpret_cast<long>(session_cache.release());}上記の関数本体から見られるように、この関数は
SessionCacheクラスのオブジェクトへのユニークポインタを作成します。SessionCacheの定義は以下に示されています。struct SessionCache {ArtifactPaths artifact_paths;Ort::Env ort_env;Ort::SessionOptions session_options;Ort::CheckpointState checkpoint_state;Ort::TrainingSession training_session;Ort::Session* inference_session;SessionCache(const std::string &checkpoint_path, const std::string &training_model_path,const std::string &eval_model_path, const std::string &optimizer_model_path,const std::string& cache_dir_path) :artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path),ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(),checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())),training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(),artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()),inference_session(nullptr) {}};ArtifactPathsの定義は:struct ArtifactPaths {std::string checkpoint_path;std::string training_model_path;std::string eval_model_path;std::string optimizer_model_path;std::string cache_dir_path;std::string inference_model_path;ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path,const std::string &eval_model_path, const std::string &optimizer_model_path,const std::string& cache_dir_path) :checkpoint_path(checkpoint_path), training_model_path(training_model_path),eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path),cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {}};c. セッションを解放
この関数は、アプリケーションがシャットダウンしようとしているときに呼び出されます。主要に起動時に作成されたCheckpointStateとTrainingSessionを解放します。
releaseSessionへの引数は:session:SessionCacheオブジェクトを表すlong。
extern "C" JNIEXPORT void JNICALLJava_com_example_ortpersonalize_MainActivity_releaseSession(JNIEnv *env, jobject /* this */,jlong session) {auto *session_cache = reinterpret_cast<SessionCache *>(session);delete session_cache->inference_session;delete session_cache;}d. トレーニングを実行
この関数は、トレーニングする必要がある各バッチに対して呼び出されます。トレーニングループはアプリケーション側でKotlinで記述されており、トレーニングループ内で各バッチに対して
performTraining関数が呼び出されます。performTrainingへの引数は:session:SessionCacheオブジェクトを表すlong。batch: トレーニングのために渡されるfloat配列としての入力画像。labels: トレーニングのために提供される入力画像に関連付けられたint配列としてのラベル。batch_size: 各TrainStepで処理する画像の数。channels: 画像のチャネル数。この例では、常に値3で呼び出されます。frame_rows: 画像の行数。この例では、常に値224で呼び出されます。frame_cols: 画像の列数。この例では、常に値224で呼び出されます。
この関数は、このバッチのトレーニング損失を表す
floatを返します。extern "C"JNIEXPORT float JNICALLJava_com_example_ortpersonalize_MainActivity_performTraining(JNIEnv *env, jobject /* this */,jlong session, jfloatArray batch, jintArray labels, jint batch_size,jint channels, jint frame_rows, jint frame_cols) {auto* session_cache = reinterpret_cast<SessionCache *>(session);if (session_cache->inference_session) {// train_stepでモデルパラメータを更新するため、// 推論セッションを無効化します。// 次の推論セッション呼び出しでは、推論セッションを再作成する必要があります。delete session_cache->inference_session;session_cache->inference_session = nullptr;}// この入力バッチを使用してモデルパラメータを更新します。return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr),env->GetIntArrayElements(labels, nullptr), batch_size,channels, frame_rows, frame_cols);}上記の関数は
train_step関数を活用します。train_step関数の定義は以下の通りです:namespace training {float train_step(SessionCache* session_cache, float *batches, int32_t *labels,int64_t batch_size, int64_t image_channels, int64_t image_rows,int64_t image_cols) {const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});const std::vector<int64_t> labels_shape({batch_size});Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);std::vector<Ort::Value> user_inputs; // {inputs, labels}// 入力をバッチ化user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches,batch_size * image_channels * image_rows * image_cols * sizeof(float),input_shape.data(), input_shape.size(),ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));// ラベルをバッチ化user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels,batch_size * sizeof(int32_t),labels_shape.data(), labels_shape.size(),ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));// トレーニングステップを実行し、前進 + 損失 + 後退を実行します。float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>());// 上記で計算された勾配の方向にステップを踏んでモデルパラメータを更新します。session_cache->training_session.OptimizerStep();// パラメータが更新されたら勾配をリセットします。// 次回の入力ラウンドで新しい勾配を計算できます。session_cache->training_session.LazyResetGrad();return loss;}} // namespace traininge. 推論を実行
この関数は、ユーザーが推論を実行したいときに呼び出されます。
performInferenceへの引数は:session:SessionCacheオブジェクトを表すlong。image_buffer: トレーニングのために渡されるfloat配列としての入力画像。batch_size: 各推論で処理する画像の数。この例では、常に値1で呼び出されます。image_channels: 画像のチャネル数。この例では、常に値3で呼び出されます。image_rows: 画像の行数。この例では、常に値224で呼び出されます。image_cols: 画像の列数。この例では、常に値224で呼び出されます。classes: 4つのカスタムクラスを表す文字列のリスト。
この関数は、提供された4つのカスタムクラスのうちの1つを表す
stringを返します。これはモデルの予測です。extern "C"JNIEXPORT jstring JNICALLJava_com_example_ortpersonalize_MainActivity_performInference(JNIEnv *env, jobject /* this */,jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows,jint image_cols, jobjectArray classes) {std::vector<std::string> classes_str;for (int i = 0; i < env->GetArrayLength(classes); ++i) {// 現在の文字列要素にアクセスjstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i));classes_str.push_back(utils::JString2String(env, elem));}auto* session_cache = reinterpret_cast<SessionCache *>(session);if (!session_cache->inference_session) {// 推論セッションが存在しないので、新しいものを作成します。session_cache->training_session.ExportModelForInferencing(session_cache->artifact_paths.inference_model_path.c_str(), {"output"});session_cache->inference_session = std::make_unique<Ort::Session>(session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(),session_cache->session_options).release();}auto prediction = inference::classify(session_cache, env->GetFloatArrayElements(image_buffer, nullptr),batch_size, image_channels, image_rows, image_cols, classes_str);return env->NewStringUTF(prediction.first.c_str());}上記の関数は
classifyを呼び出します。classifyの定義は:namespace inference {std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data,int64_t batch_size, int64_t image_channels,int64_t image_rows, int64_t image_cols,const std::vector<std::string>& classes) {std::vector<const char *> input_names = {"input"};size_t input_count = 1;std::vector<const char *> output_names = {"output"};size_t output_count = 1;std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);std::vector<Ort::Value> input_values; // {input images}input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data,batch_size * image_channels * image_rows * image_cols * sizeof(float),input_shape.data(), input_shape.size(),ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));std::vector<Ort::Value> output_values;output_values.emplace_back(nullptr);// ロジットを取得session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(),input_count, output_names.data(), output_values.data(), output_count);float *output = output_values.front().GetTensorMutableData<float>();// softmaxを実行し、各クラスの確率を取得std::vector<float> probabilities = Softmax(output, classes.size());size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end()));return {classes[best_index], probabilities[best_index]};}} // namespace inferenceclassify関数は
Softmaxという別の関数を呼び出します。Softmaxの定義は:std::vector<float> Softmax(float *logits, size_t num_logits) {std::vector<float> probabilities(num_logits, 0);float sum = 0;for (size_t i = 0; i < num_logits; ++i) {probabilities[i] = exp(logits[i]);sum += probabilities[i];}if (sum != 0.0f) {for (size_t i = 0; i < num_logits; ++i) {probabilities[i] /= sum;}}return probabilities;} -
a.
MobileNetV2モデルは、提供される入力画像が3 x 224 x 224のサイズであること- 平均
(0.485, 0.456, 0.406)が減算され、標準偏差(0.229, 0.224, 0.225)で除算された正規化された画像であること
この前処理は、Android提供のライブラリを使用してJava/Kotlinで行われます。
app/src/main/java/com/example/ortpersonalizeディレクトリの下にImageProcessingUtil.ktという新しいファイルを作成します。このファイルに画像のトリミング、リサイズ、正規化のユーティリティメソッドを追加します。b. 画像のトリミングとリサイズ。
fun processBitmap(bitmap: Bitmap) : Bitmap {// この関数は、指定されたビットマップを処理します// - より長い次元に沿ってトリミングして正方形ビットマップを取得// 幅が高さより大きい場合// ___+_________________+___// | + + |// | + + |// | + + + |// | + + |// |__+_________________+__|// <-------- width -------->// <----- height ---->// <--> cropped <-->//// 高さが幅より大きい場合// _________________________ ʌ ʌ// | | | cropped// |+++++++++++++++++++++++| | ʌ v// | | | |// | | | |// | + | height width// | | | |// | | | |// |+++++++++++++++++++++++| | v ʌ// | | | cropped// |_______________________| v v//////// - トリミングされた正方形画像をmobilenetv2モデルが必要とするサイズ(3 x 224 x 224)にリサイズlateinit var bitmapCropped: Bitmapif (bitmap.getWidth() >= bitmap.getHeight()) {// 高さが幅より小さいため、長さが高さである正方形をトリミング// 幅次元に沿ってトリミングが発生val width: Int = bitmap.getHeight()val height: Int = bitmap.getHeight()// トリミング画像の左側は、幅の両側で中心の等分部分を含むように// (bitmap.getWidth() / 2 - bitmap.getHeight() / 2)から始まる必要があります// 高さ次元に沿ってトリミングしないため、上側は0から始まりますval x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2val y: Int = 0bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)} else {// 幅が高さより小さいため、長さが幅である正方形をトリミング// 高さ次元に沿ってトリミングが発生val width: Int = bitmap.getWidth()val height: Int = bitmap.getWidth()// 幅次元に沿ってトリミングしないため、左側は0から始まります// 上側は、中心の両側で高さの等分部分を含むように// (bitmap.getHeight() / 2 - bitmap.getWidth() / 2)から始まる必要がありますval x: Int = 0val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)}// 画像をmobilenetv2モデルが必要とするチャネルx幅x高さにリサイズval width: Int = 224val height: Int = 224val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false)return bitmapResized}c. 画像の正規化。
fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) {// この関数は画像ピクセルを反復し、画像に以下を実行します// - ピクセル値を0から1の間に正規化// - 平均(0.485, 0.456, 0.406)をピクセル値から減算(movilenetv2モデル構成から派生)// 標準偏差(0.229, 0.224, 0.225)でピクセル値を除算(movilenetv2モデル構成から派生)// 値は指定されたオフセットから始まる指定されたバッファに書き込まれます。// 値は次のように書き込まれます// |____|____________________|__________________| <--- buffer// ʌ <--- offset// ʌ <--- offset + width * height * channels// |____|rrrrrr|_____________|__________________| <--- red channel read in column major order// |____|______|gggggg|______|__________________| <--- green channel read in column major order// |____|______|______|bbbbbb|__________________| <--- blue channel read in column major orderval width: Int = bitmap.getWidth()val height: Int = bitmap.getHeight()val stride: Int = width * heightfor (x in 0 until width) {for (y in 0 until height) {val color: Int = bitmap.getPixel(y, x)val index = offset + (x * height + y)// 平均を減算し、標準偏差で除算// movilenetv2モデルの値を使用buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f)buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f)buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f)}}}d. UriからBitmapを取得
fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap {// この関数は、指定されたuriの画像ファイルを読み取り、ビットマップにデコードしますval source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri)return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true)} -
a. このチュートリアルでは、次のユーザーインターフェース要素を使用します:
- TrainボタンとInferボタン
- クラスボタン
- ステータスメッセージテキスト
- 画像表示
- プログレスダイアログ
b. このチュートリアルは、グラフィカルユーザーインターフェースの作成方法を示すことを意図していません。このため、GitHubで利用可能なファイルを使用します。
c.
strings.xmlからすべての文字列定義をAndroid Studioのローカルstrings.xmlにコピーします。d.
activity_main.xmlの内容をAndroid Studioのローカルactivity_main.xmlにコピーします。e.
layoutフォルダの下にdialog.xmlという新しいファイルを作成します。dialog.xmlの内容をAndroid Studioのローカルdialog.xmlにコピーします。f. このセクションの残りの変更は、MainActivity.ktファイルで行う必要があります。
g. アプリケーションの起動
アプリケーションが起動すると、
onCreate関数が呼び出されます。この関数は、セッションキャッシュとユーザーインターフェースハンドラーをセットアップします。MainActivity.ktファイルのonCreate関数を参照してください。h. カスタムクラスボタンハンドラー - クラスボタンを使用して、ユーザーがトレーニング用のカスタム画像を選択できるようにします。これを行うリスナーを追加する必要があります。これらのリスナーは正確に行います。
MainActivity.ktのこれらのボタンハンドラーを参照してください:- onClassAClickedListener
- onClassBClickedListener
- onClassXClickedListener
- onClassYClickedListener
i. カスタムクラスラベルをパーソナライズ
デフォルトでは、カスタムクラスラベルは
[A, B, X, Y]です。しかし、明確にするために、これらのラベルをユーザーが名前変更できるようにします。これはMainActivity.ktで定義されたロングクリックリスナーによって実現されます:- onClassALongClickedListener
- onClassBLongClickedListener
- onClassXLongClickedListener
- onClassYLongClickedListener
j. カスタムクラスの切り替え
カスタムクラススイッチがオフになると、事前パッケージされた動物データセットが実行されます。オンになると、トレーニング用に独自のデータセットを提供することが期待されます。この遷移を処理するために、
MainActivity.ktのonCustomClassSettingChangedListenerスイッチハンドラーが実装されます。k. トレーニングハンドラー
各クラスに少なくとも1つの画像がある場合、
Trainボタンを有効にできます。Trainボタンがクリックされると、トレーニングが選択された画像で開始されます。トレーニングハンドラーは以下の責任を負います:- トレーニング画像を1つのコンテナに収集
- 画像の順序をシャッフル
- 画像のトリミングとリサイズ
- 画像の正規化
- 画像のバッチ化
- トレーニングループの実行(ループ内でC++の
performTraining関数を呼び出し)
MainActivity.ktで定義されたonTrainButtonClickedListener関数がこれを行います。l. 推論ハンドラー
トレーニングが完了したら、ユーザーは任意の画像を推論するために
Inferボタンをクリックできます。推論ハンドラーは以下の責任を負います- 推論画像の収集
- 画像のトリミングとリサイズ
- 画像の正規化
- C++の
performInference関数の呼び出し - ユーザーインターフェースへの推論出力のレポート
これは
MainActivity.ktのonInferenceButtonClickedListener関数によって実現されます。m. 上記のすべてのアクティビティのハンドラー
推論またはカスタムクラスの画像が選択されると、それらを処理する必要があります。
MainActivity.ktで定義されたonActivityResult関数がそれを行います。n. 最後のこと。カメラを使用するために、
AndroidManifest.xmlファイルに以下を追加します:<uses-permission android:name="android.permission.CAMERA" /><uses-feature android:name="android.hardware.camera" />
トレーニングフェーズ - デバイスでのアプリケーション実行
Section titled “トレーニングフェーズ - デバイスでのアプリケーション実行”-
a. Androidデバイスをマシンに接続し、デバイスでアプリケーションを実行しましょう。
b. デバイスでのアプリケーション起動は次のようになります:
-
a. デバイスでアプリケーションを起動して、事前ロード動物デバイスでのトレーニングを開始しましょう。
b. 下部の
Custom classesスイッチを切り替えます。c. クラスラベルが
Dog、Cat、Elephant、Cowに変更されます。d.
Trainingを実行し、トレーニング完了時にプログレスダイアログが消えるのを待ちます。e. ライブラリから任意の動物画像を使用して推論を実行します。
上記の画像からわかるように、モデルは正しく
Cowを予測しました。 -
a. ウェブからTom Cruise、Leonardo DiCaprio、Ryan Reynolds、Brad Pittの画像をダウンロードします。
b. アプリを閉じて再起動して、新しいセッションを確実に起動してください。
c. アプリケーションが起動した後、ロングクリックを使用して4つのクラスをそれぞれ
Tom、Leo、Ryan、Bradに名前変更します。d. 各クラスのボタンをクリックし、その有名人に関連する画像を選択します。各カテゴリあたり10〜15枚の画像を使用できます。
e.
Trainボタンを押して、提供されたデータから学習するようにアプリケーションに指示します。f. トレーニングが完了したら、まだ見ていない画像を提供するために
Inferボタンを押します。g. それだけです!うまくいけば、アプリケーションは画像を正しく分類したはずです。
おめでとうございます!ONNX Runtimeを使用してデバイス上で画像を分類することを学習するAndroidアプリケーションを正常に構築しました。アプリケーションはGitHubのonnxruntime-training-examplesでも利用できます。