/*
 * Ground-Zerro / HydraRoute (public GitHub repository)
 * C source file, 15.3 KB
 */
#include "../include/ipset_nl.h"
#include "../include/log.h"
#include "../include/util.h"
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <unistd.h>
#include <errno.h>
#include <sys/socket.h>
#include <linux/netlink.h>
#include <linux/netfilter/nfnetlink.h>

/* Forward declaration: ipset_create() calls this helper, which is defined further down in this file. */
static void ipset_add_to_cache(ipset_manager_t *mgr, const char *name);

/*
 * Store a 32-bit value into buf as four bytes, most significant byte first.
 * buf must have room for 4 bytes.
 */
static void put_u32_network(uint8_t *buf, uint32_t val) {
    for (int i = 0; i < 4; i++) {
        int shift = 24 - 8 * i;
        buf[i] = (uint8_t)((val >> shift) & 0xFF);
    }
}

/*
 * Write a netlink attribute carrying a single byte at buf.
 *
 * The 16-bit length and type fields are stored in host byte order, followed
 * by the value. Returns the attribute length rounded up with NLA_ALIGN, i.e.
 * how far the caller should advance. No bounds checking is done here; the
 * caller must provide enough room.
 */
static int nla_put_u8(uint8_t *buf, uint16_t type, uint8_t val) {
    const struct nlattr hdr = {
        .nla_len = (uint16_t)(NLA_HDRLEN + 1),
        .nla_type = type,
    };

    memcpy(buf, &hdr, sizeof(hdr));
    buf[NLA_HDRLEN] = val;
    return NLA_ALIGN(hdr.nla_len);
}

/*
 * Write a netlink attribute whose payload is a C string, including its
 * terminating NUL byte.
 *
 * Returns the NLA_ALIGN-ed attribute length. The destination is not bounds
 * checked; the caller must make sure the string fits.
 */
static int nla_put_string(uint8_t *buf, uint16_t type, const char *str) {
    int payload_len = (int)(strlen(str) + 1);
    const struct nlattr hdr = {
        .nla_len = (uint16_t)(NLA_HDRLEN + payload_len),
        .nla_type = type,
    };

    memcpy(buf, &hdr, sizeof(hdr));
    memcpy(buf + NLA_HDRLEN, str, payload_len);
    return NLA_ALIGN(hdr.nla_len);
}

/*
 * Write a netlink attribute with an arbitrary payload of data_len bytes.
 *
 * Header fields are stored in host byte order; the payload is copied as-is.
 * Returns the NLA_ALIGN-ed attribute length. The destination is not bounds
 * checked.
 */
static int nla_put_raw(uint8_t *buf, uint16_t type, const uint8_t *data, int data_len) {
    const struct nlattr hdr = {
        .nla_len = (uint16_t)(NLA_HDRLEN + data_len),
        .nla_type = type,
    };

    memcpy(buf, &hdr, sizeof(hdr));
    memcpy(buf + NLA_HDRLEN, data, data_len);
    return NLA_ALIGN(hdr.nla_len);
}

/*
 * Send one netlink request and read a single reply from the manager socket.
 *
 * Returns:
 *   0   - the reply is not an NLMSG_ERROR message, or it carries error code 0
 *   > 0 - the negated error code from an NLMSG_ERROR reply
 *   -1  - send()/recv() failed, or the reply is too short to parse
 */
static int nl_send_recv_ack(ipset_manager_t *mgr, uint8_t *buf, int len) {
    if (send(mgr->fd, buf, len, 0) < 0) {
        LOG_ERROR("netlink send: %s", strerror(errno));
        return -1;
    }

    uint8_t resp[512];
    int n;
    /* Retry on EINTR so the reply is not left queued for the next request. */
    do {
        n = recv(mgr->fd, resp, sizeof(resp), 0);
    } while (n < 0 && errno == EINTR);
    if (n < 0) {
        LOG_ERROR("netlink recv: %s", strerror(errno));
        return -1;
    }
    if (n < (int)NLMSG_HDRLEN) {
        LOG_ERROR("netlink recv: short reply (%d bytes)", n);
        return -1;
    }

    /* Copy the header out: resp is a byte array with no alignment guarantee. */
    struct nlmsghdr nh;
    memcpy(&nh, resp, sizeof(nh));
    if (nh.nlmsg_type != NLMSG_ERROR)
        return 0;

    /* The error code is the first field of the NLMSG_ERROR payload. */
    int error = 0;
    if (n < (int)(NLMSG_HDRLEN + sizeof(error))) {
        LOG_ERROR("netlink recv: truncated error reply (%d bytes)", n);
        return -1;
    }
    memcpy(&error, resp + NLMSG_HDRLEN, sizeof(error));

    return (error != 0) ? -error : 0;
}

/*
 * Reset the manager, open a NETLINK_NETFILTER raw socket (close-on-exec) and
 * bind it with nl_pid left at 0.
 *
 * On success the sequence counter starts at 1 and mgr->pid holds getpid().
 * Returns 0 on success, -1 if socket() or bind() fails; after a bind failure
 * the socket is closed and mgr->fd is set to -1.
 */
int ipset_manager_init(ipset_manager_t *mgr) {
    struct sockaddr_nl local;

    memset(mgr, 0, sizeof(*mgr));

    mgr->fd = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, NETLINK_NETFILTER);
    if (mgr->fd < 0) {
        LOG_ERROR("netlink socket: %s", strerror(errno));
        return -1;
    }

    memset(&local, 0, sizeof(local));
    local.nl_family = AF_NETLINK;
    local.nl_pid = 0;

    if (bind(mgr->fd, (struct sockaddr *)&local, sizeof(local)) >= 0) {
        mgr->seq = 1;
        mgr->pid = getpid();
        return 0;
    }

    /* Log before close() so errno still refers to the bind failure. */
    LOG_ERROR("netlink bind: %s", strerror(errno));
    close(mgr->fd);
    mgr->fd = -1;
    return -1;
}

/*
 * Close the manager's socket if it is open and mark the descriptor as
 * invalid (-1).
 */
void ipset_manager_close(ipset_manager_t *mgr) {
    if (mgr->fd >= 0) {
        close(mgr->fd);
    }

    mgr->fd = -1;
}

/*
 * Send an IPSET_CMD_TYPE request for the given set type and address family,
 * then scan the reply for an IPSET_ATTR_REVISION attribute.
 *
 * Returns the first payload byte of that attribute, or 0 when the type name
 * does not fit the request buffer, send() fails, the reply is too short or is
 * an NLMSG_ERROR, or no revision attribute is found.
 */
static int ipset_query_revision(ipset_manager_t *mgr, const char *type, int family) {
    uint8_t buf[256];

    /* Header + 4-byte nfgenmsg + the two single-byte attributes are fixed size;
     * reject a type name that would overflow the request buffer. */
    const size_t fixed_len = (size_t)(NLMSG_HDRLEN + 4 + 2 * NLA_ALIGN(NLA_HDRLEN + 1));
    const size_t type_attr_len = NLA_ALIGN(NLA_HDRLEN + strlen(type) + 1);
    if (type_attr_len > sizeof(buf) - fixed_len)
        return 0;

    memset(buf, 0, sizeof(buf));

    struct nlmsghdr *nlh = (struct nlmsghdr *)buf;
    nlh->nlmsg_type = (NFNL_SUBSYS_IPSET << 8) | IPSET_CMD_TYPE;
    nlh->nlmsg_flags = NLM_F_REQUEST;
    nlh->nlmsg_seq = mgr->seq++;
    nlh->nlmsg_pid = mgr->pid;

    /* Anything other than AF_INET6 is sent as AF_INET. */
    uint8_t nf_family = (family == AF_INET6) ? AF_INET6 : AF_INET;
    uint8_t *nfgen = buf + NLMSG_HDRLEN;
    nfgen[0] = nf_family;

    int offset = NLMSG_HDRLEN + 4;
    offset += nla_put_u8(buf + offset, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
    offset += nla_put_string(buf + offset, IPSET_ATTR_TYPENAME, type);
    offset += nla_put_u8(buf + offset, IPSET_ATTR_FAMILY, nf_family);
    nlh->nlmsg_len = offset;

    if (send(mgr->fd, buf, offset, 0) < 0)
        return 0;

    uint8_t resp[512];
    int n = recv(mgr->fd, resp, sizeof(resp), 0);
    if (n < (int)(NLMSG_HDRLEN + 4))
        return 0;

    /* Copy the header out: resp is a byte array with no alignment guarantee. */
    struct nlmsghdr reply;
    memcpy(&reply, resp, sizeof(reply));
    if (reply.nlmsg_type == NLMSG_ERROR)
        return 0;

    const uint8_t *attrs = resp + NLMSG_HDRLEN + 4;
    int attrs_len = n - NLMSG_HDRLEN - 4;
    int pos = 0;
    while (pos + NLA_HDRLEN <= attrs_len) {
        uint16_t nla_len, nla_type;
        memcpy(&nla_len, attrs + pos, 2);
        memcpy(&nla_type, attrs + pos + 2, 2);

        /* A length smaller than the header itself is malformed; stop parsing. */
        if (nla_len < NLA_HDRLEN)
            break;

        /* Only read the payload byte if it lies inside the received data. */
        if ((nla_type & ~(NLA_F_NESTED | NLA_F_NET_BYTEORDER)) == IPSET_ATTR_REVISION
                && nla_len >= NLA_HDRLEN + 1
                && pos + NLA_HDRLEN + 1 <= attrs_len)
            return attrs[pos + NLA_HDRLEN];

        int next = NLA_ALIGN(nla_len);
        if (pos + next > attrs_len)
            break;
        pos += next;
    }
    return 0;
}

/*
 * Build an IPSET_CMD_CREATE request in buf.
 *
 * The message carries protocol, set name, type name, revision and family
 * attributes, plus a nested IPSET_ATTR_DATA attribute holding timeout and/or
 * maxelem (each only when non-zero, as big-endian 32-bit values).
 *
 * Returns the message length, or -1 when buf_size is not positive or the
 * message (sized for the worst case, i.e. with both optional values) would
 * not fit into buf_size bytes. mgr->seq is only advanced on success.
 */
static int build_create_msg(uint8_t *buf, int buf_size, ipset_manager_t *mgr,
                            const char *name, const char *type, int family,
                            uint32_t timeout, uint32_t maxelem, uint8_t revision) {
    if (buf_size <= 0)
        return -1;

    /* name and type have no length limit of their own, so check the total
     * against the buffer before writing anything. */
    const size_t needed = (size_t)(NLMSG_HDRLEN + 4)
                        + (size_t)(3 * NLA_ALIGN(NLA_HDRLEN + 1))
                        + NLA_ALIGN(NLA_HDRLEN + strlen(name) + 1)
                        + NLA_ALIGN(NLA_HDRLEN + strlen(type) + 1)
                        + (size_t)(NLA_HDRLEN + 2 * NLA_ALIGN(NLA_HDRLEN + 4));
    if (needed > (size_t)buf_size)
        return -1;

    memset(buf, 0, (size_t)buf_size);

    struct nlmsghdr *nlh = (struct nlmsghdr *)buf;
    nlh->nlmsg_type = (NFNL_SUBSYS_IPSET << 8) | IPSET_CMD_CREATE;
    nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL;
    nlh->nlmsg_seq = mgr->seq++;
    nlh->nlmsg_pid = mgr->pid;

    /* Anything other than AF_INET6 is sent as AF_INET. */
    uint8_t nf_family = (family == AF_INET6) ? AF_INET6 : AF_INET;
    uint8_t *nfgen = buf + NLMSG_HDRLEN;
    nfgen[0] = nf_family;

    int offset = NLMSG_HDRLEN + 4;
    offset += nla_put_u8(buf + offset, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
    offset += nla_put_string(buf + offset, IPSET_ATTR_SETNAME, name);
    offset += nla_put_string(buf + offset, IPSET_ATTR_TYPENAME, type);
    offset += nla_put_u8(buf + offset, IPSET_ATTR_REVISION, revision);
    offset += nla_put_u8(buf + offset, IPSET_ATTR_FAMILY, nf_family);

    if (timeout > 0 || maxelem > 0) {
        uint8_t data_buf[32];
        int data_len = 0;
        if (timeout > 0) {
            uint8_t timeout_bytes[4];
            put_u32_network(timeout_bytes, timeout);
            data_len += nla_put_raw(data_buf + data_len,
                                    IPSET_ATTR_TIMEOUT | NLA_F_NET_BYTEORDER,
                                    timeout_bytes, 4);
        }
        if (maxelem > 0) {
            uint8_t maxelem_bytes[4];
            put_u32_network(maxelem_bytes, maxelem);
            data_len += nla_put_raw(data_buf + data_len,
                                    IPSET_ATTR_MAXELEM | NLA_F_NET_BYTEORDER,
                                    maxelem_bytes, 4);
        }
        offset += nla_put_raw(buf + offset, IPSET_ATTR_DATA | NLA_F_NESTED,
                                 data_buf, data_len);
    }

    nlh->nlmsg_len = offset;
    return offset;
}


/*
 * Create an ipset via netlink and record its name in the manager's cache.
 *
 * The type revision is queried first; a failed query yields revision 0.
 * An EEXIST reply is treated as success (the set is already there).
 *
 * Returns 0 on success or when the set already exists, -1 when the request
 * does not fit the message buffer, otherwise the non-zero result of the
 * netlink exchange.
 */
int ipset_create(ipset_manager_t *mgr, const char *name, const char *type, int family, uint32_t timeout, uint32_t maxelem) {
    uint8_t revision = (uint8_t)ipset_query_revision(mgr, type, family);

    LOG_DEBUG("Netlink CREATE: set=%s type=%s family=%d timeout=%u revision=%u",
              name, type, family, timeout, revision);

    uint8_t buf[512];
    int msg_len = build_create_msg(buf, sizeof(buf), mgr, name, type, family,
                                   timeout, maxelem, revision);
    if (msg_len < 0) {
        LOG_ERROR("Netlink CREATE: set or type name too long for %s", name);
        return -1;
    }

    int ret = nl_send_recv_ack(mgr, buf, msg_len);

    if (ret == EEXIST) {
        LOG_DEBUG("Set %s already exists", name);
        ipset_add_to_cache(mgr, name);
        return 0;
    }
    if (ret != 0) {
        LOG_ERROR("Netlink CREATE error for %s: errno=%d", name, ret);
        return ret;
    }

    LOG_DEBUG("Set %s created", name);
    ipset_add_to_cache(mgr, name);
    return 0;
}

/*
 * Send an IPSET_CMD_FLUSH request for the named set and wait for the reply.
 *
 * Returns 0 on success, -1 when the set name does not fit the request buffer,
 * otherwise the non-zero result of the netlink exchange.
 */
int ipset_flush(ipset_manager_t *mgr, const char *name) {
    uint8_t buf[256];

    /* Header + 4-byte nfgenmsg + protocol attribute are fixed size; reject a
     * name that would overflow the request buffer. */
    const size_t fixed_len = (size_t)(NLMSG_HDRLEN + 4 + NLA_ALIGN(NLA_HDRLEN + 1));
    if (NLA_ALIGN(NLA_HDRLEN + strlen(name) + 1) > sizeof(buf) - fixed_len) {
        LOG_ERROR("Netlink FLUSH: set name too long");
        return -1;
    }

    memset(buf, 0, sizeof(buf));

    struct nlmsghdr *nlh = (struct nlmsghdr *)buf;
    nlh->nlmsg_type = (NFNL_SUBSYS_IPSET << 8) | IPSET_CMD_FLUSH;
    nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
    nlh->nlmsg_seq = mgr->seq++;
    nlh->nlmsg_pid = mgr->pid;

    /* The family byte is always AF_INET for flush requests. */
    uint8_t *nfgen = buf + NLMSG_HDRLEN;
    nfgen[0] = AF_INET;

    int offset = NLMSG_HDRLEN + 4;
    offset += nla_put_u8(buf + offset, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
    offset += nla_put_string(buf + offset, IPSET_ATTR_SETNAME, name);

    nlh->nlmsg_len = offset;

    LOG_DEBUG("Netlink FLUSH: set=%s", name);

    int ret = nl_send_recv_ack(mgr, buf, offset);
    if (ret != 0) {
        LOG_DEBUG("Netlink FLUSH error: errno=%d", ret);
        return ret;
    }

    LOG_DEBUG("Netlink success: set %s flushed", name);
    return 0;
}

/*
 * Build one IPSET_CMD_ADD request in buf for a single address/prefix entry.
 *
 * Layout: netlink header, 4-byte nfgenmsg (family byte 2), protocol attribute,
 * set-name attribute (set_name_len bytes copied from set_name_nul), then a
 * nested IPSET_ATTR_DATA attribute containing the nested IP attribute, the
 * CIDR prefix and, when has_timeout is set, the 4 timeout_bytes.
 *
 * entry->family == AF_INET selects a 4-byte IPv4 address; any other family is
 * written as a 16-byte IPv6 address. extra_flags is OR-ed into nlmsg_flags.
 * At most the first 256 bytes of buf are zeroed. Returns the message length.
 */
static int build_ipset_add_msg(uint8_t *buf, int buf_size, ipset_manager_t *mgr,
                               const char *set_name_nul, int set_name_len,
                               const parsed_cidr_t *entry,
                               int has_timeout, const uint8_t *timeout_bytes,
                               uint16_t extra_flags) {
    int zero_len = buf_size;
    if (zero_len > 256) {
        zero_len = 256;
    }
    memset(buf, 0, zero_len);

    struct nlmsghdr *hdr = (struct nlmsghdr *)buf;
    hdr->nlmsg_type = (NFNL_SUBSYS_IPSET << 8) | IPSET_CMD_ADD;
    hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | extra_flags;
    hdr->nlmsg_seq = mgr->seq++;
    hdr->nlmsg_pid = mgr->pid;

    /* First byte of the nfgenmsg that follows the netlink header. */
    buf[NLMSG_HDRLEN] = 2;

    int msg_len = NLMSG_HDRLEN + 4;
    msg_len += nla_put_u8(buf + msg_len, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
    msg_len += nla_put_raw(buf + msg_len, IPSET_ATTR_SETNAME,
                           (const uint8_t *)set_name_nul, set_name_len);

    /* Innermost attribute: the address itself. */
    uint8_t addr_attr[32];
    int addr_attr_len;
    if (entry->family == AF_INET) {
        addr_attr_len = nla_put_raw(addr_attr,
                                    IPSET_ATTR_IPADDR_IPV4 | NLA_F_NET_BYTEORDER,
                                    entry->ip, 4);
    } else {
        addr_attr_len = nla_put_raw(addr_attr,
                                    IPSET_ATTR_IPADDR_IPV6 | NLA_F_NET_BYTEORDER,
                                    entry->ip, 16);
    }

    /* Payload of the nested DATA attribute: IP, CIDR and optional timeout. */
    uint8_t payload[64];
    int payload_len = nla_put_raw(payload, IPSET_ATTR_IP | NLA_F_NESTED,
                                  addr_attr, addr_attr_len);
    payload_len += nla_put_u8(payload + payload_len, IPSET_ATTR_CIDR, entry->prefix);
    if (has_timeout) {
        payload_len += nla_put_raw(payload + payload_len,
                                   IPSET_ATTR_TIMEOUT | NLA_F_NET_BYTEORDER,
                                   timeout_bytes, 4);
    }

    msg_len += nla_put_raw(buf + msg_len, IPSET_ATTR_DATA | NLA_F_NESTED,
                           payload, payload_len);

    hdr->nlmsg_len = msg_len;
    return msg_len;
}

/*
 * Report whether an address should be rejected as a "service" address.
 *
 * For AF_INET: true when the first octet is 0 or 127.
 * For any other family the address is read as 16 bytes: true when it is
 * all zeros or all zeros with a final byte of 1.
 * Returns 1 for a service address, 0 otherwise.
 */
static int is_service_ip(const uint8_t *ip, int family) {
    if (family == AF_INET) {
        switch (ip[0]) {
        case 0:
        case 127:
            return 1;
        default:
            return 0;
        }
    }

    /* Both rejected 16-byte forms share fifteen leading zero bytes. */
    for (int i = 0; i < 15; i++) {
        if (ip[i] != 0) {
            return 0;
        }
    }
    return (ip[15] == 0 || ip[15] == 1);
}

/*
 * Add a batch of parsed address/prefix entries to the named set.
 *
 * Entries are processed in chunks of IPSET_CHUNK_SIZE: every request of a
 * chunk is sent first, then one reply is read per request that was sent.
 * Entries rejected by is_service_ip() are skipped.
 *
 * The set's cached timeout settings are looked up by the low 8 bits of
 * fnv1a_hash(set_name). A timeout attribute is added whenever the cache says
 * the set has a timeout; its value is the cached timeout only when
 * with_timeout is non-zero and is 0 otherwise. with_timeout also adds
 * NLM_F_EXCL to each request.
 *
 * new_count / new_indices: when with_timeout is non-zero and both pointers
 * are non-NULL, the index of every entry whose reply carried error code 0 is
 * appended to new_indices and counted in *new_count. *new_count is reset to 0
 * on entry, including for an empty batch.
 *
 * Returns 0 on completion (per-entry kernel errors are only logged), -1 when
 * the set name does not fit the internal 64-byte buffer or a send() fails.
 */
int ipset_add_batch(ipset_manager_t *mgr, const char *set_name,
                    const parsed_cidr_t *entries, int count,
                    int with_timeout, int *new_count, int *new_indices) {
    /* Reset the caller's counter before any early return. */
    if (new_count) *new_count = 0;
    if (count <= 0) return 0;

    uint32_t idx = fnv1a_hash(set_name, strlen(set_name)) & 0xFF;
    int has_timeout = mgr->set_has_timeout[idx];
    uint32_t timeout_val = mgr->timeout_value[idx];

    uint8_t timeout_bytes[4] = {0};
    if (has_timeout && with_timeout) {
        put_u32_network(timeout_bytes, timeout_val);
    }

    uint16_t excl_flag = with_timeout ? NLM_F_EXCL : 0;

    /* snprintf returns the untruncated length, so a too-long name must be
     * rejected rather than used as a copy length. */
    char set_name_nul[64];
    int name_chars = snprintf(set_name_nul, sizeof(set_name_nul), "%s", set_name);
    if (name_chars < 0 || name_chars >= (int)sizeof(set_name_nul)) {
        LOG_ERROR("netlink add batch: set name too long");
        return -1;
    }
    int set_name_len = name_chars + 1;

    for (int start = 0; start < count; start += IPSET_CHUNK_SIZE) {
        int end = start + IPSET_CHUNK_SIZE;
        if (end > count) end = count;

        uint8_t msg_bufs[IPSET_CHUNK_SIZE][256];
        int msg_lens[IPSET_CHUNK_SIZE];
        int valid_indices[IPSET_CHUNK_SIZE];
        int msg_count = 0;

        for (int i = start; i < end; i++) {
            if (is_service_ip(entries[i].ip, entries[i].family)) {
                LOG_FILTERED("Service IP rejected (family=%d)", entries[i].family);
                continue;
            }

            msg_lens[msg_count] = build_ipset_add_msg(
                msg_bufs[msg_count], sizeof(msg_bufs[msg_count]),
                mgr, set_name_nul, set_name_len,
                &entries[i],
                has_timeout, timeout_bytes,
                excl_flag);
            valid_indices[msg_count] = i;
            msg_count++;
        }

        if (msg_count == 0) continue;

        int sent = 0;
        int send_failed = 0;
        for (int i = 0; i < msg_count; i++) {
            if (send(mgr->fd, msg_bufs[i], msg_lens[i], 0) < 0) {
                LOG_ERROR("netlink send batch: %s", strerror(errno));
                send_failed = 1;
                break;
            }
            sent++;
        }

        /* Read one reply per request actually sent, even after a send
         * failure, so stale replies are not consumed by later requests. */
        for (int i = 0; i < sent; i++) {
            uint8_t resp[256];
            int n = recv(mgr->fd, resp, sizeof(resp), 0);

            /* Skip failed reads and replies too short to hold an error code. */
            if (n < (int)(NLMSG_HDRLEN + sizeof(int))) continue;

            /* Copy the header out: resp has no alignment guarantee. */
            struct nlmsghdr nh;
            memcpy(&nh, resp, sizeof(nh));
            if (nh.nlmsg_type != NLMSG_ERROR) continue;

            int error = 0;
            memcpy(&error, resp + NLMSG_HDRLEN, sizeof(error));

            if (error == 0) {
                if (with_timeout && new_indices && new_count) {
                    new_indices[*new_count] = valid_indices[i];
                    (*new_count)++;
                }
            } else {
                int errcode = -error;
                if (errcode == IPSET_ERR_HASH_FULL) {
                    LOG_WARN("ipset '%s' full (maxelem exceeded): set IpsetMaxElem in config", set_name);
                } else if (errcode != IPSET_ERR_EXIST) {
                    LOG_DEBUG("Netlink ADD error: errno=%d", errcode);
                }
            }
        }

        if (send_failed) return -1;
    }

    return 0;
}

/*
 * Rebuild the cached list of set names from the output of `ipset list -n`.
 *
 * The output is split on newlines; each line is passed through
 * trim_whitespace() and non-empty results are stored, truncated to 63
 * characters, until IPSET_MAX_SETS names are cached.
 *
 * Returns 0 on success, -1 when run_command_output() returns non-zero (the
 * existing cache is left untouched in that case).
 */
int ipset_refresh_set_list(ipset_manager_t *mgr) {
    char output[32768];
    char *argv[] = {"ipset", "list", "-n", NULL};

    int rc = run_command_output("ipset", argv, output, sizeof(output));
    if (rc != 0) {
        LOG_WARN("ipset list -n failed (exit %d)", rc);
        return -1;
    }

    mgr->set_count = 0;

    char *state;
    for (char *tok = strtok_r(output, "\n", &state);
         tok != NULL && mgr->set_count < IPSET_MAX_SETS;
         tok = strtok_r(NULL, "\n", &state)) {
        char *set_name = trim_whitespace(tok);
        if (set_name[0] == '\0') {
            continue;
        }

        char *slot = mgr->set_names[mgr->set_count];
        strncpy(slot, set_name, 63);
        slot[63] = '\0';
        mgr->set_count++;
    }

    LOG_DEBUG("RefreshSetList: loaded %d sets into cache", mgr->set_count);
    return 0;
}

/*
 * Look up a set name in the manager's cached name list.
 * Returns 1 when an exact match is cached, 0 otherwise.
 */
int ipset_set_exists(ipset_manager_t *mgr, const char *name) {
    int i = 0;

    while (i < mgr->set_count) {
        if (!strcmp(mgr->set_names[i], name)) {
            return 1;
        }
        i++;
    }

    return 0;
}

/*
 * Append a set name to the manager's cached name list.
 *
 * Names are stored truncated to 63 characters. A name whose first 63
 * characters already match a cached entry is not added again, so repeated
 * creates of the same set do not use up slots. When the cache already holds
 * IPSET_MAX_SETS names the new name is dropped and a warning is logged.
 */
static void ipset_add_to_cache(ipset_manager_t *mgr, const char *name) {
    /* Compare at most 63 characters: that is all a cached entry can hold. */
    for (int i = 0; i < mgr->set_count; i++) {
        if (strncmp(mgr->set_names[i], name, 63) == 0) {
            return;
        }
    }

    if (mgr->set_count >= IPSET_MAX_SETS) {
        LOG_WARN("AddToCache: cache full, %s not cached", name);
        return;
    }

    strncpy(mgr->set_names[mgr->set_count], name, 63);
    mgr->set_names[mgr->set_count][63] = '\0';
    mgr->set_count++;
    LOG_DEBUG("AddToCache: added %s to cache", name);
}

/*
 * Remember whether a set uses timeouts and which timeout value applies.
 *
 * The slot is chosen by the low 8 bits of fnv1a_hash(name); nothing here
 * distinguishes names that map to the same slot, so a later call for such a
 * name overwrites the earlier values.
 */
void ipset_cache_timeout_for_set(ipset_manager_t *mgr, const char *name,
                                 int has_timeout, uint32_t timeout_val) {
    size_t name_len = strlen(name);
    uint32_t slot = fnv1a_hash(name, name_len) & 0xFF;

    mgr->set_has_timeout[slot] = has_timeout;
    mgr->timeout_value[slot] = timeout_val;
}