diff --git a/easytier/src/instance/dns_server/server_instance.rs b/easytier/src/instance/dns_server/server_instance.rs index 21288295..2d3129f1 100644 --- a/easytier/src/instance/dns_server/server_instance.rs +++ b/easytier/src/instance/dns_server/server_instance.rs @@ -140,6 +140,19 @@ impl MagicDnsServerInstanceData { } } + async fn keep_zone_authoritative(&self, zone: &str) { + if let Err(e) = self + .update_dns_records(std::iter::empty::<&Route>(), zone) + .await + { + tracing::error!( + "Failed to keep DNS zone {} authoritative after route prune: {:?}", + zone, + e + ); + } + } + fn do_system_config(&self, zone: &str) -> Result<(), anyhow::Error> { if let Some(c) = &self.system_config { c.set_dns(&OSConfig { @@ -183,10 +196,25 @@ impl MagicDnsServerRpc for MagicDnsServerInstanceData { return Err(anyhow::anyhow!("No remote addr").into()); }; let zone = input.zone.clone(); - self.route_infos - .entry(zone.clone()) - .or_default() - .insert_many(remote_addr.clone().into(), input.routes); + let remote_addr: url::Url = remote_addr.clone().into(); + let mut zone_removed = false; + + if let Some(mut routes_by_addr) = self.route_infos.get_mut(&zone) { + routes_by_addr.remove(&remote_addr); + if !input.routes.is_empty() { + routes_by_addr.insert_many(remote_addr, input.routes); + } + zone_removed = routes_by_addr.is_empty(); + } else if !input.routes.is_empty() { + let mut routes_by_addr = MultiMap::new(); + routes_by_addr.insert_many(remote_addr, input.routes); + self.route_infos.insert(zone.clone(), routes_by_addr); + } + + if zone_removed { + self.route_infos.remove(&zone); + self.keep_zone_authoritative(&zone).await; + } self.update().await; Ok(Default::default()) @@ -438,11 +466,19 @@ impl RpcServerHook for MagicDnsServerInstanceData { return; }; let remote_addr = remote_addr.into(); + let mut removed_zones = vec![]; for mut item in self.route_infos.iter_mut() { item.value_mut().remove(&remote_addr); + if item.value().is_empty() { + removed_zones.push(item.key().clone()); + } + } + for zone in &removed_zones { + self.route_infos.remove(zone); + } + for zone in removed_zones { + self.keep_zone_authoritative(&zone).await; } - self.route_infos.retain(|_, v| !v.is_empty()); - self.route_infos.shrink_to_fit(); self.update().await; } } diff --git a/easytier/src/instance/dns_server/tests.rs b/easytier/src/instance/dns_server/tests.rs index 4853362b..cad28c62 100644 --- a/easytier/src/instance/dns_server/tests.rs +++ b/easytier/src/instance/dns_server/tests.rs @@ -23,6 +23,8 @@ use crate::peers::peer_manager::{PeerManager, RouteAlgoType}; use crate::peers::create_packet_recv_chan; use crate::proto::api::instance::Route; use crate::proto::common::NatType; +use crate::proto::magic_dns::{MagicDnsServerRpc as _, UpdateDnsRecordRequest}; +use crate::proto::rpc_types::controller::{BaseController, Controller as _}; pub async fn prepare_env(dns_name: &str, tun_ip: Ipv4Inet) -> (Arc, NicCtx) { prepare_env_with_tld_dns_zone(dns_name, tun_ip, None).await @@ -78,8 +80,9 @@ pub async fn check_dns_record(fake_ip: &Ipv4Addr, domain: &str, expected_ip: &st rr::RecordType::A, ) .await - .unwrap(); - drop(background_task); + .unwrap_or_else(|e| panic!("DNS query failed unexpectedly for domain '{domain}': {e}")); + background_task.abort(); + let _ = background_task.await; println!("Response: {:?}", response); @@ -91,6 +94,29 @@ pub async fn check_dns_record(fake_ip: &Ipv4Addr, domain: &str, expected_ip: &st ); } +pub async fn check_dns_record_missing(fake_ip: &Ipv4Addr, domain: &str) { + let stream = UdpClientStream::builder( + SocketAddr::new((*fake_ip).into(), 53), + TokioRuntimeProvider::default(), + ) + .build(); + let (mut client, background) = Client::connect(stream).await.unwrap(); + let background_task = tokio::spawn(background); + let response = client + .query( + rr::Name::from_str(domain).unwrap(), + rr::DNSClass::IN, + rr::RecordType::A, + ) + .await + .unwrap_or_else(|e| { + panic!("DNS query for missing record failed unexpectedly for domain '{domain}': {e}") + }); + background_task.abort(); + let _ = background_task.await; + assert!(response.answers().is_empty(), "{:?}", response.answers()); +} + #[tokio::test] async fn test_magic_dns_server_instance() { let tun_ip = Ipv4Inet::from_str("10.144.144.10/24").unwrap(); @@ -181,3 +207,119 @@ async fn test_magic_dns_runner() { t.await.unwrap(); } } + +#[tokio::test] +async fn test_magic_dns_update_replaces_records_for_same_client() { + let tun_ip = Ipv4Inet::from_str("100.100.100.0/24").unwrap(); + let ctx = get_mock_global_ctx(); + ctx.set_hostname("test1".to_string()); + ctx.set_ipv4(Some(tun_ip)); + + let (s, _r) = create_packet_recv_chan(); + let peer_mgr = Arc::new(PeerManager::new(RouteAlgoType::Ospf, ctx, s)); + peer_mgr.run().await.unwrap(); + replace_stun_info_collector(peer_mgr.clone(), NatType::PortRestricted); + + let fake_ip = Ipv4Addr::from_str(MAGIC_DNS_FAKE_IP).unwrap(); + let dns_server_inst = MagicDnsServerInstance::new(peer_mgr.clone(), None, tun_ip, fake_ip) + .await + .unwrap(); + + let mut ctrl = BaseController::default(); + ctrl.set_tunnel_info(Some(crate::proto::common::TunnelInfo { + tunnel_type: "tcp".to_string(), + local_addr: None, + remote_addr: Some(crate::proto::common::Url { + url: "tcp://127.0.0.1:54321".to_string(), + }), + })); + + dns_server_inst + .data + .update_dns_record( + ctrl.clone(), + UpdateDnsRecordRequest { + zone: DEFAULT_ET_DNS_ZONE.to_string(), + routes: vec![Route { + hostname: "test1".to_string(), + ipv4_addr: Some(Ipv4Inet::from_str("8.8.8.8/32").unwrap().into()), + ..Default::default() + }], + }, + ) + .await + .unwrap(); + + dns_server_inst + .data + .update_dns_record( + ctrl, + UpdateDnsRecordRequest { + zone: DEFAULT_ET_DNS_ZONE.to_string(), + routes: vec![Route { + hostname: "test1".to_string(), + ipv4_addr: Some(Ipv4Inet::from_str("1.1.1.1/32").unwrap().into()), + ..Default::default() + }], + }, + ) + .await + .unwrap(); + + let dns_records = dns_server_inst + .data + .get_dns_record( + BaseController::default(), + crate::proto::common::Void::default(), + ) + .await + .unwrap(); + let zone_records = dns_records.records.get(DEFAULT_ET_DNS_ZONE).unwrap(); + let a_records = zone_records + .records + .iter() + .filter_map(|record| match record.record.as_ref() { + Some(crate::proto::magic_dns::dns_record::Record::A(a)) + if a.name == "test1.et.net." => + { + Some(a) + } + _ => None, + }) + .collect::>(); + + assert_eq!(a_records.len(), 1, "{a_records:?}"); + let resolved_ip = Ipv4Addr::from(a_records[0].value.unwrap_or_default()); + assert_eq!(resolved_ip, Ipv4Addr::new(1, 1, 1, 1)); + + let mut ctrl = BaseController::default(); + ctrl.set_tunnel_info(Some(crate::proto::common::TunnelInfo { + tunnel_type: "tcp".to_string(), + local_addr: None, + remote_addr: Some(crate::proto::common::Url { + url: "tcp://127.0.0.1:54321".to_string(), + }), + })); + + dns_server_inst + .data + .update_dns_record( + ctrl, + UpdateDnsRecordRequest { + zone: DEFAULT_ET_DNS_ZONE.to_string(), + routes: vec![], + }, + ) + .await + .unwrap(); + + let dns_records = dns_server_inst + .data + .get_dns_record( + BaseController::default(), + crate::proto::common::Void::default(), + ) + .await + .unwrap(); + assert!(!dns_records.records.contains_key(DEFAULT_ET_DNS_ZONE)); +}