在上篇文章中,我们介绍了rest framework框架的3个认证函数,perform_authentication(身份验证)、permissions(许可验证)、throttles(节流认证、限制访问数量)
,而该类方法,是通过APIView继承于View类方法,通过dispatch方法反射的操作,通过initialize_request方法封装了一个request(Request类)
,之后在调用initial方法来执行认证函数,而在此之前initial方法还有赋值的操作,而是通过什么来赋值,作用是什么?让我们来通过源码来观察吧。
initial函数方法如下:
def initial(self, request, *args, **kwargs): self.format_kwarg = self.get_format_suffix(**kwargs) neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg version, scheme = self.determine_version(request, *args, **kwargs) request.version, request.versioning_scheme = version, scheme # 权限认证相关 self.perform_authentication(request) self.check_permissions(request) self.check_throttles(request)
我们先从version开始看,能发现version的中文为版本,且通过调用self.determine_version方法,返回的参数由接收,最后再将接收到的参数赋值给request中的version、versioning_scheme对象中。
determine_version函数如下:
def determine_version(self, request, *args, **kwargs): if self.versioning_class is None: return (None, None) scheme = self.versioning_class() return (scheme.determine_version(request, *args, **kwargs), scheme)
通过determine_version函数可以发现,参数需要传递request,且判断了self.versioning_class是否为空
,而根据我们前面源码分析判断,可以推断出,该变量是APIView方法中默认的全局配置,所以我们在自定义版本的时候就可以加上versioning_class或配置在全局中。通过代码可以发现,该方法没有进行遍历,且后面通过scheme = self.versioning_class()调用了函数
,也就是说我们在子类中定义的versioning_class不是列表,而是一个函数对象,通过scheme对象调用了determine_version方法,可以得知我们在自己定义类的时候必须传入determine_version函数
。
最后return返回了2个参数,第一个则为根据determine_version函数发送来的请求来获取到版本号
,而第二个返回参数则返回了versioning_class这个函数
。
所以我们继续返回到initial函数中,发现返回的两个参数赋值给了version, scheme
,而version, scheme又赋值给了request.version, request.versioning_scheme。
由此得出当我们子类自定义函数时,若想获取版本信息,可以通过request.version获取版本,而想获取自定义类方法的对象时可以通过request.versioning_scheme获取(子类没有就从父类找)。
urls如下:
from django.conf.urls import url from api import views urlpatterns = [ url(r'^api/v1/auth/$', views.AuthView.as_view()), url(r'^api/v1/order/$', views.OrderView.as_view()), url(r'^api/v1/users/$', views.UsersView.as_view()), ]
api/view如下:
from django.shortcuts import render, HttpResponse from rest_framework.views import APIView from django.http import JsonResponse class ParamVersion(object): def determine_version(self, request, *args, **kwargs): version = request.query_params.get('version') return version class UsersView(APIView): authentication_classes = [] permission_classes = [] throttle_classes = [] versioning_class = ParamVersion def get(self, request, *args, **kwargs): print(request.version, request.versioning_scheme) return HttpResponse('ok')
此时访问http://127.0.0.1:8000/api/v1/users/?version=v1打印:
v1 <api.views.ParamVersion object at 0x000001CF29E7DD08>
可以发现通过request是可以获取到版本和ParamVersion函数的对象的,且在Reuqest类中定义了一个query_params函数,该函数的作用就是返回原生request中的GET请求的参数
(具体还有很多返回请求的函数可以自行通过Request类中查看)。
BaseVersioning类是django内置的version版本类,该方法作为多个类方法的父类(子类继承
),用于定义默认版本参数,以及版本的选择。
BaseVersioning源码如下:
class BaseVersioning: default_version = api_settings.DEFAULT_VERSION allowed_versions = api_settings.ALLOWED_VERSIONS version_param = api_settings.VERSION_PARAM def determine_version(self, request, *args, **kwargs): msg = '{cls}.determine_version() must be implemented.' raise NotImplementedError(msg.format( cls=self.__class__.__name__ )) def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): return _reverse(viewname, args, kwargs, request, format, **extra) def is_allowed_version(self, version): if not self.allowed_versions: return True return ((version is not None and version == self.default_version) or (version in self.allowed_versions))
看到api_settings我们知道了可以通过settings设置全局,所以我们就可以通过全局的设置来改变默认参数、允许的版本、传入的值通过什么来接收。
settings.py如下:
REST_FRAMEWORK = { "DEFAULT_VERSION": 'v1', "ALLOWED_VERSIONS": ['v1', 'v2'], "VERSION_PARAM": 'version' }
BaseVersioning类还有个reverse方法,该方法就是通过获取路由后面传递参数的值,然后调用django的reverse方法将url返回给用户。
urls.py如下:
from django.conf.urls import url from api import views urlpatterns = [ url(r'^api/v1/auth/$', views.AuthView.as_view()), url(r'^api/v1/order/$', views.OrderView.as_view()), url(r'^api/v1/users/$', views.UsersView.as_view(), name='uuu'), ]
api/views.py如下:
from django.shortcuts import render, HttpResponse from rest_framework.views import APIView from rest_framework.versioning import BaseVersioning class ParamVersion(BaseVersioning): def determine_version(self, request, *args, **kwargs): version = request.query_params.get('version') return version class UsersView(APIView): authentication_classes = [] permission_classes = [] throttle_classes = [] versioning_class = ParamVersion def get(self, request, *args, **kwargs): u1 = request.versioning_scheme.reverse(viewname='uuu', request=request) print(request.version, request.versioning_scheme, u1) return HttpResponse('ok')
此时访问http://127.0.0.1:8000/api/v1/users/?version=v2打印:
v2 <api.views.ParamVersion object at 0x00000286D85413C8> http://127.0.0.1:8000/api/v1/users/
QueryParameterVersioning方法继承于BaseVersioning的类
,该方法通过以GET请求的方式获取版本,且还可以通过内置的reverse类来获取URL。
QueryParameterVersioning源码如下:
class QueryParameterVersioning(BaseVersioning): invalid_version_message = _('Invalid version in query parameter.') def determine_version(self, request, *args, **kwargs): version = request.query_params.get(self.version_param, self.default_version) if not self.is_allowed_version(version): raise exceptions.NotFound(self.invalid_version_message) return version def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): url = super().reverse( viewname, args, kwargs, request, format, **extra ) if request.version is not None: return replace_query_param(url, self.version_param, request.version) return url
可以发现QueryParameterVersioning中的determine_version方法通过request.query_params.get获取GET请求中的参数,通过我们定义的全局配置属性,self.version_param获取键值,还有默认属性
,之后通过is_allowed_version判断我们是否传了值,没传则显示默认值
。
而reverse方法则是调用了父类BaseVersioning的reverse方法。
is_allowed_version函数如下:
def is_allowed_version(self, version): if not self.allowed_versions: return True return ((version is not None and version == self.default_version) or (version in self.allowed_versions))
api/view如下:
from rest_framework.versioning import QueryParameterVersioning from django.shortcuts import render, HttpResponse class UsersView(APIView): authentication_classes = [] permission_classes = [] throttle_classes = [] versioning_class = QueryParameterVersioning def get(self, request, *args, **kwargs): u1 = request.versioning_scheme.reverse(viewname='uuu', request=request) print(request.version, request.versioning_scheme, u1) return HttpResponse('ok')
此时打印的数据和之前一样的。
URLPathVersioning方法继承了BaseVersioning的类,该方法封装了版本、URL路径,可以使用版本放在路由中,通过调用父类的reverse方法来获取URL。
URLPathVersioning源码如下:
class URLPathVersioning(BaseVersioning): """ urlpatterns = [ re_path(r'^(?P<version>[v1|v2]+)/users/$', users_list, name='users-list'), re_path(r'^(?P<version>[v1|v2]+)/users/(?P<pk>[0-9]+)/$', users_detail, name='users-detail') ] """ invalid_version_message = _('Invalid version in URL path.') def determine_version(self, request, *args, **kwargs): version = kwargs.get(self.version_param, self.default_version) if version is None: version = self.default_version if not self.is_allowed_version(version): raise exceptions.NotFound(self.invalid_version_message) return version def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): if request.version is not None: kwargs = {} if (kwargs is None) else kwargs kwargs[self.version_param] = request.version return super().reverse( viewname, args, kwargs, request, format, **extra )
从源码注释来看可以知道,该类方法通过kwargs的get方法获取通过正则来获取路由上的数据,且定义路由时的方式可以为(?P[v1|v2]+)
来定义版本,如果什么都没有就返回默认版本的信息。
而reverse方法则先判断request.version是不是存在,存在则赋值给kwargs,之后调用父类的BaseVersioning方法的reverse函数
,所以我们可以知道,在URLPathVersioning中也只需要传入request即可,内部会进行赋值操作的。
urls.py如下:
from django.conf.urls import url from api import views urlpatterns = [ url(r'^api/v1/auth/$', views.AuthView.as_view()), url(r'^api/v1/order/$', views.OrderView.as_view()), url(r'^api/(?P<version>[v1|v2]+)/users/$', views.UsersView.as_view(), name='uuu'), ]
app/views.py如下:
from rest_framework.versioning import URLPathVersioning from django.shortcuts import render, HttpResponse class UsersView(APIView): authentication_classes = [] permission_classes = [] throttle_classes = [] versioning_class = URLPathVersioning def get(self, request, *args, **kwargs): # reverse(viewname='uuu', request={'version':'v1'}) u1 = request.versioning_scheme.reverse(viewname='uuu', request=request) print(request.version, request.versioning_scheme, u1) return HttpResponse('ok')
此时访问http://127.0.0.1:8000/api/v1/users/打印:
v1 <rest_framework.versioning.URLPathVersioning object at 0x000001ACD1710488> http://127.0.0.1:8000/api/v1/users/
一般我们定义版本的时候都是放在路由中,所以为了方便,我们可以通过配置versioning_class来使得全局使用上。
settings.py如下:
REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ['api.utils.auth.Authtication', ], "UNAUTHENTICATED_USER": None, "UNAUTHENTICATED_TOKEN": None, "DEFAULT_PERMISSION_CLASSES": ['api.utils.permissions.MyPermission', ], "DEFAULT_THROTTLE_CLASSES": ['api.utils.throttle.UserThrottle'], "DEFAULT_THROTTLE_RATES": { "scope": '3/m', "user": '5/m', }, "DEFAULT_VERSION": 'v1', "ALLOWED_VERSIONS": ['v1', 'v2'], "VERSION_PARAM": 'version', "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning" }
当数据返回的时候,如果请求通过post请求返回,那么就需要遵循post请求数据的规范:
(request.post请求去request.body解析数据)
一般对于封装类似请求的东西,APIView都会封装在Request类中,之后在把原生request打包好,之后就可以通过打包好的函数调用,而解析器其实就封装在Request类中的data函数中
。
initialize_request函数如下:
def initialize_request(self, request, *args, **kwargs): """ Returns the initial request object. """ parser_context = self.get_parser_context(request) return Request( request, parsers=self.get_parsers(), authenticators=self.get_authenticators(), negotiator=self.get_content_negotiator(), parser_context=parser_context )
可以发现Request类中parsers这个对象调用了get_parsers()方法。get_parsers函数如下:
def get_parsers(self): return [parser() for parser in self.parser_classes]
可以发现通过调用self.parser_classes对象遍历成每一个函数,所以从这里我们就可以知道,要想在子类定义解析器,需写上self.parser_classes这个列表
。
Request/data函数如下:
@property def data(self): if not _hasattr(self, '_full_data'): self._load_data_and_files() return self._full_data
可以看到在data函数中添加了@property装饰器,所以在使用request.data的时候是无需传括号
的,且我们可以看到data函数中调用了_load_data_and_files函数。
_load_data_and_files函数源码如下:
def _load_data_and_files(self): if not _hasattr(self, '_data'): self._data, self._files = self._parse() if self._files: self._full_data = self._data.copy() self._full_data.update(self._files) else: self._full_data = self._data if is_form_media_type(self.content_type): self._request._post = self.POST self._request._files = self.FILES
此时可以看到_load_data_and_files方法又调用了self._parse(),且返回的参数给self._data, self._files。
self._parse()函数源码如下:
def _parse(self): media_type = self.content_type try: stream = self.stream except RawPostDataException: if not hasattr(self._request, '_post'): raise if self._supports_form_parsing(): return (self._request.POST, self._request.FILES) stream = None if stream is None or media_type is None: if media_type and is_form_media_type(media_type): empty_data = QueryDict('', encoding=self._request._encoding) else: empty_data = {} empty_files = MultiValueDict() return (empty_data, empty_files) parser = self.negotiator.select_parser(self, self.parsers) if not parser: raise exceptions.UnsupportedMediaType(media_type) try: parsed = parser.parse(stream, media_type, self.parser_context) except Exception: self._data = QueryDict('', encoding=self._request._encoding) self._files = MultiValueDict() self._full_data = self._data raise try: return (parsed.data, parsed.files) except AttributeError: empty_files = MultiValueDict() return (parsed, empty_files)
可以看到_parse函数获取到了self.content_type(即请求头content_type的参数)。
我们直接看 parser = self.negotiator.select_parser(self, self.parsers)
这里传入了self.parsers(解析器),通过self.negotiator对象的select_parser方法来解析,之后将值返还给parser。
DefaultContentNegotiation类/select_parser函数如下:
def select_parser(self, request, parsers): for parser in parsers: if media_type_matches(parser.media_type, request.content_type): return parser return None
可以发现此时通过循环parsers来获取到该方法中的media_type请求头,根据支持的请求头,返回该请求头的解析器。
之后我们继续往下走到parsed = parser.parse(stream, media_type, self.parser_context)
可以发现该方法是通过调用子类的parse函数执行的。
rest_framework有内置给我们的解析器,且它们都继承BaseParser类,而该类有一个parse函数,继承BaseParser类的解析器有很多,不过每一个源码流程都相似,这里就通过JSONParser的源码来了解一下解析器的整个流程。
JSONParser/parse函数如下:
def parse(self, stream, media_type=None, parser_context=None): parser_context = parser_context or {} encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) try: decoded_stream = codecs.getreader(encoding)(stream) parse_constant = json.strict_constant if self.strict else None return json.load(decoded_stream, parse_constant=parse_constant) except ValueError as exc: raise ParseError('JSON parse error - %s' % str(exc))
decoded_stream 接收到的就是request.body返回的参数,然后通过json.load的方式将json格式解析成了字典的形式。
urls.py如下:
from django.conf.urls import url from api import views urlpatterns = [ url(r'^api/v1/auth/$', views.AuthView.as_view()), url(r'^api/v1/order/$', views.OrderView.as_view()), url(r'^api/(?P<version>[v1|v2]+)/users/$', views.UsersView.as_view(), name='uuu'), url(r'^api/(?P<version>[v1|v2]+)/parser/$', views.ParserView.as_view(), name='ddd'), ]
api/views.py如下:
from rest_framework.parsers import JSONParser class ParserView(APIView): authentication_classes = [] permission_classes = [] throttle_classes = [] parser_classes = [JSONParser] def post(self, request, *args, **kwargs): print(request.data) return HttpResponse('ParserView')
此时访问http://127.0.0.1:8000/api/v1/parser/显示:
此时控制台打印:
{'name': 'sehun', 'age': 18, 'gender': '男'}
而要想每一个都能使用,我们继续将其放入全局配置中。
settings.py如下:
REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ['api.utils.auth.Authtication', ], "UNAUTHENTICATED_USER": None, "UNAUTHENTICATED_TOKEN": None, "DEFAULT_PERMISSION_CLASSES": ['api.utils.permissions.MyPermission', ], "DEFAULT_THROTTLE_CLASSES": ['api.utils.throttle.UserThrottle'], "DEFAULT_THROTTLE_RATES": { "scope": '3/m', "user": '5/m', }, "DEFAULT_VERSION": 'v1', "ALLOWED_VERSIONS": ['v1', 'v2'], "VERSION_PARAM": 'version', "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning", "rest_framework": ["rest_framework.parsers.JSONParser", ["rest_framework.parsers.FormParser"]], }
此时就会可以全局都配置上解析器了。