diff --git a/requests_html.py b/requests_html.py
index c7f6e11..6630a2a 100644
--- a/requests_html.py
+++ b/requests_html.py
@@ -606,12 +606,23 @@ class HTMLSession(requests.Session):
class AsyncHTMLSession(requests.Session):
""" """
- def __init__(self, *args, **kwargs):
+ def __init__(self, mock_browser: bool = True, *args, **kwargs):
""" Create loop and thread pool. """
+ super().__init__(*args, **kwargs)
+
+ if mock_browser:
+ self.headers['User-Agent'] = user_agent()
+
+ self.hooks["response"].append(self.response_hook)
+
self.loop = asyncio.get_event_loop()
self.thread_pool = ThreadPoolExecutor()
- super().__init__(*args, **kwargs)
+ @staticmethod
+ def response_hook(response, **kwargs) -> HTMLResponse:
+ """ Change response enconding and replace it by a HTMLResponse. """
+ response.encoding = DEFAULT_ENCODING
+ return HTMLResponse._from_response(response)
def request(self, *args, **kwargs):
""" Partial original request func and run it in a thread. """
diff --git a/tests/test_requests_html.py b/tests/test_requests_html.py
index 3b70215..5f35aa7 100644
--- a/tests/test_requests_html.py
+++ b/tests/test_requests_html.py
@@ -1,4 +1,5 @@
import os
+from functools import partial
import pytest
from requests_html import HTMLSession, AsyncHTMLSession, HTML
@@ -24,7 +25,7 @@ def async_get(event_loop):
path = os.path.sep.join((os.path.dirname(os.path.abspath(__file__)), 'python.html'))
url = 'file://{}'.format(path)
- return async_session.get(url)
+ return partial(async_session.get, url)
@pytest.mark.ok
@@ -36,7 +37,7 @@ def test_file_get():
@pytest.mark.ok
@pytest.mark.asyncio
async def test_async_file_get(async_get):
- r = await async_get
+ r = await async_get()
assert r.status_code == 200
@@ -72,6 +73,7 @@ def test_containing():
for e in python:
assert 'python' in e.full_text.lower()
+
@pytest.mark.ok
def test_attrs():
r = get()
@@ -90,6 +92,16 @@ def test_links():
assert len(about.absolute_links) == 6
+@pytest.mark.ok
+@pytest.mark.asyncio
+async def test_async_links(async_get):
+ r = await async_get()
+ about = r.html.find('#about', first=True)
+
+ assert len(about.links) == 6
+ assert len(about.absolute_links) == 6
+
+
@pytest.mark.ok
def test_search():
r = get()