学习手写RPC框架

学习手写RPC框架

最近学习了dubbo 想自己试着写一个简单的rpc的框架

  • 使用socket做数据传输层
  • 使用JDK默认的序列化
  • 使用JDK默认的代理类

新建项目 maven 项目

image-20210317105002704

注册中心

注册中心主要有2个功能

  • 服务注册

    能提供服务的服务器要去服务器注册,并把自己的访问地址和注册的接口发送给注册中心让它知道。

  • 获取服务器

    调用方向注册中心获取对应接口的调用的服务地址和端口

RpcRegistrySocket.java 开启一个socket监听服务器

/**
* @author yzy
*/
public class RpcRegistrySocket {

ExecutorService executorService = new ScheduledThreadPoolExecutor(1, r -> new Thread(r, "mThread"));

public void start(int port){
try {
ServerSocket serverSocket = new ServerSocket(port);
while (true){
final Socket socket = serverSocket.accept();
executorService.execute(new RpcRegistryHandle(socket));
}
} catch (IOException e) {
e.printStackTrace();
}
}
}

具体的逻辑代码放到了 RpcRegistryHandle 里面处理客户端的每个请求

/**
* @author yzy
*/
public class RpcRegistryHandle implements Runnable {
private Socket socket;

public RpcRegistryHandle(Socket socket) {
this.socket = socket;
}

@Override
public void run() {
// 获取客户端发来的消息
ObjectInputStream inputStream = null;
ObjectOutputStream outputStream = null;
try {
inputStream = new ObjectInputStream(socket.getInputStream());
outputStream = new ObjectOutputStream(socket.getOutputStream());

SocketRpcRegistryEntity reg = (SocketRpcRegistryEntity) StreamUtils.getObject(inputStream);

if (reg == null) {
throw new RuntimeException("发送信息体有问题");
}
if (reg.getType() == 0) {
// 注册
RegistryHandle.add(reg);
StreamUtils.putObject(outputStream, reg.getHost() + ":" + reg.getPort() + "服务器注册成功");
System.out.println(reg.getHost() + ":" + reg.getPort() + "服务器注册成功");
} else if (reg.getType() == 1) {
// 获取服务器
SocketRpcServerEntity entity = RegistryHandle.get(reg.getServiceClassName());
StreamUtils.putObject(outputStream, entity);
System.out.println("获取服务器成功");
}


} catch (Exception e) {
e.printStackTrace();
} finally {
StreamUtils.close(inputStream);
StreamUtils.close(outputStream);
}

}
}

SocketRpcRegistryEntity 是封装的消息对象

/**
* @author yzy
*/
public class SocketRpcRegistryEntity implements Serializable {

/**
* 类型 0 注册 1获取
*/
private int type;
/**
* 服务器地址
*/
private String host;
/**
* 服务器端口
*/
private int port;
/**
* 请求服务器的服务名字(接口名字)
*/
private String serviceClassName;
/**
* 注册的服务名字(接口名字)
*/
private String[] classNames;
... get set
}

RegistryHandle.java 是服务器保存接口和请求服务器信息的地方

/**
* @author yzy
*/
public class RegistryHandle {

public static ConcurrentMap<String, SocketRpcServerEntity> map = new ConcurrentHashMap<>();

public static void add(SocketRpcRegistryEntity reg){
for (String className : reg.getClassNames()) {
map.put(className,new SocketRpcServerEntity(className,reg.getHost(),reg.getPort()));
}
}

public static SocketRpcServerEntity get(String className){
if (map.containsKey(className)){
return map.get(className);
}
return null;
}
}

启动方法

/**
* @author yzy
*/
public class Registry {
/**
* 注册服务器端口
*/
private static int regPort = 999;

public static void main(String[] args) {
System.out.println("注册中心启动成功");
new RpcRegistrySocket().start(regPort);
}

}

image-20210317134813435

服务提供方

服务提供方要做2件事

  • 向注册中心注册对应的接口调用
  • 等待客户端的发送的请求来执行对应的逻辑

RpcServerSocket.java 开启一个socket服务

/**
* @author yzy
*/
public class RpcServerSocket {

ExecutorService executorService = new ScheduledThreadPoolExecutor(1, r -> new Thread(r, "mThread"));

/**
* 启动服务
*
* @param port 监听端口
*/
public void start(int port) {
try {
ServerSocket serverSocket = new ServerSocket(port);
while (true) {
// bio 阻塞获取数据
final Socket socket = serverSocket.accept();
// 放到线程里处理
executorService.execute(new RpcServerHandle(socket));
}
} catch (IOException e) {
e.printStackTrace();
}
}
}

主要是接受调用方的请求逻辑在RpcServerHandle.java

/**
* @author yzy
*/
public class RpcServerHandle implements Runnable {
private Socket socket;

public RpcServerHandle(Socket socket) {
this.socket = socket;
}

@Override
public void run() {
// 主要这里处理逻辑

ObjectInputStream inputStream = null;
ObjectOutputStream outputStream = null;
try {
inputStream = new ObjectInputStream(socket.getInputStream());
outputStream = new ObjectOutputStream(socket.getOutputStream());
// 1.获取客户端发送来的请求
SocketRpcRequestEntity rpcRequest = (SocketRpcRequestEntity) StreamUtils.getObject(inputStream);
if (rpcRequest == null) {
throw new RuntimeException("请求参数有问题");
}
// 2.通过类名获取本地实现类
Class clazz = null;
if (ClassHandle.map.containsKey(rpcRequest.getClassName())) {
clazz = ClassHandle.map.get(rpcRequest.getClassName());
}
if (clazz == null) {
throw new RuntimeException("没有找到对应的实现类");
}
// 3.执行获得返回值
Method method = clazz.getMethod(rpcRequest.getMethodName(), rpcRequest.getTypes());
Object obj = method.invoke(clazz.newInstance(), rpcRequest.getArgs());

// 4.重新发送给客户端
StreamUtils.putObject(outputStream, obj);
} catch (Exception e) {
e.printStackTrace();
} finally {
StreamUtils.close(inputStream);
StreamUtils.close(outputStream);
}
}
}

ClassHandle是存储本地接口名字和实现类的关系的类

/**
* @author yzy
*/
public class ClassHandle {
public static ConcurrentMap<String,Class> map = new ConcurrentHashMap<>();
}

启动服务

  1. 去服务中心注册.
  2. 本地接口实现类映射的关系注册
  3. 开启服务监听
/**
* @author yzy
*/
public class Server {
/**
* 注册服务器地址
*/
private static String regHost = "localhost";
/**
* 注册服务器端口
*/
private static int regPort = 999;

/**
* 本地服务端口
*/
private static String localHost = "localhost";

/**
* 本地服务端口
*/
private static int localPort = 888;

public static void main(String[] args) {
// 向服务器注册
SocketRpcRegistryEntity reg = new SocketRpcRegistryEntity();
reg.setClassNames(new String[]{"cn.com.yangzhenyu.service.UserService"});
reg.setHost(localHost);
reg.setPort(localPort);
reg.setType(0);

ObjectOutputStream outputStream = null;
ObjectInputStream inputStream = null;
try {
Socket socket = new Socket(regHost, regPort);
outputStream = new ObjectOutputStream(socket.getOutputStream());
inputStream = new ObjectInputStream(socket.getInputStream());
StreamUtils.putObject(outputStream, reg);

String msg = (String) StreamUtils.getObject(inputStream);
System.out.println(msg);

} catch (Exception e) {
e.printStackTrace();
} finally {
StreamUtils.close(outputStream);
StreamUtils.close(outputStream);
}

// 本地注册映射关系
ClassHandle.map.put("cn.com.yangzhenyu.service.UserService", UserServiceImpl.class);

// 启动本地服务监听
System.out.println("启动成功");
new RpcServerSocket().start(localPort);
}
}

默认接口

/**
* @author yzy
*/
public interface UserService {

/**
* 获取用户通过ID
* @param id 用户ID
* @return 返回用户
*/
UserEntity getUserById(int id);
}

默认实现类

/**
* @author yzy
*/
public class UserServiceImpl implements UserService {
@Override
public UserEntity getUserById(int id) {
return new UserEntity(id, "yzy");
}
}

客户端请求的封装SocketRpcRequestEntity.java

/**
* 发送RPC请求的消息封装
* @author yzy
*/
public class SocketRpcRequestEntity implements Serializable {

/**
* 类名
*/
private String className;
/**
* 方法名
*/
private String methodName;
/**
* 方法参数
*/
private Object[] args;
/**
* 方法参数类型
*/
private Class[] types;
... get set
}

启动

image-20210317135815053

程序调用方

调用方(客户端)要做的事?

  1. 要使用一个代理类去实现接口的实现
  2. 代理类要去注册中心根据接口名字获取对应服务的地址
  3. 访问对应的服务地址把参数序列化封装传过去
  4. 等待服务端返回序列化后的数据
  5. 把对应的数据反序列化回来

主程序Main

/**
* @author yzy
*/
public class Main {
public static void main(String[] args) {
long start = System.currentTimeMillis();
RpcProxy rpcProxy = new RpcProxy();

UserService userService = (UserService) rpcProxy.call(UserService.class, "localhost", 999);
UserEntity user = userService.getUserById(1);
System.out.println(user.toString());
System.out.println("时间:" + (System.currentTimeMillis() - start));
}
}

代理类RpcProxy.java

/**
* @author yzy
*/
public class RpcProxy<T> {
public T call(Class<T> clazz, String host, int port) {
return (T) Proxy.newProxyInstance(clazz.getClassLoader(), new Class[]{clazz}, new RpcHandler(clazz.getName(), host, port));
}
}

这里使用的JDK的动态代理 只能针对接口 具体的处理逻辑转移到了 RpcHandler上

/**
* @author yzy
*/
public class RpcHandler implements InvocationHandler {

private String host;
private String className;
private int port;

public <T> RpcHandler(String className, String host, int port) {
this.host = host;
this.port = port;
this.className = className;
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {

// 从注册中心获取请求服务器
SocketRpcRegistryEntity reg = new SocketRpcRegistryEntity();
reg.setType(1);
reg.setServiceClassName(className);

Socket socket = new Socket(host, port);

ObjectOutputStream outputStream = null;
ObjectInputStream inputStream = null;

ObjectOutputStream rpcOutputStream = null;
ObjectInputStream rpcInputStream = null;
try {
outputStream = new ObjectOutputStream(socket.getOutputStream());
StreamUtils.putObject(outputStream, reg);

inputStream = new ObjectInputStream(socket.getInputStream());
SocketRpcServerEntity rpcServer = (SocketRpcServerEntity) StreamUtils.getObject(inputStream);
if (rpcServer == null) {
throw new RuntimeException("没有获取到对应服务器");
}

// 远程调用服务区
Socket socketRpc = new Socket(rpcServer.getHost(), rpcServer.getPort());
rpcOutputStream = new ObjectOutputStream(socketRpc.getOutputStream());
SocketRpcRequestEntity entity = new SocketRpcRequestEntity();
entity.setArgs(args);
entity.setClassName(className);
entity.setMethodName(method.getName());
entity.setTypes(method.getParameterTypes());

StreamUtils.putObject(rpcOutputStream, entity);

rpcInputStream = new ObjectInputStream(socketRpc.getInputStream());

return StreamUtils.getObject(rpcInputStream);

} finally {
StreamUtils.close(inputStream);
StreamUtils.close(outputStream);
StreamUtils.close(rpcInputStream);
StreamUtils.close(rpcOutputStream);
}
}
}

使用socket连接发送数据接收数据

运行效果

image-20210317140519039

image-20210317140431418

简单的rpc框架就这么实现了.

源码地址

https://github.com/yttrium2016/yzy-rpc

文章作者: 杨振宇
文章链接: https://www.yangzhenyu.com.cn/35693/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 杨振宇 个人经验
支付宝打赏
微信打赏