负载均衡的目的是按照某种策略把请求分发到多台机器上，使系统能够横向扩展。
下面我们来简单实现几种常见的负载均衡算法：随机、加权随机、轮询、加权轮询和源地址哈希。首先需要一个服务提供者类：
/**
 * Describes one service provider instance: where it listens, how its service is invoked,
 * and the metadata that the load-balancing strategies consult.
 *
 * <p>Lombok generates the accessors, {@code equals}/{@code hashCode}, the builder and the
 * no-arg / all-args constructors; {@link #toString()} is written by hand.
 */
@AllArgsConstructor
@Data
@Builder
@NoArgsConstructor
public class ProviderService {
    /** Service interface exposed by this provider. */
    private Class<?> serviceItf;
    /** Object implementing the service; {@code transient}, so skipped by Java serialization. */
    private transient Object serviceObject;
    /** Reflective method of the service; {@code transient}, so skipped by Java serialization. */
    private transient Method serviceMethod;
    /** Host address of the provider. */
    private String serverIp;
    /** Port the provider listens on. */
    private int serverPort;
    /** Timeout value. NOTE(review): the unit is not stated here — presumably milliseconds, confirm. */
    private long timeout;
    /** Weight of this provider, read by the weighted load-balancing strategies. */
    private int weight;
    /** Number of server-side worker threads. */
    private int workerThreads;
    /** Unique identifier of the service provider. */
    private String appKey;
    /** Name of the service group this provider belongs to. */
    private String groupName;

    /**
     * Creates a shallow copy of this provider.
     *
     * <p>Every field is copied as-is, so reference fields such as the service object and
     * method are shared between the original and the copy.
     *
     * @return a new {@code ProviderService} with the same field values
     */
    public ProviderService copy() {
        ProviderService duplicate = new ProviderService();
        duplicate.setServiceItf(serviceItf);
        duplicate.setServiceObject(serviceObject);
        duplicate.setServiceMethod(serviceMethod);
        duplicate.setServerIp(serverIp);
        duplicate.setServerPort(serverPort);
        duplicate.setTimeout(timeout);
        duplicate.setWeight(weight);
        duplicate.setWorkerThreads(workerThreads);
        duplicate.setAppKey(appKey);
        duplicate.setGroupName(groupName);
        return duplicate;
    }

    /**
     * Renders the provider address.
     *
     * @return {@code serverIp:serverPort} (a null IP is rendered as {@code "null"})
     */
    @Override
    public String toString() {
        return new StringBuilder(String.valueOf(serverIp))
                .append(':')
                .append(serverPort)
                .toString();
    }
}
负载均衡要做的事情，就是从服务提供者列表中选出一个服务提供者。为了让不同的负载均衡算法可以互相替换，我们先为它们定义一个统一的接口：
/**
 * Load-balancing strategy: chooses one provider from a list of candidates.
 *
 * <p>Declared as a functional interface so that simple strategies can also be supplied as
 * lambdas or method references.
 */
@FunctionalInterface
public interface ClusterStrategy {
    /**
     * Selects one service provider from the candidates.
     *
     * @param providerServices candidate providers to choose from; presumably non-null and
     *     non-empty — a strategy has nothing to return for an empty list
     * @return the chosen provider
     */
    ProviderService select(List<ProviderService> providerServices);
}
第一个、也是最简单的，是随机算法：每个服务提供者被选中的概率相同。
/**
 * Random strategy: every provider in the list has the same probability of being chosen.
 */
public class RandomClusterStrategy implements ClusterStrategy {
    /**
     * Picks one provider uniformly at random.
     *
     * @param providerServices candidate providers; must be non-null
     * @return a randomly chosen element of {@code providerServices}
     * @throws IllegalArgumentException if {@code providerServices} is empty
     */
    @Override
    public ProviderService select(List<ProviderService> providerServices) {
        int size = providerServices.size();
        if (size == 0) {
            throw new IllegalArgumentException("providerServices must not be empty");
        }
        // ThreadLocalRandom avoids allocating a new Random on every call and avoids
        // contention between threads sharing one generator.
        int index = java.util.concurrent.ThreadLocalRandom.current().nextInt(size);
        return providerServices.get(index);
    }
}
然后是加权随机算法：服务提供者被选中的概率与其权重成正比。
/**
 * Weighted random strategy: a provider is chosen with probability proportional to its weight.
 */
public class WeightRandomClusterStrategy implements ClusterStrategy {
    /**
     * Picks one provider at random, weighted by {@link ProviderService#getWeight()}.
     *
     * <p>Providers whose weight is zero or negative are never chosen. The element of
     * {@code providerServices} itself is returned; no copy is made.
     *
     * @param providerServices candidate providers; must be non-null
     * @return the chosen element of {@code providerServices}
     * @throws IllegalArgumentException if no provider has a positive weight (this includes an
     *     empty list)
     */
    @Override
    public ProviderService select(List<ProviderService> providerServices) {
        // Sum in a long so that many large weights cannot overflow.
        long totalWeight = 0;
        for (ProviderService provider : providerServices) {
            totalWeight += Math.max(provider.getWeight(), 0);
        }
        if (totalWeight == 0) {
            throw new IllegalArgumentException(
                    "providerServices must contain at least one provider with a positive weight");
        }

        // Draw a point in [0, totalWeight) and walk the cumulative weights to the provider that
        // owns it. This equals "repeat each provider `weight` times and pick uniformly", but
        // without allocating the expanded list on every call.
        long point = java.util.concurrent.ThreadLocalRandom.current().nextLong(totalWeight);
        for (ProviderService provider : providerServices) {
            int weight = provider.getWeight();
            if (weight <= 0) {
                continue;
            }
            if (point < weight) {
                return provider;
            }
            point -= weight;
        }
        // Only reachable if the list or the weights changed between the two passes.
        throw new IllegalStateException("providerServices changed during selection");
    }
}
轮询算法：按顺序依次选择每个服务提供者，到达列表末尾后再从头开始。
/**
 * Round-robin strategy: providers are chosen one after another in list order, wrapping
 * around at the end of the list.
 */
public class PollingClusterStrategy implements ClusterStrategy {
    /** Call counter shared by all callers of this instance; incremented once per selection. */
    private final AtomicInteger index = new AtomicInteger(0);

    /**
     * Returns the next provider in round-robin order.
     *
     * @param providerServices candidate providers; must be non-null
     * @return the element at position {@code counter mod size}
     * @throws IllegalArgumentException if {@code providerServices} is empty
     */
    @Override
    public ProviderService select(List<ProviderService> providerServices) {
        int size = providerServices.size();
        if (size == 0) {
            throw new IllegalArgumentException("providerServices must not be empty");
        }
        // A single getAndIncrement gives each concurrent caller its own ticket (no racy
        // check-then-reset), and floorMod keeps the position in [0, size) even after the
        // counter overflows to negative values.
        return providerServices.get(Math.floorMod(index.getAndIncrement(), size));
    }
}
加权轮询算法：在轮询的基础上，每个服务提供者在一轮中被选中的次数与其权重成正比。
/**
 * Weighted round-robin strategy: within one cycle each provider is chosen as many times as
 * its weight, in list order (e.g. weights A=2, B=1 give A, A, B, A, A, B, ...).
 */
public class WeightPollingClusterStrategy implements ClusterStrategy {
    /** Call counter shared by all callers of this instance; incremented once per selection. */
    private final AtomicInteger index = new AtomicInteger(0);

    /**
     * Returns the next provider in weighted round-robin order.
     *
     * <p>Providers whose weight is zero or negative are never chosen. The element of
     * {@code providerServices} itself is returned; no copy is made.
     *
     * @param providerServices candidate providers; must be non-null
     * @return the chosen element of {@code providerServices}
     * @throws IllegalArgumentException if no provider has a positive weight (this includes an
     *     empty list)
     */
    @Override
    public ProviderService select(List<ProviderService> providerServices) {
        // Sum in a long so that many large weights cannot overflow.
        long totalWeight = 0;
        for (ProviderService provider : providerServices) {
            totalWeight += Math.max(provider.getWeight(), 0);
        }
        if (totalWeight == 0) {
            throw new IllegalArgumentException(
                    "providerServices must contain at least one provider with a positive weight");
        }

        // One atomic ticket per call; floorMod maps it into [0, totalWeight) even after the
        // int counter overflows to negative values.
        long slot = Math.floorMod((long) index.getAndIncrement(), totalWeight);

        // Walk the cumulative weights to the provider owning this slot — the same result as
        // indexing a list where each provider is repeated `weight` times, without building it.
        for (ProviderService provider : providerServices) {
            int weight = provider.getWeight();
            if (weight <= 0) {
                continue;
            }
            if (slot < weight) {
                return provider;
            }
            slot -= weight;
        }
        // Only reachable if the list or the weights changed between the two passes.
        throw new IllegalStateException("providerServices changed during selection");
    }
}
源地址哈希算法：对源地址（这里取本机的内网 IP）计算哈希值，再对服务提供者数量取模；因此在服务列表不变时，同一来源总会被分配到同一个服务提供者。
/**
 * Source-address hash strategy: hashes this host's site-local IP (from
 * {@code IpUtils.getHostIp()}) so the same host keeps selecting the same provider as long as
 * the provider list does not change.
 */
public class HashClusterStrategy implements ClusterStrategy {
    /**
     * Selects the provider at position {@code floorMod(hash(ip), size)}.
     *
     * <p>If no IP can be determined ({@code getHostIp()} returned {@code null}), the hash is
     * taken as 0, so the first provider is chosen instead of failing with a
     * {@code NullPointerException}.
     *
     * @param providerServices candidate providers; must be non-null
     * @return the chosen element of {@code providerServices}
     * @throws IllegalArgumentException if {@code providerServices} is empty
     */
    @Override
    public ProviderService select(List<ProviderService> providerServices) {
        int size = providerServices.size();
        if (size == 0) {
            // Checked first: avoids a division by zero and skips the interface enumeration.
            throw new IllegalArgumentException("providerServices must not be empty");
        }
        String ip = IpUtils.getHostIp();
        int hashCode = ip == null ? 0 : ip.hashCode();
        // String.hashCode() may be negative; floorMod (unlike %) always yields [0, size).
        return providerServices.get(Math.floorMod(hashCode, size));
    }
}
其中 IpUtils 的代码如下：
/**
 * Helpers for finding this machine's site-local (private-network) address.
 *
 * <p>An address qualifies when it is not loopback (127.x.x.x / ::1), not link-local
 * (169.254.x.x / fe80::) and is site-local (10/8, 172.16/12, 192.168/16, or IPv6 fec0::/10).
 * When several addresses qualify, the last one in enumeration order is used.
 */
@Slf4j
public class IpUtils {

    /**
     * Returns the textual IP of this machine's site-local address.
     *
     * @return an address such as {@code "192.168.1.10"}, or {@code null} if no address
     *     qualifies or the network interfaces could not be enumerated (the error is logged)
     */
    public static String getHostIp() {
        try {
            InetAddress address = findSiteLocalAddress();
            return address == null ? null : address.getHostAddress();
        } catch (SocketException e) {
            log.error("Fail to get IP address.", e);
            return null;
        }
    }

    /**
     * Returns the host name of the same address {@link #getHostIp()} selects.
     *
     * <p>{@link InetAddress#getHostName()} may perform a reverse name lookup; it is invoked only
     * once, on the selected address, rather than for every candidate.
     *
     * @return the host name (or the IP literal if the name cannot be resolved), or {@code null}
     *     if no address qualifies or the network interfaces could not be enumerated (the error
     *     is logged)
     */
    public static String getHostName() {
        try {
            InetAddress address = findSiteLocalAddress();
            return address == null ? null : address.getHostName();
        } catch (SocketException e) {
            log.error("Fail to get host name.", e);
            return null;
        }
    }

    /**
     * Scans every address of every network interface and keeps the last qualifying one.
     *
     * @return the last qualifying address, or {@code null} if there is none
     * @throws SocketException if the network interfaces cannot be enumerated
     */
    private static InetAddress findSiteLocalAddress() throws SocketException {
        Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
        if (interfaces == null) {
            // Some JDK versions return null instead of an empty enumeration when no interface
            // is found.
            return null;
        }
        InetAddress found = null;
        while (interfaces.hasMoreElements()) {
            Enumeration<InetAddress> addresses = interfaces.nextElement().getInetAddresses();
            while (addresses.hasMoreElements()) {
                InetAddress candidate = addresses.nextElement();
                if (isUsableSiteLocal(candidate)) {
                    // Later matches overwrite earlier ones ("last one wins").
                    found = candidate;
                }
            }
        }
        return found;
    }

    /** Tells whether the address is a non-loopback, non-link-local, site-local address. */
    private static boolean isUsableSiteLocal(InetAddress address) {
        return !address.isLoopbackAddress()
                && !address.isLinkLocalAddress()
                && address.isSiteLocalAddress();
    }
}