csle_agents.agents.reinforce package
Submodules
csle_agents.agents.reinforce.reinforce_agent module
- class csle_agents.agents.reinforce.reinforce_agent.ReinforceAgent(simulation_env_config: csle_common.dao.simulation_config.simulation_env_config.SimulationEnvConfig, emulation_env_config: Union[None, csle_common.dao.emulation_config.emulation_env_config.EmulationEnvConfig], experiment_config: csle_common.dao.training.experiment_config.ExperimentConfig, env: Optional[csle_common.dao.simulation_config.base_env.BaseEnv] = None, training_job: Optional[csle_common.dao.jobs.training_job_config.TrainingJobConfig] = None, save_to_metastore: bool = True)[source]
Bases: csle_agents.agents.base.base_agent.BaseAgent
Reinforce Agent
- static compute_avg_metrics(metrics: Dict[str, List[Union[float, int]]]) Dict[str, Union[float, int]] [source]
Computes the average metrics of a dict with aggregated metrics
- Parameters
metrics – the dict with the aggregated metrics
- Returns
the average metrics
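As an illustration of the averaging described above, here is a minimal sketch. It assumes each metric is averaged independently with an arithmetic mean and that empty lists are skipped. The function name `compute_avg_metrics_sketch` is hypothetical, and this is not the library's actual implementation.

```python
from typing import Dict, List, Union

Number = Union[float, int]


def compute_avg_metrics_sketch(metrics: Dict[str, List[Number]]) -> Dict[str, Number]:
    """Hypothetical stand-in: arithmetic mean of each metric's values."""
    avg_metrics: Dict[str, Number] = {}
    for name, values in metrics.items():
        if values:  # assumption: metrics with no recorded values are skipped
            avg_metrics[name] = sum(values) / len(values)
    return avg_metrics


# Aggregated metrics collected over several episodes (made-up values)
aggregated = {"R": [1.0, 2.0, 3.0], "T": [10, 20]}
averages = compute_avg_metrics_sketch(aggregated)
```

The keys `R` and `T` are made-up examples; the metric names the agent actually tracks depend on the environment's `info` dict.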
- reinforce(exp_result: csle_common.dao.training.experiment_result.ExperimentResult, seed: int, training_job: csle_common.dao.jobs.training_job_config.TrainingJobConfig, random_seeds: List[int]) csle_common.dao.training.experiment_result.ExperimentResult [source]
Runs the REINFORCE algorithm
- Parameters
exp_result – the experiment result object to store the result
seed – the seed
training_job – the training job config
random_seeds – list of seeds
- Returns
the updated experiment result and the trained policy
- static round_vec(vec) List[float] [source]
Rounds a vector to 3 decimals
- Parameters
vec – the vector to round
- Returns
the rounded vector
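The rounding can be sketched as follows, assuming the input is any sequence of numbers and that Python's built-in `round` is sufficient. The name `round_vec_sketch` is hypothetical, and the library's own method may handle tensors or arrays differently.

```python
from typing import List, Sequence


def round_vec_sketch(vec: Sequence[float]) -> List[float]:
    """Hypothetical stand-in: round every element to 3 decimals."""
    return [round(float(x), 3) for x in vec]


rounded = round_vec_sketch([0.12345, 1.0, 2.71828])
```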
- train() csle_common.dao.training.experiment_execution.ExperimentExecution [source]
Performs the policy training for the given random seeds using REINFORCE
- Returns
the training metrics and the trained policies
- training_step(saved_rewards: List[List[float]], saved_log_probs: List[List[torch.Tensor]], policy_network: csle_common.models.fnn_w_softmax.FNNwithSoftmax, optimizer: torch.optim.optimizer.Optimizer, gamma: float) torch.Tensor [source]
Performs a training step of the REINFORCE algorithm
- Parameters
saved_rewards – list of rewards encountered in the latest episode trajectory
saved_log_probs – list of log-action probabilities (log p(a|s)) encountered in the latest episode trajectory
policy_network – the policy network
optimizer – the optimizer for updating the weights
gamma – the discount factor
- Returns
the loss of the training step
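To make the role of `saved_rewards`, `saved_log_probs`, and `gamma` concrete, the sketch below computes the standard REINFORCE loss using plain floats in place of `torch.Tensor` values, so no gradient step is taken. It assumes the loss is the negative sum of log-probabilities weighted by discounted returns, averaged over episodes. Whether the library normalizes returns or reduces differently is not stated here, and both function names are hypothetical.

```python
from typing import List


def discounted_returns(rewards: List[float], gamma: float) -> List[float]:
    """Return G_t = r_t + gamma * G_{t+1} for every time step t."""
    returns: List[float] = []
    running = 0.0
    for reward in reversed(rewards):  # accumulate from the last step backwards
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def reinforce_loss_sketch(saved_rewards: List[List[float]],
                          saved_log_probs: List[List[float]],
                          gamma: float) -> float:
    """Hypothetical stand-in: mean over episodes of -sum_t log p(a_t|s_t) * G_t."""
    episode_losses = []
    for rewards, log_probs in zip(saved_rewards, saved_log_probs):
        returns = discounted_returns(rewards, gamma)
        episode_losses.append(-sum(lp * g for lp, g in zip(log_probs, returns)))
    return sum(episode_losses) / len(episode_losses)


example_returns = discounted_returns([1.0, 1.0, 1.0], gamma=0.5)
example_loss = reinforce_loss_sketch([[1.0, 1.0, 1.0]], [[-1.0, -1.0, -1.0]], gamma=0.5)
```

In a tensor-based implementation, the log-probabilities would carry gradients, and the optimizer would typically be applied as `optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`.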
- static update_metrics(metrics: Dict[str, List[Union[float, int]]], info: Dict[str, Union[float, int]]) Dict[str, List[Union[float, int]]] [source]
Updates a dict with aggregated metrics using new information from the environment
- Parameters
metrics – the dict with the aggregated metrics
info – the new information
- Returns
the updated dict
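A minimal sketch of this aggregation, assuming each value in `info` is appended to the list stored under the same key and that unseen keys start a new list. The name `update_metrics_sketch` is hypothetical, and the library may treat unseen keys differently.

```python
from typing import Dict, List, Union

Number = Union[float, int]


def update_metrics_sketch(metrics: Dict[str, List[Number]],
                          info: Dict[str, Number]) -> Dict[str, List[Number]]:
    """Hypothetical stand-in: append each new value to its metric's list."""
    for key, value in info.items():
        # assumption: a key not yet present starts a new list
        metrics.setdefault(key, []).append(value)
    return metrics


metrics = {"R": [1.0]}
metrics = update_metrics_sketch(metrics, {"R": 2.0, "T": 5})
```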