@@ -65,13 +65,13 @@ trait SimplifyMatch<'tcx> {
6565 _ => unreachable ! ( ) ,
6666 } ;
6767
68- if !self . can_simplify ( tcx, targets, param_env, bbs) {
68+ let discr_ty = discr. ty ( local_decls, tcx) ;
69+ if !self . can_simplify ( tcx, targets, param_env, bbs, discr_ty) {
6970 return false ;
7071 }
7172
7273 // Take ownership of items now that we know we can optimize.
7374 let discr = discr. clone ( ) ;
74- let discr_ty = discr. ty ( local_decls, tcx) ;
7575
7676 // Introduce a temporary for the discriminant value.
7777 let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
@@ -101,6 +101,7 @@ trait SimplifyMatch<'tcx> {
101101 targets : & SwitchTargets ,
102102 param_env : ParamEnv < ' tcx > ,
103103 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
104+ discr_ty : Ty < ' tcx > ,
104105 ) -> bool ;
105106
106107 fn new_stmts (
@@ -154,6 +155,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
154155 targets : & SwitchTargets ,
155156 param_env : ParamEnv < ' tcx > ,
156157 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
158+ _discr_ty : Ty < ' tcx > ,
157159 ) -> bool {
158160 if targets. iter ( ) . len ( ) != 1 {
159161 return false ;
@@ -265,7 +267,7 @@ struct SimplifyToExp {
265267enum CompareType < ' tcx , ' a > {
266268 Same ( & ' a StatementKind < ' tcx > ) ,
267269 Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
268- Discr ( & ' a Place < ' tcx > , Ty < ' tcx > ) ,
270+ Discr ( & ' a Place < ' tcx > , Ty < ' tcx > , bool ) ,
269271}
270272
271273enum TransfromType {
@@ -279,7 +281,7 @@ impl From<CompareType<'_, '_>> for TransfromType {
279281 match compare_type {
280282 CompareType :: Same ( _) => TransfromType :: Same ,
281283 CompareType :: Eq ( _, _, _) => TransfromType :: Eq ,
282- CompareType :: Discr ( _, _) => TransfromType :: Discr ,
284+ CompareType :: Discr ( _, _, _ ) => TransfromType :: Discr ,
283285 }
284286 }
285287}
@@ -330,6 +332,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
330332 targets : & SwitchTargets ,
331333 param_env : ParamEnv < ' tcx > ,
332334 bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
335+ discr_ty : Ty < ' tcx > ,
333336 ) -> bool {
334337 if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
335338 return false ;
@@ -352,6 +355,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
352355 return false ;
353356 }
354357
358+ let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
355359 let first_stmts = & bbs[ first_target] . statements ;
356360 let ( second_val, second_target) = iter. next ( ) . unwrap ( ) ;
357361 let second_stmts = & bbs[ second_target] . statements ;
@@ -376,12 +380,30 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
376380 ) {
377381 ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
378382 ( Some ( f) , Some ( s) )
379- if Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
380- && Some ( s) == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) =>
383+ if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
384+ && f. try_to_int ( f. size ( ) ) . unwrap ( )
385+ == ScalarInt :: try_from_uint ( first_val, discr_size)
386+ . unwrap ( )
387+ . try_to_int ( discr_size)
388+ . unwrap ( )
389+ && s. try_to_int ( s. size ( ) ) . unwrap ( )
390+ == ScalarInt :: try_from_uint ( second_val, discr_size)
391+ . unwrap ( )
392+ . try_to_int ( discr_size)
393+ . unwrap ( ) )
394+ || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
395+ && Some ( s)
396+ == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
381397 {
382- CompareType :: Discr ( lhs_f, f_c. const_ . ty ( ) )
398+ CompareType :: Discr (
399+ lhs_f,
400+ f_c. const_ . ty ( ) ,
401+ f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
402+ )
403+ }
404+ _ => {
405+ return false ;
383406 }
384- _ => return false ,
385407 }
386408 }
387409
@@ -406,15 +428,26 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
406428 && s_c. const_ . ty ( ) == f_ty
407429 && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val) => { }
408430 (
409- CompareType :: Discr ( lhs_f, f_ty) ,
431+ CompareType :: Discr ( lhs_f, f_ty, is_signed ) ,
410432 StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
411433 ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
412434 let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
413435 return false ;
414436 } ;
415- if Some ( f) != ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
416- return false ;
437+ if is_signed
438+ && s_c. const_ . ty ( ) . is_signed ( )
439+ && f. try_to_int ( f. size ( ) ) . unwrap ( )
440+ == ScalarInt :: try_from_uint ( other_val, discr_size)
441+ . unwrap ( )
442+ . try_to_int ( discr_size)
443+ . unwrap ( )
444+ {
445+ continue ;
446+ }
447+ if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
448+ continue ;
417449 }
450+ return false ;
418451 }
419452 _ => return false ,
420453 }
0 commit comments