<template>
  <ion-card class="local-serving">
    <ion-card-header>
      <ion-card-title>
        {{ props.modelTitle }}
      </ion-card-title>
    </ion-card-header>
    <ion-card-content>
      <!-- Blocking popup while artefacts are downloaded (not saved yet) or the runtime is being initialized.
           NOTE(review): DOWNLOAD_POPU looks like a typo of DOWNLOAD_POPUP; confirm the key in @/const/strings. -->
      <ion-loading :is-open="loadingModelAndWasmArtefactsInProgress"
        :message="(modelAndWasmArtefactsSaved) ? LOCAL_SERVING_MODULE_STRINGS.INITIALIZING_POPUP : LOCAL_SERVING_MODULE_STRINGS.DOWNLOAD_POPU"
        spinner="circles" />
      <!-- Inference is running -->
      <ion-grid v-if="servingInProgress">
        <ion-row>
          <ion-col>
            <ion-spinner name="circles" />
          </ion-col>
        </ion-row>
        <ion-loading :is-open="servingInProgress" :message="LOCAL_SERVING_MODULE_STRINGS.APPLYMODEL_POPUP" spinner="circles" />
      </ion-grid>
      <!-- Artefacts are being downloaded or the runtime is being initialized -->
      <ion-grid v-if="loadingModelAndWasmArtefactsInProgress">
        <p>{{ LOCAL_SERVING_MODULE_STRINGS.INITIALIZING }}</p>
      </ion-grid>
      <!-- Artefacts are missing: ask the user to download them (button below the card content) -->
      <ion-grid v-if="!modelAndWasmArtefactsSaved && !loadingModelAndWasmArtefactsInProgress">
        <p>{{ LOCAL_SERVING_MODULE_STRINGS.DOWNLOAD }}</p>
      </ion-grid>
      <!-- The last inference failed -->
      <ion-grid
        v-if="predictionInfoList.length === 0 && !servingInProgress && modelAndWasmArtefactsSaved && !loadingModelAndWasmArtefactsInProgress && environmentInitialized && inferenceError != ''">
        <p>{{ inferenceError }}</p>
      </ion-grid>
      <!-- Downloading or initializing the environment failed -->
      <ion-grid v-if="predictionInfoList.length === 0 && environmentError != ''">
        <p>{{ environmentError }}</p>
      </ion-grid>
      <!-- Ready and waiting for data; hidden while an inference error is shown so both messages do not appear together -->
      <ion-grid
        v-if="predictionInfoList.length === 0 && !servingInProgress && modelAndWasmArtefactsSaved && !loadingModelAndWasmArtefactsInProgress && environmentInitialized && inferenceError === ''">
        <p>{{ LOCAL_SERVING_MODULE_STRINGS.NEED_DATA }}</p>
      </ion-grid>
      <!-- Results of the last successful inference -->
      <ion-grid v-if="predictionInfoList.length !== 0 && !servingInProgress">
        <p>Model Prediction:</p>
        <ion-row class="prediction-row">
          <ion-col size="auto" v-for="(prediction, index) in predictionInfoList" :key="index" class="prediction-col"
            :style="{ transform: `scale(${props.predictionDisplayScaling / 100})`, transformOrigin: 'center center', margin: `${5 * 2 * props.predictionDisplayScaling / 100}px` }">
            <div class="prediction-cell">
              <div style="color: white">{{ prediction.name.replace(/_/g, ' ') }}</div>
              <div v-if="prediction.unit !== undefined">{{ prediction.unit }}</div>
              <div :style="{ color: prediction.color }" class="prediction-result">{{ prediction.score }}</div>
            </div>
          </ion-col>
        </ion-row>
      </ion-grid>
    </ion-card-content>
    <!-- Manual (re)download of the artefacts; only offered when they are not stored locally -->
    <ion-grid v-if="!modelAndWasmArtefactsSaved && !loadingModelAndWasmArtefactsInProgress">
      <ion-button fill="clear" @click="downloadEnvironment">Download</ion-button>
    </ion-grid>
  </ion-card>
</template>

<style scoped>
/* Row containing one column per prediction: centred horizontally and wrapping on narrow screens. */
.prediction-row {
  display: flex;
  flex-wrap: wrap;
  justify-content: center;
}

/* Column wrapping a single prediction cell and centring it.
   The template's inline :style also sets `margin`, which takes precedence over the margin below. */
.prediction-col {
  display: flex;
  align-items: center;
  justify-content: center;
  margin: 10px;
  padding: 10px;
}

/* Rounded black tile showing the name, unit and score of one prediction. */
.prediction-cell {
  box-sizing: border-box;
  padding: 10px;
  border-radius: 5px;
  background-color: #000000;
  text-align: center;
}
</style>

<script setup>
import { IonLoading, IonGrid, IonRow, IonCol, IonCard, IonCardHeader, IonSpinner, IonCardTitle, IonCardContent, IonButton } from "@ionic/vue";
import { ref, onMounted, onBeforeUnmount } from "vue";
import { events } from "@/utils/events";
import * as ort from '@compolytics/onnxruntime-web';

import { useFileStore } from "@/store/file";
import { downloadFile } from "@/supabase/file";
import { useAppStateStore } from "@/store/app";
import { LOCAL_SERVING_MODULE_STRINGS } from '@/const/strings';
import { v4 as uuid4 } from "uuid";

const fileStore = useFileStore();

// Prediction tiles rendered in the template (built by processPredictions)
const predictionInfoList = ref([]);
// True while a model inference triggered by onData is running
const servingInProgress = ref(false);
// onnxruntime-web InferenceSession, created in initEnvironment
const session = ref(null);
// True when all .wasm artefacts and the .onnx model are stored locally. While true, the loading
// popup shows the "initializing" message; while false it shows the "download" message and, once
// loading stops, the Download button is offered.
const modelAndWasmArtefactsSaved = ref(true)
// True while artefacts are downloaded or the runtime is initialized (drives the loading popup)
const loadingModelAndWasmArtefactsInProgress = ref(true)
// True once ort.env.wasm.wasmPaths is set and the session was created successfully
const environmentInitialized = ref(false)
// Message shown if an error occurs during inference
const inferenceError = ref("")
// Message shown if an error occurs while downloading or initializing the environment
const environmentError = ref("")

// Component props; the values below are the defaults
const props = defineProps({

  // Title shown in the card header
  modelTitle: {
    type: String,
    default: "Model Serving",
  },

  // Comma-separated names of the model outputs to display
  predictionFieldName: {
    type: String,
    default: "",
  },

  // Name of the event whose payload is fed to the model
  dataMessage: {
    type: String,
    default: "data",
  },

  // Supabase bucket the .onnx model is downloaded from
  bucketName: {
    type: String,
    default: "cicada-example",
  },

  // model.onnx filename in the Supabase bucket (also part of the local file name, see appModelName)
  model: {
    type: String,
    default: "",
  },

  // JSON object mapping an output name to the label to display instead of it
  predictionDisplayName: {
    type: String,
    default: "",
  },

  // JSON object mapping an output name to the name of a second output whose data is a JSON string
  // with "label" and "color" arrays; used for classification tasks
  predictionMetaName: {
    type: String,
    default: "",
  },

  // Name of the model output whose data gives the task type; "classification" switches the
  // displayed score to the class label and color
  taskFieldName: {
    type: String,
    default: "",
  },

  // JSON object mapping an output name to the unit shown below it
  predictionDisplayUnits: {
    type: String,
    default: "",
  },

  // Simple scaling factor for the display in percent
  predictionDisplayScaling: {
    type: Number,
    default() {
      return 100;
    }
  }

});

// Name of the locally stored .onnx model file for this app: "<app id>_<model>".
// Computed once at setup; later changes of props.model are not picked up.
const appModelName = useAppStateStore().id + "_" + props.model;

// Finds the versioned file name of a bundled onnxruntime .wasm asset.
// Given a base name such as "ort-wasm", returns the matching file in @/assets/wasm named
// "<base>-<major>.<minor>.<patch>.wasm" (without the "./" prefix of require.context keys).
// Throws an Error naming the base if no such asset exists.
function findFullWasmName(substring) {
  // Webpack's require.context lists all .wasm files in the assets/wasm folder (arguments must stay literals)
  const assetsContext = require.context('@/assets/wasm/', false, /\.wasm$/);

  // Escape regex metacharacters in the base name and anchor the pattern at the "./" key prefix,
  // so e.g. "ort-wasm" cannot match "xort-wasm-1.2.3.wasm".
  const escapedBase = substring.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const wasmFilePattern = new RegExp(`^\\./${escapedBase}-\\d+\\.\\d+\\.\\d+\\.wasm$`);

  const matchingFile = assetsContext.keys().find((file) => wasmFilePattern.test(file));
  if (matchingFile === undefined) {
    throw new Error(`No versioned .wasm asset found for "${substring}" in @/assets/wasm`);
  }

  // crop "./" in the beginning
  return matchingFile.substring(2);
}

// Maps the file names onnxruntime-web asks for (keys, as used for ort.env.wasm.wasmPaths) to the
// bundler URL of the corresponding bundled asset (values, used to download the asset once).
// The bundled files carry a version in their name, so the actual file is looked up with findFullWasmName.
// Only the non-threaded variants are used; initEnvironment sets ort.env.wasm.numThreads = 1.
const initialWasmUrls = {
  'ort-wasm.wasm': require("@/assets/wasm/" + findFullWasmName("ort-wasm")),
  // 'ort-wasm-threaded.wasm': require("@/assets/wasm/" + findFullWasmName("ort-wasm-threaded")),
  'ort-wasm-simd.wasm': require("@/assets/wasm/" + findFullWasmName("ort-wasm-simd")),
  // 'ort-wasm-simd-threaded.wasm': require("@/assets/wasm/" + findFullWasmName("ort-wasm-simd-threaded"))
};

// Reads a Blob and resolves with its content as a bare base64 string: the
// "data:<type>;base64," prefix of the data URL is stripped. Rejects with the reader's error.
function blobToBase64(blob) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    // `load` fires only on success. `loadend` also fires after a failed read, when reader.result
    // is null, and would then throw a TypeError from inside the event handler.
    reader.onload = () => {
      const base64data = reader.result.split(',')[1]; // remove the prefix "data:application/wasm;base64,"
      resolve(base64data);
    };
    reader.onerror = () => reject(reader.error);
    reader.readAsDataURL(blob);
  });
}

// Downloads a file, encodes it as a base64 string and stores it in the local file store.
//   localfileName     - name under which the file is stored locally
//   url               - URL fetched when downloadFromCloud is false
//   type              - MIME type of the content (e.g. 'application/wasm')
//   downloadFromCloud - when true, props.model is downloaded from props.bucketName instead;
//                       `url` is not used in that case
// Throws if the download fails (including HTTP error responses) or if storing fails.
async function downloadAndSave(localfileName, url, type, downloadFromCloud = false) {
  console.log(`Downloading and saving ${localfileName}`)
  let response;
  if (downloadFromCloud) {
    // NOTE(review): assumes downloadFile resolves to something with arrayBuffer() (e.g. a Blob) — confirm.
    response = await downloadFile(props.bucketName, props.model);
  }
  else {
    response = await fetch(url);
    // fetch only rejects on network failures; without this check an HTTP error page
    // (e.g. a 404) would be stored as the requested artefact.
    if (!response.ok) {
      throw new Error(`Downloading ${url} failed with HTTP status ${response.status}`);
    }
  }
  const buffer = await response.arrayBuffer();
  const blob = new Blob([buffer], { type: type });
  const base64Data = await blobToBase64(blob);

  if (downloadFromCloud) {
    await fileStore.saveFile(localfileName, base64Data);
  }
  else {
    // .wasm artefacts are saved without a user/app restriction -> make available for everybody
    const access = {
      user: undefined,
      app: undefined,
    };
    await fileStore.saveFile(localfileName, base64Data, access);
  }
}

// Returns a blob: object URL for a locally stored (base64-encoded) file, tagged with the given MIME type.
// Used to initialize ort.env.wasm.wasmPaths, because assigning the bundled resources directly did not work.
// Decoding is delegated to getByteArrayFromLocalFile so both readers share one implementation.
async function getURLtoLocalFile(fileName, type) {
  const byteArray = await getByteArrayFromLocalFile(fileName);
  return URL.createObjectURL(new Blob([byteArray], { type: type }));
}

// Reads a locally stored base64-encoded file and returns its decoded bytes as a Uint8Array
// (used for the .onnx model). Throws if the store yields null/undefined for the file.
async function getByteArrayFromLocalFile(fileName) {
  const base64string = await fileStore.getFile(fileName);
  // atob(null) would silently decode the text "null" into three garbage bytes instead of failing.
  // NOTE(review): assumes getFile yields null/undefined for a missing file — confirm store semantics.
  if (base64string === null || base64string === undefined) {
    throw new Error(`Local file "${fileName}" could not be read`);
  }
  const binaryString = atob(base64string);
  // Every character produced by atob is in 0..255, so each maps to exactly one byte.
  return Uint8Array.from(binaryString, (character) => character.charCodeAt(0));
}

// Checks whether all required .wasm artefacts and the app's .onnx model are stored locally.
// The .wasm copies are stored under their versioned asset names (as produced by findFullWasmName),
// not under the bundler URLs held as values of initialWasmUrls, so those names are what is compared.
// Throws (without catching) if the file store cannot be listed; onMounted relies on that to retry.
async function areAllFilesSaved() {
  // get current list of available files
  const localFiles = await fileStore.listFiles();
  const requiredFiles = Object.keys(initialWasmUrls).map(
    // strip the ".wasm" extension, then resolve the versioned file name
    (wasmKey) => findFullWasmName(wasmKey.substring(0, wasmKey.length - ".wasm".length))
  );
  requiredFiles.push(appModelName);
  return requiredFiles.every((fileName) => localFiles.includes(fileName));
}

// Loads everything needed for local inference, then initializes the runtime. Steps:
//   1. delete locally stored .wasm files whose names differ from the versions currently bundled
//      (leftovers of older onnxruntime versions; other local files are kept),
//   2. download and store every required .wasm artefact that is missing,
//   3. download and store the .onnx model from the Supabase bucket if it is missing,
//   4. initialize onnxruntime-web via initEnvironment().
// Also the click handler of the Download button, so it must work as a retry after a failure.
async function downloadEnvironment() {
  // Shared failure path: log the cause, show the message, close the loading popup and mark the
  // artefacts as not saved so the template offers the Download button again.
  const fail = (message, error) => {
    console.error(message, error);
    environmentError.value = message;
    modelAndWasmArtefactsSaved.value = false;
    loadingModelAndWasmArtefactsInProgress.value = false;
  };

  loadingModelAndWasmArtefactsInProgress.value = true
  // Drop the message of a previous failed attempt.
  environmentError.value = "";

  // For each required .wasm file: the versioned local name and the bundler URL to download it from.
  let requiredWasm;
  try {
    requiredWasm = Object.entries(initialWasmUrls).map(([runtimeKey, url]) => ({
      // strip the ".wasm" extension, then resolve the versioned file name
      localName: findFullWasmName(runtimeKey.substring(0, runtimeKey.length - ".wasm".length)),
      url,
    }));
    const requiredWasmFiles = requiredWasm.map((entry) => entry.localName);

    const localFiles = await fileStore.listFiles();
    for (const fileToDelete of localFiles || []) {
      // .wasm files must be deleted only if their names differ from required ones
      if (fileToDelete.endsWith(".wasm") && !requiredWasmFiles.includes(fileToDelete)) {
        console.log(`Delete ${fileToDelete}`)
        await fileStore.deleteFile(fileToDelete);
      }
    }
  } catch (error) {
    fail(LOCAL_SERVING_MODULE_STRINGS.WEBASSEMBLY_DELETE_ERROR, error);
    return;
  }

  try {
    // downloading and saving missing .wasm artefacts
    for (const { localName, url } of requiredWasm) {
      const fileExistsLocally = await fileStore.fileExists(localName);
      if (!fileExistsLocally) {
        await downloadAndSave(localName, url, 'application/wasm', false);
      }
    }
  } catch (error) {
    fail(LOCAL_SERVING_MODULE_STRINGS.WEBASSEMBLY_SAVE_ERROR, error);
    return;
  }

  try {
    // downloading and saving the .onnx model
    const fileExistsLocally = await fileStore.fileExists(appModelName);
    if (!fileExistsLocally) {
      await downloadAndSave(appModelName, props.model, 'application/octet-stream', true);
    }
  } catch (error) {
    fail(LOCAL_SERVING_MODULE_STRINGS.ONNX_SAVE_ERROR, error);
    return;
  }

  // All artefacts are stored locally from here on, whether or not initialization succeeds.
  modelAndWasmArtefactsSaved.value = true;
  // initEnvironment reports its own failures via environmentError (and clears it on success),
  // so the message must not be reset after this call.
  await initEnvironment();
  loadingModelAndWasmArtefactsInProgress.value = false;
}

// Initializes the onnxruntime-web environment from the locally stored artefacts.
// First points ort.env.wasm.wasmPaths at blob: URLs of the local .wasm copies, then creates the
// InferenceSession from the local .onnx model bytes. A failure in either step stores a message in
// environmentError and stops; on success the error is cleared and environmentInitialized is set.
async function initEnvironment() {
  // Shared failure path: surface the message and close the loading popup.
  const abortWith = (message) => {
    environmentError.value = message;
    loadingModelAndWasmArtefactsInProgress.value = false;
  };

  loadingModelAndWasmArtefactsInProgress.value = true

  try {
    const localWasmUrls = {};
    for (const runtimeName of Object.keys(initialWasmUrls)) {
      // 'ort-wasm.wasm' -> 'ort-wasm' -> versioned local file name
      const baseName = runtimeName.substring(0, runtimeName.length - 5);
      const localName = findFullWasmName(baseName);
      localWasmUrls[runtimeName] = await getURLtoLocalFile(localName, 'application/wasm');
    }
    ort.env.wasm.wasmPaths = localWasmUrls;
  } catch (error) {
    abortWith(LOCAL_SERVING_MODULE_STRINGS.INITIALIZATION_ERROR_WEBASSEMBLY);
    return;
  }

  try {
    const modelBytes = await getByteArrayFromLocalFile(appModelName);
    ort.env.debug = true;
    ort.env.logLevel = 'verbose';
    // single-threaded: only the non-threaded .wasm variants are provided
    ort.env.wasm.numThreads = 1;
    session.value = await ort.InferenceSession.create(modelBytes);
  } catch (error) {
    abortWith(LOCAL_SERVING_MODULE_STRINGS.INITIALIZATION_ERROR_ONNX);
    return;
  }

  environmentError.value = "";
  loadingModelAndWasmArtefactsInProgress.value = false;
  environmentInitialized.value = true;
}

// On mount: subscribe to data messages, wait until the file store can be queried, then either
// initialize the runtime from local artefacts or download the missing ones first.
onMounted(async () => {
  // Setup event listener for data messages (onData ignores them until the environment is initialized)
  events.on(props.dataMessage, onData);
  await new Promise(r => setTimeout(r, 500));

  // Retry until areAllFilesSaved() stops throwing (presumably while the file store is not ready yet).
  // Its result is kept, so the store is not queried a second time.
  let filesSaved;
  for (;;) {
    try {
      filesSaved = await areAllFilesSaved();
      break;
    }
    catch (err) {
      await new Promise(r => setTimeout(r, 100));
    }
  }

  // Reflect the real state: while downloading, the popup shows the download message, and if the
  // download fails the template offers the Download button for a retry.
  modelAndWasmArtefactsSaved.value = filesSaved;
  if (filesSaved) {
    await initEnvironment();
  }
  else {
    await downloadEnvironment();
  }
});

onBeforeUnmount(function detachDataListener() {
  // Counterpart of the events.on(...) registration made in onMounted: stop handling data messages.
  events.off(props.dataMessage, onData);
});

// Parses a JSON object from a prop string. Returns {} for empty or invalid JSON and for
// non-object results such as "null", so callers can index the result safely.
function parseJsonObject(text) {
  try {
    const parsed = JSON.parse(text);
    return (parsed !== null && typeof parsed === 'object') ? parsed : {};
  } catch (error) {
    return {};
  }
}

// Builds the list of prediction tiles shown in the template from the raw model output.
//   props            - component props (field list, JSON tables for display names/units/meta, task field)
//   predicted_values - model output, keyed by output name; each value's `.data` holds the output data
// Returns one entry per listed field present in the output: { name, score, unit, color }.
// For classification tasks the score is replaced by the class label and the color comes from the
// meta output named in predictionMetaName.
function processPredictions(props, predicted_values) {
  // Ensure predictionFieldName is available and a string before touching any other prop
  if (!props || typeof props.predictionFieldName !== 'string') {
    return [];
  }

  // Tables mapping field names to display names, unit names and meta field names
  const renameTable = parseJsonObject(props.predictionDisplayName);
  const unitTable = parseJsonObject(props.predictionDisplayUnits);
  const metaTable = parseJsonObject(props.predictionMetaName);

  // Assume a regression task unless the task field says otherwise
  let taskType = "regression";
  if (typeof props.taskFieldName === 'string' && predicted_values && predicted_values[props.taskFieldName]) {
    taskType = predicted_values[props.taskFieldName].data;
  }

  const predictionList = [];
  // Split field names and trim whitespace
  const predFieldList = props.predictionFieldName.split(',').map(name => name.trim());

  for (const fieldName of predFieldList) {
    console.log("Process Field: " + fieldName + " with task type: " + taskType);
    // Skip fields that are not part of the model output
    if (!predicted_values || !predicted_values[fieldName]) {
      continue;
    }

    const number = parseFloat(predicted_values[fieldName].data);
    const prediction = {
      name: renameTable[fieldName] ? renameTable[fieldName] : fieldName,
      score: number.toFixed(2),
      unit: unitTable[fieldName] ? "[" + unitTable[fieldName] + "]" : undefined,
      color: "#ffffff" // Default color
    };

    // Loose `==` kept on purpose: taskType comes from a tensor's `.data`, presumably a one-element
    // string array, which coerces to its single string ("classification").
    if (taskType == "classification") {
      const metaFieldName = metaTable[fieldName];
      // The meta output's data is expected to be a JSON string with "label" and "color" arrays
      let meta = {};
      try {
        meta = JSON.parse(predicted_values[metaFieldName].data);
      } catch (error) {
        meta = {};
      }
      if (metaFieldName && predicted_values[metaFieldName] && meta && meta['label']) {
        prediction.score = meta['label'][0];
        // Keep the default color when no usable color string is provided
        const color = meta['color'] ? meta['color'][0] : undefined;
        if (typeof color === 'string') {
          // If the color string is too short, fill with 0
          prediction.color = color.length < 7 ? color + "0".repeat(7 - color.length) : color;
        }
      }
    }

    predictionList.push(prediction);
  }
  return predictionList;
}

// Replaces, in place, every tensor of the model output by its raw `.data`; Float32Array data is
// turned into a plain array. Returns the same (mutated) object.
function convertResponse(model_response) {
  Object.keys(model_response || {}).forEach((outputName) => {
    const payload = model_response[outputName].data;
    model_response[outputName] = (payload instanceof Float32Array) ? Array.from(payload) : payload;
  });
  return model_response;
}

// Looks for the first field of the converted model response whose value (or first array element)
// is an image data URI. Returns an "image" event payload for it, or {} when there is none.
function parseResponseForImage(response) {
  for (const fieldName of Object.keys(response || {})) {
    const rawValue = response[fieldName];
    const candidate = Array.isArray(rawValue) ? rawValue[0] : rawValue;
    const isImageUri = typeof candidate === 'string' && candidate.startsWith('data:image/');
    if (!isImageUri) {
      continue;
    }
    // Only the first data URI is used
    return {
      values: candidate,
      type: "image",
      source: "ServingModule",
      shape: [],
      config: {},
      meta: {},
      uuid: uuid4(),
      timestamp: Date.now(),
    };
  }
  return {};
}

// Sets the serving flag and then waits 250 ms, giving the GUI time to render the change.
async function setServingInProgress(value) {
  servingInProgress.value = value;
  const guiUpdateDelayMs = 250;
  await new Promise((resolve) => {
    setTimeout(resolve, guiUpdateDelayMs);
  });
}

// Handles one data message: runs the model on the JSON-serialized payload, updates the displayed
// predictions and emits "prediction" (the payload extended by `prediction.values`) and, when the
// output contains an image data URI, "image". Ignored until the environment is initialized.
// Any failure after the spinner is shown ends in inferenceError; the spinner is always cleared.
async function onData(data) {

  if (!environmentInitialized.value) {
    return;
  }

  // Set serving flag as working
  await setServingInProgress(true);
  // Clear the prediction list
  predictionInfoList.value = [];

  // Everything that can throw lives inside try so that `finally` always clears the serving flag
  // (a null payload or a non-serializable one, e.g. with circular references, used to leave it stuck).
  try {
    // Hack, we need to remove field calibration if it is any empty list
    // -----------------------------------------------------------------
    // This should be done in the model, but older version are deployed with not ignoring calibration when empty
    if (data.calibration && data.calibration.length == 0) {
      delete data.calibration;
    }

    // Lets turn this into a json string
    const inputData = JSON.stringify(data)

    const inputTensor = new ort.Tensor('string', [inputData], [1]);
    const model_output = await session.value.run({
      input: inputTensor,
    });

    predictionInfoList.value = processPredictions(props, model_output);

    data.prediction = {};
    data.prediction.values = convertResponse(model_output);

    inferenceError.value = '';
    // We emit the prediction data
    events.emit("prediction", data);
    // Parse the response for an image
    const image_data = parseResponseForImage(data.prediction.values);
    // if we have non empty data, send it
    if (Object.keys(image_data).length !== 0) {
      events.emit("image", image_data);
    }

  } catch (e) {
    // Keep the cause visible for debugging; the user only sees the generic message
    console.error(e);
    inferenceError.value = LOCAL_SERVING_MODULE_STRINGS.INFERENCE_ERROR;
  }
  finally {
    await setServingInProgress(false);
  }
}

</script>

<style></style>