Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/bevy_app/src/app_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl AppBuilder {
self.add_system_to_stage(stage::UPDATE, system)
}

pub fn on_state_enter<T: Clone + Resource, S: System<In = (), Out = ()>>(
pub fn on_state_enter<T: Resource, S: System<In = (), Out = ()>>(
&mut self,
stage: &str,
state: T,
Expand All @@ -140,7 +140,7 @@ impl AppBuilder {
})
}

pub fn on_state_update<T: Clone + Resource, S: System<In = (), Out = ()>>(
pub fn on_state_update<T: Resource, S: System<In = (), Out = ()>>(
&mut self,
stage: &str,
state: T,
Expand All @@ -151,7 +151,7 @@ impl AppBuilder {
})
}

pub fn on_state_exit<T: Clone + Resource, S: System<In = (), Out = ()>>(
pub fn on_state_exit<T: Resource, S: System<In = (), Out = ()>>(
&mut self,
stage: &str,
state: T,
Expand Down
69 changes: 31 additions & 38 deletions crates/bevy_ecs/src/schedule/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::{Resource, Resources, Stage, System, SystemStage, World};
use bevy_utils::HashMap;
use std::{mem::Discriminant, ops::Deref};
use std::{
mem::{self, Discriminant},
ops::Deref,
};
use thiserror::Error;

pub(crate) struct StateStages {
Expand All @@ -21,12 +24,14 @@ impl Default for StateStages {

pub struct StateStage<T> {
stages: HashMap<Discriminant<T>, StateStages>,
current_stage: Option<Discriminant<T>>,
}

impl<T> Default for StateStage<T> {
fn default() -> Self {
Self {
stages: Default::default(),
current_stage: None,
}
}
}
Expand Down Expand Up @@ -142,14 +147,12 @@ impl<T> StateStage<T> {
}

fn state_stages(&mut self, state: T) -> &mut StateStages {
self.stages
.entry(std::mem::discriminant(&state))
.or_default()
self.stages.entry(mem::discriminant(&state)).or_default()
}
}

#[allow(clippy::mem_discriminant_non_enum)]
impl<T: Resource + Clone> Stage for StateStage<T> {
impl<T: Resource> Stage for StateStage<T> {
fn initialize(&mut self, world: &mut World, resources: &mut Resources) {
for state_stages in self.stages.values_mut() {
state_stages.enter.initialize(world, resources);
Expand All @@ -160,33 +163,25 @@ impl<T: Resource + Clone> Stage for StateStage<T> {

fn run(&mut self, world: &mut World, resources: &mut Resources) {
let current_stage = loop {
let (next_stage, current_stage) = {
let next = {
let mut state = resources
.get_mut::<State<T>>()
.expect("Missing state resource");
let result = (
state.next.as_ref().map(|next| std::mem::discriminant(next)),
std::mem::discriminant(&state.current),
);

state.apply_next();

result
state.previous = state.apply_next().or_else(|| state.previous.take());
mem::discriminant(&state.current)
};

// if next_stage is Some, we just applied a new state
if let Some(next_stage) = next_stage {
if next_stage != current_stage {
if let Some(current_state_stages) = self.stages.get_mut(&current_stage) {
current_state_stages.exit.run(world, resources);
}
if self.current_stage == Some(next) {
break next;
} else {
if let Some(current_state_stages) =
self.current_stage.and_then(|it| self.stages.get_mut(&it))
{
current_state_stages.exit.run(world, resources);
}

if let Some(next_state_stages) = self.stages.get_mut(&next_stage) {
self.current_stage = Some(next);
if let Some(next_state_stages) = self.stages.get_mut(&next) {
next_state_stages.enter.run(world, resources);
}
} else {
break current_stage;
}
};

Expand All @@ -204,20 +199,19 @@ pub enum StateError {
}

#[derive(Debug)]
pub struct State<T: Clone> {
pub struct State<T> {
previous: Option<T>,
current: T,
next: Option<T>,
}

#[allow(clippy::mem_discriminant_non_enum)]
impl<T: Clone> State<T> {
impl<T> State<T> {
pub fn new(state: T) -> Self {
Self {
current: state.clone(),
current: state,
next: None,
previous: None,
// add value to queue so that we "enter" the state
next: Some(state),
}
}

Expand All @@ -235,7 +229,7 @@ impl<T: Clone> State<T> {

/// Queue a state change. This will fail if there is already a state in the queue, or if the given `state` matches the current state
pub fn set_next(&mut self, state: T) -> Result<(), StateError> {
if std::mem::discriminant(&self.current) == std::mem::discriminant(&state) {
if mem::discriminant(&self.current) == mem::discriminant(&state) {
return Err(StateError::AlreadyInState);
}

Expand All @@ -249,25 +243,24 @@ impl<T: Clone> State<T> {

/// Same as [Self::queue], but there is already a next state, it will be overwritten instead of failing
pub fn overwrite_next(&mut self, state: T) -> Result<(), StateError> {
if std::mem::discriminant(&self.current) == std::mem::discriminant(&state) {
if mem::discriminant(&self.current) == mem::discriminant(&state) {
return Err(StateError::AlreadyInState);
}

self.next = Some(state);
Ok(())
}

fn apply_next(&mut self) {
fn apply_next(&mut self) -> Option<T> {
if let Some(next) = self.next.take() {
let previous = std::mem::replace(&mut self.current, next);
if std::mem::discriminant(&previous) != std::mem::discriminant(&self.current) {
self.previous = Some(previous)
}
Some(std::mem::replace(&mut self.current, next))
} else {
None
}
}
}

impl<T: Clone> Deref for State<T> {
impl<T> Deref for State<T> {
type Target = T;

fn deref(&self) -> &Self::Target {
Expand Down
7 changes: 6 additions & 1 deletion examples/ecs/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ fn main() {

const STAGE: &str = "app_state";

#[derive(Clone)]
enum AppState {
Menu,
InGame,
Expand Down Expand Up @@ -101,6 +100,7 @@ fn setup_game(
commands: &mut Commands,
asset_server: Res<AssetServer>,
mut materials: ResMut<Assets<ColorMaterial>>,
state: Res<State<AppState>>,
) {
let texture_handle = asset_server.load("branding/icon.png");
commands
Expand All @@ -109,6 +109,11 @@ fn setup_game(
material: materials.add(texture_handle.into()),
..Default::default()
});
match state.previous() {
Some(AppState::Menu) => println!("Called setup_game from leaving the menu"),
Some(AppState::InGame) => unreachable!("Called setup_game from leaving InGame"),
None => unreachable!("Called setup_game as the first state"),
}
}

const SPEED: f32 = 100.0;
Expand Down