Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use an Arc to share stub resolver #393

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 34 additions & 31 deletions src/resolv/stub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ pub mod conf;
/// [`query`]: #method.query
/// [`run`]: #method.run
/// [`run_with_conf`]: #method.run_with_conf
#[derive(Debug, Clone)]
pub struct StubResolver(Arc<StubResolverInner>);

#[derive(Debug)]
pub struct StubResolver {
struct StubResolverInner {
transport: Mutex<Option<redundant::Connection<RequestMessage<Vec<u8>>>>>,

/// Resolver options.
Expand All @@ -91,16 +94,15 @@ impl StubResolver {

/// Creates a new resolver using the given configuraiton.
pub fn from_conf(conf: ResolvConf) -> Self {
StubResolver {
StubResolver(Arc::new(StubResolverInner {
transport: None.into(),
options: conf.options,

servers: conf.servers,
}
}))
}

pub fn options(&self) -> &ResolvOptions {
&self.options
&self.0.options
}

/// Adds a new connection to the running resolver.
Expand All @@ -119,7 +121,7 @@ impl StubResolver {
}

pub async fn query<N: ToName, Q: Into<Question<N>>>(
&self,
self,
question: Q,
) -> Result<Answer, io::Error> {
Query::new(self)?
Expand All @@ -128,7 +130,7 @@ impl StubResolver {
}

async fn query_message(
&self,
self,
message: QueryMessage,
) -> Result<Answer, io::Error> {
Query::new(self)?.run(message).await
Expand Down Expand Up @@ -162,10 +164,11 @@ impl StubResolver {
// We have 3 modes of operation: use_vc: only use TCP, ign_tc: only
// UDP no fallback to TCP, and normal with is UDP falling back to TCP.

for s in &self.servers {
for s in &self.0.servers {
// This assumes that Transport only has UdpTcp and Tcp. Sadly, a
// match doesn’t work here because of the use_cv flag.
if self.options.use_vc || matches!(s.transport, Transport::Tcp) {
if self.0.options.use_vc || matches!(s.transport, Transport::Tcp)
{
let (conn, tran) =
multi_stream::Connection::new(TcpConnect::new(s.addr));
// Start the run function on a separate task.
Expand All @@ -192,7 +195,7 @@ impl StubResolver {
async fn get_transport(
&self,
) -> Result<redundant::Connection<RequestMessage<Vec<u8>>>, Error> {
let mut opt_transport = self.transport.lock().await;
let mut opt_transport = self.0.transport.lock().await;

match &*opt_transport {
Some(transport) => Ok(transport.clone()),
Expand Down Expand Up @@ -230,22 +233,22 @@ impl StubResolver {
pub async fn lookup_addr(
&self,
addr: IpAddr,
) -> Result<FoundAddrs<&Self>, io::Error> {
lookup_addr(&self, addr).await
) -> Result<FoundAddrs<Self>, io::Error> {
lookup_addr(self, addr).await
}

pub async fn lookup_host(
&self,
qname: impl ToName,
) -> Result<FoundHosts<&Self>, io::Error> {
lookup_host(&self, qname).await
) -> Result<FoundHosts<Self>, io::Error> {
lookup_host(self, qname).await
}

pub async fn search_host(
&self,
qname: impl ToRelativeName,
) -> Result<FoundHosts<&Self>, io::Error> {
search_host(&self, qname).await
) -> Result<FoundHosts<Self>, io::Error> {
search_host(self, qname).await
}

/// Performs an SRV lookup using this resolver.
Expand All @@ -257,7 +260,7 @@ impl StubResolver {
name: impl ToName,
fallback_port: u16,
) -> Result<Option<FoundSrvs>, SrvError> {
lookup_srv(&self, service, name, fallback_port).await
lookup_srv(self, service, name, fallback_port).await
}
}

Expand Down Expand Up @@ -309,39 +312,39 @@ impl Default for StubResolver {
}
}

impl<'a> Resolver for &'a StubResolver {
impl Resolver for StubResolver {
type Octets = Bytes;
type Answer = Answer;
type Query =
Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + Send + 'a>>;
Pin<Box<dyn Future<Output = Result<Answer, io::Error>> + Send>>;

fn query<N, Q>(&self, question: Q) -> Self::Query
where
N: ToName,
Q: Into<Question<N>>,
{
let message = Query::create_message(question.into());
Box::pin(self.query_message(message))
Box::pin(self.clone().query_message(message))
}
}

impl<'a> SearchNames for &'a StubResolver {
impl SearchNames for StubResolver {
type Name = SearchSuffix;
type Iter = SearchIter<'a>;
type Iter = SearchIter;

fn search_iter(&self) -> Self::Iter {
SearchIter {
resolver: self,
resolver: self.clone(),
pos: 0,
}
}
}

//------------ Query ---------------------------------------------------------

pub struct Query<'a> {
pub struct Query {
/// The resolver whose configuration we are using.
resolver: &'a StubResolver,
resolver: StubResolver,

edns: Arc<AtomicBool>,

Expand All @@ -355,8 +358,8 @@ pub struct Query<'a> {
error: Result<Answer, io::Error>,
}

impl<'a> Query<'a> {
pub fn new(resolver: &'a StubResolver) -> Result<Self, io::Error> {
impl Query {
pub fn new(resolver: StubResolver) -> Result<Self, io::Error> {
Ok(Query {
resolver,
edns: Arc::new(AtomicBool::new(true)),
Expand Down Expand Up @@ -419,7 +422,7 @@ impl<'a> Query<'a> {
})?;
let mut gr_fut = transport.send_request(request_msg);
let reply =
timeout(self.resolver.options.timeout, gr_fut.get_response())
timeout(self.resolver.0.options.timeout, gr_fut.get_response())
.await?
.map_err(|e| {
io::Error::new(io::ErrorKind::Other, e.to_string())
Expand Down Expand Up @@ -506,12 +509,12 @@ impl AsRef<Message<Bytes>> for Answer {
//------------ SearchIter ----------------------------------------------------

#[derive(Clone, Debug)]
pub struct SearchIter<'a> {
resolver: &'a StubResolver,
pub struct SearchIter {
resolver: StubResolver,
pos: usize,
}

impl Iterator for SearchIter<'_> {
impl Iterator for SearchIter {
type Item = SearchSuffix;

fn next(&mut self) -> Option<Self::Item> {
Expand Down
Loading