
"""
Routes and sub-resources for the /recovery-ai resource
"""
from __future__ import annotations

from typing import Tuple, Dict, Any, Generator
from flask import Blueprint, Response, request, stream_with_context, jsonify
from peft import PeftModel, PeftConfig
import json
import random
import torch
import csle_common.constants.constants as constants
import csle_rest_api.constants.constants as api_constants
import csle_rest_api.util.rest_api_util as rest_api_util
from csle_common.logging.log import Logger
from csle_rest_api.resources.recovery_ai.prompts import Prompts
from csle_rest_api.resources.recovery_ai.recovery_ai_util import RecoveryAIUtil
from csle_common.metastore.metastore_facade import MetastoreFacade

recovery_ai_bp = Blueprint(
    api_constants.MGMT_WEBAPP.RECOVERY_AI_RESOURCE, __name__,
    url_prefix=f"{constants.COMMANDS.SLASH_DELIM}{api_constants.MGMT_WEBAPP.RECOVERY_AI_RESOURCE}")

try:
    config = MetastoreFacade.get_config()
    if config.recovery_ai:
        # Load the base LLM and tokenizer, then attach the fine-tuned PEFT adapter
        # stored in the configured output directory
        tokenizer, llm = RecoveryAIUtil.load_llm()
        output_dir = config.recovery_ai_output_dir
        peft_config = PeftConfig.from_pretrained(output_dir)
        assert isinstance(llm, torch.nn.Module)
        model = PeftModel.from_pretrained(llm, output_dir)
        model.eval()
        # Load the example incidents served by the /recovery-ai/example resource
        with open(config.recovery_ai_examples_path, "r") as f:
            examples = json.load(f)
except Exception as e:
    Logger.__call__().get_logger().warning(f"There was an exception loading the LLM for RecoveryAI. "
                                           f"Exception: {str(e)}, {repr(e)}")


[docs]@recovery_ai_bp.route("", methods=[api_constants.MGMT_WEBAPP.HTTP_REST_POST]) def recovery_ai() -> Tuple[Response, int]: """ SSE endpoint that streams tokens of the recovery plan """ authorized = rest_api_util.check_if_user_is_authorized(request=request, requires_admin=False) if authorized is not None: return authorized body: Dict[str, Any] = request.get_json(silent=True) or {} system: str = body.get(api_constants.MGMT_WEBAPP.RECOVERY_AI_JSON_SYSTEM_DESCRIPTION, "") logs: str = body.get(api_constants.MGMT_WEBAPP.RECOVERY_AI_JSON_LOGS, "") rag: bool = body.get(api_constants.MGMT_WEBAPP.RECOVERY_AI_RAG, False) lookahead_horizon: int = body.get(api_constants.MGMT_WEBAPP.RECOVERY_AI_LOOKAHEAD_HORIZON, 1) rollout_horizon: int = body.get(api_constants.MGMT_WEBAPP.RECOVERY_AI_ROLLOUT_HORIZON, 1) optimization_steps: int = body.get(api_constants.MGMT_WEBAPP.RECOVERY_AI_OPTIMIZATION_STEPS, 1) temperature: float = body.get(api_constants.MGMT_WEBAPP.RECOVERY_AI_TEMPERATURE, 1.0) def llm_stream(logs: str, system: str) -> Generator[str, None, None]: if rag: rag_output = yield from RecoveryAIUtil.gemini_rag(logs=logs) if rag_output is None: raise ValueError("There was an exception generating Rag output") identifiers, contexts = rag_output logs = RecoveryAIUtil.enrich_logs(logs=logs, identifiers=identifiers, contexts=contexts) incident_input = RecoveryAIUtil.create_incident_prompt(system, logs, Prompts.INCIDENT_PROMPT_TEMPLATE) incident_classification = yield from RecoveryAIUtil.incident_classification( incident_input=incident_input, model=model, tokenizer=tokenizer, system=system, logs=logs, num_optimization_steps=optimization_steps, temperature=temperature, lookahead_horizon=lookahead_horizon, rollout_horizon=rollout_horizon) if incident_classification is None: raise ValueError("There was an exception classifying the incident") if incident_classification[Prompts.INCIDENT] == Prompts.INCIDENT_YES: incident_str = RecoveryAIUtil.format_incident_description(incident_classification) yield from RecoveryAIUtil.recovery_loop(model=model, tokenizer=tokenizer, system=system, logs=logs, incident_str=incident_str, action_prompt_template=Prompts.ACTION_PROMPT_TEMPLATE, state_prompt_template=Prompts.STATE_PROMPT_TEMPLATE, temperature=temperature, lookahead_horizon=lookahead_horizon, rollout_horizon=rollout_horizon, num_optimization_steps=optimization_steps) response = Response( stream_with_context(llm_stream(logs=logs, system=system)), status=constants.HTTPS.OK_STATUS_CODE, content_type=api_constants.MGMT_WEBAPP.EVENT_STREAM_CONTENT_TYPE, ) return response, constants.HTTPS.OK_STATUS_CODE
[docs]@recovery_ai_bp.route("/example", methods=[api_constants.MGMT_WEBAPP.HTTP_REST_GET]) def example() -> Tuple[Response, int]: """ The /recovery-ai/example resource. :return: An example system description and log """ requires_admin = False if request.method == api_constants.MGMT_WEBAPP.HTTP_REST_DELETE: requires_admin = True authorized = rest_api_util.check_if_user_is_authorized(request=request, requires_admin=requires_admin) if authorized is not None: return authorized example = random.choice(examples) example_dict = {api_constants.MGMT_WEBAPP.RECOVERY_AI_JSON_SYSTEM_DESCRIPTION: example["System"], api_constants.MGMT_WEBAPP.RECOVERY_AI_JSON_LOGS: "\n".join(example["Logs"])} response = jsonify(example_dict) response.headers.add(api_constants.MGMT_WEBAPP.ACCESS_CONTROL_ALLOW_ORIGIN_HEADER, "*") return response, constants.HTTPS.OK_STATUS_CODE