Skip to content

Commit

Permalink
Swap out httplib for urllib3
Browse files Browse the repository at this point in the history
  • Loading branch information
AutomatedTester committed Jul 30, 2018
1 parent da3b0c5 commit 49a58eb
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 183 deletions.
11 changes: 2 additions & 9 deletions py/selenium/webdriver/firefox/webdriver.py
Expand Up @@ -16,18 +16,12 @@
# under the License.
import warnings

try:
import http.client as http_client
except ImportError:
import httplib as http_client

try:
basestring
except NameError: # Python 3.x
basestring = str

import shutil
import socket
import sys
from contextlib import contextmanager

Expand Down Expand Up @@ -200,9 +194,8 @@ def quit(self):
"""Quits the driver and close every associated window."""
try:
RemoteWebDriver.quit(self)
except (http_client.BadStatusLine, socket.error):
# Happens if Firefox shutsdown before we've read the response from
# the socket.
except Exception:
# We don't care about the message because something probably has gone wrong
pass

if self.w3c:
Expand Down
165 changes: 16 additions & 149 deletions py/selenium/webdriver/remote/remote_connection.py
Expand Up @@ -15,19 +15,17 @@
# specific language governing permissions and limitations
# under the License.

import base64
import logging
import platform
import socket
import string
import base64
import platform

import urllib3

try:
import http.client as httplib
from urllib import request as url_request
from urllib import parse
except ImportError: # above is available in py3+, below is py2.7
import httplib as httplib
import urllib2 as url_request
import urlparse as parse

from selenium.webdriver.common import utils as common_utils
Expand All @@ -39,98 +37,6 @@
LOGGER = logging.getLogger(__name__)


class Request(url_request.Request):
"""
Extends the url_request.Request to support all HTTP request types.
"""

def __init__(self, url, data=None, method=None):
"""
Initialise a new HTTP request.
:Args:
- url - String for the URL to send the request to.
- data - Data to send with the request.
"""
if method is None:
method = data is not None and 'POST' or 'GET'
elif method != 'POST' and method != 'PUT':
data = None
self._method = method
url_request.Request.__init__(self, url, data=data)

def get_method(self):
"""
Returns the HTTP method used by this request.
"""
return self._method


class Response(object):
"""
Represents an HTTP response.
"""

def __init__(self, fp, code, headers, url):
"""
Initialise a new Response.
:Args:
- fp - The response body file object.
- code - The HTTP status code returned by the server.
- headers - A dictionary of headers returned by the server.
- url - URL of the retrieved resource represented by this Response.
"""
self.fp = fp
self.read = fp.read
self.code = code
self.headers = headers
self.url = url

def close(self):
"""
Close the response body file object.
"""
self.read = None
self.fp = None

def info(self):
"""
Returns the response headers.
"""
return self.headers

def geturl(self):
"""
Returns the URL for the resource returned in this response.
"""
return self.url


class HttpErrorHandler(url_request.HTTPDefaultErrorHandler):
"""
A custom HTTP error handler.
Used to return Response objects instead of raising an HTTPError exception.
"""

def http_error_default(self, req, fp, code, msg, headers):
"""
Default HTTP error handler.
:Args:
- req - The original Request object.
- fp - The response body file object.
- code - The HTTP status code returned by the server.
- msg - The HTTP status message returned by the server.
- headers - The response headers.
:Returns:
A new Response object.
"""
return Response(fp, code, headers, req.get_full_url())


class RemoteConnection(object):
"""A connection with the Remote WebDriver server.
Expand Down Expand Up @@ -201,7 +107,6 @@ def __init__(self, remote_server_addr, keep_alive=False, resolve_ip=True):
# Attempt to resolve the hostname and get an IP address.
self.keep_alive = keep_alive
parsed_url = parse.urlparse(remote_server_addr)
addr = parsed_url.hostname
if parsed_url.hostname and resolve_ip:
port = parsed_url.port or None
if parsed_url.scheme == "https":
Expand All @@ -215,7 +120,6 @@ def __init__(self, remote_server_addr, keep_alive=False, resolve_ip=True):
port=port)
if ip:
netloc = ip
addr = netloc
if parsed_url.port:
netloc = common_utils.join_host_port(netloc,
parsed_url.port)
Expand All @@ -233,8 +137,7 @@ def __init__(self, remote_server_addr, keep_alive=False, resolve_ip=True):

self._url = remote_server_addr
if keep_alive:
self._conn = httplib.HTTPConnection(
str(addr), str(parsed_url.port), timeout=self._timeout)
self._conn = urllib3.PoolManager()

self._commands = {
Command.STATUS: ('GET', '/status'),
Expand Down Expand Up @@ -487,87 +390,51 @@ def _request(self, method, url, body=None):

parsed_url = parse.urlparse(url)
headers = self.get_remote_connection_headers(parsed_url, self.keep_alive)

resp = None
if self.keep_alive:
if body and method != 'POST' and method != 'PUT':
body = None
try:
self._conn.request(method, parsed_url.path, body, headers)
resp = self._conn.getresponse()
except (httplib.HTTPException, socket.error):
self._conn.close()
raise
resp = self._conn.request(method, url, body=body, headers=headers)

statuscode = resp.status
else:
password_manager = None
if parsed_url.username:
netloc = parsed_url.hostname
if parsed_url.port:
netloc += ":%s" % parsed_url.port
cleaned_url = parse.urlunparse((
parsed_url.scheme,
netloc,
parsed_url.path,
parsed_url.params,
parsed_url.query,
parsed_url.fragment))
password_manager = url_request.HTTPPasswordMgrWithDefaultRealm()
password_manager.add_password(None,
"%s://%s" % (parsed_url.scheme, netloc),
parsed_url.username,
parsed_url.password)
request = Request(cleaned_url, data=body.encode('utf-8'), method=method)
else:
request = Request(url, data=body.encode('utf-8'), method=method)

for key, val in headers.items():
request.add_header(key, val)
http = urllib3.PoolManager()
resp = http.request(method, url, body=body, headers=headers)

if password_manager:
opener = url_request.build_opener(url_request.HTTPRedirectHandler(),
HttpErrorHandler(),
url_request.HTTPBasicAuthHandler(password_manager))
else:
opener = url_request.build_opener(url_request.HTTPRedirectHandler(),
HttpErrorHandler())
resp = opener.open(request, timeout=self._timeout)
statuscode = resp.code
statuscode = resp.status
if not hasattr(resp, 'getheader'):
if hasattr(resp.headers, 'getheader'):
resp.getheader = lambda x: resp.headers.getheader(x)
elif hasattr(resp.headers, 'get'):
resp.getheader = lambda x: resp.headers.get(x)

data = resp.read()
data = resp.data.decode('UTF-8')
try:
if 300 <= statuscode < 304:
return self._request('GET', resp.getheader('location'))
body = data.decode('utf-8').replace('\x00', '').strip()
if 399 < statuscode <= 500:
return {'status': statuscode, 'value': body}
return {'status': statuscode, 'value': data}
content_type = []
if resp.getheader('Content-Type') is not None:
content_type = resp.getheader('Content-Type').split(';')
if not any([x.startswith('image/png') for x in content_type]):

try:
data = utils.load_json(body.strip())
data = utils.load_json(data.strip())
except ValueError:
if 199 < statuscode < 300:
status = ErrorCode.SUCCESS
else:
status = ErrorCode.UNKNOWN_ERROR
return {'status': status, 'value': body.strip()}
return {'status': status, 'value': data.strip()}

assert type(data) is dict, (
'Invalid server response body: %s' % body)
# Some of the drivers incorrectly return a response
# with no 'value' field when they should return null.
if 'value' not in data:
data['value'] = None
return data
else:
data = {'status': 0, 'value': body.strip()}
data = {'status': 0, 'value': data}
return data
finally:
LOGGER.debug("Finished Request")
Expand Down
1 change: 0 additions & 1 deletion py/selenium/webdriver/remote/utils.py
Expand Up @@ -21,7 +21,6 @@
import tempfile
import zipfile

from selenium.common.exceptions import NoSuchElementException

LOGGER = logging.getLogger(__name__)

Expand Down
24 changes: 0 additions & 24 deletions py/test/unit/selenium/webdriver/remote/test_remote_connection.py
Expand Up @@ -62,27 +62,3 @@ def close(self):

def getheader(self, *args, **kwargs):
pass


def test_remote_connection_adds_connection_headers_from_get_remote_connection_headers(mocker):
test_headers = {'FOO': 'bar', 'Content-Type': 'json'}
expected_request_headers = {'Foo': 'bar', 'Content-type': 'json'}

# Stub out the get_remote_connection_headers method to return something testable
mocker.patch(
'selenium.webdriver.remote.remote_connection.RemoteConnection.get_remote_connection_headers'
).return_value = test_headers

# Stub out response
try:
mock_open = mocker.patch('urllib.request.OpenerDirector.open')
except ImportError:
mock_open = mocker.patch('urllib2.OpenerDirector.open')

def assert_header_added(request, timeout):
assert request.headers == expected_request_headers
return MockResponse()

mock_open.side_effect = assert_header_added

RemoteConnection('http://remote', resolve_ip=False).execute('status', {})

0 comments on commit 49a58eb

Please sign in to comment.