From 0f9ac2ac879bb5c36355c7fee266e09be0c3c9a4 Mon Sep 17 00:00:00 2001 From: lbl8603 <49143209+lbl8603@users.noreply.github.com> Date: Mon, 5 Aug 2024 20:44:54 +0800 Subject: [PATCH] =?UTF-8?q?NAT1=E4=B8=8B=E7=9A=84tcp=E6=89=93=E6=B4=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- vnt/proto/message.proto | 12 ++++ vnt/src/channel/context.rs | 100 ++++++++++++--------------- vnt/src/channel/mod.rs | 2 +- vnt/src/channel/punch.rs | 93 +++++++++++++++++++++---- vnt/src/channel/sender.rs | 12 +++- vnt/src/channel/socket/mod.rs | 22 +++++- vnt/src/channel/tcp_channel.rs | 51 +++++++++----- vnt/src/channel/ws_channel.rs | 15 ++-- vnt/src/core/conn.rs | 1 + vnt/src/handle/maintain/heartbeat.rs | 12 +++- vnt/src/handle/maintain/punch.rs | 4 +- vnt/src/handle/mod.rs | 6 +- vnt/src/handle/recv_data/client.rs | 26 +++---- vnt/src/handle/recv_data/server.rs | 13 +++- vnt/src/nat/mod.rs | 28 +++----- vnt/tun/src/windows/tun/mod.rs | 1 + 16 files changed, 260 insertions(+), 138 deletions(-) diff --git a/vnt/proto/message.proto b/vnt/proto/message.proto index 197c8848..948a1238 100644 --- a/vnt/proto/message.proto +++ b/vnt/proto/message.proto @@ -64,11 +64,23 @@ message PunchInfo { uint32 tcp_port = 11; repeated uint32 udp_ports = 12; repeated uint32 public_ports = 13; + uint32 public_tcp_port = 14; + PunchNatModel punch_model = 15; } enum PunchNatType { Symmetric = 0; Cone = 1; } +enum PunchNatModel { + All = 0; + IPv4 = 1; + IPv6 = 2; + IPv4Tcp = 3; + IPv4Udp = 4; + IPv6Tcp = 5; + IPv6Udp = 6; +} + /// 向服务器上报客户端状态信息 message ClientStatusInfo { fixed32 source = 1; diff --git a/vnt/src/channel/context.rs b/vnt/src/channel/context.rs index 505df6f9..889c106c 100644 --- a/vnt/src/channel/context.rs +++ b/vnt/src/channel/context.rs @@ -60,6 +60,7 @@ impl ChannelContext { up_traffic_meter, down_traffic_meter, default_interface, + default_route_key: AtomicCell::default(), }; Self { inner: Arc::new(inner), @@ -86,7 +87,7 @@ pub struct ContextInner { // 对称网络增加的udp socket sub_udp_socket: RwLock>, // tcp数据发送器 - pub(crate) packet_map: RwLock>, + pub(crate) packet_map: RwLock>, // 路由信息 pub route_table: RouteTable, // 使用什么协议连接服务器 @@ -98,6 +99,7 @@ pub struct ContextInner { pub(crate) up_traffic_meter: Option, pub(crate) down_traffic_meter: Option, default_interface: LocalInterface, + default_route_key: AtomicCell>, } impl ContextInner { @@ -107,6 +109,9 @@ impl ContextInner { pub fn default_interface(&self) -> &LocalInterface { &self.default_interface } + pub fn set_default_route_key(&self, route_key: RouteKey) { + self.default_route_key.store(Some(route_key)); + } /// 通过sub_udp_socket是否为空来判断是否为锥形网络 pub fn is_cone(&self) -> bool { self.sub_udp_socket.read().is_empty() @@ -175,11 +180,14 @@ impl ContextInner { } Ok(ports) } - pub fn send_tcp(&self, buf: &[u8], addr: SocketAddr) -> io::Result<()> { - if let Some(tcp) = self.packet_map.read().get(&addr) { + pub fn send_tcp(&self, buf: &[u8], route_key: &RouteKey) -> io::Result<()> { + if let Some(tcp) = self.packet_map.read().get(route_key) { tcp.try_send(buf) } else { - Err(io::Error::from(io::ErrorKind::NotFound)) + Err(io::Error::new( + io::ErrorKind::NotFound, + format!("dest={:?}", route_key), + )) } } pub fn send_main_udp(&self, index: usize, buf: &[u8], addr: SocketAddr) -> io::Result<()> { @@ -203,7 +211,14 @@ impl ContextInner { self.send_main_udp(self.v4_len, buf.buffer(), addr)? } } else { - self.send_tcp(buf.buffer(), addr)? + if let Some(key) = self.default_route_key.load() { + self.send_tcp(buf.buffer(), &key)? + } else { + return Err(io::Error::new( + io::ErrorKind::NotFound, + format!("dest={:?}", addr), + )); + } } if let Some(up_traffic_meter) = &self.up_traffic_meter { up_traffic_meter.add_traffic(buf.destination(), buf.data_len()); @@ -300,7 +315,7 @@ impl ContextInner { } } ConnectProtocol::TCP | ConnectProtocol::WS | ConnectProtocol::WSS => { - self.send_tcp(buf.buffer(), route_key.addr)? + self.send_tcp(buf.buffer(), &route_key)? } } if let Some(up_traffic_meter) = &self.up_traffic_meter { @@ -376,19 +391,11 @@ impl RouteTable { let key = route.route_key(); if only_if_absent { if let Some((_, list)) = self.route_table.read().get(&id) { - let mut p2p_num = 0; for (x, _) in list { - if x.is_p2p() { - p2p_num += 1; - } if x.route_key() == key { return true; } } - if !self.first_latency && p2p_num >= self.channel_num { - // 非优先延迟的情况下,通道满了则不用再添加 - return false; - } } } let mut route_table = self.route_table.write(); @@ -413,61 +420,42 @@ impl RouteTable { } } if exist { - // 这个排序还有待优化,因为后加入的大概率排最后,被直接淘汰的概率也大,可能导致更好的通道被移除了 list.sort_by_key(|(k, _)| k.rt); - //如果延迟都稳定了,则去除多余通道 - for (route, _) in list.iter() { - if route.rt == DEFAULT_RT { - return true; - } - } - //延迟优先模式需要更多的通道探测延迟最低的路线 - let limit_len = if self.first_latency { - self.channel_num + 2 - } else { - self.channel_num - }; - self.truncate_(list, limit_len); } else { if !self.first_latency { if route.is_p2p() { //非优先延迟的情况下 添加了直连的则排除非直连的 list.retain(|(k, _)| k.is_p2p()); } - if self.channel_num <= list.len() { - return false; - } }; - //增加路由表容量,避免波动 - let limit_len = self.channel_num * 2; list.sort_by_key(|(k, _)| k.rt); - self.truncate_(list, limit_len); list.push((route, AtomicCell::new(Instant::now()))); } return true; } - fn truncate_(&self, list: &mut Vec<(Route, AtomicCell)>, len: usize) { - if list.len() <= len { - return; - } - if self.first_latency { - //找到第一个p2p通道 - if let Some(index) = - list.iter() - .enumerate() - .find_map(|(index, (route, _))| if route.is_p2p() { Some(index) } else { None }) - { - if index >= len { - //保留第一个p2p通道 - let route = list.remove(index); - list.truncate(len - 1); - list.push(route); - return; - } - } - } - list.truncate(len); - } + // 直接移除会导致通道不稳定,所以废弃这个方法,后面改用多余通道不发心跳包,从而让通道自动过期 + // fn truncate_(&self, list: &mut Vec<(Route, AtomicCell)>, len: usize) { + // if list.len() <= len { + // return; + // } + // if self.first_latency { + // //找到第一个p2p通道 + // if let Some(index) = + // list.iter() + // .enumerate() + // .find_map(|(index, (route, _))| if route.is_p2p() { Some(index) } else { None }) + // { + // if index >= len { + // //保留第一个p2p通道 + // let route = list.remove(index); + // list.truncate(len - 1); + // list.push(route); + // return; + // } + // } + // } + // list.truncate(len); + // } pub fn route(&self, id: &Ipv4Addr) -> Option> { if let Some((_, v)) = self.route_table.read().get(id) { Some(v.iter().map(|(i, _)| *i).collect()) diff --git a/vnt/src/channel/mod.rs b/vnt/src/channel/mod.rs index 8cc1f3b9..19b15257 100644 --- a/vnt/src/channel/mod.rs +++ b/vnt/src/channel/mod.rs @@ -179,7 +179,7 @@ pub struct RouteKey { } impl RouteKey { - pub(crate) fn new(protocol: ConnectProtocol, index: usize, addr: SocketAddr) -> Self { + pub(crate) const fn new(protocol: ConnectProtocol, index: usize, addr: SocketAddr) -> Self { Self { protocol, index, diff --git a/vnt/src/channel/punch.rs b/vnt/src/channel/punch.rs index 76e476cd..3775134a 100644 --- a/vnt/src/channel/punch.rs +++ b/vnt/src/channel/punch.rs @@ -14,16 +14,17 @@ use crate::channel::context::ChannelContext; use crate::channel::sender::ConnectUtil; use crate::handle::CurrentDeviceInfo; use crate::nat::{is_ipv4_global, NatTest}; +use crate::proto::message::{PunchNatModel, PunchNatType}; #[derive(Copy, Clone, Eq, PartialEq, Debug)] pub enum PunchModel { + All, IPv4, IPv6, IPv4Tcp, IPv4Udp, IPv6Tcp, IPv6Udp, - All, } impl PunchModel { @@ -72,6 +73,33 @@ impl Default for PunchModel { PunchModel::All } } +impl From for PunchNatModel { + fn from(value: PunchModel) -> Self { + match value { + PunchModel::All => PunchNatModel::All, + PunchModel::IPv4 => PunchNatModel::IPv4, + PunchModel::IPv6 => PunchNatModel::IPv6, + PunchModel::IPv4Tcp => PunchNatModel::IPv4Tcp, + PunchModel::IPv4Udp => PunchNatModel::IPv4Udp, + PunchModel::IPv6Tcp => PunchNatModel::IPv6Tcp, + PunchModel::IPv6Udp => PunchNatModel::IPv6Udp, + } + } +} + +impl Into for PunchNatModel { + fn into(self) -> PunchModel { + match self { + PunchNatModel::All => PunchModel::All, + PunchNatModel::IPv4 => PunchModel::IPv4, + PunchNatModel::IPv6 => PunchModel::IPv6, + PunchNatModel::IPv4Tcp => PunchModel::IPv4Tcp, + PunchNatModel::IPv4Udp => PunchModel::IPv4Udp, + PunchNatModel::IPv6Tcp => PunchModel::IPv6Tcp, + PunchNatModel::IPv6Udp => PunchModel::IPv6Udp, + } + } +} #[derive(Clone, Debug)] pub struct NatInfo { @@ -83,6 +111,8 @@ pub struct NatInfo { pub(crate) ipv6: Option, pub udp_ports: Vec, pub tcp_port: u16, + pub public_tcp_port: u16, + pub punch_model: PunchModel, } #[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] @@ -91,6 +121,29 @@ pub enum NatType { Cone, } +impl NatType { + pub fn is_cone(&self) -> bool { + self == &NatType::Cone + } +} +impl From for PunchNatType { + fn from(value: NatType) -> Self { + match value { + NatType::Symmetric => PunchNatType::Symmetric, + NatType::Cone => PunchNatType::Cone, + } + } +} + +impl Into for PunchNatType { + fn into(self) -> NatType { + match self { + PunchNatType::Symmetric => NatType::Symmetric, + PunchNatType::Cone => NatType::Cone, + } + } +} + impl NatInfo { pub fn new( mut public_ips: Vec, @@ -100,7 +153,9 @@ impl NatInfo { mut ipv6: Option, udp_ports: Vec, tcp_port: u16, + public_tcp_port: u16, mut nat_type: NatType, + punch_model: PunchModel, ) -> Self { public_ips.retain(|ip| { !ip.is_multicast() @@ -130,7 +185,9 @@ impl NatInfo { ipv6, udp_ports, tcp_port, + public_tcp_port, nat_type, + punch_model, } } pub fn update_addr(&mut self, index: usize, ip: Ipv4Addr, port: u16) -> bool { @@ -144,7 +201,7 @@ impl NatInfo { *public_port = port; } } - if crate::nat::is_ipv4_global(&ip) { + if is_ipv4_global(&ip) { if !self.public_ips.contains(&ip) { self.public_ips.push(ip); updated = true; @@ -153,6 +210,9 @@ impl NatInfo { } updated } + pub fn update_tcp_port(&mut self, port: u16) { + self.public_tcp_port = port; + } pub fn local_ipv4(&self) -> Option { self.local_ipv4 } @@ -252,7 +312,10 @@ impl Punch { if self.nat_test.is_local_address(true, addr) { return; } - self.connect_util.try_connect_tcp(buf.to_vec(), addr); + if addr.ip().is_unspecified() || addr.port() == 0 { + return; + } + self.connect_util.try_connect_tcp_punch(buf.to_vec(), addr); } pub fn punch( &mut self, @@ -277,44 +340,46 @@ impl Punch { nat_info.local_ipv4 = nat_info .local_ipv4 .filter(|ip| device_info.not_in_network(*ip)); - - if punch_tcp && self.punch_model.use_tcp() && nat_info.tcp_port != 0 { + if punch_tcp && self.punch_model.use_tcp() && nat_info.punch_model.use_tcp() { //向tcp发起连接 - if self.punch_model.use_ipv6() { + if self.punch_model.use_ipv6() && nat_info.punch_model.use_ipv6() { if let Some(ipv6_addr) = nat_info.local_tcp_ipv6addr() { self.connect_tcp(buf, ipv6_addr) } } - if self.punch_model.use_ipv4() { + if self.punch_model.use_ipv4() && nat_info.punch_model.use_ipv4() { if let Some(ipv4_addr) = nat_info.local_tcp_ipv4addr() { self.connect_tcp(buf, ipv4_addr) } for ip in &nat_info.public_ips { let addr = SocketAddr::V4(SocketAddrV4::new(*ip, nat_info.tcp_port)); - self.connect_tcp(buf, addr) + self.connect_tcp(buf, addr); + } + if nat_info.nat_type.is_cone() && nat_info.public_tcp_port != 0 { + for ip in &nat_info.public_ips { + let addr = SocketAddr::V4(SocketAddrV4::new(*ip, nat_info.public_tcp_port)); + self.connect_tcp(buf, addr); + } } } } - if !self.punch_model.use_udp() { + if !self.punch_model.use_udp() || !nat_info.punch_model.use_udp() { return Ok(()); } let channel_num = self.context.channel_num(); let main_len = self.context.main_len(); - if self.punch_model.use_ipv6() { + if self.punch_model.use_ipv6() && nat_info.punch_model.use_ipv6() { for index in channel_num..main_len { if let Some(ipv6_addr) = nat_info.local_udp_ipv6addr(index) { if !self.nat_test.is_local_address(false, ipv6_addr) { let rs = self.context.send_main_udp(index, buf, ipv6_addr); log::info!("发送到ipv6地址:{:?},rs={:?} {}", ipv6_addr, rs, id); - if rs.is_ok() && self.punch_model == PunchModel::IPv6 { - return Ok(()); - } } } } } - if !self.punch_model.use_ipv4() { + if !self.punch_model.use_ipv4() || !nat_info.punch_model.use_ipv4() { return Ok(()); } for index in 0..channel_num { diff --git a/vnt/src/channel/sender.rs b/vnt/src/channel/sender.rs index 5cfa27cb..00255c80 100644 --- a/vnt/src/channel/sender.rs +++ b/vnt/src/channel/sender.rs @@ -239,13 +239,13 @@ impl PacketSender { #[derive(Clone)] pub struct ConnectUtil { - connect_tcp: Sender<(Vec, SocketAddr)>, + connect_tcp: Sender<(Vec, Option, SocketAddr)>, connect_ws: Sender<(Vec, String)>, } impl ConnectUtil { pub fn new( - connect_tcp: Sender<(Vec, SocketAddr)>, + connect_tcp: Sender<(Vec, Option, SocketAddr)>, connect_ws: Sender<(Vec, String)>, ) -> Self { Self { @@ -254,7 +254,13 @@ impl ConnectUtil { } } pub fn try_connect_tcp(&self, buf: Vec, addr: SocketAddr) { - if self.connect_tcp.try_send((buf, addr)).is_err() { + if self.connect_tcp.try_send((buf, None, addr)).is_err() { + log::warn!("try_connect_tcp failed {}", addr); + } + } + pub fn try_connect_tcp_punch(&self, buf: Vec, addr: SocketAddr) { + // 打洞的连接可以绑定随机端口 + if self.connect_tcp.try_send((buf, Some(0), addr)).is_err() { log::warn!("try_connect_tcp failed {}", addr); } } diff --git a/vnt/src/channel/socket/mod.rs b/vnt/src/channel/socket/mod.rs index 14052c23..338702dc 100644 --- a/vnt/src/channel/socket/mod.rs +++ b/vnt/src/channel/socket/mod.rs @@ -27,15 +27,22 @@ pub struct LocalInterface { pub async fn connect_tcp( addr: SocketAddr, + bind_port: u16, default_interface: &LocalInterface, ) -> anyhow::Result { - let socket = create_tcp(addr.is_ipv4(), default_interface)?; + let socket = create_tcp0(addr.is_ipv4(), bind_port, default_interface)?; Ok(socket.connect(addr).await?) } - pub fn create_tcp( v4: bool, default_interface: &LocalInterface, +) -> anyhow::Result { + create_tcp0(v4, 0, default_interface) +} +pub fn create_tcp0( + v4: bool, + bind_port: u16, + default_interface: &LocalInterface, ) -> anyhow::Result { let socket = if v4 { socket2::Socket::new( @@ -53,6 +60,17 @@ pub fn create_tcp( if v4 { socket.set_ip_unicast_if(default_interface)?; } + if bind_port != 0 { + socket.set_reuse_address(true)?; + if v4 { + let addr: SocketAddr = format!("0.0.0.0:{}", bind_port).parse().unwrap(); + socket.bind(&addr.into())?; + } else { + socket.set_only_v6(true)?; + let addr: SocketAddr = format!("[::]:{}", bind_port).parse().unwrap(); + socket.bind(&addr.into())?; + } + } socket.set_nonblocking(true)?; socket.set_nodelay(true)?; Ok(tokio::net::TcpSocket::from_std_stream(socket.into())) diff --git a/vnt/src/channel/tcp_channel.rs b/vnt/src/channel/tcp_channel.rs index 71aadd6b..679616c0 100644 --- a/vnt/src/channel/tcp_channel.rs +++ b/vnt/src/channel/tcp_channel.rs @@ -1,5 +1,9 @@ use anyhow::{anyhow, Context}; use std::net::SocketAddr; +#[cfg(unix)] +use std::os::unix::io::AsRawFd; +#[cfg(windows)] +use std::os::windows::io::AsRawSocket; use std::thread; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -16,7 +20,7 @@ use crate::util::StopManager; /// 监听tcp端口,等待客户端连接 pub fn tcp_listen( tcp_server: std::net::TcpListener, - receiver: Receiver<(Vec, SocketAddr)>, + receiver: Receiver<(Vec, Option, SocketAddr)>, recv_handler: H, context: ChannelContext, stop_manager: StopManager, @@ -28,6 +32,7 @@ where let worker = stop_manager.add_listener("tcpChannel".into(), move || { let _ = stop_sender.send(()); })?; + let bind_port = tcp_server.local_addr()?.port(); let runtime = tokio::runtime::Builder::new_multi_thread() .worker_threads(2) .enable_all() @@ -46,9 +51,9 @@ where } }); } - tokio::spawn( - async move { connect_tcp_handle(receiver, recv_handler, context).await }, - ); + tokio::spawn(async move { + connect_tcp_handle(receiver, recv_handler, context, bind_port).await + }); }); runtime.block_on(async { let _ = stop_receiver.await; @@ -61,17 +66,23 @@ where } async fn connect_tcp_handle( - mut receiver: Receiver<(Vec, SocketAddr)>, + mut receiver: Receiver<(Vec, Option, SocketAddr)>, recv_handler: H, context: ChannelContext, + listener_bind_port: u16, ) where H: RecvChannelHandler, { - while let Some((data, addr)) = receiver.recv().await { + while let Some((data, bind_port, addr)) = receiver.recv().await { let recv_handler = recv_handler.clone(); let context = context.clone(); + let bind_port = if let Some(bind_port) = bind_port { + bind_port + } else { + listener_bind_port + }; tokio::spawn(async move { - if let Err(e) = connect_tcp0(data, addr, recv_handler, context).await { + if let Err(e) = connect_tcp0(data, addr, recv_handler, context, bind_port).await { log::warn!("发送失败,链接终止:{:?},{:?}", addr, e); } }); @@ -83,13 +94,14 @@ async fn connect_tcp0( addr: SocketAddr, recv_handler: H, context: ChannelContext, + bind_port: u16, ) -> anyhow::Result<()> where H: RecvChannelHandler, { let mut stream = tokio::time::timeout( Duration::from_secs(3), - crate::channel::socket::connect_tcp(addr, context.default_interface()), + crate::channel::socket::connect_tcp(addr, bind_port, context.default_interface()), ) .await??; tcp_write(&mut stream, &data).await?; @@ -110,6 +122,7 @@ where loop { let (stream, addr) = tcp_server.accept().await?; + tcp_stream_handle(stream, addr, recv_handler.clone(), context.clone()).await; } } @@ -123,12 +136,18 @@ pub async fn tcp_stream_handle( H: RecvChannelHandler, { let _ = stream.set_nodelay(true); + let local = stream.local_addr(); + #[cfg(windows)] + let index = stream.as_raw_socket() as usize; + #[cfg(unix)] + let index = stream.as_raw_fd() as usize; + let route_key = RouteKey::new(ConnectProtocol::TCP, index, addr); let (r, mut w) = stream.into_split(); let (sender, mut receiver) = channel::>(100); context .packet_map .write() - .insert(addr, PacketSender::new(sender)); + .insert(route_key, PacketSender::new(sender)); tokio::spawn(async move { while let Some(data) = receiver.recv().await { if let Err(e) = tcp_write(&mut w, &data).await { @@ -139,10 +158,10 @@ pub async fn tcp_stream_handle( let _ = w.shutdown().await; }); tokio::spawn(async move { - if let Err(e) = tcp_read(r, addr, &context, recv_handler).await { - log::warn!("tcp_read {:?}", e) + if let Err(e) = tcp_read(r, addr, &context, recv_handler, route_key).await { + log::warn!("tcp_read {:?} {local:?}-{addr}", e) } - context.packet_map.write().remove(&addr); + context.packet_map.write().remove(&route_key); }); } @@ -162,6 +181,7 @@ async fn tcp_read( addr: SocketAddr, context: &ChannelContext, recv_handler: H, + route_key: RouteKey, ) -> anyhow::Result<()> where H: RecvChannelHandler, @@ -179,11 +199,6 @@ where return Err(anyhow!("tcp数据长度无效 {}", addr)); } read.read_exact(&mut buf[..len]).await?; - recv_handler.handle( - &mut buf[..len], - &mut extend, - RouteKey::new(ConnectProtocol::TCP, 0, addr), - context, - ); + recv_handler.handle(&mut buf[..len], &mut extend, route_key, context); } } diff --git a/vnt/src/channel/ws_channel.rs b/vnt/src/channel/ws_channel.rs index f7862628..bb7409aa 100644 --- a/vnt/src/channel/ws_channel.rs +++ b/vnt/src/channel/ws_channel.rs @@ -58,14 +58,16 @@ async fn connect_ws_handle( ) where H: RecvChannelHandler, { + let mut index = 0; while let Some((data, url)) = receiver.recv().await { let recv_handler = recv_handler.clone(); let context = context.clone(); tokio::spawn(async move { - if let Err(e) = connect_ws(data, url, recv_handler, context).await { + if let Err(e) = connect_ws(data, url, recv_handler, context, index).await { log::warn!("发送失败,ws链接终止:{:?}", e); } }); + index += 1; } } const WS_ADDR: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); @@ -75,6 +77,7 @@ async fn connect_ws( mut url: String, recv_handler: H, context: ChannelContext, + index: usize, ) -> anyhow::Result<()> where H: RecvChannelHandler, @@ -114,10 +117,12 @@ where ws.send(Message::Binary(data)).await?; let (mut ws_write, ws_read) = ws.split(); let (sender, mut receiver) = channel::>(100); + let route_key = RouteKey::new(ConnectProtocol::WS, index, WS_ADDR); + context .packet_map .write() - .insert(WS_ADDR, PacketSender::new(sender)); + .insert(route_key, PacketSender::new(sender)); tokio::spawn(async move { while let Some(data) = receiver.recv().await { if let Err(e) = ws_write.send(Message::Binary(data)).await { @@ -127,22 +132,22 @@ where } let _ = ws_write.close().await; }); - if let Err(e) = ws_read_handle(ws_read, recv_handler, &context).await { + if let Err(e) = ws_read_handle(ws_read, recv_handler, &context, route_key).await { log::warn!("{:?}", e); } - context.packet_map.write().remove(&WS_ADDR); + context.packet_map.write().remove(&route_key); Ok(()) } async fn ws_read_handle( mut ws_read: SplitStream>>, recv_handler: H, context: &ChannelContext, + route_key: RouteKey, ) -> anyhow::Result<()> where H: RecvChannelHandler, { let mut extend = [0; BUFFER_SIZE]; - let route_key = RouteKey::new(ConnectProtocol::WS, 0, WS_ADDR); while let Some(msg) = ws_read.next().await { let msg = msg.context("Error during WebSocket ")?; match msg { diff --git a/vnt/src/core/conn.rs b/vnt/src/core/conn.rs index 248f8a86..a2e749d9 100644 --- a/vnt/src/core/conn.rs +++ b/vnt/src/core/conn.rs @@ -215,6 +215,7 @@ impl VntInner { udp_ports, tcp_port, config.local_ipv4.is_none(), + config.punch_model, ); // 定时器 let scheduler = Scheduler::new(stop_manager.clone())?; diff --git a/vnt/src/handle/maintain/heartbeat.rs b/vnt/src/handle/maintain/heartbeat.rs index d5e21eed..d58fe888 100644 --- a/vnt/src/handle/maintain/heartbeat.rs +++ b/vnt/src/handle/maintain/heartbeat.rs @@ -56,6 +56,7 @@ fn heartbeat0( ) { let gateway_ip = current_device.virtual_gateway; let src_ip = current_device.virtual_ip; + let channel_num = context.channel_num(); // 可能服务器ip发生变化,导致发送失败 let mut is_send_gateway = false; match heartbeat_packet_server(device_map, server_cipher, src_ip, gateway_ip) { @@ -78,6 +79,10 @@ fn heartbeat0( } heartbeat_packet_server(device_map, server_cipher, src_ip, gateway_ip) } else { + if dest_ip < src_ip { + // 只向比自己大的发 + continue; + } heartbeat_packet_client(client_cipher, src_ip, dest_ip) }; let net_packet = match net_packet { @@ -87,7 +92,12 @@ fn heartbeat0( continue; } }; - for route in routes { + for (index, route) in routes.iter().enumerate() { + if index >= channel_num + 1 { + // 多余的通道不再发送心跳包,让它自动过期 + // 这里多留一个增加稳定性 + break; + } if let Err(e) = context.send_by_key(&net_packet, route.route_key()) { log::warn!("heartbeat err={:?}", e) } diff --git a/vnt/src/handle/maintain/punch.rs b/vnt/src/handle/maintain/punch.rs index 98ded7a4..cef4039f 100644 --- a/vnt/src/handle/maintain/punch.rs +++ b/vnt/src/handle/maintain/punch.rs @@ -149,7 +149,7 @@ fn punch_start( *v += 1; *v } else { - guard.insert(peer_ip, 1); + guard.insert(peer_ip, 0); 0 } }; @@ -326,6 +326,7 @@ fn punch_packet( punch_reply.public_port = nat_info.public_ports.get(0).map_or(0, |v| *v as u32); punch_reply.public_ports = nat_info.public_ports.iter().map(|e| *e as u32).collect(); punch_reply.public_port_range = nat_info.public_port_range as u32; + punch_reply.public_tcp_port = nat_info.public_tcp_port as u32; punch_reply.local_ip = u32::from(nat_info.local_ipv4().unwrap_or(Ipv4Addr::UNSPECIFIED)); punch_reply.local_port = nat_info.udp_ports[0] as u32; punch_reply.tcp_port = nat_info.tcp_port as u32; @@ -335,6 +336,7 @@ fn punch_packet( punch_reply.ipv6 = ipv6.octets().to_vec(); } punch_reply.nat_type = protobuf::EnumOrUnknown::new(PunchNatType::from(nat_info.nat_type)); + punch_reply.punch_model = protobuf::EnumOrUnknown::new(nat_info.punch_model.into()); log::info!("请求打洞={:?}", punch_reply); let bytes = punch_reply .write_to_bytes() diff --git a/vnt/src/handle/mod.rs b/vnt/src/handle/mod.rs index ea517328..2fec003f 100644 --- a/vnt/src/handle/mod.rs +++ b/vnt/src/handle/mod.rs @@ -223,10 +223,10 @@ impl CurrentDeviceInfo { virtual_gateway: Ipv4Addr, ) { let broadcast_ip = (!u32::from_be_bytes(virtual_netmask.octets())) - | u32::from_be_bytes(virtual_gateway.octets()); + | u32::from_be_bytes(virtual_ip.octets()); let broadcast_ip = Ipv4Addr::from(broadcast_ip); - let virtual_network = u32::from_be_bytes(virtual_netmask.octets()) - & u32::from_be_bytes(virtual_gateway.octets()); + let virtual_network = + u32::from_be_bytes(virtual_netmask.octets()) & u32::from_be_bytes(virtual_ip.octets()); let virtual_network = Ipv4Addr::from(virtual_network); self.virtual_ip = virtual_ip; self.virtual_netmask = virtual_netmask; diff --git a/vnt/src/handle/recv_data/client.rs b/vnt/src/handle/recv_data/client.rs index fc617216..e658ad7f 100644 --- a/vnt/src/handle/recv_data/client.rs +++ b/vnt/src/handle/recv_data/client.rs @@ -216,17 +216,13 @@ impl ClientPacketHandler { match ControlPacket::new(net_packet.transport_protocol(), net_packet.payload())? { ControlPacket::PingPacket(_) => { let route = Route::from_default_rt(route_key, metric); - if context.route_table.add_route_if_absent(source, route) - || net_packet.source() < current_device.virtual_ip - { - //在路由表中,或者来源比自己小,就需要回复,注意不能调换顺序 - net_packet.set_transport_protocol(control_packet::Protocol::Pong.into()); - net_packet.set_source(current_device.virtual_ip); - net_packet.set_destination(source); - net_packet.first_set_ttl(MAX_TTL); - self.client_cipher.encrypt_ipv4(&mut net_packet)?; - context.send_by_key(&net_packet, route_key)?; - } + context.route_table.add_route_if_absent(source, route); + net_packet.set_transport_protocol(control_packet::Protocol::Pong.into()); + net_packet.set_source(current_device.virtual_ip); + net_packet.set_destination(source); + net_packet.first_set_ttl(MAX_TTL); + self.client_cipher.encrypt_ipv4(&mut net_packet)?; + context.send_by_key(&net_packet, route_key)?; } ControlPacket::PongPacket(pong_packet) => { let current_time = crate::handle::now_time() as u16; @@ -272,7 +268,7 @@ impl ClientPacketHandler { { return Ok(()); } - let route = Route::from_default_rt(route_key, 1); + let route = Route::from_default_rt(route_key, metric); context.route_table.add_route_if_absent(source, route); } ControlPacket::AddrRequest => match route_key.addr.ip() { @@ -318,6 +314,7 @@ impl ClientPacketHandler { .collect(); let local_ipv4 = Some(Ipv4Addr::from(punch_info.local_ip.to_be_bytes())); let tcp_port = punch_info.tcp_port as u16; + let public_tcp_port = punch_info.public_tcp_port as u16; let ipv6 = if punch_info.ipv6.len() == 16 { let ipv6: [u8; 16] = punch_info.ipv6.try_into().unwrap(); Some(Ipv6Addr::from(ipv6)) @@ -340,7 +337,9 @@ impl ClientPacketHandler { ipv6, punch_info.udp_ports.iter().map(|e| *e as u16).collect(), tcp_port, + public_tcp_port, punch_info.nat_type.enum_value_or_default().into(), + punch_info.punch_model.enum_value_or_default().into(), ); { let peer_nat_info = peer_nat_info.clone(); @@ -360,8 +359,11 @@ impl ClientPacketHandler { nat_info.public_ports.iter().map(|e| *e as u32).collect(); punch_reply.public_port_range = nat_info.public_port_range as u32; punch_reply.tcp_port = nat_info.tcp_port as u32; + punch_reply.public_tcp_port = nat_info.public_tcp_port as u32; punch_reply.nat_type = protobuf::EnumOrUnknown::new(PunchNatType::from(nat_info.nat_type)); + punch_reply.punch_model = + protobuf::EnumOrUnknown::new(nat_info.punch_model.into()); punch_reply.local_ip = u32::from(nat_info.local_ipv4().unwrap_or(Ipv4Addr::UNSPECIFIED)); punch_reply.local_port = nat_info.udp_ports[0] as u32; diff --git a/vnt/src/handle/recv_data/server.rs b/vnt/src/handle/recv_data/server.rs index a9a43806..5b9a7329 100644 --- a/vnt/src/handle/recv_data/server.rs +++ b/vnt/src/handle/recv_data/server.rs @@ -151,6 +151,8 @@ impl PacketHandler for ServerPacketHandl let response = HandshakeResponse::parse_from_bytes(net_packet.payload()) .map_err(|e| anyhow!("HandshakeResponse {:?}", e))?; log::info!("握手响应:{:?},{}", route_key, response); + //设置为默认通道 + context.set_default_route_key(route_key); //如果开启了加密,则发送加密握手请求 #[cfg(feature = "server_encrypt")] if let Some(key) = self.server_cipher.key() { @@ -211,7 +213,7 @@ impl PacketHandler for ServerPacketHandl let handshake_info = HandshakeInfo::new_no_secret(response.version); if self.callback.handshake(handshake_info) { //没有加密,则发送注册请求 - self.register(current_device, context)?; + self.register(current_device, context, route_key)?; } return Ok(()); @@ -295,6 +297,10 @@ impl ServerPacketHandler { let public_port = response.public_port as u16; self.nat_test .update_addr(route_key.index(), public_ip, public_port); + if route_key.protocol().is_tcp() { + log::info!("更新公网tcp端口 {public_port}"); + self.nat_test.update_tcp_port(public_port); + } let old = current_device; let mut cur = *current_device; loop { @@ -421,7 +427,7 @@ impl ServerPacketHandler { service_packet::Protocol::SecretHandshakeResponse => { log::info!("SecretHandshakeResponse"); //加密握手结束,发送注册数据 - self.register(current_device, context)?; + self.register(current_device, context, route_key)?; } _ => { log::warn!( @@ -466,11 +472,14 @@ impl ServerPacketHandler { &self, current_device: &CurrentDeviceInfo, context: &ChannelContext, + route_key: RouteKey, ) -> anyhow::Result<()> { if current_device.status.online() { log::info!("已连接的不需要注册,{:?}", self.config_info); return Ok(()); } + //设置为默认通道 + context.set_default_route_key(route_key); let token = self.config_info.token.clone(); let device_id = self.config_info.device_id.clone(); let name = self.config_info.name.clone(); diff --git a/vnt/src/nat/mod.rs b/vnt/src/nat/mod.rs index 1ef7b2b7..628eddfb 100644 --- a/vnt/src/nat/mod.rs +++ b/vnt/src/nat/mod.rs @@ -10,9 +10,8 @@ use parking_lot::Mutex; use rand::prelude::SliceRandom; use rand::Rng; -use crate::channel::punch::{NatInfo, NatType}; +use crate::channel::punch::{NatInfo, NatType, PunchModel}; use crate::channel::socket::LocalInterface; -use crate::proto::message::PunchNatType; #[cfg(feature = "upnp")] use crate::util::UPnP; @@ -120,24 +119,6 @@ pub struct NatTest { pub(crate) update_local_ipv4: bool, } -impl From for PunchNatType { - fn from(value: NatType) -> Self { - match value { - NatType::Symmetric => PunchNatType::Symmetric, - NatType::Cone => PunchNatType::Cone, - } - } -} - -impl Into for PunchNatType { - fn into(self) -> NatType { - match self { - PunchNatType::Symmetric => NatType::Symmetric, - PunchNatType::Cone => NatType::Cone, - } - } -} - impl NatTest { pub fn new( _channel_num: usize, @@ -147,6 +128,7 @@ impl NatTest { udp_ports: Vec, tcp_port: u16, update_local_ipv4: bool, + punch_model: PunchModel, ) -> NatTest { let ports = vec![0; udp_ports.len()]; let nat_info = NatInfo::new( @@ -157,7 +139,9 @@ impl NatTest { ipv6, udp_ports.clone(), tcp_port, + 0, NatType::Cone, + punch_model, ); let info = Arc::new(Mutex::new(nat_info)); #[cfg(feature = "upnp")] @@ -257,6 +241,10 @@ impl NatTest { let mut guard = self.info.lock(); guard.update_addr(index, ip, port) } + pub fn update_tcp_port(&self, port: u16) { + let mut guard = self.info.lock(); + guard.update_tcp_port(port) + } pub fn re_test( &self, local_ipv4: Option, diff --git a/vnt/tun/src/windows/tun/mod.rs b/vnt/tun/src/windows/tun/mod.rs index 8c826ebe..27438bde 100644 --- a/vnt/tun/src/windows/tun/mod.rs +++ b/vnt/tun/src/windows/tun/mod.rs @@ -154,6 +154,7 @@ impl Device { } fn hash_guid(input: &str) -> [u8; 16] { let mut hasher = sha2::Sha256::new(); + hasher.update(input.as_bytes()); hasher.update(b"VNT"); hasher.update(input.as_bytes()); hasher.update(b"2024");