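In the previous chapters we worked our way from the basics of Netty down to its implementation principles, so you should now have a fairly complete picture of it. Next, we will look at how Netty is used in practice through a hands-on example: writing an RPC communication framework by hand.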
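Why RPC? Netty exists to solve communication problems, and in real-world development RPC frameworks are the communication frameworks we deal with most often. So besides showing Netty in real use, this example also explains how RPC works under the hood.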
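RPC (Remote Procedure Call) is a protocol for requesting a service from a program on a remote computer over the network without needing to understand the underlying network technology. Put simply, it lets developers call a remote service the same way they call a local one.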
Since it is a protocol, it must come with a specification, as shown in Figure 6-1.
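To achieve the goal of letting developers call remote services as if they were local, an RPC implementation must carry out the remote interaction shown in Figure 6-1.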
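Any framework that implements the RPC protocol is called an RPC framework. In practice, we can use a mature open-source RPC framework to solve remote communication in a microservice architecture. Common RPC frameworks include: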
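Given this understanding of the RPC protocol, what technologies would we need to consider if we implemented it ourselves? Following the overall flow in Figure 6-1 should already give you a rough idea.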
Now that we understand the RPC protocol, let's implement an RPC communication framework on top of Netty.
The full code is in the attached netty-rpc-example project.
Required dependencies:
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
</dependency>
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.72</version>
</dependency>
<dependency>
    <groupId>io.netty</groupId>
    <artifactId>netty-all</artifactId>
</dependency>
Module dependencies:
provider depends on netty-rpc-protocol and netty-rpc-api
consumer depends on netty-rpc-protocol and netty-rpc-api
public interface IUserService {
    String saveUser(String name);
}
@Service
@Slf4j
public class UserServiceImpl implements IUserService {
    @Override
    public String saveUser(String name) {
        log.info("begin saveUser:" + name);
        return "Save User Success!";
    }
}
Note: the lines marked "case" in the code below can be left out for now; they will be added in a later step.
@ComponentScan(basePackages = {"com.example.spring","com.example.service"}) // case1 (add later)
@SpringBootApplication
public class NettyRpcProviderMain {
    public static void main(String[] args) throws Exception {
        SpringApplication.run(NettyRpcProviderMain.class, args);
        new NettyServer("127.0.0.1",8080).startNettyServer(); // case2 (add later)
    }
}
Next, we write the communication protocol module. This module is mainly responsible for the following:
We covered custom message protocols earlier; here we define one using the following format.
/*
+-----------------------------------------------------------------+
| magic 2 bytes | serialization type 1 byte | request type 1 byte |
+-----------------------------------------------------------------+
| message ID 8 bytes                        | data length 4 bytes |
+-----------------------------------------------------------------+
*/
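To check the offsets in this layout, the following minimal Java sketch (plain `java.nio.ByteBuffer`, no Netty; the class name `HeaderLayoutDemo` is hypothetical) writes the five header fields in order. It shows that the header is 2 + 1 + 1 + 8 + 4 = 16 bytes and that the data-length field starts at byte offset 12, the two numbers that reappear later as `RpcConstant.HEAD_TOTAL_LEN` and as the `lengthFieldOffset` passed to `LengthFieldBasedFrameDecoder`.

```java
import java.nio.ByteBuffer;

// Hypothetical demo: writes the protocol header in field order into a ByteBuffer.
// ByteBuffer is big-endian by default, the same byte order Netty's ByteBuf uses.
class HeaderLayoutDemo {
    static final short MAGIC = 0xca;                      // same value as RpcConstant.MAGIC
    static final int HEAD_TOTAL_LEN = 2 + 1 + 1 + 8 + 4;  // 16 bytes

    static ByteBuffer writeHeader(byte serialType, byte reqType, long requestId, int length) {
        ByteBuffer buf = ByteBuffer.allocate(HEAD_TOTAL_LEN);
        buf.putShort(MAGIC);      // offset 0,  2 bytes
        buf.put(serialType);      // offset 2,  1 byte
        buf.put(reqType);         // offset 3,  1 byte
        buf.putLong(requestId);   // offset 4,  8 bytes
        buf.putInt(length);       // offset 12, 4 bytes
        buf.flip();               // make the written bytes readable
        return buf;
    }

    public static void main(String[] args) {
        ByteBuffer header = writeHeader((byte) 0, (byte) 1, 42L, 100);
        System.out.println("header size: " + header.remaining());     // header size: 16
        System.out.println("length field @12: " + header.getInt(12)); // length field @12: 100
    }
}
```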
@AllArgsConstructor
@Data
public class Header implements Serializable {
    // Layout: magic(2) | serialType(1) | reqType(1) | requestId(8) | length(4) = 16 bytes
    private short magic;      // magic number, used to verify that the packet belongs to this protocol (2 bytes)
    private byte serialType;  // serialization type (1 byte)
    private byte reqType;     // request/message type (1 byte)
    private long requestId;   // request id (8 bytes)
    private int length;       // body length (4 bytes)
}
@Data
public class RpcRequest implements Serializable {
    private String className;
    private String methodName;
    private Object[] params;
    private Class<?>[] parameterTypes;
}
@Data
public class RpcResponse implements Serializable {
    private Object data;
    private String msg;
}
@Data
public class RpcProtocol<T> implements Serializable {
    private Header header;
    private T content;
}
The message protocol above relies on a few enums and a constants class, defined as follows.
Message type
public enum ReqType {
    REQUEST((byte)1), RESPONSE((byte)2), HEARTBEAT((byte)3);

    private byte code;

    private ReqType(byte code) {
        this.code = code;
    }

    public byte code() {
        return this.code;
    }

    public static ReqType findByCode(int code) {
        for (ReqType msgType : ReqType.values()) {
            if (msgType.code() == code) {
                return msgType;
            }
        }
        return null;
    }
}
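`findByCode` scans `values()` on every call and returns `null` for an unknown code. A common alternative, sketched below under the hypothetical name `ReqTypeLookup` so it does not clash with the article's enum, builds a code-to-constant map once in a static initializer. It still returns `null` for unknown codes, so a caller that switches on the result has to check for `null` first.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical variant of ReqType with a precomputed code -> constant lookup table.
enum ReqTypeLookup {
    REQUEST((byte) 1), RESPONSE((byte) 2), HEARTBEAT((byte) 3);

    private static final Map<Byte, ReqTypeLookup> BY_CODE = new HashMap<>();

    static {
        // runs once, after all enum constants have been created
        for (ReqTypeLookup type : values()) {
            BY_CODE.put(type.code, type);
        }
    }

    private final byte code;

    ReqTypeLookup(byte code) {
        this.code = code;
    }

    byte code() {
        return code;
    }

    // Takes a byte (not an int) so the key boxes to Byte and matches the map's keys.
    // Unknown codes still yield null, so callers must null-check before a switch.
    static ReqTypeLookup findByCode(byte code) {
        return BY_CODE.get(code);
    }

    public static void main(String[] args) {
        System.out.println(findByCode((byte) 2)); // RESPONSE
        System.out.println(findByCode((byte) 9)); // null
    }
}
```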
Serialization type
public enum SerialType {
    JSON_SERIAL((byte)0), JAVA_SERIAL((byte)1);

    private byte code;

    SerialType(byte code) {
        this.code = code;
    }

    public byte code() {
        return this.code;
    }
}
public class RpcConstant {
    // total length of the header in bytes: 2 + 1 + 1 + 8 + 4
    public final static int HEAD_TOTAL_LEN=16;
    // magic number
    public final static short MAGIC=0xca;
}
Two serialization implementations are shown here: one based on JSON and one based on native Java serialization.
public interface ISerializer {
    <T> byte[] serialize(T obj);

    <T> T deserialize(byte[] data, Class<T> clazz);

    byte getType();
}
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

public class JavaSerializer implements ISerializer {
    @Override
    public <T> byte[] serialize(T obj) {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        // try-with-resources closes (and flushes) the ObjectOutputStream before the bytes are read
        try (ObjectOutputStream outputStream = new ObjectOutputStream(byteArrayOutputStream)) {
            outputStream.writeObject(obj);
        } catch (IOException e) {
            e.printStackTrace();
            return new byte[0];
        }
        return byteArrayOutputStream.toByteArray();
    }

    @Override
    @SuppressWarnings("unchecked")
    public <T> T deserialize(byte[] data, Class<T> clazz) {
        try (ObjectInputStream objectInputStream = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return (T) objectInputStream.readObject();
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    @Override
    public byte getType() {
        return SerialType.JAVA_SERIAL.code();
    }
}
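import com.alibaba.fastjson.JSON;
import java.nio.charset.StandardCharsets;

public class JsonSerializer implements ISerializer {
    @Override
    public <T> byte[] serialize(T obj) {
        // explicit charset so both ends agree regardless of the platform default
        return JSON.toJSONString(obj).getBytes(StandardCharsets.UTF_8);
    }

    @Override
    public <T> T deserialize(byte[] data, Class<T> clazz) {
        return JSON.parseObject(new String(data, StandardCharsets.UTF_8), clazz);
    }

    @Override
    public byte getType() {
        return SerialType.JSON_SERIAL.code();
    }
}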
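What `JavaSerializer` does can be tried in isolation. The sketch below (hypothetical class `JavaSerialRoundTrip` with a stand-in `Payload` type) round-trips an object through `ObjectOutputStream`/`ObjectInputStream`, closing the streams with try-with-resources so any buffered bytes are flushed before `toByteArray()` is read. It also shows why every type carried in a request must implement `Serializable` when `JAVA_SERIAL` is used: serializing a plain `Object` throws `NotSerializableException`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

class JavaSerialRoundTrip {
    // Hypothetical stand-in for RpcRequest: like it, the type must implement Serializable.
    static class Payload implements Serializable {
        private static final long serialVersionUID = 1L;
        final String methodName;
        final Object[] params;

        Payload(String methodName, Object[] params) {
            this.methodName = methodName;
            this.params = params;
        }
    }

    static byte[] serialize(Object obj) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } // close() flushes anything still buffered into `bytes`
        return bytes.toByteArray();
    }

    static <T> T deserialize(byte[] data, Class<T> clazz) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return clazz.cast(in.readObject());
        }
    }

    public static void main(String[] args) throws Exception {
        Payload copy = deserialize(serialize(new Payload("saveUser", new Object[]{"Mic"})), Payload.class);
        System.out.println(copy.methodName + " " + copy.params[0]); // saveUser Mic
        try {
            serialize(new Object()); // java.lang.Object does not implement Serializable
        } catch (NotSerializableException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```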
SerializerManager manages the available serialization mechanisms:
public class SerializerManager {
    private final static ConcurrentHashMap<Byte, ISerializer> serializers = new ConcurrentHashMap<Byte, ISerializer>();

    static {
        // register each serializer under its type code (the serialType byte in the header)
        ISerializer jsonSerializer = new JsonSerializer();
        ISerializer javaSerializer = new JavaSerializer();
        serializers.put(jsonSerializer.getType(), jsonSerializer);
        serializers.put(javaSerializer.getType(), javaSerializer);
    }

    public static ISerializer getSerializer(byte key) {
        ISerializer serializer = serializers.get(key);
        if (serializer == null) {
            return new JavaSerializer(); // unknown type code: fall back to Java serialization
        }
        return serializer;
    }
}
Because we defined our own message protocol, we also need to implement the encoder and decoder ourselves:
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import lombok.extern.slf4j.Slf4j;
import java.util.List;

@Slf4j
public class RpcDecoder extends ByteToMessageDecoder {
    // Header layout: magic(2) | serialType(1) | reqType(1) | requestId(8) | length(4)
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        log.info("==========begin RpcDecoder ==============");
        if (in.readableBytes() < RpcConstant.HEAD_TOTAL_LEN) {
            return; // not even a full header yet: wait for more data
        }
        in.markReaderIndex(); // remember where this message starts so we can rewind later
        short magic = in.readShort(); // magic number
        if (magic != RpcConstant.MAGIC) {
            throw new IllegalArgumentException("Illegal request parameter 'magic'," + magic);
        }
        byte serialType = in.readByte(); // serialization type
        byte reqType = in.readByte();    // request type
        long requestId = in.readLong();  // request id
        int dataLength = in.readInt();   // body length
        // fewer readable bytes than the body length: rewind and wait for more data
        if (in.readableBytes() < dataLength) {
            in.resetReaderIndex();
            return;
        }
        // read the body
        byte[] content = new byte[dataLength];
        in.readBytes(content);
        // rebuild the header
        Header header = new Header(magic, serialType, reqType, requestId, dataLength);
        ISerializer serializer = SerializerManager.getSerializer(serialType);
        ReqType rt = ReqType.findByCode(reqType);
        if (rt == null) {
            // switching on a null enum would throw NullPointerException
            log.warn("unknown request type {}, message dropped", reqType);
            return;
        }
        switch (rt) {
            case REQUEST:
                RpcRequest request = serializer.deserialize(content, RpcRequest.class);
                RpcProtocol<RpcRequest> reqProtocol = new RpcProtocol<>();
                reqProtocol.setHeader(header);
                reqProtocol.setContent(request);
                out.add(reqProtocol);
                break;
            case RESPONSE:
                RpcResponse response = serializer.deserialize(content, RpcResponse.class);
                RpcProtocol<RpcResponse> resProtocol = new RpcProtocol<>();
                resProtocol.setHeader(header);
                resProtocol.setContent(response);
                out.add(resProtocol);
                break;
            case HEARTBEAT: // heartbeats carry no business data and are ignored here
            default:
                break;
        }
    }
}
@Slf4j
public class RpcEncoder extends MessageToByteEncoder<RpcProtocol<Object>> {
    // Header layout: magic(2) | serialType(1) | reqType(1) | requestId(8) | length(4)
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcProtocol<Object> msg, ByteBuf out) throws Exception {
        log.info("=============begin RpcEncoder============");
        Header header = msg.getHeader();
        out.writeShort(header.getMagic());      // magic number
        out.writeByte(header.getSerialType());  // serialization type
        out.writeByte(header.getReqType());     // request type
        out.writeLong(header.getRequestId());   // request id
        ISerializer serializer = SerializerManager.getSerializer(header.getSerialType());
        byte[] data = serializer.serialize(msg.getContent()); // serialize the body
        header.setLength(data.length);
        out.writeInt(data.length);              // body length
        out.writeBytes(data);                   // body
    }
}
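The key idea in `RpcDecoder` is to give up and rewind when a message has only partly arrived. The sketch below models that logic with `java.nio.ByteBuffer`, using `mark()`/`reset()` in place of Netty's `markReaderIndex()`/`resetReaderIndex()`. `HalfPacketDemo`, `tryDecode`, and `encode` are hypothetical names, and the header fields other than the length are read and discarded to keep it short.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Hypothetical model of RpcDecoder's "wait until the whole message is here" logic.
class HalfPacketDemo {
    static final int HEAD_TOTAL_LEN = 16;

    // Returns the body if a complete message is buffered; otherwise returns null
    // and leaves the position where it was (mark/reset ~ markReaderIndex/resetReaderIndex).
    static byte[] tryDecode(ByteBuffer in) {
        if (in.remaining() < HEAD_TOTAL_LEN) {
            return null;                // header incomplete
        }
        in.mark();
        in.getShort();                  // magic (not validated in this sketch)
        in.get();                       // serialType
        in.get();                       // reqType
        in.getLong();                   // requestId
        int dataLength = in.getInt();   // body length
        if (in.remaining() < dataLength) {
            in.reset();                 // body incomplete: rewind to the start of the header
            return null;
        }
        byte[] body = new byte[dataLength];
        in.get(body);
        return body;
    }

    // Builds one message: 16-byte header + UTF-8 body.
    static byte[] encode(String text) {
        byte[] body = text.getBytes(StandardCharsets.UTF_8);
        return ByteBuffer.allocate(HEAD_TOTAL_LEN + body.length)
                .putShort((short) 0xca).put((byte) 0).put((byte) 1)
                .putLong(1L).putInt(body.length).put(body)
                .array();
    }

    public static void main(String[] args) {
        byte[] msg = encode("hello");                      // 16 + 5 = 21 bytes
        ByteBuffer in = ByteBuffer.allocate(64);
        in.put(msg, 0, 18).flip();                         // only the first 18 bytes have arrived
        System.out.println(tryDecode(in));                 // null
        System.out.println(in.position());                 // 0
        in.compact().put(msg, 18, msg.length - 18).flip(); // the remaining 3 bytes arrive
        System.out.println(new String(tryDecode(in), StandardCharsets.UTF_8)); // hello
    }
}
```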
Build the NettyServer:
@Slf4j
public class NettyServer {
    private String serverAddress; // address
    private int serverPort;       // port

    public NettyServer(String serverAddress, int serverPort) {
        this.serverAddress = serverAddress;
        this.serverPort = serverPort;
    }

    public void startNettyServer() throws Exception {
        log.info("begin start Netty Server");
        EventLoopGroup bossGroup = new NioEventLoopGroup(); // accepts incoming connections
        EventLoopGroup workGroup = new NioEventLoopGroup(); // handles I/O on accepted connections
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workGroup)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new RpcServerInitializer());
            ChannelFuture channelFuture = bootstrap.bind(this.serverAddress, this.serverPort).sync();
            log.info("Server started Success on Port:{}", this.serverPort);
            channelFuture.channel().closeFuture().sync(); // block until the server channel is closed
        } catch (Exception e) {
            log.error("Rpc Server Exception", e);
        } finally {
            workGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }
}
public class RpcServerInitializer extends ChannelInitializer<SocketChannel> {
    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline()
                // maxFrameLength, lengthFieldOffset = 12 (2+1+1+8), lengthFieldLength = 4,
                // lengthAdjustment = 0, initialBytesToStrip = 0 (keep the header for RpcDecoder)
                .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 12, 4, 0, 0))
                .addLast(new RpcDecoder())
                .addLast(new RpcEncoder())
                .addLast(new RpcServerHandler());
    }
}
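The arguments `(Integer.MAX_VALUE, 12, 4, 0, 0)` are `maxFrameLength`, `lengthFieldOffset`, `lengthFieldLength`, `lengthAdjustment`, and `initialBytesToStrip`. With these values each frame is 12 + 4 + (value of the length field) bytes, and the header stays in the frame so that `RpcDecoder` can still read it. The following Netty-free sketch (hypothetical class `FrameSplitDemo`) applies that arithmetic to split two messages that arrived glued together in one buffer; it models the idea only and is not Netty's actual implementation.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

// Hypothetical, Netty-free model of LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 12, 4, 0, 0).
class FrameSplitDemo {
    static final int LENGTH_FIELD_OFFSET = 12;   // magic(2) + serialType(1) + reqType(1) + requestId(8)
    static final int LENGTH_FIELD_LENGTH = 4;    // the int "data length" field
    static final int LENGTH_ADJUSTMENT = 0;      // the length field counts only the body
    static final int INITIAL_BYTES_TO_STRIP = 0; // keep the header for RpcDecoder

    // Cuts as many complete frames as possible; a trailing partial frame stays in the buffer.
    static List<byte[]> split(ByteBuffer stream) {
        List<byte[]> frames = new ArrayList<>();
        while (stream.remaining() >= LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH) {
            int start = stream.position();
            int bodyLength = stream.getInt(start + LENGTH_FIELD_OFFSET); // absolute read, position unchanged
            int frameLength = LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH + bodyLength + LENGTH_ADJUSTMENT;
            if (stream.remaining() < frameLength) {
                break; // incomplete frame: wait for more bytes
            }
            stream.position(start + INITIAL_BYTES_TO_STRIP);
            byte[] frame = new byte[frameLength - INITIAL_BYTES_TO_STRIP];
            stream.get(frame); // position now points at the next frame
            frames.add(frame);
        }
        return frames;
    }

    // Builds one message in the article's format: 16-byte header followed by the body.
    static byte[] message(long requestId, String body) {
        byte[] b = body.getBytes(StandardCharsets.UTF_8);
        return ByteBuffer.allocate(16 + b.length)
                .putShort((short) 0xca).put((byte) 0).put((byte) 1)
                .putLong(requestId).putInt(b.length).put(b)
                .array();
    }

    public static void main(String[] args) {
        byte[] a = message(1, "first");
        byte[] b = message(2, "second!");
        ByteBuffer stream = ByteBuffer.allocate(a.length + b.length);
        stream.put(a).put(b).flip(); // two messages glued together in one read
        List<byte[]> frames = split(stream);
        System.out.println(frames.size());        // 2
        System.out.println(frames.get(1).length); // 23 (16-byte header + 7-byte body)
    }
}
```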
public class RpcServerHandler extends SimpleChannelInboundHandler<RpcProtocol<RpcRequest>> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcProtocol<RpcRequest> msg) throws Exception {
        RpcProtocol<RpcResponse> resProtocol = new RpcProtocol<>();
        Header header = msg.getHeader();
        header.setReqType(ReqType.RESPONSE.code()); // reuse the header (same requestId), marked as a response
        Object result = invoke(msg.getContent());
        resProtocol.setHeader(header);
        RpcResponse response = new RpcResponse();
        response.setData(result);
        response.setMsg("success");
        resProtocol.setContent(response);
        ctx.writeAndFlush(resProtocol);
    }

    private Object invoke(RpcRequest request) {
        try {
            Class<?> clazz = Class.forName(request.getClassName());
            Object bean = SpringBeansManager.getBean(clazz); // get the bean instance (CASE)
            Method declaredMethod = clazz.getDeclaredMethod(request.getMethodName(), request.getParameterTypes());
            return declaredMethod.invoke(bean, request.getParams());
        } catch (ReflectiveOperationException e) {
            // ClassNotFound / NoSuchMethod / IllegalAccess / InvocationTarget
            e.printStackTrace();
        }
        return null;
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        super.exceptionCaught(ctx, cause);
    }
}
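`invoke` resolves the interface by name, asks Spring for the implementing bean, looks the method up on the interface with the exact parameter types, and calls it on the bean. The self-contained sketch below reproduces those steps with hypothetical stand-ins: `IGreeter`/`GreeterImpl` play the roles of `IUserService`/`UserServiceImpl`, and a plain `HashMap` replaces the Spring `ApplicationContext` behind `SpringBeansManager`.

```java
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;

class ReflectiveDispatchDemo {
    // Stand-in for the Spring ApplicationContext: interface type -> implementing bean.
    static final Map<Class<?>, Object> BEANS = new HashMap<>();

    static {
        BEANS.put(IGreeter.class, new GreeterImpl());
    }

    // Same steps as RpcServerHandler.invoke: resolve the interface by name, find the bean,
    // look the method up on the interface, then invoke it on the implementation.
    static Object invoke(String className, String methodName,
                         Class<?>[] parameterTypes, Object[] params) throws Exception {
        Class<?> clazz = Class.forName(className);
        Object bean = BEANS.get(clazz);
        Method method = clazz.getDeclaredMethod(methodName, parameterTypes);
        return method.invoke(bean, params);
    }

    public static void main(String[] args) throws Exception {
        Object result = invoke(IGreeter.class.getName(), "greet",
                new Class<?>[]{String.class}, new Object[]{"Mic"});
        System.out.println(result); // Hello, Mic
    }
}

// Hypothetical service interface and implementation (roles of IUserService / UserServiceImpl).
interface IGreeter {
    String greet(String name);
}

class GreeterImpl implements IGreeter {
    @Override
    public String greet(String name) {
        return "Hello, " + name;
    }
}
```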
@Component
public class SpringBeansManager implements ApplicationContextAware {
    private static ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringBeansManager.applicationContext = applicationContext;
    }

    public static <T> T getBean(Class<T> clazz) {
        return applicationContext.getBean(clazz);
    }
}
Note: once this class is in place, add a @ComponentScan to the main class of the netty-rpc-provider module so that Spring picks it up:
@ComponentScan(basePackages = {"com.example.spring","com.example.service"}) // change here
@SpringBootApplication
public class NettyRpcProviderMain {
    public static void main(String[] args) throws Exception {
        SpringApplication.run(NettyRpcProviderMain.class, args);
        new NettyServer("127.0.0.1",8080).startNettyServer(); // change here
    }
}
Next, let's implement the consumer side.
public class RpcClientProxy {
    @SuppressWarnings("unchecked")
    public <T> T clientProxy(final Class<T> interfaceCls, final String host, final int port) {
        return (T) Proxy.newProxyInstance(interfaceCls.getClassLoader(),
                new Class<?>[]{interfaceCls},
                new RpcInvokerProxy(host, port));
    }
}
@Slf4j
public class RpcInvokerProxy implements InvocationHandler {
    private String serviceAddress;
    private int servicePort;

    public RpcInvokerProxy(String serviceAddress, int servicePort) {
        this.serviceAddress = serviceAddress;
        this.servicePort = servicePort;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        log.info("begin invoke target server");
        // assemble the request; the length field (0 here) is filled in by RpcEncoder
        RpcProtocol<RpcRequest> protocol = new RpcProtocol<>();
        long requestId = RequestHolder.REQUEST_ID.incrementAndGet();
        Header header = new Header(RpcConstant.MAGIC, SerialType.JSON_SERIAL.code(),
                ReqType.REQUEST.code(), requestId, 0);
        protocol.setHeader(header);
        RpcRequest request = new RpcRequest();
        request.setClassName(method.getDeclaringClass().getName());
        request.setMethodName(method.getName());
        request.setParameterTypes(method.getParameterTypes());
        request.setParams(args);
        protocol.setContent(request);
        // send the request (a new NettyClient, with its own event loop group, is created on every call)
        NettyClient nettyClient = new NettyClient(serviceAddress, servicePort);
        // asynchronous result handling: register the future under requestId before sending
        RpcFuture<RpcResponse> future = new RpcFuture<>(new DefaultPromise<>(new DefaultEventLoop()));
        RequestHolder.REQUEST_MAP.put(requestId, future);
        nettyClient.sendRequest(protocol);
        // block until RpcClientHandler completes the promise, then return the business result
        return future.getPromise().get().getData();
    }
}
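The essential trick in `RpcInvokerProxy` is that the JDK dynamic proxy turns an ordinary interface call into data: the `Method` object and the arguments are enough to fill in `className`, `methodName`, `parameterTypes`, and `params`. The sketch below shows only that part, with no network: the handler records the request in a hypothetical `CapturedRequest` and returns a canned result where the real client would send the request and wait for the response. `ProxyCaptureDemo` and `IUserServiceLike` are illustrative names.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.Arrays;

class ProxyCaptureDemo {
    // Hypothetical stand-in for RpcRequest: the four fields the proxy fills in.
    static class CapturedRequest {
        String className;
        String methodName;
        Class<?>[] parameterTypes;
        Object[] params;
    }

    static CapturedRequest last; // the most recent request built by the proxy

    @SuppressWarnings("unchecked")
    static <T> T proxy(Class<T> interfaceCls, Object cannedResult) {
        InvocationHandler handler = (p, method, methodArgs) -> {
            CapturedRequest req = new CapturedRequest();
            req.className = method.getDeclaringClass().getName();
            req.methodName = method.getName();
            req.parameterTypes = method.getParameterTypes();
            req.params = methodArgs;
            last = req;          // a real client would serialize and send this...
            return cannedResult; // ...and return the data from the server's response
        };
        return (T) Proxy.newProxyInstance(interfaceCls.getClassLoader(),
                new Class<?>[]{interfaceCls}, handler);
    }

    public static void main(String[] args) {
        IUserServiceLike svc = proxy(IUserServiceLike.class, "Save User Success!");
        System.out.println(svc.saveUser("Mic"));                  // Save User Success!
        System.out.println(last.methodName);                      // saveUser
        System.out.println(Arrays.toString(last.parameterTypes)); // [class java.lang.String]
    }
}

// Hypothetical interface mirroring IUserService.
interface IUserServiceLike {
    String saveUser(String name);
}
```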
In the protocol package of the netty-rpc-protocol module, create NettyClient:
@Slf4j
public class NettyClient {
    private final Bootstrap bootstrap;
    private final EventLoopGroup eventLoopGroup = new NioEventLoopGroup();
    private String serviceAddress;
    private int servicePort;

    public NettyClient(String serviceAddress, int servicePort) {
        log.info("begin init NettyClient");
        bootstrap = new Bootstrap();
        bootstrap.group(eventLoopGroup)
                .channel(NioSocketChannel.class)
                .handler(new RpcClientInitializer());
        this.serviceAddress = serviceAddress;
        this.servicePort = servicePort;
    }

    public void sendRequest(RpcProtocol<RpcRequest> protocol) throws InterruptedException {
        // sync() waits for the connection attempt and rethrows the cause if it failed,
        // so the future is already complete when the listener below is added
        ChannelFuture future = bootstrap.connect(this.serviceAddress, this.servicePort).sync();
        future.addListener(listener -> {
            if (future.isSuccess()) {
                log.info("connect rpc server {} success.", this.serviceAddress);
            } else {
                log.error("connect rpc server {} failed .", this.serviceAddress);
                future.cause().printStackTrace();
                eventLoopGroup.shutdownGracefully();
            }
        });
        log.info("begin transfer data");
        future.channel().writeAndFlush(protocol);
    }
}
@Slf4j
public class RpcClientInitializer extends ChannelInitializer<SocketChannel> {
    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        log.info("begin initChannel");
        ch.pipeline()
                .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 12, 4, 0, 0))
                .addLast(new LoggingHandler())
                .addLast(new RpcEncoder())
                .addLast(new RpcDecoder())
                .addLast(new RpcClientHandler());
    }
}
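Note that Netty separates inbound and outbound processing: the request is written on the outbound path, while the response arrives later on the inbound path, in a different handler. To hand the result back to the caller, we therefore need a Future object.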
@Slf4j
public class RpcClientHandler extends SimpleChannelInboundHandler<RpcProtocol<RpcResponse>> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcProtocol<RpcResponse> msg) throws Exception {
        log.info("receive rpc server result");
        long requestId = msg.getHeader().getRequestId();
        RpcFuture<RpcResponse> future = RequestHolder.REQUEST_MAP.remove(requestId);
        if (future == null) {
            // no caller is waiting for this id (unknown or duplicate response)
            log.warn("no pending request for id {}", requestId);
            return;
        }
        future.getPromise().setSuccess(msg.getContent()); // hand the result back to the waiting caller
    }
}
Add the RpcFuture implementation to the netty-rpc-protocol module:
@Data
public class RpcFuture<T> {
    // A Promise is a writable Future: Future itself has no methods for setting a result,
    // so Netty extends Future with Promise, which is used to set the result of an I/O operation.
    private Promise<T> promise;

    public RpcFuture(Promise<T> promise) {
        this.promise = promise;
    }
}
RequestHolder keeps the mapping between each requestId and its future:
public class RequestHolder {
    // generates unique, increasing request ids
    public static final AtomicLong REQUEST_ID = new AtomicLong();
    // requestId -> future waiting for that response
    public static final Map<Long, RpcFuture> REQUEST_MAP = new ConcurrentHashMap<>();
}
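The correlation scheme formed by `RequestHolder`, `RpcInvokerProxy`, and `RpcClientHandler` (register a future under a fresh request id, send, then complete that future when a response with the same id arrives on the I/O thread) can be sketched without Netty. The version below swaps Netty's `Promise` for the JDK's `CompletableFuture`, adds a timeout with cleanup of the pending entry, and simulates the server reply on an executor thread. `RequestCorrelationDemo`, `register`, `onResponse`, and `call` are hypothetical names, not part of netty-rpc-example.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

class RequestCorrelationDemo {
    static final AtomicLong REQUEST_ID = new AtomicLong();
    static final Map<Long, CompletableFuture<String>> PENDING = new ConcurrentHashMap<>();

    // Caller side (role of RpcInvokerProxy): register a future under a fresh id before sending.
    static long register(CompletableFuture<String> future) {
        long id = REQUEST_ID.incrementAndGet();
        PENDING.put(id, future);
        return id;
    }

    // I/O side (role of RpcClientHandler): complete the future that matches the id.
    static boolean onResponse(long requestId, String data) {
        CompletableFuture<String> future = PENDING.remove(requestId);
        if (future == null) {
            return false; // unknown or timed-out request: ignore instead of throwing
        }
        return future.complete(data);
    }

    static String call(ExecutorService io, String arg, long timeoutMs) throws Exception {
        CompletableFuture<String> future = new CompletableFuture<>();
        long id = register(future);
        io.submit(() -> onResponse(id, "echo:" + arg)); // simulated server reply on another thread
        try {
            return future.get(timeoutMs, TimeUnit.MILLISECONDS);
        } finally {
            PENDING.remove(id); // no-op on success; drops the entry if we timed out
        }
    }

    public static void main(String[] args) throws Exception {
        ExecutorService io = Executors.newSingleThreadExecutor();
        try {
            System.out.println(call(io, "a", 1000)); // echo:a
            System.out.println(call(io, "b", 1000)); // echo:b
        } finally {
            io.shutdown();
        }
    }
}
```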
Copyright notice: unless otherwise stated, all articles on this blog are licensed under CC BY-NC-SA 4.0. Please credit Mic带你学架构 when reposting!