import * as ort from "onnxruntime-web";
import { IVector3 } from "@ar_template/component/interface";

// Module-wide inference session: written by init_onnx() when a model loads
// successfully and read by predict(). It stays null until the first successful
// init_onnx() call, and predict() returns false while it is null.
let session: ort.InferenceSession | null = null;

/**
 * Creates the module's shared ONNX Runtime session for one of the bundled models.
 *
 * On success the new session replaces the current one, and the previously loaded
 * session (if any) is released so switching models does not leak it. On failure
 * the error is logged and any previously loaded session is left in place.
 *
 * @param type - Model variant to load: `"landmark"` (default) or `"seg"`.
 * @returns `true` if the session was created, `false` if creation failed.
 */
async function init_onnx(type: "landmark" | "seg" = "landmark"): Promise<boolean> {
    const modelPath = type === "seg"
        ? "/static/model/re_optimized_mbo_bisenetV10_seg_pose106_fused_model_HWC.onnx"
        : "/static/model/re_optimized_mbo_bisenetV10_landmark_pose106_fused_model_HWC.onnx";

    let created: ort.InferenceSession;
    try {
        created = await ort.InferenceSession.create(modelPath, {
            executionProviders: ["wasm"],
            enableMemPattern: true,
            // NOTE(review): profiling adds per-run overhead; confirm it is meant to stay on.
            enableProfiling: true,
            // NOTE(review): with the wasm EP the thread count is usually configured via
            // ort.env.wasm.numThreads; confirm this option has any effect here.
            interOpNumThreads: 4,
            graphOptimizationLevel: "extended"
        });
    } catch (e: unknown) {
        // Keep the boolean contract, but do not lose the reason the load failed.
        console.error(`init_onnx: failed to load model "${modelPath}"`, e);
        return false;
    }

    const previous = session;
    session = created;
    if (previous) {
        // Best-effort cleanup: the new session is already live, so a failed release
        // must not turn a successful initialisation into a failure.
        // NOTE(review): a predict() call still awaiting on the previous session could
        // be affected; confirm callers do not re-initialise while inference is in flight.
        try {
            await previous.release();
        } catch (e: unknown) {
            console.warn("init_onnx: failed to release previous session", e);
        }
    }

    return true;
}

/**
 * Runs the shared session on a single input tensor.
 *
 * @param data - Tensor bound to the model's first declared input.
 * @returns The session's output map, or `false` when no session has been loaded yet.
 */
async function predict(data: ort.Tensor): Promise<boolean | ort.InferenceSession.ReturnType> {
    const active = session;
    if (!active) return false;

    const inputName = active.inputNames[0];
    const feeds: Record<string, ort.Tensor> = { [inputName]: data };
    const outputs = await active.run(feeds);
    return outputs;
}

/**
 * Reads the `index`-th point from a flat landmark tensor.
 *
 * The tensor data is read as consecutive (x, y, z) triples: point `i` occupies
 * elements `3 * i`, `3 * i + 1` and `3 * i + 2`. Values are converted with
 * `Number(...)`, which leaves numeric data unchanged and converts bigint data
 * instead of mislabelling it as `number`.
 *
 * @param landmark - Tensor whose `data` holds the flattened triples.
 * @param index - Zero-based point index.
 * @returns The point's coordinates.
 * @throws RangeError if `index` is not a non-negative integer or the tensor does
 *   not contain a full triple at that position.
 */
function get_landmark_point(landmark: ort.Tensor, index: number): IVector3 {
    const data = landmark.data;
    const offset = index * 3;

    // Without this check an out-of-range index silently yields `undefined` coordinates.
    if (!Number.isInteger(index) || index < 0 || offset + 2 >= data.length) {
        throw new RangeError(
            `landmark index ${index} is out of range for a tensor with ${data.length} elements`
        );
    }

    return {
        x: Number(data[offset]),
        y: Number(data[offset + 1]),
        z: Number(data[offset + 2])
    };
}

// Public API of this module: model loading, inference, and landmark lookup.
export { get_landmark_point, init_onnx, predict };
