Tests for providers callbacks
Also added SERVER_NAME in config_testing and pre-populated the keys of OAUTH_CREDENTIALS, since the implementation of providers is part of the application.
This commit is contained in:
parent
cecf81a07d
commit
41a82c44c5
@ -16,26 +16,36 @@ class OAuthUserResponse:
|
||||
email = attr.ib(validator=attr.validators.instance_of(str))
|
||||
|
||||
|
||||
class ProviderConfigurationMissing(ValueError):
|
||||
class OAuthError(Exception):
|
||||
"""Superclass of all exceptions raised by this module."""
|
||||
|
||||
|
||||
class ProviderConfigurationMissing(OAuthError):
|
||||
"""Raised when an OAuth provider is used but not configured."""
|
||||
|
||||
|
||||
class ProviderNotImplemented(ValueError):
|
||||
class ProviderNotImplemented(OAuthError):
|
||||
"""Raised when a provider is requested that does not exist."""
|
||||
|
||||
|
||||
class OAuthCodeNotProvided(OAuthError):
|
||||
"""Raised when the 'code' arg is not provided in the OAuth callback."""
|
||||
|
||||
|
||||
class OAuthSignIn(metaclass=abc.ABCMeta):
|
||||
_providers = None # initialized in get_provider()
|
||||
|
||||
def __init__(self, provider_name):
|
||||
self.provider_name = provider_name
|
||||
try:
|
||||
credentials = current_app.config['OAUTH_CREDENTIALS'][provider_name]
|
||||
except KeyError:
|
||||
credentials = current_app.config['OAUTH_CREDENTIALS'].get(provider_name)
|
||||
if not credentials:
|
||||
raise ProviderConfigurationMissing(f'Missing OAuth credentials for {provider_name}')
|
||||
self.consumer_id = credentials['id']
|
||||
self.consumer_secret = credentials['secret']
|
||||
|
||||
# Set in a subclass
|
||||
self.service: OAuth2Service = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def authorize(self) -> Response:
|
||||
"""Redirect to the correct authorization endpoint for the current provider.
|
||||
@ -58,6 +68,25 @@ class OAuthSignIn(metaclass=abc.ABCMeta):
|
||||
return url_for('users.oauth_callback', provider=self.provider_name,
|
||||
_external=True)
|
||||
|
||||
@staticmethod
|
||||
def auth_code_from_request() -> str:
|
||||
try:
|
||||
return request.args['code']
|
||||
except KeyError:
|
||||
raise OAuthCodeNotProvided('A code argument was not provided in the request')
|
||||
|
||||
@staticmethod
|
||||
def decode_json(payload):
|
||||
return json.loads(payload.decode('utf-8'))
|
||||
|
||||
def make_oauth_session(self):
|
||||
return self.service.get_auth_session(
|
||||
data={'code': self.auth_code_from_request(),
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.get_callback_url()},
|
||||
decoder=self.decode_json
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, provider_name) -> 'OAuthSignIn':
|
||||
if cls._providers is None:
|
||||
@ -96,20 +125,9 @@ class BlenderIdSignIn(OAuthSignIn):
|
||||
)
|
||||
|
||||
def callback(self):
|
||||
def decode_json(payload):
|
||||
return json.loads(payload.decode('utf-8'))
|
||||
|
||||
if 'code' not in request.args:
|
||||
return None, None, None
|
||||
oauth_session = self.service.get_auth_session(
|
||||
data={'code': request.args['code'],
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.get_callback_url()},
|
||||
decoder=decode_json
|
||||
)
|
||||
oauth_session = self.make_oauth_session()
|
||||
|
||||
# TODO handle exception for failed oauth or not authorized
|
||||
|
||||
session['blender_id_oauth_token'] = oauth_session.access_token
|
||||
me = oauth_session.get('user').json()
|
||||
return OAuthUserResponse(str(me['id']), me['email'])
|
||||
@ -135,17 +153,8 @@ class FacebookSignIn(OAuthSignIn):
|
||||
)
|
||||
|
||||
def callback(self):
|
||||
def decode_json(payload):
|
||||
return json.loads(payload.decode('utf-8'))
|
||||
oauth_session = self.make_oauth_session()
|
||||
|
||||
if 'code' not in request.args:
|
||||
return None, None, None
|
||||
oauth_session = self.service.get_auth_session(
|
||||
data={'code': request.args['code'],
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.get_callback_url()},
|
||||
decoder=decode_json
|
||||
)
|
||||
me = oauth_session.get('me?fields=id,email').json()
|
||||
# TODO handle case when user chooses not to disclose en email
|
||||
# see https://developers.facebook.com/docs/graph-api/reference/user/
|
||||
@ -172,16 +181,7 @@ class GoogleSignIn(OAuthSignIn):
|
||||
)
|
||||
|
||||
def callback(self):
|
||||
def decode_json(payload):
|
||||
return json.loads(payload.decode('utf-8'))
|
||||
oauth_session = self.make_oauth_session()
|
||||
|
||||
if 'code' not in request.args:
|
||||
return None, None, None
|
||||
oauth_session = self.service.get_auth_session(
|
||||
data={'code': request.args['code'],
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.get_callback_url()},
|
||||
decoder=decode_json
|
||||
)
|
||||
me = oauth_session.get('userinfo').json()
|
||||
return OAuthUserResponse(str(me['id']), me['email'])
|
||||
|
@ -102,7 +102,12 @@ BLENDER_ID_SUBCLIENT_ID = 'PILLAR'
|
||||
# 'base_url': 'http://blender_id:8000/'
|
||||
# }
|
||||
# }
|
||||
OAUTH_CREDENTIALS = {}
|
||||
# OAuth providers are defined in pillar.auth.oauth
|
||||
OAUTH_CREDENTIALS = {
|
||||
'blender-id': {},
|
||||
'facebook': {},
|
||||
'google': {},
|
||||
}
|
||||
|
||||
# See https://docs.python.org/2/library/logging.config.html#configuration-dictionary-schema
|
||||
LOGGING = {
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
BLENDER_ID_ENDPOINT = 'http://127.0.0.1:8001' # nonexistant server, no trailing slash!
|
||||
|
||||
SERVER_NAME = 'localhost:5000'
|
||||
|
||||
DEBUG = False
|
||||
TESTING = True
|
||||
|
||||
|
@ -11,7 +11,7 @@
|
||||
| Login using one of the following providers:
|
||||
|
||||
.login-providers
|
||||
| {% for login_provider, login_provider_conf in config['OAUTH_CREDENTIALS'].items() %}
|
||||
| {% for login_provider, login_provider_conf in config['OAUTH_CREDENTIALS'].items() if config['OAUTH_CREDENTIALS'].get(login_provider) %}
|
||||
|
||||
| {% set login_provider_name = login_provider | undertitle %}
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import responses
|
||||
from pillar.tests import AbstractPillarTest
|
||||
|
||||
|
||||
@ -9,9 +10,9 @@ class OAuthTests(AbstractPillarTest):
|
||||
def test_providers_init(self):
|
||||
from pillar.auth.oauth import OAuthSignIn, BlenderIdSignIn
|
||||
|
||||
blender_id_oauth_provider = OAuthSignIn.get_provider('blender-id')
|
||||
self.assertIsInstance(blender_id_oauth_provider, BlenderIdSignIn)
|
||||
self.assertEqual(blender_id_oauth_provider.service.base_url, 'http://blender_id:8000/api/')
|
||||
oauth_provider = OAuthSignIn.get_provider('blender-id')
|
||||
self.assertIsInstance(oauth_provider, BlenderIdSignIn)
|
||||
self.assertEqual(oauth_provider.service.base_url, 'http://blender_id:8000/api/')
|
||||
|
||||
def test_provider_not_implemented(self):
|
||||
from pillar.auth.oauth import OAuthSignIn, ProviderNotImplemented
|
||||
@ -29,3 +30,46 @@ class OAuthTests(AbstractPillarTest):
|
||||
del self.app.config['OAUTH_CREDENTIALS']['blender-id']
|
||||
with self.assertRaises(ProviderConfigurationMissing):
|
||||
OAuthSignIn.get_provider('blender-id')
|
||||
|
||||
def test_provider_authorize(self):
|
||||
from pillar.auth.oauth import OAuthSignIn
|
||||
from urllib.parse import urlparse, parse_qsl
|
||||
oauth_provider = OAuthSignIn.get_provider('blender-id')
|
||||
r = oauth_provider.authorize()
|
||||
self.assertEqual(r.status_code, 302)
|
||||
url_parts = list(urlparse(r.location))
|
||||
# Get the query arguments as a dict
|
||||
query = dict(parse_qsl(url_parts[4]))
|
||||
self.assertEqual(query['client_id'], oauth_provider.service.client_id)
|
||||
|
||||
@responses.activate
|
||||
def test_provider_callback_happy(self):
|
||||
from pillar.auth.oauth import OAuthSignIn
|
||||
|
||||
responses.add(responses.POST, 'http://blender_id:8000/oauth/token',
|
||||
json={'access_token': 'successful-token'},
|
||||
status=200)
|
||||
|
||||
responses.add(responses.GET, 'http://blender_id:8000/api/user',
|
||||
json={'id': '7',
|
||||
'email': 'harry@blender.org'},
|
||||
status=200)
|
||||
|
||||
oauth_provider = OAuthSignIn.get_provider('blender-id')
|
||||
|
||||
with self.app.test_request_context('/oauth/blender-id/authorized?code=123'):
|
||||
# We override the call to blender-id
|
||||
cb = oauth_provider.callback()
|
||||
self.assertEqual(cb.id, '7')
|
||||
|
||||
@responses.activate
|
||||
def test_provider_callback_missing_code(self):
|
||||
from pillar.auth.oauth import OAuthSignIn, OAuthCodeNotProvided
|
||||
|
||||
oauth_provider = OAuthSignIn.get_provider('blender-id')
|
||||
|
||||
# Check exception when the 'code' argument is not returned
|
||||
with self.assertRaises(OAuthCodeNotProvided):
|
||||
with self.app.test_request_context('/oauth/blender-id/authorized'):
|
||||
oauth_provider.callback()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user