Wrong shape of dense_input #508
@hsahovic did you find any solution/temporary fix? Please let me know, since I have to complete my college project in a week.
Hi @Ank-22 and @snufkinnikfuns,

```python
import numpy as np
from stable_baselines3 import A2C
from gymnasium.spaces import Box
from poke_env.data import GenData
from poke_env.player import Gen9EnvSinglePlayer, RandomPlayer

# Type chart used to compute damage multipliers in embed_battle
GEN_9_DATA = GenData.from_gen(9)
# We define our RL player
# It needs a state embedder and a reward computer, hence these two methods
class SimpleRLPlayer(Gen9EnvSinglePlayer):
def embed_battle(self, battle):
# -1 indicates that the move does not have a base power
# or is not available
moves_base_power = -np.ones(4)
moves_dmg_multiplier = np.ones(4)
for i, move in enumerate(battle.available_moves):
moves_base_power[i] = (
move.base_power / 100
) # Simple rescaling to facilitate learning
if move.type:
moves_dmg_multiplier[i] = move.type.damage_multiplier(
battle.opponent_active_pokemon.type_1,
battle.opponent_active_pokemon.type_2,
type_chart=GEN_9_DATA.type_chart
)
        # We count how many pokemons have fainted in each team
        fainted_mon_team = (
            len([mon for mon in battle.team.values() if mon.fainted]) / 6
        )
        fainted_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if mon.fainted]) / 6
        )
# Final vector with 10 components
return np.concatenate(
[
moves_base_power,
moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
]
)
def calc_reward(self, last_state, current_state) -> float:
return self.reward_computing_helper(
current_state, fainted_value=2, hp_value=1, victory_value=30
)
def describe_embedding(self):
low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
return Box(
np.array(low, dtype=np.float32),
np.array(high, dtype=np.float32),
dtype=np.float32,
)
class MaxDamagePlayer(RandomPlayer):
def choose_move(self, battle):
# If the player can attack, it will
if battle.available_moves:
# Finds the best move among available ones
best_move = max(battle.available_moves, key=lambda move: move.base_power)
return self.create_order(best_move)
# If no attack is available, a random switch will be made
else:
return self.choose_random_move(battle)
NB_TRAINING_STEPS = 20_000
TEST_EPISODES = 100

np.random.seed(0)

model_store = {}

# This is the function that will be used to train the A2C model
def a2c_training(player, nb_steps):
    model = A2C("MlpPolicy", player, verbose=1)
    model.learn(total_timesteps=nb_steps)
    model_store[player] = model

def a2c_evaluation(player, nb_episodes):
    # Reset battle statistics, then play nb_episodes battles with the trained model
    model = model_store[player]
    player.reset_battles()
    finished_episodes = 0
    obs, _ = player.reset()
    while finished_episodes < nb_episodes:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = player.step(action)
        if done:
            finished_episodes += 1
            obs, _ = player.reset()
    print(
        "A2C Evaluation: %d victories out of %d episodes"
        % (player.n_won_battles, nb_episodes)
    )
if __name__ == "__main__":
opponent = RandomPlayer()
env_player = SimpleRLPlayer(opponent=opponent)
second_opponent = MaxDamagePlayer()
model = A2C("MlpPolicy", env_player, verbose=1)
model.learn(total_timesteps=NB_TRAINING_STEPS)
    # Finish any battle left in progress when training stopped
    obs, reward, done, _, info = env_player.step(0)
while not done:
action, _ = model.predict(obs, deterministic=True)
obs, reward, done, _, info = env_player.step(action)
    # Evaluate the trained model against the random player
    finished_episodes = 0
    env_player.reset_battles()
    obs, _ = env_player.reset()
while True:
action, _ = model.predict(obs, deterministic=True)
obs, reward, done, _, info = env_player.step(action)
if done:
finished_episodes += 1
if finished_episodes >= TEST_EPISODES:
break
obs, _ = env_player.reset()
print("Won", env_player.n_won_battles, "battles against", env_player._opponent)
    # Evaluate the trained model against the max damage player
    finished_episodes = 0
    env_player._opponent = second_opponent
    env_player.reset_battles()
    obs, _ = env_player.reset()
while True:
action, _ = model.predict(obs, deterministic=True)
obs, reward, done, _, info = env_player.step(action)
if done:
finished_episodes += 1
obs, _ = env_player.reset()
if finished_episodes >= TEST_EPISODES:
break
print("Won", env_player.n_won_battles, "battles against", env_player._opponent) For what it's worth, updating the docs and examples to use such a framework is on my to do list. |
Thanks for the update. I am also exploring more of poke-env, and I am looking forward to contributing to this repo!
When running the above scripts, after a few battles the bots stop and don't make a move. In the terminal I get this error: `2024-04-15 17:04:58,195 - RandomPlayer 1 - ERROR - Unhandled exception raised while handling message:`

Here is a screenshot of the Showdown battle.
Tried to run the RL example with the new gym wrapper code, and it gives the following error:

`expected dense_input to have shape (1, 10) but got array with shape (1, 2)`

This is the code:

```python
import asyncio
from poke_env.ps_client import AccountConfiguration, ShowdownServerConfiguration
import numpy as np
from gymnasium.spaces import Box, Space
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy
from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras.optimizers import Adam
from poke_env.data import GenData
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player import (
Gen9EnvSinglePlayer,
MaxBasePowerPlayer,
ObsType,
RandomPlayer
)
from keras.internal import enable_unsafe_deserialization
enable_unsafe_deserialization()
class AdvancedDQNPlayer(Gen9EnvSinglePlayer):
def calc_reward(self, last_battle, current_battle) -> float:
return self.reward_computing_helper(
current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
)
async def main():
opponent = RandomPlayer(battle_format="gen9randombattle")
env = AdvancedDQNPlayer(
battle_format="gen9randombattle", opponent=opponent, start_challenging=True
)
env2 = AdvancedDQNPlayer(
battle_format="gen9randombattle", opponent=opponent, start_challenging=True
)
if __name__ == "__main__":
    asyncio.get_event_loop().run_until_complete(main())
```
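For the `dense_input` shape error above: keras-rl builds the network for a fixed input shape, so the `(1, 10)` expected by the first layer has to match the observation the player actually returns. One likely cause of the `(1, 2)` observation is that `AdvancedDQNPlayer` does not define `embed_battle` and `describe_embedding` consistently with the network. A minimal sketch, assuming the same 10-component embedding as the `SimpleRLPlayer` example earlier in this thread:

```python
# Sketch: make the player's observation agree with the network's
# dense_input shape of (1, 10). Reuses the imports from the snippet above.
GEN_9_DATA = GenData.from_gen(9)

class AdvancedDQNPlayer(Gen9EnvSinglePlayer):
    def calc_reward(self, last_battle, current_battle) -> float:
        return self.reward_computing_helper(
            current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def embed_battle(self, battle):
        # -1 marks moves that are missing or have no base power
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = move.base_power / 100
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                    type_chart=GEN_9_DATA.type_chart,
                )
        fainted_mon_team = len([m for m in battle.team.values() if m.fainted]) / 6
        fainted_mon_opponent = (
            len([m for m in battle.opponent_team.values() if m.fainted]) / 6
        )
        # 10 components: 4 base powers + 4 damage multipliers + 2 fainted counts
        return np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        )

    def describe_embedding(self):
        # Must describe the same 10-dimensional vector as embed_battle
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )
```

With that in place, the first layer of the keras-rl model has to be built for the matching input shape, e.g. `Flatten(input_shape=(1, 10))` followed by `Dense` layers, so both sides agree on the observation size.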