diff --git a/requests_html.py b/requests_html.py index 92b8b40..71ce376 100644 --- a/requests_html.py +++ b/requests_html.py @@ -1,7 +1,7 @@ import asyncio from urllib.parse import urlparse, urlunparse from concurrent.futures._base import TimeoutError -from typing import Set +from typing import Set, Union, List import pyppeteer import requests @@ -21,6 +21,10 @@ DEFAULT_URL = 'https://example.org/' useragent = UserAgent() +# Typing. +_Find = Union[List['Element'], 'Element'] +_XPath = Union[List[str], List['Element'], str, 'Element'] +_HTML = Union[str, bytes] class BaseParser: @@ -33,12 +37,17 @@ class BaseParser: """ - def __init__(self, *, element, default_encoding: str = None, html: str = None, url: str) -> None: + def __init__(self, *, element, default_encoding: str = None, html: _HTML = None, url: str) -> None: self.element = element self.url = url self.skip_anchors = True self.default_encoding = default_encoding self._encoding = None + + # Encode incoming unicode HTML into bytes. + if isinstance(html, str): + html = html.encode(DEFAULT_ENCODING) + self._html = html @property @@ -100,7 +109,7 @@ class BaseParser: """The full text content (including links) of the :class:`Element ` or :class:`HTML `..""" return self.lxml.text_content() - def find(self, selector: str, first: bool = False, _encoding: str = None): + def find(self, selector: str, first: bool = False, _encoding: str = None) -> _Find: """Given a CSS Selector, returns a list of :class:`Element ` objects. :param selector: CSS Selector to use. @@ -131,7 +140,7 @@ class BaseParser: else: return elements - def xpath(self, selector: str, first: bool = False, _encoding: str = None): + def xpath(self, selector: str, first: bool = False, _encoding: str = None) -> _XPath: """Given an XPath selector, returns a list of :class:`Element ` objects. @@ -154,7 +163,7 @@ class BaseParser: if not isinstance(selection, etree._ElementUnicodeResult): element = Element(element=selection, url=self.url, default_encoding=_encoding or self.encoding) else: - element = selection + element = str(selection) c.append(element) if first: @@ -280,7 +289,7 @@ class HTML(BaseParser): :param default_encoding: Which encoding to default to. """ - def __init__(self, *, url=DEFAULT_URL, html, default_encoding=DEFAULT_ENCODING) -> None: + def __init__(self, *, url: str = DEFAULT_URL, html: _HTML, default_encoding: str =DEFAULT_ENCODING) -> None: # Convert incoming unicode HTML into bytes. if isinstance(html, str): @@ -382,9 +391,9 @@ class HTMLResponse(requests.Response): intelligent ``.html`` property added. """ - def __init__(self, *args, **kwargs) -> None: - super(HTMLResponse, self).__init__(*args, **kwargs) - self._html = None + def __init__(self) -> None: + super(HTMLResponse, self).__init__() + self._html = None # type: HTML @property def html(self) -> HTML: