diff --git a/requests_html.py b/requests_html.py
index 92b8b40..71ce376 100644
--- a/requests_html.py
+++ b/requests_html.py
@@ -1,7 +1,7 @@
import asyncio
from urllib.parse import urlparse, urlunparse
from concurrent.futures._base import TimeoutError
-from typing import Set
+from typing import Set, Union, List
import pyppeteer
import requests
@@ -21,6 +21,10 @@ DEFAULT_URL = 'https://example.org/'
useragent = UserAgent()
+# Typing.
+_Find = Union[List['Element'], 'Element']
+_XPath = Union[List[str], List['Element'], str, 'Element']
+_HTML = Union[str, bytes]
class BaseParser:
@@ -33,12 +37,17 @@ class BaseParser:
"""
- def __init__(self, *, element, default_encoding: str = None, html: str = None, url: str) -> None:
+ def __init__(self, *, element, default_encoding: str = None, html: _HTML = None, url: str) -> None:
self.element = element
self.url = url
self.skip_anchors = True
self.default_encoding = default_encoding
self._encoding = None
+
+ # Encode incoming unicode HTML into bytes.
+ if isinstance(html, str):
+ html = html.encode(DEFAULT_ENCODING)
+
self._html = html
@property
@@ -100,7 +109,7 @@ class BaseParser:
"""The full text content (including links) of the :class:`Element ` or :class:`HTML `.."""
return self.lxml.text_content()
- def find(self, selector: str, first: bool = False, _encoding: str = None):
+ def find(self, selector: str, first: bool = False, _encoding: str = None) -> _Find:
"""Given a CSS Selector, returns a list of :class:`Element ` objects.
:param selector: CSS Selector to use.
@@ -131,7 +140,7 @@ class BaseParser:
else:
return elements
- def xpath(self, selector: str, first: bool = False, _encoding: str = None):
+ def xpath(self, selector: str, first: bool = False, _encoding: str = None) -> _XPath:
"""Given an XPath selector, returns a list of
:class:`Element ` objects.
@@ -154,7 +163,7 @@ class BaseParser:
if not isinstance(selection, etree._ElementUnicodeResult):
element = Element(element=selection, url=self.url, default_encoding=_encoding or self.encoding)
else:
- element = selection
+ element = str(selection)
c.append(element)
if first:
@@ -280,7 +289,7 @@ class HTML(BaseParser):
:param default_encoding: Which encoding to default to.
"""
- def __init__(self, *, url=DEFAULT_URL, html, default_encoding=DEFAULT_ENCODING) -> None:
+ def __init__(self, *, url: str = DEFAULT_URL, html: _HTML, default_encoding: str =DEFAULT_ENCODING) -> None:
# Convert incoming unicode HTML into bytes.
if isinstance(html, str):
@@ -382,9 +391,9 @@ class HTMLResponse(requests.Response):
intelligent ``.html`` property added.
"""
- def __init__(self, *args, **kwargs) -> None:
- super(HTMLResponse, self).__init__(*args, **kwargs)
- self._html = None
+ def __init__(self) -> None:
+ super(HTMLResponse, self).__init__()
+ self._html = None # type: HTML
@property
def html(self) -> HTML: