diff --git a/wireguard-linux/drivers/net/wireguard/device.c b/wireguard-linux/drivers/net/wireguard/device.c
index deb9636b0ecf..af3a001fe772 100644
--- a/wireguard-linux/drivers/net/wireguard/device.c
+++ b/wireguard-linux/drivers/net/wireguard/device.c
@@ -19,6 +19,8 @@
#include <linux/if_arp.h>
#include <linux/icmp.h>
#include <linux/suspend.h>
+#include <linux/spinlock.h>
+#include <linux/wireguard.h>
#include <net/dst_metadata.h>
#include <net/gso.h>
#include <net/icmp.h>
@@ -27,7 +29,7 @@
#include <net/addrconf.h>
static LIST_HEAD(device_list);
-
+
static int wg_open(struct net_device *dev)
{
struct in_device *dev_v4 = __in_dev_get_rtnl(dev);
@@ -48,7 +50,11 @@ static int wg_open(struct net_device *dev)
dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
mutex_lock(&wg->device_update_lock);
- ret = wg_socket_init(wg, wg->incoming_port);
+ // XXX: transport is forced to TCP here for now; it is normally
+ // configured via the WGDEVICE_A_TRANSPORT netlink attribute
+ wg->transport = WG_TRANSPORT_TCP;
+ if (wg->transport == WG_TRANSPORT_TCP)
+ ret = wg_tcp_socket_init(wg, wg->incoming_port);
+ else
+ ret = wg_socket_init(wg, wg->incoming_port);
if (ret < 0)
goto out;
list_for_each_entry(peer, &wg->peer_list, peer_list) {
@@ -265,6 +271,7 @@ static void wg_destruct(struct net_device *dev)
free_percpu(dev->tstats);
kvfree(wg->index_hashtable);
kvfree(wg->peer_hashtable);
+ wg_destruct_tcp_connection_list(wg);
mutex_unlock(&wg->device_update_lock);
pr_debug("%s: Interface destroyed\n", dev->name);
@@ -286,7 +293,7 @@ static void wg_setup(struct net_device *dev)
dev->header_ops = &ip_tunnel_header_ops;
dev->hard_header_len = 0;
dev->addr_len = 0;
- dev->needed_headroom = DATA_PACKET_HEAD_ROOM;
+ dev->needed_headroom = DATA_PACKET_HEAD_ROOM + (wg->transport == WG_TRANSPORT_TCP ? WG_TCP_ENCAP_HDR_LEN : 0);
dev->needed_tailroom = noise_encrypted_len(MESSAGE_PADDING_MULTIPLE);
dev->type = ARPHRD_NONE;
dev->flags = IFF_POINTOPOINT | IFF_NOARP;
@@ -296,7 +303,8 @@ static void wg_setup(struct net_device *dev)
dev->hw_features |= WG_NETDEV_FEATURES;
dev->hw_enc_features |= WG_NETDEV_FEATURES;
dev->mtu = ETH_DATA_LEN - overhead;
- dev->max_mtu = round_down(INT_MAX, MESSAGE_PADDING_MULTIPLE) - overhead;
+ dev->max_mtu = round_down(INT_MAX, MESSAGE_PADDING_MULTIPLE) - overhead -
+ (wg->transport == WG_TRANSPORT_TCP ? WG_TCP_ENCAP_HDR_LEN : 0);
SET_NETDEV_DEVTYPE(dev, &device_type);
@@ -305,6 +313,10 @@ static void wg_setup(struct net_device *dev)
memset(wg, 0, sizeof(*wg));
wg->dev = dev;
+ INIT_LIST_HEAD(&wg->tcp_connection_list);
+ spin_lock_init(&wg->tcp_connection_list_lock);
+ wg->tcp_socket4_ready = false;
+ wg->tcp_socket6_ready = false;
}
static int wg_newlink(struct net *src_net, struct net_device *dev,
diff --git a/wireguard-linux/drivers/net/wireguard/device.h b/wireguard-linux/drivers/net/wireguard/device.h
index 43c7cebbf50b..648dd36941d6 100644
--- a/wireguard-linux/drivers/net/wireguard/device.h
+++ b/wireguard-linux/drivers/net/wireguard/device.h
@@ -14,9 +14,11 @@
#include <linux/types.h>
#include <linux/netdevice.h>
#include <linux/workqueue.h>
+#include <linux/ktime.h>
#include <linux/mutex.h>
#include <linux/net.h>
#include <linux/ptr_ring.h>
+#include <linux/spinlock.h>
struct wg_device;
@@ -40,7 +42,8 @@ struct prev_queue {
struct wg_device {
struct net_device *dev;
struct crypt_queue encrypt_queue, decrypt_queue, handshake_queue;
- struct sock __rcu *sock4, *sock6;
+ struct sock __rcu *sock4, *sock6; // UDP listening sockets
+ struct socket __rcu *tcp_listen_socket4, *tcp_listen_socket6; // TCP listening sockets
struct net __rcu *creating_net;
struct noise_static_identity static_identity;
struct workqueue_struct *packet_crypt_wq,*handshake_receive_wq, *handshake_send_wq;
@@ -49,11 +52,17 @@ struct wg_device {
struct index_hashtable *index_hashtable;
struct allowedips peer_allowedips;
struct mutex device_update_lock, socket_update_lock;
- struct list_head device_list, peer_list;
+ struct list_head device_list, peer_list, tcp_connection_list;
+ struct task_struct *tcp_listener4_thread, *tcp_listener6_thread;
+ struct delayed_work tcp_cleanup_work;
+ bool tcp_socket4_ready;
+ bool tcp_socket6_ready;
+ spinlock_t tcp_connection_list_lock;
atomic_t handshake_queue_len;
unsigned int num_peers, device_update_gen;
u32 fwmark;
u16 incoming_port;
+ u8 transport;
};
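+
+// Values carried in the 'transport' field above and in the WGDEVICE_A_TRANSPORT /
+// WGPEER_A_TRANSPORT netlink attributes. The authoritative definitions are expected
+// to live in the uapi <linux/wireguard.h> additions (not part of this diff); the
+// concrete values below are an assumption for illustration only:
+//
+// WG_TRANSPORT_UDP = 0 (default, plain UDP as in stock WireGuard)
+// WG_TRANSPORT_TCP = 1 (TCP encapsulation added by this patch)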
int wg_device_init(void);
diff --git a/wireguard-linux/drivers/net/wireguard/netlink.c b/wireguard-linux/drivers/net/wireguard/netlink.c
index e220d761b1f2..f8c0f6a9b1e2 100644
--- a/wireguard-linux/drivers/net/wireguard/netlink.c
+++ b/wireguard-linux/drivers/net/wireguard/netlink.c
@@ -15,6 +15,7 @@
#include <linux/if.h>
#include <net/genetlink.h>
#include <net/sock.h>
+#include <crypto/algapi.h>
#include <crypto/utils.h>
static struct genl_family genl_family;
@@ -27,7 +28,8 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
[WGDEVICE_A_FLAGS] = { .type = NLA_U32 },
[WGDEVICE_A_LISTEN_PORT] = { .type = NLA_U16 },
[WGDEVICE_A_FWMARK] = { .type = NLA_U32 },
- [WGDEVICE_A_PEERS] = { .type = NLA_NESTED }
+ [WGDEVICE_A_PEERS] = { .type = NLA_NESTED },
+ [WGDEVICE_A_TRANSPORT] = { .type = NLA_U8 }
};
static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
@@ -40,7 +42,8 @@ static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
[WGPEER_A_RX_BYTES] = { .type = NLA_U64 },
[WGPEER_A_TX_BYTES] = { .type = NLA_U64 },
[WGPEER_A_ALLOWEDIPS] = { .type = NLA_NESTED },
- [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32 }
+ [WGPEER_A_PROTOCOL_VERSION] = { .type = NLA_U32 },
+ [WGPEER_A_TRANSPORT] = { .type = NLA_U8 }
};
static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
@@ -200,7 +203,10 @@ static int wg_get_device_start(struct netlink_callback *cb)
{
struct wg_device *wg;
+ // Different kernel versions expose the dump attributes through different
+ // helpers (genl_info_dump() vs. genl_dumpit_info()); a version guard is
+ // needed here rather than hard-coding one of them.
wg = lookup_interface(genl_info_dump(cb)->attrs, cb->skb);
+ //wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
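+ // A possible compile-time guard (sketch only; the 6.7 boundary is an
+ // assumption and should be checked against the kernels actually targeted):
+ //
+ // #include <linux/version.h>
+ // #if LINUX_VERSION_CODE >= KERNEL_VERSION(6, 7, 0)
+ // wg = lookup_interface(genl_info_dump(cb)->attrs, cb->skb);
+ // #else
+ // wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
+ // #endif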
if (IS_ERR(wg))
return PTR_ERR(wg);
DUMP_CTX(cb)->wg = wg;
@@ -233,7 +239,8 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
wg->incoming_port) ||
nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) ||
nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) ||
- nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
+ nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name) ||
+ nla_put_u8(skb, WGDEVICE_A_TRANSPORT, wg->transport))
goto out;
down_read(&wg->static_identity.lock);
@@ -323,7 +330,11 @@ static int set_port(struct wg_device *wg, u16 port)
wg->incoming_port = port;
return 0;
}
- return wg_socket_init(wg, port);
+ // XXX: transport is forced to TCP here for now; it is normally
+ // configured via the WGDEVICE_A_TRANSPORT netlink attribute
+ wg->transport = WG_TRANSPORT_TCP;
+ if (wg->transport == WG_TRANSPORT_TCP)
+ return wg_tcp_socket_init(wg, port);
+ else
+ return wg_socket_init(wg, port);
}
static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
@@ -479,6 +490,12 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
wg_packet_send_keepalive(peer);
}
+ if (attrs[WGPEER_A_TRANSPORT]) {
+ u8 transport = nla_get_u8(attrs[WGPEER_A_TRANSPORT]);
+ peer->transport = transport;
+ pr_debug("WireGuard: Setting peer %p transport mode to %u\n", peer, transport);
+ }
+
if (netif_running(wg->dev))
wg_packet_send_staged_packets(peer);
@@ -578,8 +595,14 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
}
up_write(&wg->static_identity.lock);
}
skip_set_private_key:
+ if (info->attrs[WGDEVICE_A_TRANSPORT]) {
+ u8 transport = nla_get_u8(info->attrs[WGDEVICE_A_TRANSPORT]);
+ wg->transport = transport;
+ pr_debug("%s: Setting device transport mode to %u\n", wg->dev->name, transport);
+ }
+
if (info->attrs[WGDEVICE_A_PEERS]) {
struct nlattr *attr, *peer[WGPEER_A_MAX + 1];
int rem;
diff --git a/wireguard-linux/drivers/net/wireguard/peer.c b/wireguard-linux/drivers/net/wireguard/peer.c
index 1cb502a932e0..3d3f7806fa58 100644
--- a/wireguard-linux/drivers/net/wireguard/peer.c
+++ b/wireguard-linux/drivers/net/wireguard/peer.c
@@ -9,11 +9,13 @@
#include "timers.h"
#include "peerlookup.h"
#include "noise.h"
+#include "socket.h"
#include <linux/kref.h>
#include <linux/lockdep.h>
#include <linux/rcupdate.h>
#include <linux/list.h>
+#include <linux/wireguard.h>
static struct kmem_cache *peer_cache;
static atomic64_t peer_counter = ATOMIC64_INIT(0);
@@ -37,6 +39,18 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
goto err;
peer->device = wg;
+ // XXX: transport is forced to TCP here for now; it is normally
+ // configured via the WGPEER_A_TRANSPORT netlink attribute
+ peer->transport = WG_TRANSPORT_TCP;
+ if (peer->transport == WG_TRANSPORT_TCP) {
+ spin_lock_init(&peer->tcp_lock);
+ spin_lock_init(&peer->send_queue_lock);
+ skb_queue_head_init(&peer->tcp_packet_queue);
+ skb_queue_head_init(&peer->send_queue);
+ // Set up connection retry timer
+ timer_setup(&peer->tcp_connect_retry_timer, wg_tcp_connection_retry_timer, 0);
+
+ // Initially mark the connection as not established
+ peer->tcp_established = false;
+ wg_tcp_connect(peer);
+ }
wg_noise_handshake_init(&peer->handshake, &wg->static_identity,
public_key, preshared_key, peer);
peer->internal_id = atomic64_inc_return(&peer_counter);
@@ -59,6 +73,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
list_add_tail(&peer->peer_list, &wg->peer_list);
INIT_LIST_HEAD(&peer->allowedips_list);
wg_pubkey_hashtable_add(wg->peer_hashtable, peer);
+
++wg->num_peers;
pr_debug("%s: Peer %llu created\n", wg->dev->name, peer->internal_id);
return peer;
diff --git a/wireguard-linux/drivers/net/wireguard/peer.h b/wireguard-linux/drivers/net/wireguard/peer.h
index 76e4d3128ad4..5b7956964c25 100644
--- a/wireguard-linux/drivers/net/wireguard/peer.h
+++ b/wireguard-linux/drivers/net/wireguard/peer.h
@@ -64,8 +64,30 @@ struct wg_peer {
struct list_head allowedips_list;
struct napi_struct napi;
u64 internal_id;
+ // TCP-related members
+ u8 transport;
+ struct socket *tcp_receive_socket;
+ void (*original_state_change)(struct sock *sk);
+ void (*original_write_space)(struct sock *sk);
+ void (*original_data_ready)(struct sock *sk);
+ void (*original_error_report)(struct sock *sk);
+ void (*original_destruct)(struct sock *sk);
+ struct sk_buff *partial_skb;
+ size_t expected_len;
+ size_t received_len;
+ struct sk_buff_head tcp_packet_queue; // For queuing TCP packets
+ struct timer_list tcp_connect_retry_timer; // Timer for retrying TCP connection
+ struct delayed_work tcp_retry_work; // Work for retrying TCP connection
+ struct delayed_work tcp_connect_retry_work; // Work for retrying TCP connection
+ bool tcp_established; // Flag to track TCP connection status
+ bool tcp_pending; // Flag to track pending TCP connection status
+ spinlock_t tcp_lock; // Protects TCP-related state
+ struct sk_buff_head send_queue; // TX queue
+ spinlock_t send_queue_lock; // TX lock
+ struct list_head pending_connection_list; //peers pending connection handshake
};
+
struct wg_peer *wg_peer_create(struct wg_device *wg,
const u8 public_key[NOISE_PUBLIC_KEY_LEN],
const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]);
@@ -83,4 +105,8 @@ void wg_peer_remove_all(struct wg_device *wg);
int wg_peer_init(void);
void wg_peer_uninit(void);
+void wg_peer_tcp_connect(struct work_struct *work);
+void wg_peer_tcp_send(struct work_struct *work);
+void wg_peer_tcp_receive(struct work_struct *work);
+
#endif /* _WG_PEER_H */
diff --git a/wireguard-linux/drivers/net/wireguard/socket.c b/wireguard-linux/drivers/net/wireguard/socket.c
index 0414d7a6ce74..36280c2672a9 100644
--- a/wireguard-linux/drivers/net/wireguard/socket.c
+++ b/wireguard-linux/drivers/net/wireguard/socket.c
@@ -3,6 +3,8 @@
* Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
+/* TCP Support by Jeff Nathan and Dragos Ruiu 2024-03-16 */
+
#include "device.h"
#include "peer.h"
#include "socket.h"
@@ -10,16 +12,222 @@
#include "messages.h"
#include <linux/ctype.h>
-#include <linux/net.h>
#include <linux/if_vlan.h>
#include <linux/if_ether.h>
#include <linux/inetdevice.h>
+#include <linux/wireguard.h>
#include <net/udp_tunnel.h>
#include <net/ipv6.h>
+#include <net/sock.h>
+#include <linux/inet.h>
+#include <linux/kthread.h>
+#include <net/inet_sock.h>
+
+#include <linux/kernel.h>
+#include <linux/skbuff.h>
+#include <linux/net.h>
+#include <linux/tcp.h>
+#include <linux/time.h>
+#include <linux/in.h>
+#include <asm/byteorder.h> // For ntohl
+#include <linux/workqueue.h>
+#include <linux/spinlock.h>
+#include <linux/socket.h>
+#include <linux/in6.h>
+#include <net/inet_common.h>
+
+struct wg_tcp_socket_list_entry {
+ struct wg_peer *tcp_peer;
+ struct list_head tcp_connection;
+ ktime_t timestamp; // Timestamp when the connection was added
+};
+
+static int wg_tcp_queuepkt(struct wg_peer *, const void *, size_t);
+static void wg_setup_tcp_socket_callbacks(struct wg_peer *);
+static void wg_reset_tcp_socket_callbacks(struct wg_peer *);
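+
+// Sketch of the encapsulation header assumed throughout this file. The actual
+// definition (together with WG_TCP_ENCAP_HDR_LEN, WG_MAX_PACKET_SIZE and the
+// WG_TRANSPORT_* constants) is expected to come from <linux/wireguard.h> and is
+// not part of this diff; the layout below only mirrors how the fields are used
+// here (length is in network byte order and covers header plus payload):
+//
+// struct wg_tcp_encap_header {
+// __be32 length;
+// __u8 type;
+// __u8 flags;
+// __be16 checksum; // wg_header_checksum() over length, type and flags
+// };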
+
+// ******** DIAGNOSTIC CODE ********
+
+#include <linux/module.h>
+#include <linux/list.h>
+#include <linux/timer.h>
+#include <linux/workqueue.h>
+#include <linux/rcupdate.h>
+#include <linux/kref.h>
+
+// Function to print details of sk_buff_head for diagnostic purposes
+void print_skbuff_head_info(const char *label, struct sk_buff_head *queue) {
+ const struct sk_buff *skb;
+ unsigned long flags;
+
+ printk(KERN_INFO "%s:\n", label);
+ if (!queue) {
+ printk(KERN_INFO "Queue is NULL\n");
+ return;
+ }
+
+ spin_lock_irqsave(&queue->lock, flags);
+ skb_queue_walk(queue, skb) {
+ printk(KERN_INFO "Packet: len=%u, data_len=%u, users=%d\n",
+ skb->len, skb->data_len, refcount_read(&skb->users));
+ }
+ spin_unlock_irqrestore(&queue->lock, flags);
+}
+
+void print_wg_peer(struct wg_peer *peer) {
+ if (!peer) {
+ printk(KERN_ERR "NULL wg_peer provided\n");
+ return;
+ }
+
+ printk(KERN_INFO "WG Peer Complete Diagnostic Info:\n");
+ printk(KERN_INFO "Device Pointer: %p, Serial Work CPU: %d, "
+ "Is Dead: %d, Transport Mode: %u\n",
+ peer->device, peer->serial_work_cpu, peer->is_dead,
+ peer->transport);
+ printk(KERN_INFO "RX Bytes: %llu, TX Bytes: %llu, Internal ID: %llu\n",
+ peer->rx_bytes, peer->tx_bytes, peer->internal_id);
+ printk(KERN_INFO "Last Sent Handshake: %llu\n",
+ atomic64_read(&peer->last_sent_handshake));
+
+ // Endpoint info
+ printk(KERN_INFO "Endpoint Address Family: %u\n",
+ peer->endpoint.addr.sa_family);
+ if (peer->endpoint.addr.sa_family == AF_INET) {
+ printk(KERN_INFO "IPv4 Address: %pI4, IPv4 Source: %pI4, "
+ "Interface: %d\n",
+ &peer->endpoint.addr4.sin_addr, &peer->endpoint.src4,
+ peer->endpoint.src_if4);
+ } else if (peer->endpoint.addr.sa_family == AF_INET6) {
+ printk(KERN_INFO "IPv6 Address: %pI6c, IPv6 Source: %pI6c\n",
+ &peer->endpoint.addr6.sin6_addr, &peer->endpoint.src6);
+ }
+
+ // Correctly accessing sk_buff_head queues
+ if (!skb_queue_empty(&peer->staged_packet_queue)) {
+ print_skbuff_head_info("Staged Packet Queue",
+ &peer->staged_packet_queue);
+ } else {
+ printk(KERN_INFO "Staged Packet Queue: NULL\n");
+ }
+
+ // Additional diagnostics and corrections for TCP
+ if (peer->tcp_receive_socket) {
+ printk(KERN_INFO "TCP Socket: %p, Established: %d\n",
+ peer->tcp_receive_socket, peer->tcp_established);
+ if (!skb_queue_empty(&peer->tcp_packet_queue)) {
+ print_skbuff_head_info("TCP Packet Queue",
+ &peer->tcp_packet_queue);
+ } else {
+ printk(KERN_INFO "TCP Packet Queue: NULL\n");
+ }
+ } else {
+ printk(KERN_INFO "TCP Socket: NULL\n");
+ }
+
+ // Timer diagnostics
+ printk(KERN_INFO "Timer for Retransmit Handshake Expires: %ld\n",
+ peer->timer_retransmit_handshake.expires);
+ printk(KERN_INFO "Timer for Sending Keepalive Expires: %ld\n",
+ peer->timer_send_keepalive.expires);
+ printk(KERN_INFO "Timer for New Handshake Expires: %ld\n",
+ peer->timer_new_handshake.expires);
+ printk(KERN_INFO "Timer for Zero Key Material Expires: %ld\n",
+ peer->timer_zero_key_material.expires);
+ printk(KERN_INFO "Timer for Persistent Keepalive Expires: %ld\n",
+ peer->timer_persistent_keepalive.expires);
+ printk(KERN_INFO "Timer for TCP Connect Retry Expires: %ld\n",
+ peer->tcp_connect_retry_timer.expires);
+
+ // RCU and reference count
+ printk(KERN_INFO "RCU Head Address: %p, Reference Count: %d\n",
+ &peer->rcu, kref_read(&peer->refcount));
+}
+
+// Function to print information about crypt_queue
+void print_crypt_queue(const char *label, struct crypt_queue *queue) {
+ if (!queue) {
+ printk(KERN_INFO "%s: NULL\n", label);
+ return;
+ }
+
+ printk(KERN_INFO "%s:\n", label);
+ printk(KERN_INFO " Last CPU used: %d\n", queue->last_cpu);
+ // Assuming you have a way to inspect ptr_ring structure:
+ // printk(KERN_INFO " Ring capacity: %d\n", queue->ring.size);
+ if (queue->worker)
+ printk(KERN_INFO " Worker pointer: %p\n", queue->worker);
+ else
+ printk(KERN_INFO " Worker: NULL\n");
+}
+
+// Diagnostic function for wg_device
+void print_wg_device(struct wg_device *device) {
+ if (!device) {
+ printk(KERN_ERR "NULL wg_device provided\n");
+ return;
+ }
+
+ printk(KERN_INFO "WG Device Diagnostic Info:\n");
+
+ if (device->dev)
+ printk(KERN_INFO "Net device: %s\n", device->dev->name);
+ else
+ printk(KERN_INFO "Net device: NULL\n");
+
+ print_crypt_queue("Encrypt Queue", &(device->encrypt_queue));
+ print_crypt_queue("Decrypt Queue", &(device->decrypt_queue));
+ print_crypt_queue("Handshake Queue", &(device->handshake_queue));
+
+ if (rcu_access_pointer(device->sock4))
+ printk(KERN_INFO "UDP IPv4 Socket: %p\n", device->sock4);
+ else
+ printk(KERN_INFO "UDP IPv4 Socket: NULL\n");
+
+ if (rcu_access_pointer(device->sock6))
+ printk(KERN_INFO "UDP IPv6 Socket: %p\n", device->sock6);
+ else
+ printk(KERN_INFO "UDP IPv6 Socket: NULL\n");
+
+ if (rcu_access_pointer(device->tcp_listen_socket4))
+ printk(KERN_INFO "TCP Listener IPv4 Socket: %p\n",
+ device->tcp_listen_socket4);
+ else
+ printk(KERN_INFO "TCP Listener IPv4 Socket: NULL\n");
+
+ if (rcu_access_pointer(device->tcp_listen_socket6))
+ printk(KERN_INFO "TCP Listener IPv6 Socket: %p\n",
+ device->tcp_listen_socket6);
+ else
+ printk(KERN_INFO "TCP Listener IPv6 Socket: NULL\n");
+
+ if (device->creating_net)
+ printk(KERN_INFO "Creating net namespace: %p\n",
+ device->creating_net);
+ else
+ printk(KERN_INFO "Creating net namespace: NULL\n");
+
+ // Assuming noise_static_identity and other structures have similar diagnostic print functions
+ printk(KERN_INFO "Static Identity: (printing details not implemented)\n");
+ printk(KERN_INFO "Workqueues and other components would similarly have their details printed based on available data.\n");
+
+ printk(KERN_INFO "FW Mark: %u, Incoming Port: %u, Transport: %u\n", device->fwmark, device->incoming_port, device->transport);
+ printk(KERN_INFO "Handshake queue length: %d\n", atomic_read(&device->handshake_queue_len));
+ printk(KERN_INFO "Number of Peers: %u, Device Update Generation: %u\n", device->num_peers, device->device_update_gen);
+}
+// ******** END OF DIAGNOSTIC CODE ********
+
+
static int send4(struct wg_device *wg, struct sk_buff *skb,
- struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
-{
+ struct endpoint *endpoint, u8 ds, struct dst_cache *cache) {
+ printk(KERN_INFO "Entering function send4\n");
+ if (wg->transport == WG_TRANSPORT_TCP) {
+ // Hand the packet payload to the TCP path, which copies it into its
+ // own buffer, then consume the original skb
+ int ret = wg_tcp_queuepkt(wg->tcp_listen_socket4->sk->sk_user_data,
+ (const void *)skb->data, (size_t)skb->len);
+
+ kfree_skb(skb);
+ return ret;
+ }
+
struct flowi4 fl = {
.saddr = endpoint->src4.s_addr,
.daddr = endpoint->addr4.sin_addr.s_addr,
@@ -92,12 +300,18 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
out:
rcu_read_unlock_bh();
+ printk(KERN_INFO "Exiting function send4\n");
return ret;
}
static int send6(struct wg_device *wg, struct sk_buff *skb,
- struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
-{
+ struct endpoint *endpoint, u8 ds, struct dst_cache *cache) {
+ printk(KERN_INFO "Entering function send6\n");
#if IS_ENABLED(CONFIG_IPV6)
+ if (wg->transport == WG_TRANSPORT_TCP) {
+ // Hand the packet payload to the TCP path, which copies it into its
+ // own buffer, then consume the original skb
+ int ret = wg_tcp_queuepkt(wg->tcp_listen_socket6->sk->sk_user_data,
+ (const void *)skb->data, (size_t)skb->len);
+
+ kfree_skb(skb);
+ return ret;
+ }
+
struct flowi6 fl = {
.saddr = endpoint->src6,
.daddr = endpoint->addr6.sin6_addr,
@@ -158,15 +372,18 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
kfree_skb(skb);
out:
rcu_read_unlock_bh();
+ printk(KERN_INFO "Exiting function send6\n");
return ret;
#else
kfree_skb(skb);
+ printk(KERN_INFO "Exiting function send6\n");
return -EAFNOSUPPORT;
#endif
+ printk(KERN_INFO "Exiting function send6\n");
}
-int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds)
-{
+int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds) {
+ printk(KERN_INFO "Entering function wg_socket_send_skb_to_peer\n");
size_t skb_len = skb->len;
int ret = -EAFNOSUPPORT;
@@ -184,11 +401,12 @@ int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds)
read_unlock_bh(&peer->endpoint_lock);
+ printk(KERN_INFO "Exiting function wg_socket_send_skb_to_peer\n");
return ret;
}
int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *buffer,
- size_t len, u8 ds)
-{
+ size_t len, u8 ds) {
+ printk(KERN_INFO "Entering function wg_socket_send_buffer_to_peer\n");
struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);
if (unlikely(!skb))
@@ -198,12 +416,13 @@ int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *buffer,
skb_set_inner_network_header(skb, 0);
skb_put_data(skb, buffer, len);
+ printk(KERN_INFO "Exiting function wg_socket_send_buffer_to_peer\n");
return wg_socket_send_skb_to_peer(peer, skb, ds);
}
int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg,
struct sk_buff *in_skb, void *buffer,
- size_t len)
-{
+ size_t len) {
+ printk(KERN_INFO "Entering function wg_socket_send_buffer_as_reply_to_skb\n");
int ret = 0;
struct sk_buff *skb;
struct endpoint endpoint;
@@ -229,12 +448,46 @@ int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg,
* as we checked above.
*/
+ printk(KERN_INFO "Exiting function wg_socket_send_buffer_as_reply_to_skb\n");
return ret;
}
-int wg_socket_endpoint_from_skb(struct endpoint *endpoint,
- const struct sk_buff *skb)
-{
+void hexdump_skb(const struct sk_buff *skb, int bytes) {
+ int i, j;
+ int lines = min(DIV_ROUND_UP(bytes, 32), 4); // Number of 32-byte lines to print, at most 4
+ const unsigned char *data = skb->data;
+
+ for (i = 0; i < lines; ++i) {
+ char line[160]; // offset + 32 hex bytes + spacing + 32 ASCII chars needs more than 80 bytes
+ char *pos = line;
+
+ pos += sprintf(pos, "%06x: ", i * 32); // Print offset
+
+ for (j = 0; j < 32; ++j) {
+ if (i * 32 + j < skb->len) {
+ pos += sprintf(pos, "%02x ", data[i * 32 + j]); // Print hex values
+ } else {
+ pos += sprintf(pos, " "); // Print spaces for padding
+ }
+ }
+
+ pos += sprintf(pos, " "); // Space between hex and ASCII
+
+ for (j = 0; j < 32; ++j) {
+ if (i * 32 + j < skb->len) {
+ unsigned char c = data[i * 32 + j];
+ pos += sprintf(pos, "%c", (c >= 32 && c <= 126) ? c : '.'); // Print ASCII characters
+ } else {
+ pos += sprintf(pos, " "); // Print spaces for padding
+ }
+ }
+
+ printk(KERN_INFO "%s\n", line); // Print the line to the kernel log
+ }
+}
+
+int wg_socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb) {
+ printk(KERN_INFO "Entering function wg_socket_endpoint_from_skb\n");
memset(endpoint, 0, sizeof(*endpoint));
if (skb->protocol == htons(ETH_P_IP)) {
endpoint->addr4.sin_family = AF_INET;
@@ -242,21 +495,29 @@ int wg_socket_endpoint_from_skb(struct endpoint *endpoint,
endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr;
endpoint->src4.s_addr = ip_hdr(skb)->daddr;
endpoint->src_if4 = skb->skb_iif;
+ printk(KERN_INFO "wg_socket_endpoint_from_skb: Extracted IPv4 address %pI4:%d\n",
+ &endpoint->addr4.sin_addr, ntohs(endpoint->addr4.sin_port));
} else if (IS_ENABLED(CONFIG_IPV6) && skb->protocol == htons(ETH_P_IPV6)) {
endpoint->addr6.sin6_family = AF_INET6;
endpoint->addr6.sin6_port = udp_hdr(skb)->source;
endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr;
- endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(
- &ipv6_hdr(skb)->saddr, skb->skb_iif);
+ endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, skb->skb_iif);
endpoint->src6 = ipv6_hdr(skb)->daddr;
+ printk(KERN_INFO "wg_socket_endpoint_from_skb: Extracted IPv6 address %pI6c:%d\n",
+ &endpoint->addr6.sin6_addr, ntohs(endpoint->addr6.sin6_port));
} else {
return -EINVAL;
}
+
+ // Hexdump the first 128 bytes or the entire packet, whichever is smaller
+ hexdump_skb(skb, min(skb->len, (unsigned int)128));
+ printk(KERN_INFO "Exiting function wg_socket_endpoint_from_skb\n");
return 0;
}
-static bool endpoint_eq(const struct endpoint *a, const struct endpoint *b)
-{
+static bool endpoint_eq(const struct endpoint *a, const struct endpoint *b) {
+ printk(KERN_INFO "Entering function endpoint_eq\n");
+ printk(KERN_INFO "Exiting function endpoint_eq\n");
return (a->addr.sa_family == AF_INET && b->addr.sa_family == AF_INET &&
a->addr4.sin_port == b->addr4.sin_port &&
a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr &&
@@ -270,9 +531,8 @@ static bool endpoint_eq(const struct endpoint *a, const struct endpoint *b)
unlikely(!a->addr.sa_family && !b->addr.sa_family);
}
-void wg_socket_set_peer_endpoint(struct wg_peer *peer,
- const struct endpoint *endpoint)
-{
+void wg_socket_set_peer_endpoint(struct wg_peer *peer, const struct endpoint *endpoint) {
+ printk(KERN_INFO "Entering function wg_socket_set_peer_endpoint\n");
/* First we check unlocked, in order to optimize, since it's pretty rare
* that an endpoint will change. If we happen to be mid-write, and two
* CPUs wind up writing the same thing or something slightly different,
@@ -294,27 +554,31 @@ void wg_socket_set_peer_endpoint(struct wg_peer *peer,
dst_cache_reset(&peer->endpoint_cache);
out:
write_unlock_bh(&peer->endpoint_lock);
+ printk(KERN_INFO "Exiting function wg_socket_set_peer_endpoint\n");
}
void wg_socket_set_peer_endpoint_from_skb(struct wg_peer *peer,
- const struct sk_buff *skb)
-{
+ const struct sk_buff *skb) {
+ printk(KERN_INFO "Entering function wg_socket_set_peer_endpoint_from_skb\n");
struct endpoint endpoint;
+
if (!wg_socket_endpoint_from_skb(&endpoint, skb))
wg_socket_set_peer_endpoint(peer, &endpoint);
+ printk(KERN_INFO "Exiting function wg_socket_set_peer_endpoint_from_skb\n");
}
-void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer)
-{
+void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer) {
+ printk(KERN_INFO "Entering function wg_socket_clear_peer_endpoint_src\n");
write_lock_bh(&peer->endpoint_lock);
memset(&peer->endpoint.src6, 0, sizeof(peer->endpoint.src6));
dst_cache_reset_now(&peer->endpoint_cache);
write_unlock_bh(&peer->endpoint_lock);
+ printk(KERN_INFO "Exiting function wg_socket_clear_peer_endpoint_src\n");
}
-static int wg_receive(struct sock *sk, struct sk_buff *skb)
-{
+static int wg_receive(struct sock *sk, struct sk_buff *skb) {
+ printk(KERN_INFO "Entering function wg_receive\n");
struct wg_device *wg;
if (unlikely(!sk))
@@ -324,30 +588,34 @@ static int wg_receive(struct sock *sk, struct sk_buff *skb)
goto err;
skb_mark_not_on_list(skb);
wg_packet_receive(wg, skb);
+ printk(KERN_INFO "Exiting function wg_receive\n");
return 0;
err:
kfree_skb(skb);
+ printk(KERN_INFO "Exiting function wg_receive\n");
return 0;
}
-static void sock_free(struct sock *sock)
-{
+static void sock_free(struct sock *sock) {
+ printk(KERN_INFO "Entering function sock_free\n");
if (unlikely(!sock))
return;
sk_clear_memalloc(sock);
udp_tunnel_sock_release(sock->sk_socket);
+ printk(KERN_INFO "Exiting function sock_free\n");
}
-static void set_sock_opts(struct socket *sock)
-{
+static void set_sock_opts(struct socket *sock) {
+ printk(KERN_INFO "Entering function set_sock_opts\n");
sock->sk->sk_allocation = GFP_ATOMIC;
sock->sk->sk_sndbuf = INT_MAX;
sk_set_memalloc(sock->sk);
+ printk(KERN_INFO "Exiting function set_sock_opts\n");
}
-int wg_socket_init(struct wg_device *wg, u16 port)
-{
+int wg_socket_init(struct wg_device *wg, u16 port) {
+ printk(KERN_INFO "Entering function wg_socket_init\n");
struct net *net;
int ret;
struct udp_tunnel_sock_cfg cfg = {
@@ -413,12 +681,13 @@ int wg_socket_init(struct wg_device *wg, u16 port)
ret = 0;
out:
put_net(net);
+ printk(KERN_INFO "Exiting function wg_socket_init\n");
return ret;
}
void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
- struct sock *new6)
-{
+ struct sock *new6) {
+ printk(KERN_INFO "Entering function wg_socket_reinit\n");
struct sock *old4, *old6;
mutex_lock(&wg->socket_update_lock);
@@ -434,4 +703,1147 @@ void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
synchronize_net();
sock_free(old4);
sock_free(old6);
+ printk(KERN_INFO "Exiting function wg_socket_reinit\n");
+}
+
+static int wg_set_socket_timeouts(struct socket *sock, unsigned long snd_timeout, unsigned long rcv_timeout) {
+ printk(KERN_INFO "Entering function wg_set_socket_timeouts\n");
+ if (!sock || !sock->sk) {
+ pr_err("Invalid socket or sock is NULL\n");
+ return -EINVAL;
+ }
+
+ struct sock *sk = sock->sk;
+
+ sk->sk_sndtimeo = snd_timeout;
+ sk->sk_rcvtimeo = rcv_timeout;
+
+ printk(KERN_INFO "Exiting function wg_set_socket_timeouts\n");
+ return 0;
+}
+
+/* Attempt to establish a TCP connection */
+int wg_tcp_connect(struct wg_peer *peer) {
+ printk(KERN_INFO "Entering function wg_tcp_connect\n");
+ struct sockaddr_storage addr_storage;
+ struct sockaddr *addr = (struct sockaddr *)&addr_storage;
+ struct socket **socket_ptr = NULL;
+ unsigned long timeout = 5 * HZ; // 5 seconds in jiffies
+ int ret;
+
+ if (peer->transport != WG_TRANSPORT_TCP || peer->tcp_established) {
+ pr_err("Invalid state for TCP connection attempt.\n");
+ printk(KERN_INFO "Exiting function wg_tcp_connect\n");
+ return -EINVAL;
+ }
+
+ memset(&addr_storage, 0, sizeof(addr_storage));
+
+ if (peer->endpoint.addr.sa_family == AF_INET) {
+ struct sockaddr_in *addr4 = (struct sockaddr_in *)&addr_storage;
+ addr4->sin_family = AF_INET;
+ addr4->sin_port = htons(peer->device->incoming_port);
+ addr4->sin_addr.s_addr = peer->endpoint.addr4.sin_addr.s_addr;
+ addr = (struct sockaddr *)addr4;
+ socket_ptr = &peer->device->tcp_listen_socket4;
+ }
+#ifdef CONFIG_IPV6
+ else if (peer->endpoint.addr.sa_family == AF_INET6) {
+ struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)&addr_storage;
+ addr6->sin6_family = AF_INET6;
+ addr6->sin6_port = htons(peer->device->incoming_port);
+ memcpy(&addr6->sin6_addr, &peer->endpoint.addr6.sin6_addr,
+ sizeof(peer->endpoint.addr6.sin6_addr));
+ addr = (struct sockaddr *)addr6;
+ socket_ptr = &peer->device->tcp_listen_socket6;
+ }
+#endif
+ else {
+ pr_err("Unsupported address family\n");
+ printk(KERN_INFO "Exiting function wg_tcp_connect\n");
+ return -EAFNOSUPPORT;
+ }
+
+ // Ensure the socket pointer is not already set
+ if (*socket_ptr != NULL) {
+ pr_err("Socket already exists.\n");
+ printk(KERN_INFO "Exiting function wg_tcp_connect\n");
+ return -EEXIST;
+ }
+
+ // Create the socket
+ ret = sock_create_kern(&init_net, peer->endpoint.addr.sa_family, SOCK_STREAM,
+ IPPROTO_TCP, socket_ptr);
+ if (ret) {
+ pr_err("Failed to create TCP socket\n");
+ printk(KERN_INFO "Exiting function wg_tcp_connect\n");
+ return ret;
+ }
+
+ // Set socket timeouts for send and receive operations
+ ret = wg_set_socket_timeouts(*socket_ptr, timeout, timeout);
+ if (ret) {
+ sock_release(*socket_ptr);
+ *socket_ptr = NULL;
+ printk(KERN_INFO "Exiting function wg_tcp_connect\n");
+ return ret;
+ }
+
+ // Initiate the non-blocking connect
+ ret = kernel_connect(*socket_ptr, addr, sizeof(addr_storage), O_NONBLOCK);
+ if (ret != -EINPROGRESS && ret != 0) {
+ pr_err("TCP connection attempt failed: %d\n", ret);
+ sock_release(*socket_ptr);
+ *socket_ptr = NULL;
+ printk(KERN_INFO "Exiting function wg_tcp_connect\n");
+ return ret;
+ }
+
+
+ pr_info("TCP connection attempt initiated\n");
+ // Setup socket callbacks here after a successful connection initiation
+ wg_setup_tcp_socket_callbacks(peer);
+ (*socket_ptr)->sk->sk_user_data = peer;
+ spin_lock_bh(&peer->tcp_lock);
+ peer->tcp_pending = true;
+ spin_unlock_bh(&peer->tcp_lock);
+
+ printk(KERN_INFO "Exiting function wg_tcp_connect\n");
+ return 0; // Indicate success
+}
+
+void wg_tcp_state_change(struct sock *sk) {
+ printk(KERN_INFO "Entering function wg_tcp_state_change\n");
+ struct wg_peer *peer = sk->sk_user_data;
+ if (!peer) return;
+
+ spin_lock_bh(&peer->tcp_lock);
+ switch (sk->sk_state) {
+ case TCP_ESTABLISHED:
+ if (!peer->tcp_established) {
+ peer->tcp_pending = false;
+ peer->tcp_established = true;
+ spin_unlock_bh(&peer->tcp_lock);
+ pr_info("TCP connection established.\n");
+ // If the socket is writable, send immediately
+ if (sk_stream_is_writeable(sk)) {
+ wg_tcp_write_space(sk);
+ }
+ } else {
+ spin_unlock_bh(&peer->tcp_lock);
+ }
+ break;
+ case TCP_CLOSE:
+ case TCP_CLOSE_WAIT:
+ case TCP_CLOSING:
+ peer->tcp_pending = false;
+ // Handle closure/failure
+ if (peer->tcp_established) {
+ // Connection failed or closed unexpectedly
+ pr_info("TCP connection failed or closed, scheduling retry.\n");
+ // Reset the tcp_established flag because it tracks connection state
+ peer->tcp_established = false;
+ spin_unlock_bh(&peer->tcp_lock);
+ // Schedule a retry. Note: This uses a delayed work for simplicity
+ schedule_delayed_work(&peer->tcp_retry_work, msecs_to_jiffies(5000));
+ break;
+ }
+ fallthrough;
+ case TCP_FIN_WAIT1:
+ case TCP_FIN_WAIT2:
+ case TCP_LAST_ACK:
+ pr_info("TCP connection is in the process of closing.\n");
+ peer->tcp_pending = false;
+ if (peer->tcp_established) {
+ peer->tcp_established = false;
+ }
+ spin_unlock_bh(&peer->tcp_lock);
+ break;
+ default:
+ spin_unlock_bh(&peer->tcp_lock);
+ break;
+ }
+ if (peer->original_state_change) {
+ peer->original_state_change(sk);
+ }
+ printk(KERN_INFO "Exiting function wg_tcp_state_change\n");
+}
+
+void wg_extract_endpoint_from_sock(struct sock *sk,
+ struct endpoint *endpoint) {
+ printk(KERN_INFO "Entering function wg_extract_endpoint_from_sock\n");
+ if (!sk || !endpoint) {
+ pr_warn("Socket or endpoint is NULL.\n");
+ return;
+ }
+ memset(endpoint, 0, sizeof(*endpoint)); // Clear the endpoint structure
+
+ if (sk->sk_family == AF_INET) {
+ // IPv4
+ struct inet_sock *inet = inet_sk(sk);
+
+ endpoint->addr4.sin_family = AF_INET;
+ endpoint->addr4.sin_port = inet->inet_dport; // Destination port
+ endpoint->addr4.sin_addr.s_addr = inet->inet_daddr; // Destination IP address
+ } else if (sk->sk_family == AF_INET6) {
+#if IS_ENABLED(CONFIG_IPV6)
+ // IPv6
+ endpoint->addr6.sin6_family = AF_INET6;
+ endpoint->addr6.sin6_port = sk->sk_dport; // Destination port
+ endpoint->addr6.sin6_addr = sk->sk_v6_daddr; // Destination IP address
+
+ if (ipv6_addr_type((struct in6_addr *)&sk->sk_v6_daddr) & IPV6_ADDR_LINKLOCAL) {
+ // The destination address is link-local; use the socket's bound device for the scope ID
+ endpoint->addr6.sin6_scope_id = sk->sk_bound_dev_if;
+ } else {
+ // Not a link-local address; no scope ID required
+ endpoint->addr6.sin6_scope_id = 0;
+ }
+ } else {
+#endif
+ pr_warn("Unsupported socket family: %d.\n", sk->sk_family);
+ }
+ printk(KERN_INFO "Exiting function wg_extract_endpoint_from_sock\n");
+}
+
+static bool wg_endpoints_match(const struct endpoint *a,
+ const struct endpoint *b) {
+ printk(KERN_INFO "Entering function wg_endpoints_match\n");
+ // Compare endpoints
+ if (a->addr.sa_family != b->addr.sa_family) {
+ printk(KERN_INFO "Exiting function wg_endpoints_match\n");
+ return false;
+ }
+
+ if (a->addr.sa_family == AF_INET) {
+ return a->addr4.sin_port == b->addr4.sin_port &&
+ a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr;
+ } else if (a->addr.sa_family == AF_INET6) {
+ // For IPv6, also compare the scope ID if the address is link-local
+ bool is_link_local_a = ipv6_addr_type(&a->addr6.sin6_addr) & IPV6_ADDR_LINKLOCAL;
+ bool is_link_local_b = ipv6_addr_type(&b->addr6.sin6_addr) & IPV6_ADDR_LINKLOCAL;
+
+ return a->addr6.sin6_port == b->addr6.sin6_port &&
+ ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) &&
+ (!is_link_local_a || !is_link_local_b || a->addr6.sin6_scope_id == b->addr6.sin6_scope_id);
+ }
+ printk(KERN_INFO "Exiting function wg_endpoints_match\n");
+ return false;
+}
+
+static int wg_tcp_queuepkt(struct wg_peer *peer, const void *data,
+ size_t len) {
+ printk(KERN_INFO "Entering function wg_tcp_queuepkt\n");
+ struct endpoint current_endpoint;
+ struct wg_tcp_socket_list_entry *socket_iter;
+ struct sock *sk;
+ bool found = false;
+
+ if (!peer || !data || len == 0) {
+ printk(KERN_INFO "Exiting function wg_tcp_queuepkt\n");
+ return -EINVAL;
+ }
+
+ struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);
+ if (!skb) {
+ printk(KERN_INFO "Exiting function wg_tcp_queuepkt\n");
+ return -ENOMEM;
+ }
+
+ skb_reserve(skb, SKB_HEADER_LEN);
+ skb_put_data(skb, data, len);
+
+ // Extract the destination address from skb
+ if (wg_socket_endpoint_from_skb(&current_endpoint, skb) < 0) {
+ kfree_skb(skb);
+ printk(KERN_INFO "Exiting function wg_tcp_queuepkt\n");
+ return -EINVAL; // Failed to extract endpoint
+ }
+
+ read_lock_bh(&peer->endpoint_lock);
+
+ // Check if the current destination matches the peer's destination address
+ if (!wg_endpoints_match(&current_endpoint, &peer->endpoint)) {
+ list_for_each_entry(socket_iter, &peer->device->tcp_connection_list, tcp_connection) {
+ sk = socket_iter->tcp_peer->tcp_receive_socket->sk;
+ struct endpoint socket_endpoint;
+ wg_extract_endpoint_from_sock(sk, &socket_endpoint);
+
+ if (endpoint_eq(&current_endpoint, &socket_endpoint)) {
+ found = true;
+ break;
+ }
+ }
+ }
+
+ if (found) {
+ struct socket *old_sock = peer->tcp_receive_socket;
+ struct socket *new_sock = socket_iter->tcp_peer->tcp_receive_socket;
+
+ // Before switching sockets, reset callbacks on the old socket, drop it
+ // from the connection list and release it
+ mutex_lock(&peer->device->socket_update_lock);
+ wg_reset_tcp_socket_callbacks(peer);
+ if (old_sock) {
+ wg_remove_from_tcp_connection_list(peer->device, old_sock);
+ kernel_sock_shutdown(old_sock, SHUT_RDWR);
+ sock_release(old_sock);
+ }
+ peer->tcp_receive_socket = NULL;
+ if (new_sock->sk->sk_family == AF_INET) {
+ peer->device->tcp_listen_socket4 = new_sock;
+ } else if (new_sock->sk->sk_family == AF_INET6) {
+#if IS_ENABLED(CONFIG_IPV6)
+ peer->device->tcp_listen_socket6 = new_sock;
+#else
+ mutex_unlock(&peer->device->socket_update_lock);
+ read_unlock_bh(&peer->endpoint_lock);
+ kfree_skb(skb);
+ printk(KERN_INFO "Exiting function wg_tcp_queuepkt\n");
+ return -EAFNOSUPPORT;
+#endif
+ }
+ mutex_unlock(&peer->device->socket_update_lock);
+
+ peer->tcp_receive_socket = new_sock; // Switch to the new socket
+ peer->endpoint = current_endpoint; // Update the peer's endpoint
+
+ // Setup callbacks on the new socket
+ wg_setup_tcp_socket_callbacks(peer);
+
+ } else {
+ // No matching socket found; initiate a new connection
+ wg_tcp_connect(peer);
+ }
+
+ read_unlock_bh(&peer->endpoint_lock);
+
+ // Queue the packet for sending
+ spin_lock_bh(&peer->send_queue_lock);
+ skb_queue_tail(&peer->send_queue, skb);
+ spin_unlock_bh(&peer->send_queue_lock);
+
+ // Trigger sending if possible
+ if (peer->tcp_established && sk_stream_is_writeable(peer->tcp_receive_socket->sk)) {
+ wg_tcp_write_space(peer->tcp_receive_socket->sk);
+ }
+
+ printk(KERN_INFO "Exiting function wg_tcp_queuepkt\n");
+ return 0;
+}
+
+// Simple checksum function for TCP encapsulation header
+static __be16 wg_header_checksum(const struct wg_tcp_encap_header *hdr) {
+ printk(KERN_INFO "Entering function wg_header_checksum\n");
+ uint16_t checksum = 0;
+ uint32_t length = ntohl(hdr->length); // Ensure network byte order is converted to host byte order for calculation
+
+ // Break the length into two 16-bit halves and XOR them with the flags and type
+ checksum ^= (length >> 16) & 0xFFFF;
+ checksum ^= length & 0xFFFF;
+ checksum ^= (hdr->flags << 8) | hdr->type;
+
+ // Simple rotate to mix bits a bit more
+ checksum = (checksum << 5) | (checksum >> (16 - 5));
+
+ printk(KERN_INFO "Exiting function wg_header_checksum\n");
+ return htons(checksum); // Convert back to network byte order
+}
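+
+// Worked example for the checksum above (a sketch; the 64-byte length is just an
+// illustrative value): with length = 64, type = 0 and flags = 0, the folded value
+// is 0x0000 ^ 0x0040 ^ 0x0000 = 0x0040; rotating left by 5 gives 0x0800, so the
+// header's checksum field carries htons(0x0800).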
+
+// Function to validate the header checksum
+static bool wg_validate_header_checksum(const struct wg_tcp_encap_header *hdr) {
+ printk(KERN_INFO "Entering function wg_validate_header_checksum\n");
+ printk(KERN_INFO "Exiting function wg_validate_header_checksum\n");
+ return wg_header_checksum(hdr) == hdr->checksum;
+}
+
+// Function to check if the given data pointer has a valid WireGuard TCP encapsulation header
+bool wg_check_potential_header_validity(__u8 *data, size_t remaining_len) {
+ printk(KERN_INFO "Entering function wg_check_potential_header_validity\n");
+ if (remaining_len < WG_TCP_ENCAP_HDR_LEN) return false; // Not enough data for a header
+
+ struct wg_tcp_encap_header *potential_hdr = (struct wg_tcp_encap_header *)data;
+ // Adjust checksum validation as necessary. This is just an example
+ printk(KERN_INFO "Exiting function wg_check_potential_header_validity\n");
+ return wg_validate_header_checksum(potential_hdr);
+}
+
+static int wg_tcp_send(struct socket *sock, const void *buff, size_t len,
+ __u8 type, __u8 flags) {
+ printk(KERN_INFO "Entering function wg_tcp_send\n");
+ struct wg_tcp_encap_header header;
+ struct msghdr msg = { .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL };
+ struct kvec vec[2];
+ int sent;
+
+ // Prepare the header
+ header.length = htonl(len + WG_TCP_ENCAP_HDR_LEN); // Total length: header plus payload, in network byte order
+ header.type = type;
+ header.flags = flags;
+ header.checksum = wg_header_checksum(&header); // Compute checksum for the header
+
+ // Set up the vector for the header and the payload
+ vec[0].iov_base = &header;
+ vec[0].iov_len = WG_TCP_ENCAP_HDR_LEN;
+ vec[1].iov_base = (void *)buff; // Cast away const
+ vec[1].iov_len = len;
+
+ // Send the message including the header and the payload
+ sent = kernel_sendmsg(sock, &msg, vec, 2, WG_TCP_ENCAP_HDR_LEN + len);
+ if (sent >= 0) {
+ // Successfully sent some or all data
+ printk(KERN_INFO "Exiting function wg_tcp_send\n");
+ return sent;
+ } else if (sent == -EAGAIN) {
+ // Socket buffer is full, operation would block
+ printk(KERN_INFO "Exiting function wg_tcp_send\n");
+ return 0;
+ } else {
+ // An error occurred; return the error code
+ printk(KERN_INFO "Exiting function wg_tcp_send\n");
+ return sent;
+ }
+}
+
+void wg_tcp_write_space(struct sock *sk) {
+ printk(KERN_INFO "Entering function wg_tcp_write_space\n");
+ struct wg_peer *peer = sk->sk_user_data;
+ struct sk_buff *skb;
+ int sent;
+
+ if (!peer || !peer->tcp_receive_socket) {
+ printk(KERN_INFO "Exiting function wg_tcp_write_space\n");
+ return;
+ }
+
+ sk = peer->tcp_receive_socket->sk;
+ if (!sk_stream_is_writeable(sk)) {
+ // Socket is not ready for writing, exit and wait for sk_write_space callback
+ printk(KERN_INFO "Exiting function wg_tcp_write_space\n");
+ return;
+ }
+
+ spin_lock_bh(&peer->tcp_lock);
+ while ((skb = skb_peek(&peer->send_queue)) != NULL && sk_stream_is_writeable(sk)) {
+ sent = wg_tcp_send(peer->tcp_receive_socket, skb->data, skb->len, 0, 0); // no type or flags for now
+ if (sent > 0) {
+ if (sent < skb->len) {
+ // Handle partial send by trimming the skb and leaving it in the queue
+ skb_pull(skb, sent);
+ } else {
+ // Full send successful, dequeue and free the skb
+ __skb_unlink(skb, &peer->send_queue);
+ kfree_skb(skb);
+ }
+ } else if (sent == 0) {
+ // Socket buffer is full, stop sending and wait for sk_write_space
+ break;
+ } else {
+ // An error occurred, dequeue and free the skb
+ __skb_unlink(skb, &peer->send_queue);
+ kfree_skb(skb);
+ break;
+ }
+ }
+ spin_unlock_bh(&peer->tcp_lock);
+ if (peer->original_write_space) {
+ peer->original_write_space(sk);
+ }
+ printk(KERN_INFO "Exiting function wg_tcp_write_space\n");
+}
+
+void wg_tcp_data_ready(struct sock *sk) {
+ printk(KERN_INFO "Entering function wg_tcp_data_ready\n");
+ struct wg_peer *peer = sk->sk_user_data;
+ struct msghdr msg = { .msg_flags = MSG_DONTWAIT | MSG_TRUNC };
+ struct kvec vec;
+ ssize_t read_bytes;
+ struct sk_buff *new_skb = NULL;
+ struct sk_buff *read_skb = NULL;
+
+ if (!peer || !peer->tcp_receive_socket) {
+ printk(KERN_INFO "Exiting function wg_tcp_data_ready\n");
+ return;
+ }
+
+ lock_sock(sk); // Lock the socket for reading
+
+ while (true) {
+ if (!peer->partial_skb || peer->received_len < WG_TCP_ENCAP_HDR_LEN) {
+ // Allocate buffer for new read (header initially)
+ size_t alloc_size = max_t(size_t, WG_TCP_ENCAP_HDR_LEN, peer->received_len);
+ new_skb = alloc_skb(alloc_size + NET_IP_ALIGN, GFP_ATOMIC);
+ if (!new_skb) {
+ pr_err("WireGuard: Failed to allocate skb\n");
+ break;
+ }
+ skb_reserve(new_skb, NET_IP_ALIGN);
+ peer->expected_len = 0;
+ if (peer->partial_skb && peer->received_len > 0) {
+ // If there's already some data, copy it into the new buffer
+ skb_put_data(new_skb, peer->partial_skb->data, peer->received_len);
+ kfree_skb(peer->partial_skb);
+ } else if (peer->partial_skb) {
+ kfree_skb(peer->partial_skb);
+ }
+ peer->partial_skb = new_skb;
+ }
+ // Read the encapsulation header or the remainder of it
+ if (peer->received_len < WG_TCP_ENCAP_HDR_LEN) {
+ vec.iov_base = skb_tail_pointer(peer->partial_skb);
+ vec.iov_len = WG_TCP_ENCAP_HDR_LEN - peer->received_len;
+ read_bytes = kernel_recvmsg(peer->tcp_receive_socket, &msg, &vec, 1, vec.iov_len, msg.msg_flags);
+ if (read_bytes <= 0) {
+ if (read_bytes == -EAGAIN) break; // No more data available
+ pr_err("WireGuard: Error receiving TCP data\n");
+ if (peer->partial_skb)
+ kfree_skb(peer->partial_skb);
+ peer->partial_skb = NULL;
+ peer->expected_len = 0;
+ peer->received_len = 0;
+ break;
+ }
+ skb_put(peer->partial_skb, read_bytes);
+ peer->received_len += read_bytes;
+ }
+
+ if (peer->received_len >= WG_TCP_ENCAP_HDR_LEN) {
+ // Complete header received, validate and prepare for packet data
+
+ struct wg_tcp_encap_header *hdr = (struct wg_tcp_encap_header *)peer->partial_skb->data;
+ peer->expected_len = ntohl(hdr->length);
+
+ // Check header validity
+ if (peer->received_len >= WG_TCP_ENCAP_HDR_LEN) {
+ // Assume the header is valid unless the checks below fail
+ bool valid_header_found = true;
+
+ // Use wg_validate_header_checksum as the criterion for header validity
+ if (peer->expected_len > WG_MAX_PACKET_SIZE || !wg_validate_header_checksum(hdr)) {
+ valid_header_found = false;
+ pr_err("WireGuard: Invalid packet header detected, attempting to resynchronize\n");
+ // No valid header found in the existing data
+ // Attempt to read more data from the socket
+ // Clear existing partial read, if any, and prepare for bulk read
+
+ if (peer->partial_skb) {
+ kfree_skb(peer->partial_skb);
+ peer->partial_skb = NULL;
+ }
+ peer->received_len = 0;
+ peer->expected_len = 0;
+
+ // Attempt to read as much data as available from the socket
+ read_skb = alloc_skb(WG_MAX_PACKET_SIZE + NET_IP_ALIGN,
+ GFP_ATOMIC); // Allocate buffer for bulk read
+ if (!read_skb) {
+ pr_err("WireGuard: Failed to allocate skb for bulk data read\n");
+ break;
+ }
+ skb_reserve(read_skb, NET_IP_ALIGN);
+
+ // Perform the read operation
+ vec.iov_base = skb_put(read_skb, WG_MAX_PACKET_SIZE); // Prepare space
+ vec.iov_len = WG_MAX_PACKET_SIZE;
+ read_bytes = kernel_recvmsg(peer->tcp_receive_socket, &msg, &vec, 1, vec.iov_len,
+ MSG_DONTWAIT | MSG_TRUNC);
+ if (read_bytes <= 0) {
+ if (read_bytes == -EAGAIN) {
+ // No more data available, exit
+ kfree_skb(read_skb);
+ break;
+ }
+ pr_err("WireGuard: Error receiving bulk data from socket\n");
+ kfree_skb(read_skb);
+ break;
+ }
+ skb_trim(read_skb, read_bytes); // Trim skb to actual size of received data
+
+ // Now attempt to find the next valid header within the newly read data
+
+ for (size_t i = 0; i <= read_skb->len - WG_TCP_ENCAP_HDR_LEN; ++i) {
+ // Attempt to validate the header starting from the current byte
+ struct wg_tcp_encap_header *potential_hdr = (struct wg_tcp_encap_header *)(read_skb->data + i);
+ if (wg_validate_header_checksum(potential_hdr)) {
+ valid_header_found = true;
+ // Adjust the skb to start from the found valid header
+ skb_pull(read_skb, i);
+ if (peer->partial_skb)
+ kfree_skb(peer->partial_skb); // free discarded data buffer
+ peer->partial_skb = read_skb; // Transfer ownership of the buffer to partial_skb for further processing
+ peer->received_len = read_bytes - i; // Update received_len to remaining data length
+ peer->expected_len = ntohl(potential_hdr->length); // Set expected length from valid header
+ break; // Exit the loop as we've found a starting point
+ }
+ }
+ }
+
+ // If a valid header was found, continue processing based on the adjusted partial_skb
+ // If not, the data is discarded, and the socket read loop continues for new data
+ if (!valid_header_found) {
+ pr_err("WireGuard: Failed to find valid header in bulk read data\n");
+ if (peer->partial_skb)
+ kfree_skb(peer->partial_skb);
+ peer->partial_skb = NULL;
+ peer->expected_len = 0;
+ peer->received_len = 0;
+ kfree_skb(read_skb); // Clean up as no valid header was found
+ break;
+ }
+ }
+ }
+
+ // If received_len is greater than expected_len (which includes WG_TCP_ENCAP_HDR_LEN),
+ // it implies there's more data potentially for another packet or part of the current
+ //packet beyond what was expected.
+ if (peer->received_len < peer->expected_len) {
+ if ((skb_tailroom(peer->partial_skb) < peer->expected_len) &&
+ (peer->received_len < peer->expected_len)) {
+ // check if need a bigger buffer
+ struct sk_buff *resized_skb = skb_copy_expand(peer->partial_skb, 0,
+ peer->expected_len - skb_tailroom(peer->partial_skb),
+ GFP_ATOMIC);
+ if (!resized_skb) {
+ pr_err("WireGuard: Failed to resize skb\n");
+ if (peer->partial_skb)
+ kfree_skb(peer->partial_skb);
+ peer->partial_skb = NULL;
+ peer->expected_len = 0;
+ peer->received_len = 0;
+
+ break;
+ }
+ if (peer->partial_skb)
+ kfree_skb(peer->partial_skb);
+ peer->partial_skb = resized_skb;
+ }
+
+ // Read packet data, ensuring we don't read beyond the buffer size
+ vec.iov_base = skb_tail_pointer(peer->partial_skb);
+ vec.iov_len = peer->expected_len - peer->received_len;
+ read_bytes = kernel_recvmsg(peer->tcp_receive_socket, &msg, &vec, 1, vec.iov_len, msg.msg_flags);
+ if (read_bytes <= 0) {
+ if (read_bytes == -EAGAIN) break; // No more data available
+ pr_err("WireGuard: Error or incomplete packet data received\n");
+ if (peer->partial_skb)
+ kfree_skb(peer->partial_skb);
+ peer->partial_skb = NULL;
+ peer->expected_len = 0;
+ peer->received_len = 0;
+ break;
+ }
+ skb_put(peer->partial_skb, read_bytes);
+ peer->received_len += read_bytes;
+ }
+ // Check if we've received the complete packet now
+ if (peer->expected_len >= WG_TCP_ENCAP_HDR_LEN &&
+ peer->received_len >= peer->expected_len) {
+ struct sk_buff *complete_skb = peer->partial_skb;
+ size_t leftover_len = peer->received_len - peer->expected_len;
+
+ // Detach any leftover data for the next packet before handing the
+ // skb to wg_receive(), which takes ownership of it
+ if (leftover_len > 0) {
+ struct sk_buff *next_skb = alloc_skb(leftover_len + NET_IP_ALIGN, GFP_ATOMIC);
+
+ if (next_skb) {
+ skb_reserve(next_skb, NET_IP_ALIGN);
+ skb_put_data(next_skb, complete_skb->data + peer->expected_len, leftover_len);
+ peer->partial_skb = next_skb;
+ peer->received_len = leftover_len;
+ } else {
+ pr_err("WireGuard: Failed to allocate skb for leftover data, dropping it\n");
+ peer->partial_skb = NULL;
+ peer->received_len = 0;
+ }
+ skb_trim(complete_skb, peer->expected_len);
+ } else {
+ // No leftover data, reset for the next packet
+ peer->partial_skb = NULL;
+ peer->received_len = 0;
+ }
+ peer->expected_len = 0; // Reset for the next packet
+
+ // Process the complete packet: strip the encapsulation header first
+ skb_pull(complete_skb, WG_TCP_ENCAP_HDR_LEN);
+ wg_receive(sk, complete_skb);
+ // Continue the loop to process more data if available
+ }
+ }
+ release_sock(sk); // Unlock the socket
+
+ // Call the original data_ready callback if it exists
+ if (peer->original_data_ready) {
+ peer->original_data_ready(sk);
+ }
+ printk(KERN_INFO "Exiting function wg_tcp_data_ready\n");
+}
+
+void wg_setup_tcp_socket_callbacks(struct wg_peer *peer) {
+ printk(KERN_INFO "Entering function wg_setup_tcp_socket_callbacks\n");
+ if (!peer || !peer->tcp_receive_socket) return;
+ struct sock *sk = peer->tcp_receive_socket->sk;
+
+ write_lock_bh(&sk->sk_callback_lock);
+
+ // Save the original callbacks
+ peer->original_state_change = sk->sk_state_change;
+ peer->original_write_space = sk->sk_write_space;
+ peer->original_data_ready = sk->sk_data_ready;
+
+ // Assign new callbacks and pass `peer` as user data for callback functions
+ sk->sk_user_data = peer;
+ sk->sk_state_change = wg_tcp_state_change; // Callback when socket state changes
+ sk->sk_write_space = wg_tcp_write_space; // Callback when socket becomes writable
+ sk->sk_data_ready = wg_tcp_data_ready; // Callback when data is ready to be read
+
+ write_unlock_bh(&sk->sk_callback_lock);
+ printk(KERN_INFO "Exiting function wg_setup_tcp_socket_callbacks\n");
+}
+
+void wg_reset_tcp_socket_callbacks(struct wg_peer *peer) {
+ printk(KERN_INFO "Entering function wg_reset_tcp_socket_callbacks\n");
+ struct sock *sk;
+
+ if (!peer || !peer->tcp_receive_socket) {
+ printk(KERN_INFO "Exiting function wg_reset_tcp_socket_callbacks\n");
+ return;
+ }
+
+ sk = peer->tcp_receive_socket->sk;
+
+ // Lock the socket to safely update callback pointers
+ write_lock_bh(&sk->sk_callback_lock);
+
+ // Check if we previously saved original callbacks and restore them
+ if (peer->original_state_change) {
+ sk->sk_state_change = peer->original_state_change;
+ peer->original_state_change = NULL; // Clear the reference in the peer structure
+ }
+ if (peer->original_write_space) {
+ sk->sk_write_space = peer->original_write_space;
+ peer->original_write_space = NULL; // Clear the reference in the peer structure
+ }
+ if (peer->original_data_ready) {
+ sk->sk_data_ready = peer->original_data_ready;
+ peer->original_data_ready = NULL; // Clear the reference in the peer structure
+ }
+
+ // Clear the user data to avoid any dangling references
+ sk->sk_user_data = NULL;
+
+ write_unlock_bh(&sk->sk_callback_lock);
+ printk(KERN_INFO "Exiting function wg_reset_tcp_socket_callbacks\n");
+}
+
+void wg_tcp_connection_retry_timer(struct timer_list *t) {
+ printk(KERN_INFO "Entering function wg_tcp_connection_retry_timer\n");
+ struct wg_peer *peer = from_timer(peer, t, tcp_connect_retry_timer);
+ if (peer->tcp_established == false)
+ wg_tcp_connect(peer);
+ printk(KERN_INFO "Exiting function wg_tcp_connection_retry_timer\n");
+}
+
+void peer_remove_after_dead(struct wg_peer *peer) {
+ printk(KERN_INFO "Entering function peer_remove_after_dead\n");
+ // Existing cleanup logic...
+ pr_err("WireGuard: Invalid packet header detected, attempting to resynchronize\n");
+
+ if (peer->tcp_receive_socket) {
+ kernel_sock_shutdown(peer->tcp_receive_socket, SHUT_RDWR);
+ sock_release(peer->tcp_receive_socket);
+ peer->tcp_receive_socket = NULL;
+ peer->tcp_established = false;
+ pr_info("TCP resources cleaned up for peer.\n");
+ }
+
+ del_timer_sync(&peer->tcp_connect_retry_timer);
+ printk(KERN_INFO "Exiting function peer_remove_after_dead\n");
+}
+
+void wg_add_tcp_socket_to_list(struct wg_device *wg, struct socket *receive_socket) {
+ printk(KERN_INFO "Entering function wg_add_tcp_socket_to_list\n");
+ struct wg_tcp_socket_list_entry *entry;
+
+ entry = kmalloc(sizeof(*entry), GFP_KERNEL);
+ if (!entry) {
+ pr_err("Failed to allocate wg_tcp_socket_list_entry\n");
+ printk(KERN_INFO "Exiting function wg_add_tcp_socket_to_list\n");
+ return;
+ }
+
+ entry->tcp_peer = receive_socket->sk->sk_user_data; // Peer (or temporary peer) that owns this socket
+ entry->timestamp = ktime_get(); // Capture the current time
+
+ spin_lock_bh(&wg->tcp_connection_list_lock);
+ list_add_tail(&entry->tcp_connection, &wg->tcp_connection_list);
+ spin_unlock_bh(&wg->tcp_connection_list_lock);
+ printk(KERN_INFO "Exiting function wg_add_tcp_socket_to_list\n");
+}
+
+void wg_remove_from_tcp_connection_list(struct wg_device *wg,
+ struct socket *sock) {
+ printk(KERN_INFO "Entering function wg_remove_from_tcp_connection_list\n");
+ struct wg_tcp_socket_list_entry *entry, *tmp;
+
+ spin_lock_bh(&wg->tcp_connection_list_lock);
+
+ list_for_each_entry_safe(entry, tmp, &wg->tcp_connection_list, tcp_connection) {
+ if (entry->tcp_peer->tcp_receive_socket == sock) {
+ list_del(&entry->tcp_connection); // Removes the entry from the list
+ // Check if there's associated user_data to free
+ if (entry->tcp_peer->tcp_receive_socket && entry->tcp_peer->tcp_receive_socket->sk && entry->tcp_peer->tcp_receive_socket->sk->sk_user_data) {
+ struct wg_peer *temp_peer = entry->tcp_peer->tcp_receive_socket->sk->sk_user_data;
+ entry->tcp_peer->tcp_receive_socket->sk->sk_user_data = NULL; // Clear the pointer to prevent use-after-free
+ kfree(temp_peer);
+ }
+ kfree(entry); // Frees the memory allocated for the entry
+ break; // Exit the loop after removing the entry
+ }
+ }
+ spin_unlock_bh(&wg->tcp_connection_list_lock);
+ printk(KERN_INFO "Exiting function wg_remove_from_tcp_connection_list\n");
+}
+
+int wg_tcp_listener_worker(struct wg_device *wg, struct socket *tcp_socket)
+{
+	struct socket *new_peer_listen_socket = NULL;
+
+	if (!tcp_socket) {
+		pr_err("tcp_socket is NULL\n");
+		return -EINVAL;
+	}
+
+	while (!kthread_should_stop()) {
+		struct wg_peer *peer, *matched_peer = NULL;
+		struct endpoint new_endpoint;
+		int err;
+
+		err = kernel_accept(tcp_socket, &new_peer_listen_socket, 0);
+		if (err < 0) {
+			if (err == -EAGAIN || err == -ERESTARTSYS)
+				continue;
+			pr_err("Error accepting new connection: %d\n", err);
+			continue;
+		}
+		if (!new_peer_listen_socket) {
+			pr_err("new_peer_listen_socket is NULL after kernel_accept\n");
+			continue;
+		}
+
+		/* Set send/receive timeouts on the new socket (5 seconds, in jiffies) */
+		wg_set_socket_timeouts(new_peer_listen_socket, 5 * HZ, 5 * HZ);
+
+		wg_extract_endpoint_from_sock(new_peer_listen_socket->sk, &new_endpoint);
+
+		/* Walk the peer list under device_update_lock (a mutex); the
+		 * socket teardown below may sleep, so the connection-list
+		 * spinlock must not be held across any of this. */
+		mutex_lock(&wg->device_update_lock);
+
+		list_for_each_entry(peer, &wg->peer_list, peer_list) {
+			if (wg_endpoints_match(&peer->endpoint, &new_endpoint)) {
+				matched_peer = peer;
+				break;
+			}
+		}
+
+		/* XXX - This needs to be fixed to deal with active and pending sockets */
+		if (matched_peer) {
+			/* Reset callbacks on the old socket and tear it down
+			 * before switching the peer over to the new one. */
+			mutex_lock(&wg->socket_update_lock);
+			wg_reset_tcp_socket_callbacks(matched_peer);
+
+			if (matched_peer->tcp_receive_socket) {
+				kernel_sock_shutdown(matched_peer->tcp_receive_socket, SHUT_RDWR);
+				sock_release(matched_peer->tcp_receive_socket);
+				matched_peer->tcp_receive_socket = NULL;
+			}
+
+			/* Switch to the new socket and hook up its callbacks */
+			matched_peer->tcp_receive_socket = new_peer_listen_socket;
+			wg_setup_tcp_socket_callbacks(matched_peer);
+			mutex_unlock(&wg->socket_update_lock);
+
+			/* When switching to a new connection, free any partial skb */
+			if (matched_peer->partial_skb) {
+				kfree_skb(matched_peer->partial_skb);
+				matched_peer->partial_skb = NULL;
+			}
+		}
+
+		mutex_unlock(&wg->device_update_lock);
+
+		if (!matched_peer) {
+			struct wg_peer *temp_peer = kzalloc(sizeof(*temp_peer), GFP_KERNEL);
+
+			if (temp_peer) {
+				/* Associate the socket with the temporary peer */
+				new_peer_listen_socket->sk->sk_user_data = temp_peer;
+				temp_peer->tcp_receive_socket = new_peer_listen_socket;
+				temp_peer->device = wg;
+
+				/* Allocations from the socket callbacks must not sleep */
+				new_peer_listen_socket->sk->sk_allocation = GFP_ATOMIC;
+
+				wg_setup_tcp_socket_callbacks(temp_peer);
+				wg_add_tcp_socket_to_list(wg, new_peer_listen_socket);
+			} else {
+				pr_err("Failed to allocate memory for temp_peer\n");
+				/* The socket won't be used; close and release it */
+				kernel_sock_shutdown(new_peer_listen_socket, SHUT_RDWR);
+				sock_release(new_peer_listen_socket);
+			}
+		}
+
+		new_peer_listen_socket = NULL;
+	}
+
+	return 0;
+}
+
+int wg_tcp_listener4_thread(void *data)
+{
+	struct wg_device *wg = data;
+
+	/* wg_tcp_socket_init() sets this once the IPv4 socket is listening */
+	if (!wg->tcp_socket4_ready) {
+		pr_debug("tcp_socket4 is not ready, exiting wg_tcp_listener4_thread\n");
+		return 0;
+	}
+
+	return wg_tcp_listener_worker(wg, wg->tcp_listen_socket4);
+}
+
+int wg_tcp_listener6_thread(void *data)
+{
+	struct wg_device *wg = data;
+
+	/* wg_tcp_socket_init() sets this once the IPv6 socket is listening */
+	if (!wg->tcp_socket6_ready) {
+		pr_debug("tcp_socket6 is not ready, exiting wg_tcp_listener6_thread\n");
+		return 0;
+	}
+
+	return wg_tcp_listener_worker(wg, wg->tcp_listen_socket6);
+}
+
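+/*
+ * Wrap an accepted socket in a temporary wg_peer (kept in sk_user_data) and
+ * park it on the device's pending-connection list.
+ */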
+void process_new_connection(struct wg_device *wg, struct socket *new_sock)
+{
+	struct wg_peer *temp_peer = kzalloc(sizeof(*temp_peer), GFP_KERNEL);
+
+	if (!temp_peer) {
+		pr_err("Failed to allocate temp wg_peer\n");
+		sock_release(new_sock);
+		return;
+	}
+
+	/* Initialize the temporary wg_peer and tie it to the socket */
+	temp_peer->device = wg;
+	temp_peer->tcp_receive_socket = new_sock;
+	new_sock->sk->sk_user_data = temp_peer;
+
+	/* Set up socket callbacks with the temporary peer as context */
+	wg_setup_tcp_socket_callbacks(temp_peer);
+
+	/* Track the connection on the device list until a peer claims it */
+	wg_add_tcp_socket_to_list(wg, new_sock);
+}
+
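+/*
+ * Called from wg_destruct() when the interface is torn down: drops every
+ * tracked TCP connection along with its temporary peer and socket.
+ */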
+void wg_destruct_tcp_connection_list(struct wg_device *wg)
+{
+	struct wg_tcp_socket_list_entry *entry, *tmp;
+	LIST_HEAD(condemned);
+
+	/* Detach the whole list under the lock, then release the sockets
+	 * outside of it, since sock_release() may sleep. */
+	spin_lock_bh(&wg->tcp_connection_list_lock);
+	list_splice_init(&wg->tcp_connection_list, &condemned);
+	spin_unlock_bh(&wg->tcp_connection_list_lock);
+
+	list_for_each_entry_safe(entry, tmp, &condemned, tcp_connection) {
+		struct socket *sock = entry->tcp_peer ? entry->tcp_peer->tcp_receive_socket : NULL;
+
+		list_del(&entry->tcp_connection);
+		if (sock) {
+			/* Free the temporary peer referenced by sk_user_data */
+			if (sock->sk && sock->sk->sk_user_data) {
+				kfree(sock->sk->sk_user_data);
+				sock->sk->sk_user_data = NULL;
+			}
+			sock_release(sock);
+		}
+		kfree(entry);
+	}
+}
+
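+/*
+ * Periodic garbage collection for the pending-connection list: connections
+ * that no peer has claimed within five seconds are torn down.  The work
+ * requeues itself every five seconds.
+ */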
+void wg_tcp_cleanup_worker(struct work_struct *work)
+{
+	struct wg_device *wg = container_of(work, struct wg_device, tcp_cleanup_work.work);
+	struct wg_tcp_socket_list_entry *entry, *tmp;
+	ktime_t now = ktime_get();
+	LIST_HEAD(stale);
+
+	/* Collect entries older than 5 seconds under the lock, then tear them
+	 * down outside of it, since releasing a socket may sleep. */
+	spin_lock_bh(&wg->tcp_connection_list_lock);
+	list_for_each_entry_safe(entry, tmp, &wg->tcp_connection_list, tcp_connection) {
+		if (ktime_ms_delta(now, entry->timestamp) > 5000)
+			list_move_tail(&entry->tcp_connection, &stale);
+	}
+	spin_unlock_bh(&wg->tcp_connection_list_lock);
+
+	list_for_each_entry_safe(entry, tmp, &stale, tcp_connection) {
+		struct socket *sock = entry->tcp_peer ? entry->tcp_peer->tcp_receive_socket : NULL;
+
+		list_del(&entry->tcp_connection);
+		if (sock) {
+			/* Unclaimed connections are only referenced here */
+			if (sock->sk && sock->sk->sk_user_data) {
+				kfree(sock->sk->sk_user_data);
+				sock->sk->sk_user_data = NULL;
+			}
+			kernel_sock_shutdown(sock, SHUT_RDWR);
+			sock_release(sock);
+		}
+		kfree(entry);
+	}
+
+	/* Reschedule the worker */
+	schedule_delayed_work(&wg->tcp_cleanup_work, msecs_to_jiffies(5000));
+}
+
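+/*
+ * TCP counterpart of wg_socket_init(): creates, binds and listens on an
+ * IPv4 (and, when enabled, IPv6) TCP socket for the configured port, starts
+ * one accept kthread per address family and schedules the periodic cleanup
+ * work for unclaimed connections.
+ */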
+int wg_tcp_socket_init(struct wg_device *wg, u16 port)
+{
+	struct net *net;
+	int ret;
+	struct socket *listen_socket4 = NULL;
+	struct sockaddr_in addr4 = {
+		.sin_family = AF_INET,
+		.sin_port = htons(port),
+		.sin_addr = { htonl(INADDR_ANY) }
+	};
+#if IS_ENABLED(CONFIG_IPV6)
+	struct socket *listen_socket6 = NULL;
+	struct sockaddr_in6 addr6 = {
+		.sin6_family = AF_INET6,
+		.sin6_port = htons(port),
+		.sin6_addr = IN6ADDR_ANY_INIT,
+	};
+#endif
+
+	rcu_read_lock();
+	net = rcu_dereference(wg->creating_net);
+	net = net ? maybe_get_net(net) : NULL;
+	rcu_read_unlock();
+	if (unlikely(!net))
+		return -ENONET;
+
+	ret = sock_create_kern(net, AF_INET, SOCK_STREAM, IPPROTO_TCP, &listen_socket4);
+	if (ret < 0) {
+		pr_err("%s: Could not create IPv4 TCP socket, error: %d\n", wg->dev->name, ret);
+		goto out;
+	}
+
+	ret = kernel_bind(listen_socket4, (struct sockaddr *)&addr4, sizeof(addr4));
+	if (ret < 0) {
+		pr_err("%s: Could not bind IPv4 TCP socket, error: %d\n", wg->dev->name, ret);
+		goto release_ipv4;
+	}
+
+	ret = kernel_listen(listen_socket4, SOMAXCONN);
+	if (ret < 0) {
+		pr_err("%s: Could not listen on IPv4 TCP socket, error: %d\n", wg->dev->name, ret);
+		goto release_ipv4;
+	}
+
+#if IS_ENABLED(CONFIG_IPV6)
+	ret = sock_create_kern(net, AF_INET6, SOCK_STREAM, IPPROTO_TCP, &listen_socket6);
+	if (ret < 0) {
+		pr_err("%s: Could not create IPv6 TCP socket, error: %d\n", wg->dev->name, ret);
+		goto release_ipv4;
+	}
+
+	ret = kernel_bind(listen_socket6, (struct sockaddr *)&addr6, sizeof(addr6));
+	if (ret < 0) {
+		pr_err("%s: Could not bind IPv6 TCP socket, error: %d\n", wg->dev->name, ret);
+		goto release_ipv6;
+	}
+
+	ret = kernel_listen(listen_socket6, SOMAXCONN);
+	if (ret < 0) {
+		pr_err("%s: Could not listen on IPv6 TCP socket, error: %d\n", wg->dev->name, ret);
+		goto release_ipv6;
+	}
+#endif
+
+	wg->tcp_listen_socket4 = listen_socket4;
+	wg->tcp_socket4_ready = true;
+#if IS_ENABLED(CONFIG_IPV6)
+	wg->tcp_listen_socket6 = listen_socket6;
+	wg->tcp_socket6_ready = true;
+#endif
+
+	/* One accept thread for IPv4 ... */
+	if (!wg->tcp_listener4_thread) {
+		wg->tcp_listener4_thread = kthread_run(wg_tcp_listener4_thread, wg, "wg-tcp-listen4");
+		if (IS_ERR(wg->tcp_listener4_thread)) {
+			pr_err("Failed to establish IPv4 TCP listener thread\n");
+			wg->tcp_listener4_thread = NULL;
+		}
+	}
+
+	/* ... and a second one for IPv6 */
+#if IS_ENABLED(CONFIG_IPV6)
+	if (!wg->tcp_listener6_thread) {
+		wg->tcp_listener6_thread = kthread_run(wg_tcp_listener6_thread, wg, "wg-tcp-listen6");
+		if (IS_ERR(wg->tcp_listener6_thread)) {
+			pr_err("Failed to establish IPv6 TCP listener thread\n");
+			wg->tcp_listener6_thread = NULL;
+		}
+	}
+#endif
+
+	/* Periodically expire pending connections that no peer has claimed */
+	INIT_DELAYED_WORK(&wg->tcp_cleanup_work, wg_tcp_cleanup_worker);
+	schedule_delayed_work(&wg->tcp_cleanup_work, msecs_to_jiffies(5000));
+
+	ret = 0;
+	goto out;
+
+#if IS_ENABLED(CONFIG_IPV6)
+release_ipv6:
+	sock_release(listen_socket6);
+#endif
+release_ipv4:
+	sock_release(listen_socket4);
+out:
+	put_net(net);
+	return ret;
+}
diff --git a/wireguard-linux/drivers/net/wireguard/socket.h b/wireguard-linux/drivers/net/wireguard/socket.h
index bab5848efbcd..89867abf27f5 100644
--- a/wireguard-linux/drivers/net/wireguard/socket.h
+++ b/wireguard-linux/drivers/net/wireguard/socket.h
@@ -6,6 +6,7 @@
#ifndef _WG_SOCKET_H
#define _WG_SOCKET_H
+#include <linux/net.h>
#include <linux/netdevice.h>
#include <linux/udp.h>
#include <linux/if_vlan.h>
@@ -29,6 +30,18 @@ void wg_socket_set_peer_endpoint(struct wg_peer *peer,
void wg_socket_set_peer_endpoint_from_skb(struct wg_peer *peer,
const struct sk_buff *skb);
void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer);
+int wg_tcp_socket_init(struct wg_device *wg, u16 port);
+void wg_destruct_tcp_connection_list(struct wg_device *wg);
+
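+/*
+ * Framing header prepended to each WireGuard message carried over the TCP
+ * transport.  TCP is a byte stream, so the receiver relies on the length
+ * field (network byte order) to recover message boundaries; type, flags and
+ * checksum describe the encapsulated message to the framing layer.
+ */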
+struct wg_tcp_encap_header {
+ __be32 length;
+ __u8 type;
+ __u8 flags;
+ __be16 checksum;
+};
+
+#define WG_TCP_ENCAP_HDR_LEN sizeof(struct wg_tcp_encap_header)
+#define WG_MAX_PACKET_SIZE (65535 + WG_TCP_ENCAP_HDR_LEN)
#if defined(CONFIG_DYNAMIC_DEBUG) || defined(DEBUG)
#define net_dbg_skb_ratelimited(fmt, dev, skb, ...) do { \
@@ -41,4 +54,28 @@ void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer);
#define net_dbg_skb_ratelimited(fmt, skb, ...)
#endif
+/* TCP transport helpers */
+void wg_tcp_state_change(struct sock *sk);
+void wg_extract_endpoint_from_sock(struct sock *sk, struct endpoint *endpoint);
+bool wg_check_potential_header_validity(__u8 *data, size_t remaining_len);
+void wg_tcp_write_space(struct sock *sk);
+void wg_tcp_data_ready(struct sock *sk);
+void wg_add_tcp_socket_to_list(struct wg_device *wg, struct socket *sock);
+void wg_remove_from_tcp_connection_list(struct wg_device *wg, struct socket *sock);
+void process_new_connection(struct wg_device *wg, struct socket *new_sock);
+void wg_tcp_cleanup_worker(struct work_struct *work);
+void wg_tcp_connection_retry_timer(struct timer_list *t);
+int wg_tcp_connect(struct wg_peer *peer);
+
#endif /* _WG_SOCKET_H */
diff --git a/wireguard-linux/include/uapi/linux/wireguard.h b/wireguard-linux/include/uapi/linux/wireguard.h
index ae88be14c947..7b9d67d4555d 100644
--- a/wireguard-linux/include/uapi/linux/wireguard.h
+++ b/wireguard-linux/include/uapi/linux/wireguard.h
@@ -136,6 +136,9 @@
#define WG_KEY_LEN 32
+#define WG_TRANSPORT_UDP 0
+#define WG_TRANSPORT_TCP 1
+
enum wg_cmd {
WG_CMD_GET_DEVICE,
WG_CMD_SET_DEVICE,
@@ -157,6 +160,7 @@ enum wgdevice_attribute {
WGDEVICE_A_LISTEN_PORT,
WGDEVICE_A_FWMARK,
WGDEVICE_A_PEERS,
+ WGDEVICE_A_TRANSPORT,
__WGDEVICE_A_LAST
};
#define WGDEVICE_A_MAX (__WGDEVICE_A_LAST - 1)
@@ -180,6 +184,7 @@ enum wgpeer_attribute {
WGPEER_A_TX_BYTES,
WGPEER_A_ALLOWEDIPS,
WGPEER_A_PROTOCOL_VERSION,
+ WGPEER_A_TRANSPORT,
__WGPEER_A_LAST
};
#define WGPEER_A_MAX (__WGPEER_A_LAST - 1)
diff --git a/wireguard-tools/src/config.c b/wireguard-tools/src/config.c
index 81ccb479c367..135b90e4bb20 100644
--- a/wireguard-tools/src/config.c
+++ b/wireguard-tools/src/config.c
@@ -410,6 +410,21 @@ err:
return false;
}
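+/*
+ * Parses the transport selector used by the "TransportMode" /
+ * "PeerTransportMode" configuration keys and the "transport-mode" /
+ * "peertransport-mode" command-line arguments, for example:
+ *
+ *	[Interface]
+ *	PrivateKey = <base64 private key>
+ *	TransportMode = tcp
+ *
+ *	[Peer]
+ *	PublicKey = <base64 public key>
+ *	PeerTransportMode = tcp
+ */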
+static inline bool parse_transport(uint32_t *transport, const char *value)
+{
+	if (!strcasecmp(value, "tcp")) {
+		*transport = WG_TRANSPORT_TCP;
+		return true;
+	}
+	if (!strcasecmp(value, "udp")) {
+		*transport = WG_TRANSPORT_UDP;
+		return true;
+	}
+	fprintf(stderr, "Transport protocol is neither tcp nor udp: `%s'\n", value);
+	return false;
+}
+
static bool process_line(struct config_ctx *ctx, const char *line)
{
const char *value;
@@ -436,6 +451,7 @@ static bool process_line(struct config_ctx *ctx, const char *line)
ctx->is_peer_section = true;
ctx->is_device_section = false;
ctx->last_peer->flags |= WGPEER_REPLACE_ALLOWEDIPS;
+ ctx->last_peer->transport = WG_TRANSPORT_UDP;
return true;
}
@@ -450,6 +466,8 @@ static bool process_line(struct config_ctx *ctx, const char *line)
ret = parse_key(ctx->device->private_key, value);
if (ret)
ctx->device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
+ } else if (key_match("TransportMode")) {
+ ret = parse_transport(&ctx->device->transport, value);
} else
goto error;
} else if (ctx->is_peer_section) {
@@ -467,6 +485,8 @@ static bool process_line(struct config_ctx *ctx, const char *line)
ret = parse_key(ctx->last_peer->preshared_key, value);
if (ret)
ctx->last_peer->flags |= WGPEER_HAS_PRESHARED_KEY;
+ } else if (key_match("PeerTransportMode")) {
+ ret = parse_transport(&ctx->last_peer->transport, value);
} else
goto error;
} else
@@ -588,6 +608,11 @@ struct wgdevice *config_read_cmd(const char *argv[], int argc)
device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
argv += 2;
argc -= 2;
+ } else if (!strcmp(argv[0], "transport-mode") && argc >= 2 && !peer) {
+ if (!parse_transport(&device->transport, argv[1]))
+ goto error;
+ argv += 2;
+ argc -= 2;
} else if (!strcmp(argv[0], "peer") && argc >= 2) {
struct wgpeer *new_peer = calloc(1, sizeof(*new_peer));
@@ -604,6 +629,7 @@ struct wgdevice *config_read_cmd(const char *argv[], int argc)
if (!parse_key(peer->public_key, argv[1]))
goto error;
peer->flags |= WGPEER_HAS_PUBLIC_KEY;
+ peer->transport = WG_TRANSPORT_UDP;
argv += 2;
argc -= 2;
} else if (!strcmp(argv[0], "remove") && argc >= 1 && peer) {
@@ -638,6 +664,10 @@ struct wgdevice *config_read_cmd(const char *argv[], int argc)
peer->flags |= WGPEER_HAS_PRESHARED_KEY;
argv += 2;
argc -= 2;
+ } else if (!strcmp(argv[0], "peertransport-mode") && argc >= 2 && peer) {
+ if (!parse_transport(&peer->transport, argv[1]))
+ goto error;
+		argv += 2;
+		argc -= 2;
} else {
fprintf(stderr, "Invalid argument: %s\n", argv[0]);
goto error;
diff --git a/wireguard-tools/src/containers.h b/wireguard-tools/src/containers.h
index a82e8ddee46a..5c7f21121fb6 100644
--- a/wireguard-tools/src/containers.h
+++ b/wireguard-tools/src/containers.h
@@ -64,6 +64,7 @@ struct wgpeer {
struct wgallowedip *first_allowedip, *last_allowedip;
struct wgpeer *next_peer;
+ uint32_t transport;
};
enum {
@@ -87,6 +88,7 @@ struct wgdevice {
uint16_t listen_port;
struct wgpeer *first_peer, *last_peer;
+ uint32_t transport;
};
#define for_each_wgpeer(__dev, __peer) for ((__peer) = (__dev)->first_peer; (__peer); (__peer) = (__peer)->next_peer)
diff --git a/wireguard-tools/src/ipc-linux.h b/wireguard-tools/src/ipc-linux.h
index d29c0c5dbf9b..1123b1128719 100644
--- a/wireguard-tools/src/ipc-linux.h
+++ b/wireguard-tools/src/ipc-linux.h
@@ -155,6 +155,7 @@ static int kernel_set_device(struct wgdevice *dev)
again:
nlh = mnlg_msg_prepare(nlg, WG_CMD_SET_DEVICE, NLM_F_REQUEST | NLM_F_ACK);
mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, dev->name);
+ mnl_attr_put_u8(nlh, WGDEVICE_A_TRANSPORT, dev->transport);
if (!peer) {
uint32_t flags = 0;
@@ -184,6 +185,7 @@ again:
goto toobig_peers;
if (peer->flags & WGPEER_REMOVE_ME)
flags |= WGPEER_F_REMOVE_ME;
+ mnl_attr_put_u8(nlh, WGPEER_A_TRANSPORT, peer->transport);
if (!allowedip) {
if (peer->flags & WGPEER_REPLACE_ALLOWEDIPS)
flags |= WGPEER_F_REPLACE_ALLOWEDIPS;
@@ -371,6 +373,10 @@ static int parse_peer(const struct nlattr *attr, void *data)
if (!mnl_attr_validate(attr, MNL_TYPE_U64))
peer->tx_bytes = mnl_attr_get_u64(attr);
break;
+ case WGPEER_A_TRANSPORT:
+ if (!mnl_attr_validate(attr, MNL_TYPE_U8))
+ peer->transport = mnl_attr_get_u8(attr);
+ break;
case WGPEER_A_ALLOWEDIPS:
return mnl_attr_parse_nested(attr, parse_allowedips, peer);
}
@@ -439,6 +445,10 @@ static int parse_device(const struct nlattr *attr, void *data)
if (!mnl_attr_validate(attr, MNL_TYPE_U32))
device->fwmark = mnl_attr_get_u32(attr);
break;
+ case WGDEVICE_A_TRANSPORT:
+ if (!mnl_attr_validate(attr, MNL_TYPE_U8))
+ device->transport = mnl_attr_get_u8(attr);
+ break;
case WGDEVICE_A_PEERS:
return mnl_attr_parse_nested(attr, parse_peers, device);
}
diff --git a/wireguard-tools/src/netlink.h b/wireguard-tools/src/netlink.h
index f9729ee280f1..b7348a5740d0 100644
--- a/wireguard-tools/src/netlink.h
+++ b/wireguard-tools/src/netlink.h
@@ -318,6 +318,11 @@ static void mnl_attr_put(struct nlmsghdr *nlh, uint16_t type, size_t len,
memset(mnl_attr_get_payload(attr) + len, 0, pad);
}
+static void mnl_attr_put_u8(struct nlmsghdr *nlh, uint16_t type, uint8_t data)
+{
+ mnl_attr_put(nlh, type, sizeof(uint8_t), &data);
+}
+
static void mnl_attr_put_u16(struct nlmsghdr *nlh, uint16_t type, uint16_t data)
{
mnl_attr_put(nlh, type, sizeof(uint16_t), &data);
diff --git a/wireguard-tools/src/set.c b/wireguard-tools/src/set.c
index 75560fd8cf62..b65d69912ae7 100644
--- a/wireguard-tools/src/set.c
+++ b/wireguard-tools/src/set.c
@@ -18,7 +18,7 @@ int set_main(int argc, const char *argv[])
int ret = 1;
if (argc < 3) {
- fprintf(stderr, "Usage: %s %s <interface> [listen-port <port>] [fwmark <mark>] [private-key <file path>] [peer <base64 public key> [remove] [preshared-key <file path>] [endpoint <ip>:<port>] [persistent-keepalive <interval seconds>] [allowed-ips <ip1>/<cidr1>[,<ip2>/<cidr2>]...] ]...\n", PROG_NAME, argv[0]);
+		fprintf(stderr, "Usage: %s %s <interface> [listen-port <port>] [fwmark <mark>] [private-key <file path>] [transport-mode <tcp|udp>] [peer <base64 public key> [remove] [preshared-key <file path>] [endpoint <ip>:<port>] [persistent-keepalive <interval seconds>] [allowed-ips <ip1>/<cidr1>[,<ip2>/<cidr2>]...] [peertransport-mode <tcp|udp>] ]...\n", PROG_NAME, argv[0]);
return 1;
}
diff --git a/wireguard-tools/src/show.c b/wireguard-tools/src/show.c
index 13777cf04280..3f8c49afad7c 100644
--- a/wireguard-tools/src/show.c
+++ b/wireguard-tools/src/show.c
@@ -126,6 +126,15 @@ static char *endpoint(const struct sockaddr *addr)
return buf;
}
+static const char *transport(const uint32_t transport_val)
+{
+	if (transport_val == WG_TRANSPORT_UDP)
+		return "udp";
+	if (transport_val == WG_TRANSPORT_TCP)
+		return "tcp";
+	return "unknown";
+}
+
static size_t pretty_time(char *buf, const size_t len, unsigned long long left)
{
size_t offset = 0;
@@ -220,6 +229,8 @@ static void pretty_print(struct wgdevice *device)
terminal_printf(" " TERMINAL_BOLD "listening port" TERMINAL_RESET ": %u\n", device->listen_port);
if (device->fwmark)
terminal_printf(" " TERMINAL_BOLD "fwmark" TERMINAL_RESET ": 0x%x\n", device->fwmark);
+ if (device->transport)
+ terminal_printf(" " TERMINAL_BOLD "transport" TERMINAL_RESET ": %s\n", transport(device->transport));
if (device->first_peer) {
sort_peers(device);
terminal_printf("\n");
@@ -245,6 +256,8 @@ static void pretty_print(struct wgdevice *device)
}
if (peer->persistent_keepalive_interval)
terminal_printf(" " TERMINAL_BOLD "persistent keepalive" TERMINAL_RESET ": %s\n", every(peer->persistent_keepalive_interval));
+ if (peer->transport)
+ terminal_printf(" " TERMINAL_BOLD "transport" TERMINAL_RESET ": %s\n", transport(peer->transport));
if (peer->next_peer)
terminal_printf("\n");
}
diff --git a/wireguard-tools/src/showconf.c b/wireguard-tools/src/showconf.c
index 62070dc27af2..ea5d49807c84 100644
--- a/wireguard-tools/src/showconf.c
+++ b/wireguard-tools/src/showconf.c
@@ -46,6 +46,9 @@ int showconf_main(int argc, const char *argv[])
key_to_base64(base64, device->private_key);
printf("PrivateKey = %s\n", base64);
}
+ if (device->transport) {
+ printf("TransportMode = %s\n", (device->transport == WG_TRANSPORT_TCP ? "tcp" : "udp"));
+ }
printf("\n");
for_each_wgpeer(device, peer) {
key_to_base64(base64, peer->public_key);
@@ -91,6 +94,8 @@ int showconf_main(int argc, const char *argv[])
if (peer->persistent_keepalive_interval)
printf("PersistentKeepalive = %u\n", peer->persistent_keepalive_interval);
+ if (peer->transport)
+ printf("PeerTransportMode = %s\n", (peer->transport == WG_TRANSPORT_TCP ? "tcp" : "udp"));
if (peer->next_peer)
printf("\n");
diff --git a/wireguard-tools/src/uapi/linux/linux/wireguard.h b/wireguard-tools/src/uapi/linux/linux/wireguard.h
index 0efd52c3687d..674b069f8f58 100644
--- a/wireguard-tools/src/uapi/linux/linux/wireguard.h
+++ b/wireguard-tools/src/uapi/linux/linux/wireguard.h
@@ -136,6 +136,9 @@
#define WG_KEY_LEN 32
+#define WG_TRANSPORT_UDP 0
+#define WG_TRANSPORT_TCP 1
+
enum wg_cmd {
@@ -157,6 +160,7 @@ enum wgdevice_attribute {
WGDEVICE_A_LISTEN_PORT,
WGDEVICE_A_FWMARK,
WGDEVICE_A_PEERS,
+	WGDEVICE_A_TRANSPORT,
__WGDEVICE_A_LAST
};
#define WGDEVICE_A_MAX (__WGDEVICE_A_LAST - 1)
@@ -180,6 +184,7 @@ enum wgpeer_attribute {
WGPEER_A_TX_BYTES,
WGPEER_A_ALLOWEDIPS,
WGPEER_A_PROTOCOL_VERSION,
+	WGPEER_A_TRANSPORT,
__WGPEER_A_LAST
};
#define WGPEER_A_MAX (__WGPEER_A_LAST - 1)