Wrong shape of dense_input #508

Open

snufkinnikfuns opened this issue Mar 2, 2024 · 4 comments

Tried to run RL with the new gym wrapper code and it gives the following error:

expected dense_input to have shape (1, 10) but got array with shape (1, 2)

This is the code:

import asyncio
from poke_env.ps_client import AccountConfiguration, ShowdownServerConfiguration
import numpy as np
from gymnasium.spaces import Box, Space
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy, LinearAnnealedPolicy
from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras.optimizers import Adam
from poke_env.data import GenData
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player import (
    Gen9EnvSinglePlayer,
    MaxBasePowerPlayer,
    ObsType,
    RandomPlayer,
)
from keras.internal import enable_unsafe_deserialization

enable_unsafe_deserialization()


class AdvancedDQNPlayer(Gen9EnvSinglePlayer):
    def calc_reward(self, last_battle, current_battle) -> float:
        return self.reward_computing_helper(
            current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def embed_battle(self, battle: AbstractBattle) -> ObsType:
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = (
                move.base_power / 100
            )  # Simple rescaling to facilitate learning
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                    type_chart=GenData.from_gen(9).type_chart,
                )

        # We count how many pokemons have fainted in each team
        fainted_mon_team = len([mon for mon in battle.team.values() if mon.fainted]) / 6
        fainted_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if mon.fainted]) / 6
        )

        # Final vector with 10 components
        final_vector = np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [fainted_mon_team, fainted_mon_opponent],
            ]
        )
        return np.float32(final_vector)

    def describe_embedding(self) -> Space:
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )


async def main():
    opponent = RandomPlayer(battle_format="gen9randombattle")
    env = AdvancedDQNPlayer(
        battle_format="gen9randombattle", opponent=opponent, start_challenging=True
    )
    env2 = AdvancedDQNPlayer(
        battle_format="gen9randombattle", opponent=opponent, start_challenging=True
    )

    n_action = env.action_space.n
    input_shape = (1,) + env.observation_space.shape

    # Create model
    model = Sequential()
    model.add(Dense(128, activation="elu", input_shape=input_shape))
    model.add(Flatten())
    model.add(Dense(64, activation="elu"))
    model.add(Dense(n_action, activation="linear"))

    # Defining the DQN
    memory = SequentialMemory(limit=10000, window_length=1)

    policy = LinearAnnealedPolicy(
        EpsGreedyQPolicy(),
        attr="eps",
        value_max=1.0,
        value_min=0.05,
        value_test=0.0,
        nb_steps=10000,
    )

    dqn = DQNAgent(
        model=model,
        nb_actions=n_action,
        policy=policy,
        memory=memory,
        nb_steps_warmup=1000,
        gamma=0.5,
        target_model_update=1,
        delta_clip=0.01,
        enable_double_dqn=True,
    )
    dqn.compile(Adam(learning_rate=0.00025), metrics=["mae"])

    dqn.fit(env, nb_steps=10000)
    env.close()

    # Evaluating the model
    print("Results against random player:")
    dqn.test(env2, nb_episodes=100, verbose=False, visualize=False)
    print(
        f"DQN Evaluation: {env2.n_won_battles} victories out of {env2.n_finished_battles} episodes"
    )
    second_opponent = MaxBasePowerPlayer(battle_format="gen9randombattle")
    env2.reset_env(restart=True, opponent=second_opponent)
    print("Results against max base power player:")
    dqn.test(env2, nb_episodes=100, verbose=False, visualize=False)
    print(
        f"DQN Evaluation: {env2.n_won_battles} victories out of {env2.n_finished_battles} episodes"
    )
    env2.reset_env(restart=False)

    model.save("advanced_dqn_model.h5")


if __name__ == "__main__":
    asyncio.get_event_loop().run_until_complete(main())
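
A quick way to see where the mismatch comes from is to compare what the environment declares against what the embedding actually returns (a minimal check, assuming the AdvancedDQNPlayer class above and a running local Showdown server):

check_env = AdvancedDQNPlayer(
    battle_format="gen9randombattle",
    opponent=RandomPlayer(battle_format="gen9randombattle"),
    start_challenging=True,
)
print(check_env.observation_space.shape)  # declared by describe_embedding: (10,)
obs, _ = check_env.reset()
print(obs.shape)  # what embed_battle actually produced
check_env.close()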

hsahovic self-assigned this on Mar 3, 2024

Ank-22 commented Apr 14, 2024

@hsahovic did you find any solution or temporary fix? Please let me know, since I have to complete my college project in a week.

hsahovic (Owner) commented:

Hi @Ank-22 and @snufkinnikfuns,
keras-rl stopped development a while ago and is hard to get running these days. There are combinations of package versions where this example works - you can refer to older issues to find them - but I would recommend using a maintained framework.
Here's an example using stable-baselines3:

import numpy as np
from stable_baselines3 import A2C
from gymnasium.spaces import Box
from poke_env.data import GenData

from poke_env.player import Gen9EnvSinglePlayer, RandomPlayer


# We define our RL player
# It needs a state embedder and a reward computer, hence these two methods
class SimpleRLPlayer(Gen9EnvSinglePlayer):
    def embed_battle(self, battle):
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = (
                move.base_power / 100
            )  # Simple rescaling to facilitate learning
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                    type_chart=GEN_9_DATA.type_chart,
                )

        # We count the proportion of pokemons that have not fainted in each team
        remaining_mon_team = (
            len([mon for mon in battle.team.values() if not mon.fainted]) / 6
        )
        remaining_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if not mon.fainted]) / 6
        )

        # Final vector with 10 components
        return np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [remaining_mon_team, remaining_mon_opponent],
            ]
        )

    def calc_reward(self, last_state, current_state) -> float:
        return self.reward_computing_helper(
            current_state, fainted_value=2, hp_value=1, victory_value=30
        )
    
    def describe_embedding(self):
        low = [-1, -1, -1, -1, 0, 0, 0, 0, 0, 0]
        high = [3, 3, 3, 3, 4, 4, 4, 4, 1, 1]
        return Box(
            np.array(low, dtype=np.float32),
            np.array(high, dtype=np.float32),
            dtype=np.float32,
        )


class MaxDamagePlayer(RandomPlayer):
    def choose_move(self, battle):
        # If the player can attack, it will
        if battle.available_moves:
            # Finds the best move among available ones
            best_move = max(battle.available_moves, key=lambda move: move.base_power)
            return self.create_order(best_move)

        # If no attack is available, a random switch will be made
        else:
            return self.choose_random_move(battle)



np.random.seed(0)


model_store = {}


# This is the function that will be used to train the A2C
def a2c_training(player, nb_steps):
    model = A2C("MlpPolicy", player, verbose=1)
    model.learn(total_timesteps=nb_steps)
    model_store[player] = model


def a2c_evaluation(player, nb_episodes):
    # Reset battle statistics
    model = model_store[player]
    player.reset_battles()

    # stable-baselines3 models have no test() helper, so the evaluation
    # episodes are rolled out manually
    finished_episodes = 0
    obs, _ = player.reset()
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = player.step(action)
        if done:
            finished_episodes += 1
            if finished_episodes >= nb_episodes:
                break
            obs, _ = player.reset()

    print(
        "A2C Evaluation: %d victories out of %d episodes"
        % (player.n_won_battles, nb_episodes)
    )


NB_TRAINING_STEPS = 20_000
TEST_EPISODES = 100
GEN_9_DATA = GenData.from_gen(9)

if __name__ == "__main__":
    opponent = RandomPlayer()
    env_player = SimpleRLPlayer(opponent=opponent)
    second_opponent = MaxDamagePlayer()

    model = A2C("MlpPolicy", env_player, verbose=1)
    model.learn(total_timesteps=NB_TRAINING_STEPS)

    obs, reward, done, _, info = env_player.step(0)
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = env_player.step(action)

    finished_episodes = 0

    env_player.reset_battles()
    obs, _ = env_player.reset()
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = env_player.step(action)

        if done:
            finished_episodes += 1
            if finished_episodes >= TEST_EPISODES:
                break
            obs, _ = env_player.reset()

    print("Won", env_player.n_won_battles, "battles against", env_player._opponent)

    finished_episodes = 0
    env_player._opponent = second_opponent

    env_player.reset_battles()
    obs, _ = env_player.reset()
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, _, info = env_player.step(action)

        if done:
            finished_episodes += 1
            obs, _ = env_player.reset()
            if finished_episodes >= TEST_EPISODES:
                break

    print("Won", env_player.n_won_battles, "battles against", env_player._opponent)
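
As a side note, if you want to keep the trained agent around, stable-baselines3 models can be saved and reloaded; a minimal sketch, where "a2c_poke_env" is an arbitrary filename and not something the script above requires:

# Optional: persist the trained policy and restore it later
model.save("a2c_poke_env")
loaded_model = A2C.load("a2c_poke_env")
action, _ = loaded_model.predict(obs, deterministic=True)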

For what it's worth, updating the docs and examples to use such a framework is on my to-do list.

Ank-22 commented Apr 14, 2024

Thanks for the update.

I am also exploring more of Poke-env, and I am looking forward to contributing to this repo!

Ank-22 commented Apr 15, 2024

@hsahovic

When running the above script, after a few battles the bots stop and don't make a move. In the terminal I get this error:

2024-04-15 17:04:58,195 - RandomPlayer 1 - ERROR - Unhandled exception raised while handling message:

battle-gen9randombattle-11218
|player|p2|RandomPlayer 1|1|
|teamsize|p1|6
|teamsize|p2|6
|gen|9
|tier|[Gen 9] Random Battle
|rule|Species Clause: Limit one of each Pokémon
|rule|HP Percentage Mod: HP is shown in percentages
|rule|Sleep Clause Mod: Limit one foe put to sleep
|rule|Illusion Level Mod: Illusion disguises the Pokémon's true level
|
|t:|1713215098
|start
|switch|p1a: Archaludon|Archaludon, L78, F|100/100
|switch|p2a: Iron Valiant|Iron Valiant, L79|247/247
|turn|1
Traceback (most recent call last):
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/ps_client/ps_client.py", line 138, in _handle_message
    await self._handle_battle_message(split_messages)  # type: ignore
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/player/player.py", line 353, in _handle_battle_message
    battle.parse_message(split_message)
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/environment/abstract_battle.py", line 387, in parse_message
    self.switch(pokemon, details, hp_status)
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/environment/battle.py", line 148, in switch
    pokemon = self.get_pokemon(pokemon_str, details=details)
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/environment/abstract_battle.py", line 226, in get_pokemon
    team[identifier] = Pokemon(details=details, gen=self._data.gen)
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/environment/pokemon.py", line 110, in __init__
    self._update_from_details(details)
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/environment/pokemon.py", line 425, in _update_from_details
    self._update_from_pokedex(species)
  File "/home/ank22/.local/lib/python3.10/site-packages/poke_env/environment/pokemon.py", line 359, in _update_from_pokedex
    dex_entry = self._data.pokedex[species]
KeyError: 'archaludon'

Here is a screenshot from Showdown:

[screenshot]
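
One thing worth checking (a guess at the cause, not something confirmed in this thread): whether the pokedex data shipped with the installed poke_env version even contains this species. If it doesn't, the KeyError points at outdated data rather than at the script, and upgrading poke_env should fix it:

from poke_env.data import GenData

# Hypothetical diagnostic: prints False if the bundled pokedex predates Archaludon
print("archaludon" in GenData.from_gen(9).pokedex)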
