コンテンツにスキップ

Androidアプリケーションの構築

オンデバイストレーニング: Androidアプリケーションの構築

Section titled “オンデバイストレーニング: Androidアプリケーションの構築”

このチュートリアルでは、ONNX Runtimeのオンデバイストレーニングソリューションを組み込んだAndroidアプリケーションを構築する方法を探ります。オンデバイストレーニングとは、クラウドサービスや外部サーバーに依存せず、エッジデバイス上で直接機械学習モデルをトレーニングするプロセスを指します。

このチュートリアルの最後でアプリケーションは次のようになります:

トム・クルーズが中央にいる画像分類アプリ

オンデバイストレーニング技術を使用して、簡単な画像分類モデルをトレーニングできるAndroidアプリを作成する手順を案内します。このチュートリアルでは、転移学習技術を紹介し、1つのタスクでトレーニングしたモデルが別の関連するタスクでパフォーマンスを向上させる方法を示します。学習プロセスをゼロから開始するのではなく、転移学習により、事前トレーニングされたモデルが学習した知識や特徴を新しいタスクに転移させることができます。

このチュートリアルでは、MobileNetV2モデルを活用します。このモデルは、ImageNet(1,000クラスを持つ)などの大規模画像データセットでトレーニングされています。このモデルを使用して、カスタムデータを4つのクラスのうちの1つに分類します。MobileNetV2の初期層は特徴抽出器として機能し、さまざまなタスクに適用可能な汎用的な視覚特徴をキャプチャします。分類器の最後の層のみが現在のタスクに対してトレーニングされます。

このチュートリアルでは、以下のデータを学習します:

  • 事前パッケージされた動物データセットを使用して、動物を4つのカテゴリのいずれかに分類する
  • カスタムの有名人データセットを使用して、有名人を4つのカテゴリのいずれかに分類する
  • TOCプレースホルダー

このチュートリアルに従うには、JavaまたはKotlinを使用したAndroidアプリ開発の基本的な理解が必要です。C++とニューラルネットワークや画像分類などの機械学習概念についての知識も役立ちます。

  • トレーニングアーティファクトを準備するためのPython開発環境
  • Android Studio 4.1+
  • Android SDK 29+
  • Android NDK r21+
  • カメラ付きのAndroidデバイス(開発者モードでUSBデバッグが有効)

Androidアプリケーション全体もonnxruntime-training-examples GitHubリポジトリで利用できます。

オフラインフェーズ - トレーニングアーティファクトの構築

Section titled “オフラインフェーズ - トレーニングアーティファクトの構築”
  1. モデルをONNXにエクスポートする

    事前トレーニングされたPyTorchモデルから開始し、ONNXにエクスポートします。MobileNetV2モデルはimagenetデータセットで事前トレーニングされており、1000カテゴリのデータを持っています。画像分類タスクでは、画像を4クラスに分類するだけでよいため、モデルの最後の層を1,000ではなく4つのロジットを出力するように変更します。

    PyTorchモデルをONNXにエクスポートする方法の詳細についてはこちらを参照してください。

    import torch
    import torchvision
    model = torchvision.models.mobilenet_v2(
    weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
    # オリジナルモデルはimagenetでトレーニングされており、1000クラスを持っています。
    # 画像分類シナリオでは、4つのカテゴリに分類する必要があります。
    # モデルの最後の層を4つの出力を持つように変更します。
    model.classifier[1] = torch.nn.Linear(1280, 4)
    # モデルをONNXにエクスポートします。
    model_name = "mobilenetv2"
    torch.onnx.export(model, torch.randn(1, 3, 224, 224),
    f"training_artifacts/{model_name}.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
  2. トレーニング可能なパラメータとトレーニング不可能なパラメータを定義する

    import onnx
    # onnxモデルを読み込みます。
    onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
    # 勾配を計算する必要があるパラメータ(トレーニング可能なパラメータ)と
    # 必要ないパラメータ(凍結/トレーニング不可能なパラメータ)を定義します。
    requires_grad = ["classifier.1.weight", "classifier.1.bias"]
    frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
    ]
  3. トレーニングアーティファクトを生成する

    このチュートリアルでは、CrossEntropyLoss損失とAdamWオプティマイザーを使用します。アーティファクト生成の詳細についてはこちらを参照してください。

    from onnxruntime.training import artifacts
    # トレーニングアーティファクトを生成します。
    artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts"
    )

    これでオフラインフェーズが終了しました。これらのアーティファクトは、トレーニングのためにAndroidデバイスにデプロイする準備ができています。

トレーニングフェーズ - Androidアプリケーション開発

Section titled “トレーニングフェーズ - Androidアプリケーション開発”
  1. Android Studioでのプロジェクトセットアップ

    a. Android Studioを開き、New Projectをクリックします。 Android Studioセットアップ - 新規プロジェクト

    b. Native C++ -> Nextをクリックします。New Projectの詳細を次のように入力します:

    • 名前 - ORT Personalize
    • パッケージ名 - com.example.ortpersonalize
    • 言語 - Kotlin

    Nextをクリックします。

    Android Studioセットアップ - プロジェクト名

    c. C++17ツールチェーンを選択 -> Finish

    Android Studioセットアップ - プロジェクトC++ツールチェーン

    d. これで完了です!Android Studioプロジェクトがセットアップされました。現在、Android Studioエディターにボイラープレートコードが表示されているはずです。

  2. ONNX Runtime依存関係の追加

    a. Android Studioプロジェクトのcppディレクトリの下にlibinclude\onnxruntimeという2つの新しいフォルダーを作成します。

    libとincludeフォルダー

    b. Maven Centralにアクセスします。Versions->Browse->でonnxruntime-training-androidアーカイブパッケージ(aarファイル)をダウンロードします。

    c. aar拡張子をzipに変更します。onnxruntime-training-android-1.15.0.aaronnxruntime-training-android-1.15.0.zipに変更します。

    d. zipファイルの内容を抽出します。

    e. jni\arm64-v8aフォルダーからlibonnxruntime.so共有ライブラリを、新しく作成したlibフォルダーのAndroidプロジェクトにコピーします。

    f. headersフォルダーの内容を新しく作成したinclude\onnxruntimeフォルダーにコピーします。

    g. native-lib.cppファイルでトレーニングcxxヘッダーファイルを含めます。

    #include "onnxruntime_training_cxx_api.h"

    h. build.gradle (Module)ファイルにabiFiltersを追加してarm64-v8aを選択します。この設定はbuild.gradledefaultConfigの下に追加する必要があります:

    ndk {
    abiFilters 'arm64-v8a'
    }

    build.gradleファイルのdefaultConfigセクションは次のようになります:

    defaultConfig {
    applicationId "com.example.ortpersonalize"
    minSdk 29
    targetSdk 33
    versionCode 1
    versionName "1.0"
    testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
    externalNativeBuild {
    cmake {
    cppFlags '-std=c++17'
    }
    }
    ndk {
    abiFilters 'arm64-v8a'
    }
    }

    i. cmakeが共有ライブラリを見つけてビルドできるように、CMakeLists.txtonnxruntime共有ライブラリを追加します。これを行うには、CMakeLists.txtortpersonalizeライブラリが追加された後に次の行を追加します:

    Terminal window
    add_library(onnxruntime SHARED IMPORTED)
    set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)

    上記の2行の直後に、ONNX RuntimeヘッダーファイルがどこにあるかをCMakeに知らせます:

    Terminal window
    target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)

    Android C++プロジェクトをonnxruntimeライブラリに対してリンクします:

    Terminal window
    target_link_libraries( # Specifies the target library.
    ortpersonalize
    # Links the target library to the log library
    # included in the NDK.
    ${log-lib}
    onnxruntime)

    CMakeLists.txtファイルは次のようになります:

    project("ortpersonalize")
    add_library( # Sets the name of the library.
    ortpersonalize
    # Sets the library as a shared library.
    SHARED
    # Provides a relative path to your source file(s).
    native-lib.cpp
    utils.cpp
    inference.cpp
    train.cpp)
    add_library(onnxruntime SHARED IMPORTED)
    set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    find_library( # Sets the name of the path variable.
    log-lib
    # Specifies the name of the NDK library that
    # you want CMake to locate.
    log)
    target_link_libraries( # Specifies the target library.
    ortpersonalize
    # Links the target library to the log library
    # included in the NDK.
    ${log-lib}
    onnxruntime)

    j. アプリケーションをビルドし、アプリがONNX Runtimeヘッダーを正常にインクルードし、共有onnxruntimeライブラリに対してリンクできることを確認して成功するのを待ちます。

  3. 事前ビルドトレーニングアーティファクトとデータセットのパッケージング

    a. Android Studioプロジェクトのappの下の左ペインから新しいassetsフォルダーを作成します(右クリック app -> New -> Folder -> Assets Folderを選択し、mainの下に配置)。

    b. ステップ2で生成されたトレーニングアーティファクトをこのフォルダーにコピーします。

    c. 次に、onnxruntime-training-examplesリポジトリにアクセスし、マシンにデータセット(images.zip)をダウンロードして抽出します。このデータセットは、KaggleでCorrado Alessioによって作成された元のanimals-10データセットから変更されました。

    d. ダウンロードしたimagesフォルダーをAndroid Studioのassets/imagesディレクトリにコピーします。

    プロジェクトの左ペインは次のようになります:

    プロジェクトアセット

  4. ONNX Runtimeとのインターフェース - C++コード

    a. アプリケーションから呼び出される4つの関数をC++で実装します:

    • createSession: アプリケーション起動時に呼び出されます。新しいCheckpointStateTrainingSessionオブジェクトを作成します。
    • releaseSession: アプリケーションが閉じようとしているときに呼び出されます。この関数は、開始時に割り当てられたリソースを解放します。
    • performTraining: UIのTrainボタンをクリックしたときに呼び出されます。
    • performInference: UIのInferボタンをクリックしたときに呼び出されます。

    b. セッションを作成

    この関数は、アプリケーションが起動されたときに呼び出されます。この関数は、トレーニングアーティファクトアセットを使用して、C++ CheckpointStateとTrainingSessionオブジェクトを作成します。これらのオブジェクトは、デバイス上でモデルをトレーニングするために使用されます。

    createSessionへの引数は:

    • checkpoint_path: チェックポイントアーティファクトへのキャッシュされたパス。
    • train_model_path: トレーニングモデルアーティファクトへのキャッシュされたパス。
    • eval_model_path: evalモデルアーティファクトへのキャッシュされたパス。
    • optimizer_model_path: オプティマイザーモデルアーティファクトへのキャッシュされたパス。
    • cache_dir_path: Androidデバイス上のキャッシュディレクトリへのパス。キャッシュディレクトリは、C++コードからトレーニングアーティファクトにアクセスするための方法として使用されます。

    この関数は、session_cacheオブジェクトへのポインタを表すlongを返します。このlongは、SessionCacheにアクセスする必要があるときはいつでもSessionCacheにキャストできます。

    extern "C" JNIEXPORT jlong JNICALL
    Java_com_example_ortpersonalize_MainActivity_createSession(
    JNIEnv *env, jobject /* this */,
    jstring checkpoint_path, jstring train_model_path, jstring eval_model_path,
    jstring optimizer_model_path, jstring cache_dir_path)
    {
    std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>(
    utils::JString2String(env, checkpoint_path),
    utils::JString2String(env, train_model_path),
    utils::JString2String(env, eval_model_path),
    utils::JString2String(env, optimizer_model_path),
    utils::JString2String(env, cache_dir_path));
    return reinterpret_cast<long>(session_cache.release());
    }

    上記の関数本体から見られるように、この関数はSessionCacheクラスのオブジェクトへのユニークポインタを作成します。SessionCacheの定義は以下に示されています。

    struct SessionCache {
    ArtifactPaths artifact_paths;
    Ort::Env ort_env;
    Ort::SessionOptions session_options;
    Ort::CheckpointState checkpoint_state;
    Ort::TrainingSession training_session;
    Ort::Session* inference_session;
    SessionCache(const std::string &checkpoint_path, const std::string &training_model_path,
    const std::string &eval_model_path, const std::string &optimizer_model_path,
    const std::string& cache_dir_path) :
    artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path),
    ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(),
    checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())),
    training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(),
    artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()),
    inference_session(nullptr) {}
    };

    ArtifactPathsの定義は:

    struct ArtifactPaths {
    std::string checkpoint_path;
    std::string training_model_path;
    std::string eval_model_path;
    std::string optimizer_model_path;
    std::string cache_dir_path;
    std::string inference_model_path;
    ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path,
    const std::string &eval_model_path, const std::string &optimizer_model_path,
    const std::string& cache_dir_path) :
    checkpoint_path(checkpoint_path), training_model_path(training_model_path),
    eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path),
    cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {}
    };

    c. セッションを解放

    この関数は、アプリケーションがシャットダウンしようとしているときに呼び出されます。主要に起動時に作成されたCheckpointStateとTrainingSessionを解放します。

    releaseSessionへの引数は:

    • session: SessionCacheオブジェクトを表すlong
    extern "C" JNIEXPORT void JNICALL
    Java_com_example_ortpersonalize_MainActivity_releaseSession(
    JNIEnv *env, jobject /* this */,
    jlong session) {
    auto *session_cache = reinterpret_cast<SessionCache *>(session);
    delete session_cache->inference_session;
    delete session_cache;
    }

    d. トレーニングを実行

    この関数は、トレーニングする必要がある各バッチに対して呼び出されます。トレーニングループはアプリケーション側でKotlinで記述されており、トレーニングループ内で各バッチに対してperformTraining関数が呼び出されます。

    performTrainingへの引数は:

    • session: SessionCacheオブジェクトを表すlong
    • batch: トレーニングのために渡されるfloat配列としての入力画像。
    • labels: トレーニングのために提供される入力画像に関連付けられたint配列としてのラベル。
    • batch_size: 各TrainStepで処理する画像の数。
    • channels: 画像のチャネル数。この例では、常に値3で呼び出されます。
    • frame_rows: 画像の行数。この例では、常に値224で呼び出されます。
    • frame_cols: 画像の列数。この例では、常に値224で呼び出されます。

    この関数は、このバッチのトレーニング損失を表すfloatを返します。

    extern "C"
    JNIEXPORT float JNICALL
    Java_com_example_ortpersonalize_MainActivity_performTraining(
    JNIEnv *env, jobject /* this */,
    jlong session, jfloatArray batch, jintArray labels, jint batch_size,
    jint channels, jint frame_rows, jint frame_cols) {
    auto* session_cache = reinterpret_cast<SessionCache *>(session);
    if (session_cache->inference_session) {
    // train_stepでモデルパラメータを更新するため、
    // 推論セッションを無効化します。
    // 次の推論セッション呼び出しでは、推論セッションを再作成する必要があります。
    delete session_cache->inference_session;
    session_cache->inference_session = nullptr;
    }
    // この入力バッチを使用してモデルパラメータを更新します。
    return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr),
    env->GetIntArrayElements(labels, nullptr), batch_size,
    channels, frame_rows, frame_cols);
    }

    上記の関数はtrain_step関数を活用します。train_step関数の定義は以下の通りです:

    namespace training {
    float train_step(SessionCache* session_cache, float *batches, int32_t *labels,
    int64_t batch_size, int64_t image_channels, int64_t image_rows,
    int64_t image_cols) {
    const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
    const std::vector<int64_t> labels_shape({batch_size});
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    std::vector<Ort::Value> user_inputs; // {inputs, labels}
    // 入力をバッチ化
    user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches,
    batch_size * image_channels * image_rows * image_cols * sizeof(float),
    input_shape.data(), input_shape.size(),
    ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    // ラベルをバッチ化
    user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels,
    batch_size * sizeof(int32_t),
    labels_shape.data(), labels_shape.size(),
    ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));
    // トレーニングステップを実行し、前進 + 損失 + 後退を実行します。
    float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>());
    // 上記で計算された勾配の方向にステップを踏んでモデルパラメータを更新します。
    session_cache->training_session.OptimizerStep();
    // パラメータが更新されたら勾配をリセットします。
    // 次回の入力ラウンドで新しい勾配を計算できます。
    session_cache->training_session.LazyResetGrad();
    return loss;
    }
    } // namespace training

    e. 推論を実行

    この関数は、ユーザーが推論を実行したいときに呼び出されます。

    performInferenceへの引数は:

    • session: SessionCacheオブジェクトを表すlong
    • image_buffer: トレーニングのために渡されるfloat配列としての入力画像。
    • batch_size: 各推論で処理する画像の数。この例では、常に値1で呼び出されます。
    • image_channels: 画像のチャネル数。この例では、常に値3で呼び出されます。
    • image_rows: 画像の行数。この例では、常に値224で呼び出されます。
    • image_cols: 画像の列数。この例では、常に値224で呼び出されます。
    • classes: 4つのカスタムクラスを表す文字列のリスト。

    この関数は、提供された4つのカスタムクラスのうちの1つを表すstringを返します。これはモデルの予測です。

    extern "C"
    JNIEXPORT jstring JNICALL
    Java_com_example_ortpersonalize_MainActivity_performInference(
    JNIEnv *env, jobject /* this */,
    jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows,
    jint image_cols, jobjectArray classes) {
    std::vector<std::string> classes_str;
    for (int i = 0; i < env->GetArrayLength(classes); ++i) {
    // 現在の文字列要素にアクセス
    jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i));
    classes_str.push_back(utils::JString2String(env, elem));
    }
    auto* session_cache = reinterpret_cast<SessionCache *>(session);
    if (!session_cache->inference_session) {
    // 推論セッションが存在しないので、新しいものを作成します。
    session_cache->training_session.ExportModelForInferencing(
    session_cache->artifact_paths.inference_model_path.c_str(), {"output"});
    session_cache->inference_session = std::make_unique<Ort::Session>(
    session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(),
    session_cache->session_options).release();
    }
    auto prediction = inference::classify(
    session_cache, env->GetFloatArrayElements(image_buffer, nullptr),
    batch_size, image_channels, image_rows, image_cols, classes_str);
    return env->NewStringUTF(prediction.first.c_str());
    }

    上記の関数はclassifyを呼び出します。classifyの定義は:

    namespace inference {
    std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data,
    int64_t batch_size, int64_t image_channels,
    int64_t image_rows, int64_t image_cols,
    const std::vector<std::string>& classes) {
    std::vector<const char *> input_names = {"input"};
    size_t input_count = 1;
    std::vector<const char *> output_names = {"output"};
    size_t output_count = 1;
    std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
    Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    std::vector<Ort::Value> input_values; // {input images}
    input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data,
    batch_size * image_channels * image_rows * image_cols * sizeof(float),
    input_shape.data(), input_shape.size(),
    ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    std::vector<Ort::Value> output_values;
    output_values.emplace_back(nullptr);
    // ロジットを取得
    session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(),
    input_count, output_names.data(), output_values.data(), output_count);
    float *output = output_values.front().GetTensorMutableData<float>();
    // softmaxを実行し、各クラスの確率を取得
    std::vector<float> probabilities = Softmax(output, classes.size());
    size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end()));
    return {classes[best_index], probabilities[best_index]};
    }
    } // namespace inference

    classify関数はSoftmaxという別の関数を呼び出します。Softmaxの定義は:

    std::vector<float> Softmax(float *logits, size_t num_logits) {
    std::vector<float> probabilities(num_logits, 0);
    float sum = 0;
    for (size_t i = 0; i < num_logits; ++i) {
    probabilities[i] = exp(logits[i]);
    sum += probabilities[i];
    }
    if (sum != 0.0f) {
    for (size_t i = 0; i < num_logits; ++i) {
    probabilities[i] /= sum;
    }
    }
    return probabilities;
    }
  5. 画像前処理

    a. MobileNetV2モデルは、提供される入力画像が

    • 3 x 224 x 224のサイズであること
    • 平均(0.485, 0.456, 0.406)が減算され、標準偏差(0.229, 0.224, 0.225)で除算された正規化された画像であること

    この前処理は、Android提供のライブラリを使用してJava/Kotlinで行われます。

    app/src/main/java/com/example/ortpersonalizeディレクトリの下にImageProcessingUtil.ktという新しいファイルを作成します。このファイルに画像のトリミング、リサイズ、正規化のユーティリティメソッドを追加します。

    b. 画像のトリミングとリサイズ。

    fun processBitmap(bitmap: Bitmap) : Bitmap {
    // この関数は、指定されたビットマップを処理します
    // - より長い次元に沿ってトリミングして正方形ビットマップを取得
    // 幅が高さより大きい場合
    // ___+_________________+___
    // | + + |
    // | + + |
    // | + + + |
    // | + + |
    // |__+_________________+__|
    // <-------- width -------->
    // <----- height ---->
    // <--> cropped <-->
    //
    // 高さが幅より大きい場合
    // _________________________ ʌ ʌ
    // | | | cropped
    // |+++++++++++++++++++++++| | ʌ v
    // | | | |
    // | | | |
    // | + | height width
    // | | | |
    // | | | |
    // |+++++++++++++++++++++++| | v ʌ
    // | | | cropped
    // |_______________________| v v
    //
    //
    //
    // - トリミングされた正方形画像をmobilenetv2モデルが必要とするサイズ(3 x 224 x 224)にリサイズ
    lateinit var bitmapCropped: Bitmap
    if (bitmap.getWidth() >= bitmap.getHeight()) {
    // 高さが幅より小さいため、長さが高さである正方形をトリミング
    // 幅次元に沿ってトリミングが発生
    val width: Int = bitmap.getHeight()
    val height: Int = bitmap.getHeight()
    // トリミング画像の左側は、幅の両側で中心の等分部分を含むように
    // (bitmap.getWidth() / 2 - bitmap.getHeight() / 2)から始まる必要があります
    // 高さ次元に沿ってトリミングしないため、上側は0から始まります
    val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2
    val y: Int = 0
    bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
    } else {
    // 幅が高さより小さいため、長さが幅である正方形をトリミング
    // 高さ次元に沿ってトリミングが発生
    val width: Int = bitmap.getWidth()
    val height: Int = bitmap.getWidth()
    // 幅次元に沿ってトリミングしないため、左側は0から始まります
    // 上側は、中心の両側で高さの等分部分を含むように
    // (bitmap.getHeight() / 2 - bitmap.getWidth() / 2)から始まる必要があります
    val x: Int = 0
    val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2
    bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
    }
    // 画像をmobilenetv2モデルが必要とするチャネルx幅x高さにリサイズ
    val width: Int = 224
    val height: Int = 224
    val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false)
    return bitmapResized
    }

    c. 画像の正規化。

    fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) {
    // この関数は画像ピクセルを反復し、画像に以下を実行します
    // - ピクセル値を0から1の間に正規化
    // - 平均(0.485, 0.456, 0.406)をピクセル値から減算(movilenetv2モデル構成から派生)
    // 標準偏差(0.229, 0.224, 0.225)でピクセル値を除算(movilenetv2モデル構成から派生)
    // 値は指定されたオフセットから始まる指定されたバッファに書き込まれます。
    // 値は次のように書き込まれます
    // |____|____________________|__________________| <--- buffer
    // ʌ <--- offset
    // ʌ <--- offset + width * height * channels
    // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order
    // |____|______|gggggg|______|__________________| <--- green channel read in column major order
    // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order
    val width: Int = bitmap.getWidth()
    val height: Int = bitmap.getHeight()
    val stride: Int = width * height
    for (x in 0 until width) {
    for (y in 0 until height) {
    val color: Int = bitmap.getPixel(y, x)
    val index = offset + (x * height + y)
    // 平均を減算し、標準偏差で除算
    // movilenetv2モデルの値を使用
    buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f)
    buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f)
    buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f)
    }
    }
    }

    d. UriからBitmapを取得

    fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap {
    // この関数は、指定されたuriの画像ファイルを読み取り、ビットマップにデコードします
    val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri)
    return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true)
    }
  6. アプリケーションのフロントエンド

    a. このチュートリアルでは、次のユーザーインターフェース要素を使用します:

    • TrainボタンとInferボタン
    • クラスボタン
    • ステータスメッセージテキスト
    • 画像表示
    • プログレスダイアログ

    b. このチュートリアルは、グラフィカルユーザーインターフェースの作成方法を示すことを意図していません。このため、GitHubで利用可能なファイルを使用します。

    c. strings.xmlからすべての文字列定義をAndroid Studioのローカルstrings.xmlにコピーします。

    d. activity_main.xmlの内容をAndroid Studioのローカルactivity_main.xmlにコピーします。

    e. layoutフォルダの下にdialog.xmlという新しいファイルを作成します。dialog.xmlの内容をAndroid Studioのローカルdialog.xmlにコピーします。

    f. このセクションの残りの変更は、MainActivity.ktファイルで行う必要があります。

    g. アプリケーションの起動

    アプリケーションが起動すると、onCreate関数が呼び出されます。この関数は、セッションキャッシュとユーザーインターフェースハンドラーをセットアップします。

    MainActivity.ktファイルのonCreate関数を参照してください。

    h. カスタムクラスボタンハンドラー - クラスボタンを使用して、ユーザーがトレーニング用のカスタム画像を選択できるようにします。これを行うリスナーを追加する必要があります。これらのリスナーは正確に行います。

    MainActivity.ktのこれらのボタンハンドラーを参照してください:

    • onClassAClickedListener
    • onClassBClickedListener
    • onClassXClickedListener
    • onClassYClickedListener

    i. カスタムクラスラベルをパーソナライズ

    デフォルトでは、カスタムクラスラベルは[A, B, X, Y]です。しかし、明確にするために、これらのラベルをユーザーが名前変更できるようにします。これはMainActivity.ktで定義されたロングクリックリスナーによって実現されます:

    • onClassALongClickedListener
    • onClassBLongClickedListener
    • onClassXLongClickedListener
    • onClassYLongClickedListener

    j. カスタムクラスの切り替え

    カスタムクラススイッチがオフになると、事前パッケージされた動物データセットが実行されます。オンになると、トレーニング用に独自のデータセットを提供することが期待されます。この遷移を処理するために、MainActivity.ktonCustomClassSettingChangedListenerスイッチハンドラーが実装されます。

    k. トレーニングハンドラー

    各クラスに少なくとも1つの画像がある場合、Trainボタンを有効にできます。Trainボタンがクリックされると、トレーニングが選択された画像で開始されます。トレーニングハンドラーは以下の責任を負います:

    • トレーニング画像を1つのコンテナに収集
    • 画像の順序をシャッフル
    • 画像のトリミングとリサイズ
    • 画像の正規化
    • 画像のバッチ化
    • トレーニングループの実行(ループ内でC++のperformTraining関数を呼び出し)

    MainActivity.ktで定義されたonTrainButtonClickedListener関数がこれを行います。

    l. 推論ハンドラー

    トレーニングが完了したら、ユーザーは任意の画像を推論するためにInferボタンをクリックできます。推論ハンドラーは以下の責任を負います

    • 推論画像の収集
    • 画像のトリミングとリサイズ
    • 画像の正規化
    • C++のperformInference関数の呼び出し
    • ユーザーインターフェースへの推論出力のレポート

    これはMainActivity.ktonInferenceButtonClickedListener関数によって実現されます。

    m. 上記のすべてのアクティビティのハンドラー

    推論またはカスタムクラスの画像が選択されると、それらを処理する必要があります。MainActivity.ktで定義されたonActivityResult関数がそれを行います。

    n. 最後のこと。カメラを使用するために、AndroidManifest.xmlファイルに以下を追加します:

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature android:name="android.hardware.camera" />

トレーニングフェーズ - デバイスでのアプリケーション実行

Section titled “トレーニングフェーズ - デバイスでのアプリケーション実行”
  1. デバイスでのアプリケーション実行

    a. Androidデバイスをマシンに接続し、デバイスでアプリケーションを実行しましょう。

    b. デバイスでのアプリケーション起動は次のようになります:

    Barebones ORT Personalize app
  2. 事前ロードデータセットでのトレーニング - 動物

    a. デバイスでアプリケーションを起動して、事前ロード動物デバイスでのトレーニングを開始しましょう。

    b. 下部のCustom classesスイッチを切り替えます。

    c. クラスラベルがDogCatElephantCowに変更されます。

    d. Trainingを実行し、トレーニング完了時にプログレスダイアログが消えるのを待ちます。

    e. ライブラリから任意の動物画像を使用して推論を実行します。

    牛の画像を持つORT Personalizeアプリ

    上記の画像からわかるように、モデルは正しくCowを予測しました。

  3. カスタムデータセットでのトレーニング - 有名人の場合

    a. ウェブからTom Cruise、Leonardo DiCaprio、Ryan Reynolds、Brad Pittの画像をダウンロードします。

    b. アプリを閉じて再起動して、新しいセッションを確実に起動してください。

    c. アプリケーションが起動した後、ロングクリックを使用して4つのクラスをそれぞれTomLeoRyanBradに名前変更します。

    d. 各クラスのボタンをクリックし、その有名人に関連する画像を選択します。各カテゴリあたり10〜15枚の画像を使用できます。

    e. Trainボタンを押して、提供されたデータから学習するようにアプリケーションに指示します。

    f. トレーニングが完了したら、まだ見ていない画像を提供するためにInferボタンを押します。

    g. それだけです!うまくいけば、アプリケーションは画像を正しく分類したはずです。

    トム・クルーズが中央にいる画像分類アプリ

おめでとうございます!ONNX Runtimeを使用してデバイス上で画像を分類することを学習するAndroidアプリケーションを正常に構築しました。アプリケーションはGitHubのonnxruntime-training-examplesでも利用できます。