问题描述

有个需求是根据 ip 来获取对应的用户,从而让 ip 地址匹配的用户进行自动登录等操作。想到的解决方式是将 ip -> userId 的映射关系存储到 redis 中,然后根据 ip 获取对应的 userId。

ip 工具类如下:


import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Helpers for expanding textual IPv4 range specifications into individual addresses.
 */
public class IpUtils {

    /**
     * Expands a specification such as {@code "192.168.2.2-192.168.3.4;192.168.3.3"} into the
     * individual addresses it describes.
     *
     * <p>Segments are separated by {@code ;}. A segment containing {@code -} is an inclusive
     * {@code start-end} range and is expanded address by address; any other segment is added
     * verbatim (trimmed, not validated). Blank segments are skipped. Duplicates are not removed.
     *
     * @param ipRange the specification to expand; must not be {@code null}
     * @return the addresses in the order they appear in the specification
     * @throws IllegalArgumentException if a range segment is malformed, contains an invalid
     *                                  IPv4 address, or has a start greater than its end
     */
    public static List<String> parseIpRange(String ipRange) {
        List<String> ips = new ArrayList<>();
        for (String part : ipRange.split(";")) {
            String segment = part.trim();
            if (segment.isEmpty()) {
                // "a;;b" or "a; ;b" must not produce an empty address.
                continue;
            }
            if (!segment.contains("-")) {
                ips.add(segment);
                continue;
            }
            String[] range = segment.split("-");
            if (range.length != 2) {
                throw new IllegalArgumentException("Invalid IP range: " + segment);
            }
            ips.addAll(generateRangeIps(range[0].trim(), range[1].trim()));
        }
        return ips;
    }

    /**
     * Lists every address from {@code startIp} to {@code endIp}, both inclusive.
     *
     * @throws IllegalArgumentException if either address is invalid or start is after end
     */
    private static List<String> generateRangeIps(String startIp, String endIp) {
        long start = toLong(startIp);
        long end = toLong(endIp);
        if (start > end) {
            throw new IllegalArgumentException(
                    "Range start " + startIp + " is greater than range end " + endIp);
        }
        List<String> ips = new ArrayList<>();
        // long arithmetic handles the octet carry (x.x.2.255 -> x.x.3.0) for free.
        for (long value = start; value <= end; value++) {
            ips.add(toDotted(value));
        }
        return ips;
    }

    /** Converts a dotted-quad IPv4 string into its 32-bit value held in a long. */
    private static long toLong(String ip) {
        String[] parts = ip.split("\\.");
        if (parts.length != 4) {
            throw new IllegalArgumentException("Invalid IPv4 address: " + ip);
        }
        long value = 0;
        for (String part : parts) {
            int octet = Integer.parseInt(part);
            if (octet < 0 || octet > 255) {
                throw new IllegalArgumentException("Invalid IPv4 address: " + ip);
            }
            value = (value << 8) | octet;
        }
        return value;
    }

    /** Formats a 32-bit address value as a dotted-quad string. */
    private static String toDotted(long value) {
        return ((value >> 24) & 0xFF) + "."
                + ((value >> 16) & 0xFF) + "."
                + ((value >> 8) & 0xFF) + "."
                + (value & 0xFF);
    }

}

jedis 工具类如下:

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;

import javax.annotation.PostConstruct;
import java.util.List;
import java.util.Map;

/**
 * Static facade over a Spring-managed {@link JedisPool} for bulk key operations.
 *
 * <p>The static methods dereference {@code self}, which is only assigned in {@link #init()};
 * calling them before that method has run fails with a {@link NullPointerException}.
 */
@Component
public class JedisUtils {

    /**
     * Fixed Lua script: clears the current database, then stores every KEYS[i] with the TTL in
     * ARGV[1] and the value in ARGV[i + 1]. Keys and values travel as arguments rather than
     * being spliced into the script text, so quotes or Lua syntax inside them cannot alter it.
     */
    private static final String FLUSH_AND_SETEX_SCRIPT =
            "redis.call('FLUSHDB') "
                    + "for i = 1, #KEYS do "
                    + "redis.call('SETEX', KEYS[i], ARGV[1], ARGV[i + 1]) "
                    + "end";

    private final JedisPool jedisPool;

    private static JedisUtils self;

    // Entries expire after two hours.
    public static final int EXPIRE_SECONDS = 7200;

    @Autowired
    public JedisUtils(JedisPool jedisPool) {
        this.jedisPool = jedisPool;
    }

    /** Publishes this bean to the static methods. */
    @PostConstruct
    public void init() {
        self = this;
    }

    /**
     * Deletes the given keys using a single pipeline.
     *
     * @param keys keys to delete
     */
    public static void batchDeleteKey(List<String> keys) {
        try (Jedis jedis = self.jedisPool.getResource()) {
            Pipeline pipeline = jedis.pipelined();
            for (String key : keys) {
                pipeline.del(key);
            }
            pipeline.sync();
        }
    }

    /**
     * Stores the given key-value pairs using a single pipeline, each with a TTL of
     * {@link #EXPIRE_SECONDS}.
     *
     * @param keyValues key-value pairs to store
     */
    public static void batchInsertKeyValues(Map<String, String> keyValues) {
        try (Jedis jedis = self.jedisPool.getResource()) {
            Pipeline pipeline = jedis.pipelined();
            for (Map.Entry<String, String> entry : keyValues.entrySet()) {
                pipeline.setex(entry.getKey(), EXPIRE_SECONDS, entry.getValue());
            }
            pipeline.sync();
        }
    }

    /**
     * Clears the current database ({@code FLUSHDB}) and stores the given key-value pairs, each
     * with a TTL of {@link #EXPIRE_SECONDS}, inside one Lua script so that both steps are
     * applied together.
     *
     * @param keyValues key-value pairs to store after the flush
     */
    public static void flushAllAndBatchInsertKeyValues(Map<String, String> keyValues) {
        int keyCount = keyValues.size();
        // Layout expected by EVAL: all keys first, then ARGV = [ttl, value1, value2, ...].
        String[] params = new String[2 * keyCount + 1];
        params[keyCount] = String.valueOf(EXPIRE_SECONDS);
        int index = 0;
        for (Map.Entry<String, String> entry : keyValues.entrySet()) {
            params[index] = entry.getKey();
            params[keyCount + 1 + index] = entry.getValue();
            index++;
        }
        try (Jedis jedis = self.jedisPool.getResource()) {
            jedis.eval(FLUSH_AND_SETEX_SCRIPT, keyCount, params);
        }
    }

}

使用示例:

public class Main {

    public static void main(String[] args) {
        String ipRange = "192.168.2.2-192.168.3.4;192.168.3.3";
        String userId = "123"
        List<String> ips = IpUtils.parseIpRange(ipRange);
        Map<String, String> keyValues = new HashMap<>();
        for (String ip : ips) {
            keyValues.put(ip, userId);
        }
        JedisUtils.flushAllAndBatchInsertKeyValues(keyValues);
    }

}

但上述代码存在一个问题:如果 ip 段很大,比如 10.0.0.1-192.168.0.1,那么生成的 ip 列表会非常大,约 (192-10) × 256^3 ≈ 30 亿个,无论是生成还是存储,开销都难以接受。因此需要优化存储方式。

优化方案

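由原来的存储每个 ip 改为存储 ip 段,startIp -> endIp:userId,这样可以减少存储的数据量。为了提高性能,使用有序集合 zset 来存储:以 startIp 作为 score,以 endIp:userId 作为 member;同时 ip 采用整数形式存储,通过移位和或运算将 ip 转换为整数,这样可以方便比较大小。查询时先取 score 小于等于目标 ip 的最大一条记录,再判断其 endIp 是否大于等于目标 ip。注意该方案要求各 ip 段互不重叠。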

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.Pipeline;

import javax.annotation.PostConstruct;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Stores IPv4 ranges in a single Redis sorted set and resolves an address to a user id.
 *
 * <p>Each range is one sorted-set entry: the score is the numeric start address and the member
 * is {@code "<numeric end address>:<userId>"}. A lookup takes the entry with the greatest start
 * that is not above the queried address and accepts it only if its end covers the address, so
 * overlapping ranges are not resolved reliably.
 *
 * <p>The static methods dereference {@code self}, which is only assigned in {@link #init()};
 * calling them before that method has run fails with a {@link NullPointerException}.
 */
@Component
public class IpRangeStore {

    /** Redis key of the sorted set holding every range. */
    private static final String IP_RANGES_KEY = "ip_ranges";

    private final JedisPool jedisPool;

    private static IpRangeStore self;

    // The whole sorted set expires after two hours.
    public static final int EXPIRE_SECONDS = 7200;

    @Autowired
    public IpRangeStore(JedisPool jedisPool) {
        this.jedisPool = jedisPool;
    }

    /** Publishes this bean to the static methods. */
    @PostConstruct
    public void init() {
        self = this;
    }

    /**
     * Adds ranges to the sorted set and resets its TTL to {@link #EXPIRE_SECONDS}.
     *
     * <p>All entries are parsed before a connection is borrowed, so invalid input is rejected
     * without sending any command to Redis.
     *
     * @param ipUserIdMap  maps a single address ({@code "a.b.c.d"}) or an inclusive range
     *                     ({@code "a.b.c.d-e.f.g.h"}) to a user id
     * @param deleteOrigin whether to delete the existing sorted set first
     * @throws IllegalArgumentException if an address is not valid IPv4 or a range starts after
     *                                  it ends
     */
    public static void batchAddIpRanges(Map<String, String> ipUserIdMap, boolean deleteOrigin) {
        int size = ipUserIdMap.size();
        double[] scores = new double[size];
        String[] members = new String[size];
        int count = 0;
        for (Map.Entry<String, String> entry : ipUserIdMap.entrySet()) {
            long[] bounds = parseBounds(entry.getKey());
            scores[count] = bounds[0];
            members[count] = bounds[1] + ":" + entry.getValue();
            count++;
        }

        try (Jedis jedis = self.jedisPool.getResource()) {
            Pipeline pipeline = jedis.pipelined();
            if (deleteOrigin) {
                pipeline.del(IP_RANGES_KEY);
            }
            for (int i = 0; i < count; i++) {
                pipeline.zadd(IP_RANGES_KEY, scores[i], members[i]);
            }
            pipeline.expire(IP_RANGES_KEY, EXPIRE_SECONDS);
            pipeline.sync();
        }
    }

    /**
     * Removes stored ranges identified by their start address.
     *
     * <p>Members are stored as {@code "<numeric end>:<userId>"}, so they cannot be matched by
     * the textual address; entries are removed by score (the numeric start address) instead.
     * Every entry sharing that start address is removed.
     *
     * @param ips single addresses or {@code start-end} ranges, in the same form as the keys
     *            passed to {@link #batchAddIpRanges(Map, boolean)}
     * @throws IllegalArgumentException if a start address is not valid IPv4
     */
    public static void batchDeleteIpRanges(List<String> ips) {
        long[] starts = new long[ips.size()];
        int count = 0;
        for (String ip : ips) {
            starts[count++] = startOf(ip);
        }

        try (Jedis jedis = self.jedisPool.getResource()) {
            Pipeline pipeline = jedis.pipelined();
            for (int i = 0; i < count; i++) {
                pipeline.zremrangeByScore(IP_RANGES_KEY, (double) starts[i], (double) starts[i]);
            }
            pipeline.sync();
        }
    }

    /**
     * Resolves an address to the user id of the stored range that contains it.
     *
     * @param ipAddress IPv4 address in dotted-quad form
     * @return the user id, or {@code null} if no stored range contains the address
     * @throws IllegalArgumentException if the address is not valid IPv4
     * @throws NumberFormatException    if the stored user id is not an integer
     */
    public static Integer findUserIdByIp(String ipAddress) {
        try (Jedis jedis = self.jedisPool.getResource()) {
            String userId = findUserIdByIp(ipAddress, jedis);
            return userId == null ? null : Integer.parseInt(userId);
        }
    }

    /** Returns {start, end} for a single address (start == end) or a {@code start-end} range. */
    private static long[] parseBounds(String ipString) {
        int dash = ipString.indexOf('-');
        if (dash < 0) {
            long ip = ipToLong(ipString);
            return new long[] {ip, ip};
        }
        long start = ipToLong(ipString.substring(0, dash));
        long end = ipToLong(ipString.substring(dash + 1));
        if (start > end) {
            throw new IllegalArgumentException("Range start is greater than range end: " + ipString);
        }
        return new long[] {start, end};
    }

    /** Returns the numeric start address of a single address or a {@code start-end} range. */
    private static long startOf(String ipString) {
        int dash = ipString.indexOf('-');
        return ipToLong(dash < 0 ? ipString : ipString.substring(0, dash));
    }

    /**
     * Converts a dotted-quad IPv4 string into its 32-bit value by shifting in one octet at a
     * time.
     *
     * @throws IllegalArgumentException if the text does not have four octets in 0..255
     */
    private static long ipToLong(String ipAddress) {
        String[] ipParts = ipAddress.trim().split("\\.");
        if (ipParts.length != 4) {
            throw new IllegalArgumentException("Invalid IPv4 address: " + ipAddress);
        }
        long ip = 0;
        for (String ipPart : ipParts) {
            int octet = Integer.parseInt(ipPart);
            if (octet < 0 || octet > 255) {
                throw new IllegalArgumentException("Invalid IPv4 address: " + ipAddress);
            }
            ip = (ip << 8) | octet;
        }
        return ip;
    }

    private static String findUserIdByIp(String ipAddress, Jedis jedis) {
        long ip = ipToLong(ipAddress);
        // Fetch the entry whose start address is the closest one not above ip.
        Set<String> results = jedis.zrevrangeByScore(IP_RANGES_KEY, ip, 0, 0, 1);
        if (!results.isEmpty()) {
            String result = results.iterator().next();
            // Limit 2 keeps a user id that itself contains ':' intact.
            String[] parts = result.split(":", 2);
            // The entry only matches if its end address covers ip.
            if (parts.length == 2 && Long.parseLong(parts[0]) >= ip) {
                return parts[1];
            }
        }
        return null;
    }

}