1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
15+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Set , Tuple , Type , Union
1616
1717from selenium import webdriver
1818from selenium .common .exceptions import (
2121 UnknownMethodException ,
2222 WebDriverException ,
2323)
24- from selenium .webdriver .common .by import By
2524from selenium .webdriver .remote .command import Command as RemoteCommand
2625from selenium .webdriver .remote .remote_connection import RemoteConnection
2726from typing_extensions import Self
2827
2928from appium .common .logger import logger
3029from appium .options .common .base import AppiumOptions
31- from appium .webdriver .common .appiumby import AppiumBy
3230
3331from .appium_connection import AppiumConnection
3432from .client_config import AppiumClientConfig
@@ -241,7 +239,7 @@ class WebDriver(
241239 def __init__ (
242240 self ,
243241 command_executor : Union [str , AppiumConnection ] = 'http://127.0.0.1:4723' ,
244- extensions : Optional [List ['WebDriver' ]] = None ,
242+ extensions : Optional [List [Type [ 'ExtensionBase' ] ]] = None ,
245243 options : Union [AppiumOptions , List [AppiumOptions ], None ] = None ,
246244 client_config : Optional [AppiumClientConfig ] = None ,
247245 ):
@@ -250,28 +248,22 @@ def __init__(
250248 )
251249 super ().__init__ (
252250 command_executor = command_executor ,
253- options = options ,
251+ options = options , # type: ignore[arg-type]
254252 locator_converter = AppiumLocatorConverter (),
255253 web_element_cls = MobileWebElement ,
256254 client_config = client_config ,
257255 )
258256
257+ # to explicitly set type after the initialization
258+ self .command_executor : RemoteConnection
259+
259260 self ._add_commands ()
260261
261262 self .error_handler = MobileErrorHandler ()
262263
263264 if client_config and client_config .direct_connection :
264265 self ._update_command_executor (keep_alive = client_config .keep_alive )
265266
266- # add new method to the `find_by_*` pantheon
267- By .IOS_PREDICATE = AppiumBy .IOS_PREDICATE
268- By .IOS_CLASS_CHAIN = AppiumBy .IOS_CLASS_CHAIN
269- By .ANDROID_UIAUTOMATOR = AppiumBy .ANDROID_UIAUTOMATOR
270- By .ANDROID_VIEWTAG = AppiumBy .ANDROID_VIEWTAG
271- By .ACCESSIBILITY_ID = AppiumBy .ACCESSIBILITY_ID
272- By .IMAGE = AppiumBy .IMAGE
273- By .CUSTOM = AppiumBy .CUSTOM
274-
275267 self ._absent_extensions : Set [str ] = set ()
276268
277269 self ._extensions = extensions or []
@@ -286,6 +278,14 @@ def __init__(
286278 method , url_cmd = instance .add_command ()
287279 self .command_executor .add_command (method_name , method .upper (), url_cmd )
288280
281+ if TYPE_CHECKING :
282+
283+ def find_element (self , by : str , value : Union [str , Dict , None ] = None ) -> 'MobileWebElement' : # type: ignore[override]
284+ ...
285+
286+ def find_elements (self , by : str , value : Union [str , Dict , None ] = None ) -> List ['MobileWebElement' ]: # type: ignore[override]
287+ ...
288+
289289 def delete_extensions (self ) -> None :
290290 """Delete extensions added in the class with 'setattr'"""
291291 for extension in self ._extensions :
0 commit comments