@@ -44,6 +44,7 @@ export default class ModalPortal extends Component {
4444 closeTimeoutMS : PropTypes . number ,
4545 shouldFocusAfterRender : PropTypes . bool ,
4646 shouldCloseOnOverlayClick : PropTypes . bool ,
47+ shouldReturnFocusAfterClose : PropTypes . bool ,
4748 role : PropTypes . string ,
4849 contentLabel : PropTypes . string ,
4950 aria : PropTypes . object ,
@@ -137,8 +138,23 @@ export default class ModalPortal extends Component {
137138 afterClose = ( ) => {
138139 // Remove body class
139140 bodyClassList . remove ( this . props . bodyOpenClassName ) ;
140- focusManager . returnFocus ( ) ;
141- focusManager . teardownScopedFocus ( ) ;
141+
142+ if ( this . shouldReturnFocus ( ) ) {
143+ focusManager . returnFocus ( ) ;
144+ focusManager . teardownScopedFocus ( ) ;
145+ }
146+ } ;
147+
148+ shouldReturnFocus = ( ) => {
149+ // Don't restore focus to the element that had focus prior to
150+ // the modal's display if:
151+ // 1. Focus was never shifted to the modal in the first place
152+ // (shouldFocusAfterRender = false)
153+ // 2. Explicit direction to not restore focus
154+ return (
155+ this . props . shouldFocusAfterRender ||
156+ this . props . shouldReturnFocusAfterClose
157+ ) ;
142158 } ;
143159
144160 open = ( ) => {
@@ -147,8 +163,11 @@ export default class ModalPortal extends Component {
147163 clearTimeout ( this . closeTimer ) ;
148164 this . setState ( { beforeClose : false } ) ;
149165 } else {
150- focusManager . setupScopedFocus ( this . node ) ;
151- focusManager . markForFocusLater ( ) ;
166+ if ( this . shouldReturnFocus ( ) ) {
167+ focusManager . setupScopedFocus ( this . node ) ;
168+ focusManager . markForFocusLater ( ) ;
169+ }
170+
152171 this . setState ( { isOpen : true } , ( ) => {
153172 this . setState ( { afterOpen : true } ) ;
154173
0 commit comments