csle_agents.agents.sarsa package
csle_agents.agents.sarsa.sarsa_agent module
- class csle_agents.agents.sarsa.sarsa_agent.SARSAAgent(simulation_env_config: SimulationEnvConfig, experiment_config: ExperimentConfig, training_job: Optional[TrainingJobConfig] = None, save_to_metastore: bool = True, env: Optional[BaseEnv] = None)
- create_policy_from_q_table(num_states: int, num_actions: int, q_table: ndarray[Any, dtype[Any]]) → ndarray[Any, dtype[Any]]
Creates a tabular policy from a Q-table
- Parameters
num_states – the number of states
num_actions – the number of actions
q_table – the Q-table
- Returns
the tabular policy
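A tabular policy is commonly derived from a Q-table by acting greedily in every state. The sketch below assumes (this is not confirmed by the reference above) that the policy is stored as a num_states × num_actions matrix of action probabilities with all mass on the highest-valued action.

```python
import numpy as np


def create_policy_from_q_table(num_states: int, num_actions: int,
                               q_table: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: deterministic greedy policy stored as a
    num_states x num_actions matrix of action probabilities."""
    policy = np.zeros((num_states, num_actions))
    for s in range(num_states):
        # Put all probability mass on the action with the highest Q-value
        # (np.argmax breaks ties in favour of the lowest action index).
        policy[s, int(np.argmax(q_table[s]))] = 1.0
    return policy


q = np.array([[0.1, 0.5], [0.7, 0.2]])
greedy_policy = create_policy_from_q_table(2, 2, q)
```

A one-hot matrix keeps the same shape whether the policy is deterministic or stochastic, so the same evaluation code can consume either.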
- eps_greedy(q_table: ndarray[Any, dtype[Any]], A: List[int], s: int, epsilon: float = 0.2) → int
Selects an action according to the epsilon-greedy strategy
- Parameters
q_table – the Q-table
A – the action space
s – the state
epsilon – the exploration epsilon
- Returns
the sampled action
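Epsilon-greedy selection explores with probability epsilon and otherwise exploits the current Q-estimates. A minimal sketch with the same parameters as above, assuming uniform exploration over A and that A holds valid column indices of q_table; the actual method may differ in details such as tie-breaking.

```python
import random
from typing import List

import numpy as np


def eps_greedy(q_table: np.ndarray, A: List[int], s: int,
               epsilon: float = 0.2) -> int:
    """Sketch of epsilon-greedy action selection over a tabular Q-function."""
    if random.random() < epsilon:
        # Explore: pick an action uniformly at random from the action space.
        return random.choice(A)
    # Exploit: among the actions in A, pick the one with the highest Q-value
    # in state s (fancy indexing restricts the row to the columns in A).
    return int(A[int(np.argmax(q_table[s, A]))])
```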
- evaluate_policy(policy: ndarray[Any, dtype[Any]], eval_batch_size: int) → float
Evaluates a tabular policy
- Parameters
policy – the tabular policy to evaluate
eval_batch_size – the batch size
- Returns
the average return of the policy over the evaluation batch
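Policy evaluation of this kind is typically a Monte-Carlo average over a batch of episodes. The sketch below does not use the agent's environment; instead it assumes a hypothetical toy MDP given by transition probabilities P[s, a, s'] and rewards R[s, a], a fixed horizon, initial state 0, and undiscounted returns. The names P, R, horizon and seed are illustrative only.

```python
import numpy as np


def evaluate_policy(policy: np.ndarray, eval_batch_size: int,
                    P: np.ndarray, R: np.ndarray, horizon: int = 20,
                    seed: int = 0) -> float:
    """Hypothetical sketch: Monte-Carlo estimate of a tabular policy's
    average (undiscounted) return over eval_batch_size episodes of an
    assumed toy MDP (P[s, a, s'] transitions, R[s, a] rewards, start state 0)."""
    rng = np.random.default_rng(seed)
    num_states, num_actions = policy.shape
    returns = []
    for _ in range(eval_batch_size):
        s, total = 0, 0.0
        for _ in range(horizon):
            # Sample an action from the policy row, collect the reward,
            # then sample the next state from the transition distribution.
            a = rng.choice(num_actions, p=policy[s])
            total += R[s, a]
            s = rng.choice(num_states, p=P[s, a])
        returns.append(total)
    return float(np.mean(returns))
```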
- initialize_count_table(n_states: int = 256, n_actions: int = 5) → ndarray[Any, dtype[Any]]
Initializes the count table
- Parameters
n_states – the number of states in the MDP
n_actions – the number of actions in the MDP
- Returns
the initialized count table
- initialize_q_table(n_states: int = 256, n_actions: int = 5) → ndarray[Any, dtype[Any]]
Initializes the Q table
- Parameters
n_states – the number of states in the MDP
n_actions – the number of actions in the MDP
- Returns
the initialized Q table
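Both tables are n_states × n_actions arrays indexed by state-action pair. A minimal sketch of the two initializers, assuming (not confirmed above) that Q-values and visit counts both start at zero:

```python
import numpy as np


def initialize_q_table(n_states: int = 256, n_actions: int = 5) -> np.ndarray:
    # Assumption: every Q-value starts at zero.
    return np.zeros((n_states, n_actions))


def initialize_count_table(n_states: int = 256, n_actions: int = 5) -> np.ndarray:
    # One visit counter per state-action pair, starting at zero; the counts
    # later determine the per-pair stochastic approximation step sizes.
    return np.zeros((n_states, n_actions))
```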
- q_learning(exp_result: ExperimentResult, seed: int) → ExperimentResult
Runs the SARSA algorithm
- Parameters
exp_result – the experiment result object
seed – the random seed
- Returns
the updated experiment result
- sarsa_update(q_table: ndarray[Any, dtype[Any]], count_table: ndarray[Any, dtype[Any]], s: int, a: int, r: float, s_prime: int, gamma: float, a1: int) → Tuple[ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]]]
Performs a single SARSA update of the Q-table
- Parameters
q_table – the Q-table
count_table – the count table (used to determine the stochastic approximation (SA) step sizes)
s – the sampled state
a – the exploration action
r – the reward
s_prime – the next sampled state
gamma – the discount factor
a1 – the next action (selected in s_prime)
- Returns
the updated Q-table and the updated count table
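The standard SARSA rule moves Q(s, a) toward the on-policy target r + gamma * Q(s_prime, a1), i.e. Q(s, a) ← Q(s, a) + α_n [r + γ Q(s_prime, a1) − Q(s, a)], where α_n depends on how often (s, a) has been visited. A minimal sketch with the parameters listed above, assuming a step size of 1/n (n being the updated visit count); the step-size rule actually used by the agent is not specified here.

```python
from typing import Tuple

import numpy as np


def sarsa_update(q_table: np.ndarray, count_table: np.ndarray, s: int, a: int,
                 r: float, s_prime: int, gamma: float,
                 a1: int) -> Tuple[np.ndarray, np.ndarray]:
    """Sketch of one SARSA update; both tables are modified in place
    and returned for convenience."""
    # Count this visit to (s, a) and derive its step size (assumed 1/n).
    count_table[s, a] += 1
    alpha = 1.0 / count_table[s, a]
    # On-policy TD target: bootstrap from the action a1 actually chosen in
    # s_prime (Q-learning would use the max over actions instead).
    td_target = r + gamma * q_table[s_prime, a1]
    q_table[s, a] += alpha * (td_target - q_table[s, a])
    return q_table, count_table
```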
- step_size(n: int) → float
Calculates the SA step size
- Parameters
n – the iteration
- Returns
the step size
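A common choice of SA step size is the polynomial schedule α_n = 1 / n^ω. The sketch below uses it purely as an illustration; the exponent parameter omega and its default are hypothetical and not taken from the signature above. For 0.5 < ω ≤ 1 the schedule satisfies the Robbins-Monro conditions (Σ α_n diverges, Σ α_n² converges).

```python
def step_size(n: int, omega: float = 0.85) -> float:
    """Hypothetical polynomial SA step size alpha_n = 1 / n**omega.

    For 0.5 < omega <= 1 the sequence satisfies the Robbins-Monro
    conditions: sum(alpha_n) diverges while sum(alpha_n**2) converges."""
    if n < 1:
        raise ValueError("n must be a positive iteration count")
    return 1.0 / (n ** omega)
```

Smaller ω decays more slowly, which keeps later updates larger; ω = 1 gives the plain sample-average rule 1/n.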
- train_sarsa(A: List[int], S: List[int], gamma: float = 0.8, N: int = 10000, epsilon: float = 0.2) → Tuple[List[float], List[float], List[int], ndarray[Any, dtype[Any]], ndarray[Any, dtype[Any]]]
Runs the SARSA algorithm
- Parameters
A – the action space
S – the state space
gamma – the discount factor
N – the number of iterations
epsilon – the exploration parameter
- Returns
the average returns, the running average returns, the initial state values, the Q-table, and the policy
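The pieces above combine into a tabular SARSA training loop: choose a with epsilon-greedy, observe r and s_prime, choose a1 with the same policy, update Q(s, a) toward r + gamma * Q(s_prime, a1), then continue from (s_prime, a1). The sketch below runs on a hypothetical toy MDP (transition probabilities P[s, a, s'], rewards R[s, a], start state 0, fixed horizon) rather than the agent's environment, interprets N as the number of episodes, uses a 1/n step size, and returns only the per-episode returns, the Q-table and the greedy policy, a subset of the documented outputs. P, R, horizon and seed are illustrative names.

```python
from typing import List, Tuple

import numpy as np


def train_sarsa(P: np.ndarray, R: np.ndarray, gamma: float = 0.8,
                N: int = 10000, epsilon: float = 0.2, horizon: int = 20,
                seed: int = 0) -> Tuple[List[float], np.ndarray, np.ndarray]:
    """Hypothetical sketch of tabular SARSA on a toy MDP given by transition
    probabilities P[s, a, s'] and rewards R[s, a]; episodes start in state 0."""
    rng = np.random.default_rng(seed)
    n_states, n_actions = R.shape
    q = np.zeros((n_states, n_actions))
    counts = np.zeros((n_states, n_actions))

    def act(s: int) -> int:
        # Epsilon-greedy behaviour policy (also the policy being learned).
        if rng.random() < epsilon:
            return int(rng.integers(n_actions))
        return int(np.argmax(q[s]))

    episode_returns: List[float] = []
    for _ in range(N):
        s = 0
        a = act(s)
        total = 0.0
        for _ in range(horizon):
            r = R[s, a]
            s_prime = int(rng.choice(n_states, p=P[s, a]))
            a1 = act(s_prime)
            # SARSA update with a 1/n step size per state-action pair.
            counts[s, a] += 1
            alpha = 1.0 / counts[s, a]
            q[s, a] += alpha * (r + gamma * q[s_prime, a1] - q[s, a])
            total += r
            s, a = s_prime, a1
        episode_returns.append(total)
    # Greedy (one-hot) policy extracted from the learned Q-table.
    policy = np.eye(n_actions)[np.argmax(q, axis=1)]
    return episode_returns, q, policy
```

Because a1 is chosen before the update and then reused as the next action, the update evaluates the same epsilon-greedy policy that generates the data, which is what makes SARSA on-policy.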