背景:看了博主“一抹浅笑”的rest_framework认证模板,发现其登录视图是基于APIView类封装的。
优化:改用ModelViewSet类,通过重写create方法来编写登录视图。
环境:既然接触到rest_framework的使用,相信已经搭建好相关环境了。
编写模型类
# models.py
from django.db import models


class User(models.Model):
    # Login credentials. The password is stored, and compared at login,
    # as plain text.
    username = models.CharField(max_length=16, unique=True, verbose_name='用户名称')
    password = models.CharField(max_length=16, verbose_name='登陆密码')


class Token(models.Model):
    # At most one row per username, holding that user's current token.
    # The login view writes '' here when the credentials do not match.
    username = models.CharField(max_length=16, unique=True, verbose_name='用户名称')
    token = models.CharField(max_length=32, verbose_name='验证密钥')
生成迁移文件
python manage.py makemigrations
执行迁移,同步数据表
python manage.py migrate
查看ModelViewSet类源码
''' class ModelViewSet(mixins.CreateModelMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, mixins.ListModelMixin, GenericViewSet): """ A viewset that provides default `create()`, `retrieve()`, `update()`, `partial_update()`, `destroy()` and `list()` actions. """ pass '''
最终目的是往Token模型对应的表中添加数据,所以应查看CreateModelMixin类的源码。
''' class CreateModelMixin: """ Create a model instance. """ def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) self.perform_create(serializer) headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) def perform_create(self, serializer): serializer.save() def get_success_headers(self, data): try: return {'Location': str(data[api_settings.URL_FIELD_NAME])} except (TypeError, KeyError): return {} '''
查看可知,CreateModelMixin类的create方法通过perform_create调用了序列化器(serializer)的save方法来创建数据。继续查看save方法。
从serializers.ModelSerializer跳转到rest_framework的serializers.py文件,搜索'def save('可定位到BaseSerializer中的以下内容。
''' def save(self, **kwargs): assert hasattr(self, '_errors'), ( 'You must call `.is_valid()` before calling `.save()`.' ) assert not self.errors, ( 'You cannot call `.save()` on a serializer with invalid data.' ) # Guard against incorrect use of `serializer.save(commit=False)` assert 'commit' not in kwargs, ( "'commit' is not a valid keyword argument to the 'save()' method. " "If you need to access data before committing to the database then " "inspect 'serializer.validated_data' instead. " "You can also pass additional keyword arguments to 'save()' if you " "need to set extra attributes on the saved model instance. " "For example: 'serializer.save(owner=request.user)'.'" ) assert not hasattr(self, '_data'), ( "You cannot call `.save()` after accessing `serializer.data`." "If you need to access data before committing to the database then " "inspect 'serializer.validated_data' instead. " ) validated_data = {**self.validated_data, **kwargs} if self.instance is not None: self.instance = self.update(self.instance, validated_data) assert self.instance is not None, ( '`update()` did not return an object instance.' ) else: self.instance = self.create(validated_data) assert self.instance is not None, ( '`create()` did not return an object instance.' ) '''
看最后这个if……else……语句中的self.instance = self.create(validated_data)。
说明这里调用的是序列化器自身的create方法,它返回一个模型对象。于是查看ModelSerializer类的create方法。
''' def create(self, validated_data): """ We have a bit of extra checking around this in order to provide descriptive messages when something goes wrong, but this method is essentially just: return ExampleModel.objects.create(**validated_data) If there are many to many fields present on the instance then they cannot be set until the model is instantiated, in which case the implementation is like so: example_relationship = validated_data.pop('example_relationship') instance = ExampleModel.objects.create(**validated_data) instance.example_relationship = example_relationship return instance The default implementation also does not handle nested relationships. If you want to support writable nested relationships you'll need to write an explicit `.create()` method. """ raise_errors_on_nested_writes('create', self, validated_data) ModelClass = self.Meta.model # Remove many-to-many relationships from validated_data. # They are not valid arguments to the default `.create()` method, # as they require that the instance has already been saved. info = model_meta.get_field_info(ModelClass) many_to_many = {} for field_name, relation_info in info.relations.items(): if relation_info.to_many and (field_name in validated_data): many_to_many[field_name] = validated_data.pop(field_name) try: instance = ModelClass._default_manager.create(**validated_data) except TypeError: tb = traceback.format_exc() msg = ( 'Got a `TypeError` when calling `%s.%s.create()`. ' 'This may be because you have a writable field on the ' 'serializer class that is not a valid argument to ' '`%s.%s.create()`. You may need to make the field ' 'read-only, or override the %s.create() method to handle ' 'this correctly.\nOriginal exception was:\n %s' % ( ModelClass.__name__, ModelClass._default_manager.name, ModelClass.__name__, ModelClass._default_manager.name, self.__class__.__name__, tb ) ) raise TypeError(msg) # Save many-to-many relationships after the instance is created. 
if many_to_many: for field_name, value in many_to_many.items(): field = getattr(instance, field_name) field.set(value) return instance '''
这段逻辑我没有完全看懂,但通过print、type、dir函数可以确定:
接收对象validated_data是一个字典,
返回对象instance是一个模型对象。
于是可以把源码复制(cv)过来,重写create方法,简单测试能否跑通。
import time
import hashlib
from rest_framework import status
from rest_framework import serializers
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from myapp import models as myapp_models


class TokenSerializer(serializers.ModelSerializer):
    """First draft: upsert the Token row keyed by ``username`` (no credential check yet)."""

    class Meta:
        model = myapp_models.Token
        fields = '__all__'
        # Token.username is unique=True, so ModelSerializer auto-attaches a
        # UniqueValidator to it. That validator would reject every request for
        # a username that already has a Token row before create() (and its
        # update_or_create) ever runs, so it is disabled here.
        extra_kwargs = {'username': {'validators': []}}

    def create(self, validated_data):
        """Update the Token row for ``validated_data['username']``, or create it.

        Returns the saved Token instance, as ``ModelSerializer.create`` must.
        """
        instance, _created = myapp_models.Token.objects.update_or_create(
            username=validated_data['username'],
            defaults={'token': validated_data['token']},
        )
        return instance


class LoginView(ModelViewSet):
    """Viewset whose create() is a verbatim copy of ``CreateModelMixin.create``."""

    queryset = myapp_models.Token.objects.all()
    serializer_class = TokenSerializer

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        # perform_create -> serializer.save() -> TokenSerializer.create
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
TokenSerializer
1.获取username和password(password通过token字段传入)。
2.验证username、password匹配性。
3.匹配失败:更新或创建该username对应的Token记录,将token置为空字符串,返回模型对象。
4.匹配成功:通过md5哈希生成token,更新或创建该username对应的Token记录,将token设为新密钥,返回模型对象。
LoginView(继承ModelViewSet)
1.根据username查询token值。
2.将username、token值设置到session会话。
import time
import hashlib
import secrets
from rest_framework import status
from rest_framework import serializers
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from myapp import models as myapp_models


class TokenSerializer(serializers.ModelSerializer):
    """Login serializer. The client sends its password in the ``token`` field.

    On save, the Token row for ``username`` is upserted. It gets a freshly
    generated token when the credentials match a User row, or '' when they
    do not.
    """

    class Meta:
        model = myapp_models.Token
        fields = '__all__'
        # Token.username is unique=True, so ModelSerializer auto-attaches a
        # UniqueValidator to it. That validator would make every login after
        # the first one for a user fail validation (400) before create() runs,
        # so it is disabled; create() handles existing rows via update_or_create.
        extra_kwargs = {'username': {'validators': []}}

    def create(self, validated_data):
        """Check the credentials and upsert the user's Token row.

        Args:
            validated_data: dict with ``username`` and ``token`` (the password).

        Returns:
            The saved Token instance. Its ``token`` is '' when the credentials
            do not match.
        """
        username = validated_data['username']
        password = validated_data['token']  # the login request carries the password here
        credentials_ok = myapp_models.User.objects.filter(
            username=username, password=password).exists()
        if credentials_ok:
            # Mix in random bytes so the token is unique per login and not guessable
            # from the username and the login time alone.
            seed = '{}{}{}'.format(time.time(), username, secrets.token_hex(16))
            new_token = hashlib.md5(seed.encode()).hexdigest()  # 32 hex chars == max_length
        else:
            new_token = ''
        instance, _created = myapp_models.Token.objects.update_or_create(
            username=username,
            defaults={'token': new_token},
        )
        return instance


class LoginView(ModelViewSet):
    """Login endpoint: validates credentials through TokenSerializer and records the result in the session."""

    queryset = myapp_models.Token.objects.all()
    serializer_class = TokenSerializer

    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        # serializer.save() stored the Token row it wrote in serializer.instance.
        # Use it instead of re-querying via request.POST, which is empty for JSON bodies.
        token_obj = serializer.instance
        # Store the outcome in the session either way. An empty token is later
        # rejected by the authentication class, so a failed login leaves the
        # session unauthenticated.
        request.session['username'] = token_obj.username
        request.session['token'] = token_obj.token
        if token_obj.token == '':
            # NOTE(review): still answers 200 for a failed login; consider 400/401.
            return Response('检查输入的账户和密码')
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
1.从session中获取username,token。
2.判断username,token是否不存在、或token是否为空字符串。
3.条件成立:抛出认证失败异常(AuthenticationFailed)。
4.条件不成立:返回由username和模型对象组成的元组。
from rest_framework import exceptions
from rest_framework.authentication import BaseAuthentication
from myapp import models as myapp_models


class Authentication(BaseAuthentication):
    """Authenticate requests from the username/token pair stored in the Django session."""

    def authenticate(self, request):
        """Return ``(username, Token)`` when the session holds a valid token.

        Raises:
            exceptions.AuthenticationFailed: the session's username/token pair
                matches no Token row, or the stored token is ''.
        """
        session = request._request.session
        username = session.get('username', '')
        token = session.get('token', '')
        # first() returns None when nothing matches, so a missing or stale
        # session is rejected with AuthenticationFailed instead of crashing on
        # ``None.token``. It also fetches the row once instead of three times.
        token_obj = myapp_models.Token.objects.filter(
            username=username, token=token).first()
        if token_obj is None or token_obj.token == '':
            raise exceptions.AuthenticationFailed('认证失败')
        return (token_obj.username, token_obj)
# Route POST /login/ to LoginView.create.
path('login/', myapp_views.LoginView.as_view(actions={'post': 'create'}), name='login')