# Source code for csle_agents.job_controllers.training_job_manager

import subprocess
import csle_common.constants.constants as constants
from csle_common.dao.jobs.training_job_config import TrainingJobConfig
from csle_common.dao.training.agent_type import AgentType
from csle_common.metastore.metastore_facade import MetastoreFacade
from csle_agents.agents.t_spsa.t_spsa_agent import TSPSAAgent
from csle_agents.agents.ppo.ppo_agent import PPOAgent


class TrainingJobManager:
    """
    Entry points for executing CSLE training jobs, either directly in the
    calling process or through a separately launched shell command
    """

    @staticmethod
    def run_training_job(job_config: TrainingJobConfig) -> None:
        """
        Executes the training job described by the given configuration.

        The emulation and simulation environment configurations are looked up
        in the metastore by name (a name of None leaves the corresponding
        configuration as None). The agent matching the experiment's agent type
        is then constructed and trained. For PPO, every policy in the result of
        the returned experiment execution is saved through the metastore.

        :param job_config: the configuration of the training job to execute
        :return: None
        :raises ValueError: if the agent type of the experiment is neither T_SPSA nor PPO
        """
        # Resolve the environments first, emulation before simulation.
        emulation_name = job_config.emulation_env_name
        emulation_env_config = (
            MetastoreFacade.get_emulation_by_name(name=emulation_name) if emulation_name is not None else None
        )
        simulation_name = job_config.simulation_env_name
        simulation_env_config = (
            MetastoreFacade.get_simulation_by_name(name=simulation_name) if simulation_name is not None else None
        )

        # Both supported agents are constructed with the same keyword arguments.
        agent_kwargs = {
            "emulation_env_config": emulation_env_config,
            "simulation_env_config": simulation_env_config,
            "experiment_config": job_config.experiment_config,
            "training_job": job_config,
        }
        agent_type = job_config.experiment_config.agent_type

        if agent_type == AgentType.T_SPSA:
            TSPSAAgent(**agent_kwargs).train()
            return

        if agent_type == AgentType.PPO:
            execution = PPOAgent(**agent_kwargs).train()
            for trained_policy in execution.result.policies.values():
                MetastoreFacade.save_ppo_policy(ppo_policy=trained_policy)
            return

        raise ValueError(f"Agent type: {job_config.experiment_config.agent_type} not recognized")

    @staticmethod
    def start_training_job_in_background(training_job: TrainingJobConfig) -> None:
        """
        Launches the shell command that starts the given training job.

        The command is built by formatting the START_TRAINING_JOB template with
        the id of the job and is run through a shell with its standard output
        discarded. This call waits for that shell command to exit.
        NOTE(review): presumably the command template detaches the actual
        training process so that it keeps running afterwards; the template is
        defined outside this file, confirm there.

        :param training_job: the configuration of the training job to start
        :return: None
        """
        start_command = constants.COMMANDS.START_TRAINING_JOB.format(training_job.id)
        launcher = subprocess.Popen(start_command, stdout=subprocess.DEVNULL, shell=True)
        # communicate() returns once the launched shell has terminated.
        launcher.communicate()