本篇文章是对以前学习MVVM,RxSwift,以及机器学习的一些总结和实践。其中涉及的知识点:
关于这些技术的理论知识点,以后有时间我会一项一项做学习笔记,详细编码可以参看demo。
本项目是一个拥有登录、同意、指引、人脸识别4大模块的简单人脸识别项目,麻雀虽小,五脏俱全。特点,使用MVVM架构,层次较为分明,数据层和逻辑层,展示层基本都完成分层,可以自行进行剥离模块化,同时数据层加入假数据模块可以在开发时候自由加入模仿数据。(时间精力有限没有进行模块化开发,这个后面有时间会针对大型项目的模块化再进行实践)
因为篇幅有限,简单的部分请移步看demo。
- UserSessionRepository.swift (整个用户数据协议接口)
- RemoteAPI.swift (网络请求协议接口,具体实现可以自由切换假数据和真实数据)
- UserSessionDataStore.swift (本地数据持久化协议接口,具体实现可以自由切换假数据和真实数据)
具体实现: APPUserSessionRepository.swift作为整个app数据接入口 继承UserSessionRepository.swift并实现接口,根据项目需要实现网络拉取数据(交由由实现RemoteAPI协议的功能模块Remote完成),本地数据持久化(交由实现UserSessionDataStore协议的功能模块Persistence完成)。
class APPUserSessionRepository: UserSessionRepository { let dataStore: UserSessionDataStore let remoteAPI: RemoteAPI init(dataStore: UserSessionDataStore, remoteAPI: RemoteAPI) { self.dataStore = dataStore self.remoteAPI = remoteAPI } ... ... 复制代码
机器学习基本理论部分如感兴趣请自行入门:
这里只描述大致流程(没有模型的情况下,有模型请从4开始部署到移动端):
- 1、明确机器学习目标,例如OCR、人脸识别、物品识别等,然后根据目的选择算法(或称网络结构)简单的如决策树,随机森林,支持向量机等等;复杂点的大神提供的网络结构FaceNet,MTCNN等,此项目是识别人脸选择FaceNet;
- 2、使用别人的网络结构,一般都有对应的训练方法和建议参数,如果你不添加新需求不修改网络结构就按照对方给的资料准备数据进行训练。(数据准备和预处理跟据情况自己搜集,基本都是用的开源数据,个人实在是没发找那么多);
- 3、训练完成后测试结果,根据实际效果调参数(利用GridSearch等),调整训练数据集和测试数据集;模型稳定后进行quant 缩减模型大小;
开始部署到移动端
- 4、使用转换工具(Google官方有提供)将训练出来的模型.pb 转为 .tflite ;
- 5、把训练的tflite 模型加入项目,导入对应机器学习框架;Swift使用"TensorFlowLiteSwift",OC使用"TensorFlowLite";其实2者本质都是C++版的TensorFlowLite,Swift版只是Google加了桥接便于Swift调用。
- 6 、按照模型要求输入对应数据(一般是一张图片的像素点),模型正常识别输出后会是一组或几组数据(根你的模型输出一致)。拿到输出数据后就需要进行后处理,转换成我们需要的数据模型。
下面我们看下具体核心实现代码:
- 1、模型初始化(输入数据参数和你训练模型输入的参数是对应的);
let threadCount: Int // 线程使用条数 let batchSize = 1 // 数据分为多少批次喂给模型,此处为1次全部输入 let inputChannels = 3 // 图片像素对应的3个通道R、G、B let inputWidth = 160 // 图片宽度像素值 let inputHeight = 160 // 图片高度像素值 private var interpreter: Interpreter // 模型解释器 private let alphaComponent = (baseOffset: 4, moduloRemainder: 0) //FileInfo 是自定义的关于模型的数据结构,具体请参看完整demo init?(modelFileInfo: FileInfo,threadCount: Int = 1) { let modelFilename = modelFileInfo.name guard let modelPath = Bundle.main.path(forResource: modelFilename, ofType: modelFileInfo.extension) else { print("Failed to load the model file with name: \(modelFilename).\(modelFileInfo.extension)") return nil } self.threadCount = threadCount var options = InterpreterOptions() options.threadCount = threadCount do { interpreter = try Interpreter(modelPath:modelPath, options: options) try interpreter.allocateTensors() }catch let error { print("Failed to create the interpreter with error: \(error.localizedDescription)") return nil } super.init() } ... ... 复制代码
- 2、把数据读入模型,获取返回数据并进行后处理(这里后处理用的C++,也可以用Swift,但是当时为了跨平台故选择了C++);
//this model just return only one biggest face func runModel(onFrame pixelBuffer: CVPixelBuffer) -> Result?{ let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer) assert(sourcePixelFormat == kCVPixelFormatType_32ARGB || sourcePixelFormat == kCVPixelFormatType_32BGRA || sourcePixelFormat == kCVPixelFormatType_32RGBA) let imageChannels = 4 assert(imageChannels >= inputChannels) // 根据输入图片的像素值,把宽高转换为模型需要的宽高值,此处必须保证喂入数据和模型要求数据一致 guard let thumbnailPixelBuffer = resizePixelBuffer(pixelBuffer, width: inputWidth, height: inputHeight) else { return nil } let interval: TimeInterval //根据模型,这里可以有很多输出 let outputTensor0: Tensor //输出数据1 let outputTensor1: Tensor //输出数据2 // let outputTensor2: Tensor do { let inputTensor = try interpreter.input(at: 0) guard let rgbData = rgbDataFromBuffer( thumbnailPixelBuffer, byteCount: batchSize * inputWidth * inputHeight * inputChannels, isModelQuantized: inputTensor.dataType == .uInt8 ) else { print("Failed to convert the image buffer to RGB data.") return nil } try interpreter.copy(rgbData, toInputAt: 0) let startDate = Date() try interpreter.invoke() interval = Date().timeIntervalSince(startDate) * 1000 outputTensor0 = try interpreter.output(at: 0) outputTensor1 = try interpreter.output(at: 1) }catch let error { print("Failed to invoke the interpreter with error: \(error.localizedDescription)") return nil } var ouputresults0: [Float] var ouputresults1: [Float] ouputresults0 = [Float32](unsafeData: outputTensor0.data) ?? [] ouputresults1 = [Float32](unsafeData: outputTensor1.data) ?? [] var faceBoxs = [Float32](repeating: 0, count: 4)//[Float]() //对应后处理 C++ 编写的后处理 getDetection(&ouputresults0, &ouputresults1, Int32(inputWidth), Int32(inputHeight), 0.7, &faceBoxs) //结果 只会返回一个人脸 faceBoxs 4个点 对应 左上。右下 //筛选不合格的数据 if faceBoxs[0] > 0.0 && faceBoxs[1] > 0.0 && faceBoxs[2] > 0.0 && faceBoxs[3] > 0.0 { //将合格数据转换成我们需要的数据结构供程序使用 let resultFace = FaceBox(x_min: faceBoxs[0], y_min: faceBoxs[1], x_max: faceBoxs[2], y_max: faceBoxs[3]) let result = Result(inferenceTime: interval, faceboxes: [resultFace]) return result }else { return nil } } 复制代码
1、AppDependencyContainer,项目整体配置模块。
let sharedUserSessionRepository: UserSessionRepository //数据层 let sharedMainViewModel: MainViewModel //逻辑层 init() { func makeUserSessionRepository() -> UserSessionRepository { let dataStore = makeUserSessionDataStore() let remoteAPI = makeRemoteAPI() return APPUserSessionRepository(dataStore: dataStore, remoteAPI: remoteAPI) } // 用户数据切换,用户数据也使用假数据 func makeUserSessionDataStore() -> UserSessionDataStore { return FakeUserSessionDataStore() // return FileUserSessionDataStore() } //真实数据接口因为安全原因已经被删除,如果需要你可以直接套用 func makeRemoteAPI() -> RemoteAPI { //切换开发使用的加数据 return FakeRemoteAPI() //切换开发使用的加数据 // return OfficalRemoteAPI() //切换真实数据 } func makeMainViewModel() -> MainViewModel { return MainViewModel() } self.sharedMainViewModel = makeMainViewModel() self.sharedUserSessionRepository = makeUserSessionRepository() } 复制代码2、MainViewController,壳子,用于页面跳转,对应逻辑数据层MainViewModel。代码就不贴了,基本的基于RxSwift实现的MVVM结构。
本项目是学习RxSwift后的第一个实践项目,同时把POP和机器学习的内容也加进去了。学习的知识点得到了实践,但是项目架构依然不够清晰,模块基本没有封装。目前正在组件化学习,后续会使用组件化工程的思想再进行实践,同时里面一些知识点例如机器学习,RxSwift 的使用还需要进一步总结。