/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.confignode.manager;

import java.util.List;
import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.client.exception.ClientManagerException;
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.model.ModelStatus;
import org.apache.iotdb.commons.model.ModelType;
import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import org.apache.iotdb.confignode.exception.NoAvailableAINodeException;
import org.apache.iotdb.confignode.manager.ConfigManager;
import org.apache.iotdb.confignode.persistence.ModelInfo;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
import org.apache.iotdb.consensus.exception.ConsensusException;
import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
import org.apache.iotdb.rpc.TSStatusCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Serves model-related requests on the ConfigNode.
 *
 * <p>Model metadata changes (create / drop / update) are persisted through the consensus layer or
 * submitted as procedures, while runtime operations (load, unload, show models / devices) are
 * forwarded to the first registered AINode via a client borrowed from {@link AINodeClientManager}.
 */
public class ModelManager {
    private static final Logger LOGGER = LoggerFactory.getLogger(ModelManager.class);
    private final ConfigManager configManager;
    private final ModelInfo modelInfo;

    public ModelManager(ConfigManager configManager, ModelInfo modelInfo) {
        this.configManager = configManager;
        this.modelInfo = modelInfo;
    }

    /**
     * Registers a new model.
     *
     * <p>Without a URI the model metadata is written directly through consensus as a
     * {@link CreateModelPlan}; with a URI a create-model procedure is submitted instead.
     *
     * @param req the create request (model name and optional URI)
     * @return {@code MODEL_EXIST_ERROR} if the name is already taken, {@code EXECUTE_STATEMENT_ERROR}
     *     if the consensus write fails, otherwise the status of the write / procedure submission
     */
    public TSStatus createModel(TCreateModelReq req) {
        if (this.modelInfo.contain(req.modelName)) {
            return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
                    .setMessage(String.format("Model name %s already exists", req.modelName));
        }
        try {
            // An absent URI is handled like an empty one instead of failing with an NPE.
            if (req.uri == null || req.uri.isEmpty()) {
                return this.configManager.getConsensusManager().write(new CreateModelPlan(req.modelName));
            }
            return this.configManager.getProcedureManager().createModel(req.modelName, req.uri);
        } catch (ConsensusException e) {
            LOGGER.warn("Unexpected error happened while creating model: ", e);
            return new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
                    .setMessage(e.getMessage());
        }
    }

    /**
     * Submits a procedure that drops a user-defined model.
     *
     * <p>The model-type check runs before the existence check, so any id whose type is not
     * {@link ModelType#USER_DEFINED} is rejected as built-in. NOTE(review): what
     * {@code checkModelType} returns for an unknown id is not visible here.
     *
     * @param req the drop request carrying the model id
     * @return {@code DROP_MODEL_ERROR} for non-user-defined models, {@code MODEL_EXIST_ERROR} when
     *     the model is missing, otherwise the status of the drop procedure
     */
    public TSStatus dropModel(TDropModelReq req) {
        if (this.modelInfo.checkModelType(req.getModelId()) != ModelType.USER_DEFINED) {
            return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode())
                    .setMessage(String.format("Built-in model %s can't be removed", req.modelId));
        }
        if (!this.modelInfo.contain(req.modelId)) {
            return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
                    .setMessage(String.format("Model name %s doesn't exists", req.modelId));
        }
        return this.configManager.getProcedureManager().dropModel(req.getModelId());
    }

    /**
     * Asks an AINode to load a model onto the given devices.
     *
     * @return the AINode's status, or {@code AI_NODE_INTERNAL_ERROR} if no AINode is available, the
     *     client cannot be borrowed, or the RPC fails
     */
    public TSStatus loadModel(TLoadModelReq req) {
        // Acquiring the client inside the resource clause means "no AINode" / borrow failures are
        // reported as a status instead of escaping the method.
        try (AINodeClient client = this.getAINodeClient()) {
            return client.loadModel(new TLoadModelReq(req.existingModelId, req.deviceIdList));
        } catch (Exception e) {
            LOGGER.warn("Failed to load model due to", e);
            return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
                    .setMessage(e.getMessage());
        }
    }

    /**
     * Asks an AINode to unload a model from the given devices.
     *
     * @return the AINode's status, or {@code AI_NODE_INTERNAL_ERROR} if no AINode is available, the
     *     client cannot be borrowed, or the RPC fails
     */
    public TSStatus unloadModel(TUnloadModelReq req) {
        try (AINodeClient client = this.getAINodeClient()) {
            return client.unloadModel(new TUnloadModelReq(req.modelId, req.deviceIdList));
        } catch (Exception e) {
            LOGGER.warn("Failed to unload model due to", e);
            return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
                    .setMessage(e.getMessage());
        }
    }

    /**
     * Lists models known to the AINode, optionally filtered by model id.
     *
     * <p>The returned response always carries a SUCCESS status when the RPC itself succeeds; the
     * model id list and type / category / state maps are copied from the AINode's reply.
     * NOTE(review): the AINode reply's own status is not inspected here.
     *
     * @return the copied listing, or a response with {@code AI_NODE_INTERNAL_ERROR} on failure
     */
    public TShowModelsResp showModel(TShowModelsReq req) {
        try (AINodeClient client = this.getAINodeClient()) {
            TShowModelsReq showModelsReq = new TShowModelsReq();
            if (req.isSetModelId()) {
                showModelsReq.setModelId(req.getModelId());
            }
            TShowModelsResp resp = client.showModels(showModelsReq);
            TShowModelsResp res =
                    new TShowModelsResp()
                            .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
            res.setModelIdList(resp.getModelIdList());
            res.setModelTypeMap(resp.getModelTypeMap());
            res.setCategoryMap(resp.getCategoryMap());
            res.setStateMap(resp.getStateMap());
            return res;
        } catch (Exception e) {
            LOGGER.warn("Failed to show models due to", e);
            return new TShowModelsResp()
                    .setStatus(
                            new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
                                    .setMessage(e.getMessage()));
        }
    }

    /**
     * Lists the models currently loaded on the requested devices.
     *
     * @return the per-device loaded-model map copied from the AINode, or a response with
     *     {@code AI_NODE_INTERNAL_ERROR} on failure
     */
    public TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req) {
        try (AINodeClient client = this.getAINodeClient()) {
            TShowLoadedModelsReq showModelsReq =
                    new TShowLoadedModelsReq().setDeviceIdList(req.getDeviceIdList());
            TShowLoadedModelsResp resp = client.showLoadedModels(showModelsReq);
            TShowLoadedModelsResp res =
                    new TShowLoadedModelsResp()
                            .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
            res.setDeviceLoadedModelsMap(resp.getDeviceLoadedModelsMap());
            return res;
        } catch (Exception e) {
            LOGGER.warn("Failed to show loaded models due to", e);
            return new TShowLoadedModelsResp()
                    .setStatus(
                            new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
                                    .setMessage(e.getMessage()));
        }
    }

    /**
     * Lists the devices reported by the AINode.
     *
     * @return the device id list copied from the AINode, or a response with
     *     {@code AI_NODE_INTERNAL_ERROR} on failure
     */
    public TShowAIDevicesResp showAIDevices() {
        try (AINodeClient client = this.getAINodeClient()) {
            TShowAIDevicesResp resp = client.showAIDevices();
            TShowAIDevicesResp res =
                    new TShowAIDevicesResp()
                            .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
            res.setDeviceIdList(resp.getDeviceIdList());
            return res;
        } catch (Exception e) {
            LOGGER.warn("Failed to show AI devices due to", e);
            return new TShowAIDevicesResp()
                    .setStatus(
                            new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
                                    .setMessage(e.getMessage()));
        }
    }

    /**
     * Returns the internal endpoint of the first registered AINode. The request content is not
     * consulted.
     *
     * @return a SUCCESS response carrying the AINode address, or {@code AI_NODE_INTERNAL_ERROR}
     *     when no AINode is registered (previously this threw IndexOutOfBoundsException)
     */
    public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
        if (this.configManager.getNodeManager().getRegisteredAINodes().isEmpty()) {
            return new TGetModelInfoResp()
                    .setStatus(
                            new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
                                    .setMessage("No registered AINode found"));
        }
        return new TGetModelInfoResp()
                .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()))
                .setAiNodeAddress(
                        this.configManager
                                .getNodeManager()
                                .getRegisteredAINodes()
                                .get(0)
                                .getLocation()
                                .getInternalEndPoint());
    }

    /**
     * Replaces the stored information of an existing model through a consensus write.
     *
     * <p>The new {@link ModelInformation} is always USER_DEFINED with input column size 1; input /
     * output lengths and AINode ids are applied only when set on the request.
     *
     * @return {@code MODEL_NOT_FOUND_ERROR} for unknown models, {@code EXECUTE_STATEMENT_ERROR} for
     *     an out-of-range model status or a failed consensus write, otherwise the write status
     */
    public TSStatus updateModelInfo(TUpdateModelInfoReq req) {
        if (!this.modelInfo.contain(req.getModelId())) {
            return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode())
                    .setMessage(String.format("Model %s doesn't exists", req.getModelId()));
        }
        // The status arrives as an ordinal over the wire; reject values that would otherwise
        // throw an uncaught ArrayIndexOutOfBoundsException.
        int statusOrdinal = req.getModelStatus();
        ModelStatus[] statuses = ModelStatus.values();
        if (statusOrdinal < 0 || statusOrdinal >= statuses.length) {
            return new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
                    .setMessage(
                            String.format(
                                    "Invalid model status %d for model %s", statusOrdinal, req.getModelId()));
        }
        try {
            ModelInformation modelInformation =
                    new ModelInformation(ModelType.USER_DEFINED, req.getModelId());
            modelInformation.updateStatus(statuses[statusOrdinal]);
            modelInformation.setAttribute(req.getAttributes());
            modelInformation.setInputColumnSize(1);
            if (req.isSetOutputLength()) {
                modelInformation.setOutputLength(req.getOutputLength());
            }
            if (req.isSetInputLength()) {
                modelInformation.setInputLength(req.getInputLength());
            }
            UpdateModelInfoPlan updateModelInfoPlan =
                    new UpdateModelInfoPlan(req.getModelId(), modelInformation);
            if (req.isSetAiNodeIds()) {
                updateModelInfoPlan.setNodeIds(req.getAiNodeIds());
            }
            return this.configManager.getConsensusManager().write(updateModelInfoPlan);
        } catch (ConsensusException e) {
            LOGGER.warn("Unexpected error happened while updating model info: ", e);
            return new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
                    .setMessage(e.getMessage());
        }
    }

    /**
     * Borrows a client connected to the first registered AINode's internal endpoint.
     *
     * @throws NoAvailableAINodeException if no AINode is registered
     * @throws RuntimeException wrapping any failure raised while borrowing the client
     */
    private AINodeClient getAINodeClient() throws NoAvailableAINodeException, ClientManagerException {
        List<TAINodeInfo> aiNodeInfo = this.configManager.getNodeManager().getRegisteredAINodeInfoList();
        if (aiNodeInfo.isEmpty()) {
            throw new NoAvailableAINodeException();
        }
        TAINodeInfo target = aiNodeInfo.get(0);
        TEndPoint targetAINodeEndPoint =
                new TEndPoint(target.getInternalAddress(), target.getInternalPort());
        try {
            return AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the node ids recorded in {@link ModelInfo} for the given model. */
    public List<Integer> getModelDistributions(String modelName) {
        return this.modelInfo.getNodeIds(modelName);
    }
}

