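Storing IP ranges in Redis
Problem description
The requirement is to look up the user that corresponds to an IP address, so that users whose IP matches can be logged in automatically, among other things. The approach I came up with is to store the ip -> userId mapping in Redis and then look up the userId by IP.
The IP utility class: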
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class IpUtils {
public static List<String> parseIpRange(String ipRange) {
List<String> ips = new ArrayList<>();
String[] parts = ipRange.split(";");
for (String part : parts) {
if (part.contains("-")) {
String[] range = part.split("-");
String startIp = range[0].trim();
String endIp = range[1].trim();
List<String> rangeIps = generateRangeIps(startIp, endIp);
ips.addAll(rangeIps);
} else {
ips.add(part.trim());
}
}
return ips;
}
private static List<String> generateRangeIps(String startIp, String endIp) {
List<String> ips = new ArrayList<>();
String[] startParts = startIp.split("\\.");
String[] endParts = endIp.split("\\.");
Integer[] intStartParts = Arrays.stream(startParts).map(Integer::parseInt).toArray(Integer[]::new);
Integer[] intEndParts = Arrays.stream(endParts).map(Integer::parseInt).toArray(Integer[]::new);
// Emit the current address, then advance by one until the end address has been passed.
do {
ips.add(String.join(".", Arrays.stream(intStartParts).map(String::valueOf).toArray(String[]::new)));
plusIp(intStartParts);
} while (compareIp(intStartParts, intEndParts) <= 0);
return ips;
}
private static int compareIp(Integer[] ip1, Integer[] ip2) {
for (int i = 0; i < 4; i++) {
if (!ip1[i].equals(ip2[i])) {
return ip1[i] - ip2[i];
}
}
return 0;
}
private static void plusIp(Integer[] ip) {
ip[3]++;
for (int i = 3; i > 0; i--) {
if (ip[i] > 255) {
ip[i] = 0;
ip[i - 1]++;
}
}
}
}
The Jedis utility class:
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;
import javax.annotation.PostConstruct;
import java.util.List;
import java.util.Map;
@Component
public class JedisUtils {
private final JedisPool jedisPool;
private static JedisUtils self;
// Expiration time: 2 hours
public static final int EXPIRE_SECONDS = 7200;
@Autowired
public JedisUtils(JedisPool jedisPool) {
this.jedisPool = jedisPool;
}
@PostConstruct
public void init() {
self = this;
}
/**
* Deletes keys from Redis in a batch.
*
* @param keys list of keys
*/
public static void batchDeleteKey(List<String> keys) {
try (Jedis jedis = self.jedisPool.getResource()) {
Pipeline pipeline = jedis.pipelined();
for (String key : keys) {
pipeline.del(key);
}
pipeline.sync();
}
}
/**
* Inserts key-value pairs into Redis in a batch.
*
* @param keyValues key-value pairs
*/
public static void batchInsertKeyValues(Map<String, String> keyValues) {
try (Jedis jedis = self.jedisPool.getResource()) {
Pipeline pipeline = jedis.pipelined();
for (Map.Entry<String, String> entry : keyValues.entrySet()) {
pipeline.setex(entry.getKey(), EXPIRE_SECONDS, entry.getValue());
}
pipeline.sync();
}
}
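/**
* Clears the current database (FLUSHDB) and bulk-inserts the new data inside one Lua script, so the whole operation is atomic.
* Keys and values are passed as KEYS/ARGV rather than concatenated into the script text, so quotes in the data cannot break the script.
*
* @param keyValues key-value pairs
*/
public static void flushAllAndBatchInsertKeyValues(Map<String, String> keyValues) {
try (Jedis jedis = self.jedisPool.getResource()) {
List<String> keys = new java.util.ArrayList<>();
List<String> scriptArgs = new java.util.ArrayList<>();
// ARGV[1] is the TTL in seconds; ARGV[i + 1] is the value for KEYS[i].
scriptArgs.add(String.valueOf(EXPIRE_SECONDS));
for (Map.Entry<String, String> entry : keyValues.entrySet()) {
keys.add(entry.getKey());
scriptArgs.add(entry.getValue());
}
String script = "redis.call('FLUSHDB') "
+ "for i = 1, #KEYS do redis.call('SETEX', KEYS[i], ARGV[1], ARGV[i + 1]) end";
jedis.eval(script, keys, scriptArgs);
}
}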
}
Usage example:
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class Main {
public static void main(String[] args) {
// Entries are separated by ";"; each entry is a single IP or a "start-end" range.
String ipRange = "192.168.2.2-192.168.3.4;192.168.3.3";
String userId = "123";
List<String> ips = IpUtils.parseIpRange(ipRange);
Map<String, String> keyValues = new HashMap<>();
for (String ip : ips) {
keyValues.put(ip, userId);
}
JedisUtils.flushAllAndBatchInsertKeyValues(keyValues);
}
}
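The code above has a problem, though: if the IP range is large, for example 10.0.0.1-192.168.0.1, the generated IP list becomes enormous. That range covers roughly (192-10) * 256^3 addresses, which is about 3 billion, so both generating and storing them is far too expensive. The storage scheme therefore needs to be optimized.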
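To make the size estimate concrete, here is a minimal sketch that counts the addresses in an inclusive range without enumerating them. The class `IpRangeSize` and its methods are hypothetical helpers for illustration and are not part of the utility classes above.

```java
// Hypothetical helper for illustration only; not part of the original utility classes.
final class IpRangeSize {

    /** Converts a dotted-quad IPv4 string to its unsigned 32-bit value, held in a long. */
    static long toLong(String ip) {
        String[] parts = ip.trim().split("\\.");
        if (parts.length != 4) {
            throw new IllegalArgumentException("not an IPv4 address: " + ip);
        }
        long value = 0;
        for (String part : parts) {
            int octet = Integer.parseInt(part);
            if (octet < 0 || octet > 255) {
                throw new IllegalArgumentException("octet out of range: " + ip);
            }
            value = (value << 8) | octet;
        }
        return value;
    }

    /** Number of addresses in the inclusive range [startIp, endIp]. */
    static long size(String startIp, String endIp) {
        return toLong(endIp) - toLong(startIp) + 1;
    }

    public static void main(String[] args) {
        System.out.println(size("192.168.2.2", "192.168.3.4")); // 259
        System.out.println(size("10.0.0.1", "192.168.0.1"));    // 3064463361
    }
}
```

The small range from the usage example produces 259 keys, while the large one would need more than 3 billion keys if every address were stored individually.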
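Optimized approach
Instead of storing every IP, store IP ranges as startIp -> endIp:userId, which greatly reduces the amount of data. For performance, use a sorted set (zset) with startIp as the score and endIp:userId as the member. IPs are stored as integers, converted with shift and bitwise OR operations, which makes them easy to compare.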
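As a small illustration of why the integer form is easier to compare, the sketch below converts between the dotted string and the integer and contrasts string ordering with numeric ordering. The class `IpAsInteger` and the inverse function `longToIp` are assumptions added for illustration; only the shift-and-OR conversion itself comes from the approach described above.

```java
// Illustrative sketch; IpAsInteger and longToIp are hypothetical names.
final class IpAsInteger {

    /** Packs the four octets into one value: shift left by 8 bits, then OR in the next octet. */
    static long ipToLong(String ip) {
        long value = 0;
        for (String part : ip.trim().split("\\.")) {
            value = (value << 8) | Integer.parseInt(part);
        }
        return value;
    }

    /** Inverse conversion: extracts each octet with a shift and a mask. */
    static String longToIp(long value) {
        return ((value >> 24) & 0xFF) + "." + ((value >> 16) & 0xFF) + "."
                + ((value >> 8) & 0xFF) + "." + (value & 0xFF);
    }

    public static void main(String[] args) {
        System.out.println(ipToLong("192.168.2.2"));  // 3232236034
        System.out.println(longToIp(3232236034L));    // 192.168.2.2
        // As strings, "10.0.0.2" sorts before "9.0.0.1", which is the wrong order for addresses.
        System.out.println("10.0.0.2".compareTo("9.0.0.1") < 0);            // true
        // As integers, the order matches the address order.
        System.out.println(ipToLong("10.0.0.2") > ipToLong("9.0.0.1"));     // true
    }
}
```

An IPv4 address fits in 32 bits, so the value is also represented exactly when Redis stores it as a double-precision zset score.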
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;
import javax.annotation.PostConstruct;
import java.util.List;
import java.util.Map;
import java.util.Set;
@Component
public class IpRangeStore {
private final JedisPool jedisPool;
private static IpRangeStore self;
// Expiration time: 2 hours
public static final int EXPIRE_SECONDS = 7200;
@Autowired
public IpRangeStore(JedisPool jedisPool) {
this.jedisPool = jedisPool;
}
@PostConstruct
public void init() {
self = this;
}
/**
* Adds IP ranges in a batch.
*
* @param ipUserIdMap mapping from IP range (or single IP) to user id
* @param deleteOrigin whether to delete the existing data first
*/
public static void batchAddIpRanges(Map<String, String> ipUserIdMap, boolean deleteOrigin) {
try (Jedis jedis = self.jedisPool.getResource()) {
Pipeline pipeline = jedis.pipelined();
if (deleteOrigin) {
pipeline.del("ip_ranges");
}
for (Map.Entry<String, String> entry : ipUserIdMap.entrySet()) {
String ipString = entry.getKey();
if (ipString.contains("-")) {
String[] ipRange = ipString.split("-");
long start = ipToLong(ipRange[0]);
long end = ipToLong(ipRange[1]);
pipeline.zadd("ip_ranges", start, end + ":" + entry.getValue());
} else {
long ip = ipToLong(ipString);
pipeline.zadd("ip_ranges", ip, ip + ":" + entry.getValue());
}
}
pipeline.expire("ip_ranges", EXPIRE_SECONDS);
pipeline.sync();
}
}
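/**
* Deletes the given IPs or IP ranges in a batch.
* Members are stored as "end:userId" with the start address as the score, so entries are removed by start score rather than by member name.
*
* @param ips IPs or IP ranges ("start-end") to delete
*/
public static void batchDeleteIpRanges(List<String> ips) {
try (Jedis jedis = self.jedisPool.getResource()) {
Pipeline pipeline = jedis.pipelined();
for (String ip : ips) {
long start = ipToLong(ip.split("-")[0].trim());
pipeline.zremrangeByScore("ip_ranges", start, start);
}
pipeline.sync();
}
}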
public static Integer findUserIdByIp(String ipAddress) {
try (Jedis jedis = self.jedisPool.getResource()) {
String userId = findUserIdByIp(ipAddress, jedis);
return userId == null ? null : Integer.parseInt(userId);
}
}
private static long ipToLong(String ipAddress) {
String[] ipParts = ipAddress.split("\\.");
long ip = 0;
for (String ipPart : ipParts) {
ip <<= 8;
ip |= Integer.parseInt(ipPart);
}
return ip;
}
private static String findUserIdByIp(String ipAddress, Jedis jedis) {
long ip = ipToLong(ipAddress);
// Find the entry in ip_ranges whose start is closest to ip and less than or equal to it; its member is "end:userId"
Set<String> results = jedis.zrevrangeByScore("ip_ranges", ip, 0, 0, 1);
if (!results.isEmpty()) {
String result = results.iterator().next();
String[] parts = result.split(":");
if (Long.parseLong(parts[0]) >= ip) {
return parts[1];
}
}
return null;
}
}
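The lookup logic can be tried without a Redis instance. The sketch below is an in-memory analogue that swaps the zset for a `java.util.TreeMap` keyed by the start address: `floorEntry` plays the role of the `zrevrangeByScore` call, and the end-address check is the same. The class `InMemoryIpRanges` and its method names are hypothetical and exist only for illustration.

```java
import java.util.Map;
import java.util.TreeMap;

// In-memory analogue of the zset lookup; hypothetical class for illustration only.
final class InMemoryIpRanges {

    /** Value stored per start address: the end of the range and the owning user. */
    private static final class Range {
        final long end;
        final String userId;

        Range(long end, String userId) {
            this.end = end;
            this.userId = userId;
        }
    }

    // Key = start address as an integer (the zset score in the Redis version).
    private final TreeMap<Long, Range> rangesByStart = new TreeMap<>();

    void add(String startIp, String endIp, String userId) {
        rangesByStart.put(ipToLong(startIp), new Range(ipToLong(endIp), userId));
    }

    /** Returns the user id of the range containing ip, or null if no range matches. */
    String findUserId(String ip) {
        long value = ipToLong(ip);
        // Closest start that is <= ip, like ZREVRANGEBYSCORE ... LIMIT 0 1.
        Map.Entry<Long, Range> entry = rangesByStart.floorEntry(value);
        if (entry != null && entry.getValue().end >= value) {
            return entry.getValue().userId;
        }
        return null;
    }

    static long ipToLong(String ip) {
        long value = 0;
        for (String part : ip.trim().split("\\.")) {
            value = (value << 8) | Integer.parseInt(part);
        }
        return value;
    }

    public static void main(String[] args) {
        InMemoryIpRanges ranges = new InMemoryIpRanges();
        ranges.add("192.168.2.2", "192.168.3.4", "123");
        System.out.println(ranges.findUserId("192.168.3.1")); // 123
        System.out.println(ranges.findUserId("192.168.3.5")); // null
        System.out.println(ranges.findUserId("192.168.2.1")); // null
    }
}
```

Because only the nearest start at or below the address is inspected, this scheme (like the Redis version) assumes the stored ranges do not overlap. With nested or overlapping ranges, a shorter range that starts later can hide a longer one that also contains the address.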