How to close a TcpStream connection when switching to another connection?

My TCP client uses Reader and Writer structs to read/write packets. This is how they are defined:

pub struct Reader {
    stream: OwnedReadHalf,
    // ...
}

pub struct Writer {
    stream: OwnedWriteHalf,
    // ...
}

On request, the client sometimes reconnects to another IP address:

// initial connection
pub async fn connect(&mut self, host: &str, port: u16) -> Result<(), Error> {
	// ...

	return match Self::connect_inner(host, port).await {
		Ok(stream) => {
			Self::set_stream_halves(stream, &self._reader, &self._writer).await;
			// ...
		},
		Err(err) => {
			// ...
		},
	}
}

async fn connect_inner(host: &str, port: u16) -> Result<TcpStream, Error> {
	let addr = format!("{}:{}", host, port);
	match TcpStream::connect(&addr).await {
		Ok(stream) => Ok(stream),
		Err(err) => Err(err),
	}
}

async fn set_stream_halves(
	stream: TcpStream,
	reader: &Arc<Mutex<Option<Reader>>>,
	writer: &Arc<Mutex<Option<Writer>>>
) {
	let (rx, tx) = stream.into_split();

	let mut reader = reader.lock().await;
	*reader = Some(Reader::new(rx));
	let mut writer = writer.lock().await;
	*writer = Some(Writer::new(tx));
}

// RECONNECT handles below
fn handle_queue(&mut self) -> JoinHandle<()> {
    // reader and writer are Arc<Mutex<Option<Reader or Writer>>>
	let reader = Arc::clone(&self._reader);
	let writer = Arc::clone(&self._writer);

	tokio::spawn(async move {
		loop {
			let packets = input_queue.lock().unwrap().pop_front();
			if packets.is_some() {
				for packet in packets.unwrap() {
					// ...

					for mut handler in handler_list {
						let response = handler.handle(&mut input).await;
						match response {
							Ok(output) => {
								match output {
									HandlerOutput::Data((opcode, header, body)) => {},
									HandlerOutput::ConnectionRequest(host, port) => {
										match Self::connect_inner(&host, port).await {
											Ok(stream) => {
												let message = format!(
													"Connected to {}:{}", host, port
												);

												Self::set_stream_halves(
													stream, &reader, &writer
												).await;

												message_income.send_success_message(message);

											},
											Err(err) => {
												message_income.send_error_message(
													err.to_string()
												);
											}
										}
									},
									HandlerOutput::UpdateState(state) => {},
									HandlerOutput::Freeze => {},
									HandlerOutput::Void => {},
								};
							},
							Err(err) => {},
						};

						sleep(Duration::from_millis(WRITE_TIMEOUT)).await;
					}
				}
			}
		}
	})
}

The issue with the code is that when I reconnect to another IP:

match Self::connect_inner(&host, port).await {
	Ok(stream) => {
		Self::set_stream_halves(stream, &reader, &writer).await;
	},
	Err(err) => {}
}

my handle_read task still uses the previous connection:

fn handle_read(&mut self) -> JoinHandle<()> {
	// ...
	let reader = Arc::clone(&self._reader);

	tokio::spawn(async move {
		loop {
			match &mut *reader.lock().await {
				Some(reader) => {
					if let Some(packets) = reader.read().await.ok() {
						input_queue.lock().unwrap().push_back(packets);
					}
				},
				None => {},
			};

			sleep(Duration::from_millis(READ_TIMEOUT)).await;
		}
	})
}

Could somebody advise what I can do in this case? Should I drop the connection in some way, restart the reader task, or something else?

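It looks to me like you're updating the reader inside the Mutex, so the problem isn't that your read task has the wrong reader. It's that the task is already waiting for data on the old reader, and it can't see the new reader until that read returns. Because the lock guard in match &mut *reader.lock().await lives until the end of the match, the read task also holds the lock for the whole read, so the reader.lock().await in set_stream_halves has to wait for it too.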
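To make the locking concrete, here is a small synchronous sketch. It uses std::sync::Mutex instead of tokio's Mutex and a hypothetical FakeReader in place of Reader; both are assumptions for illustration. It checks that the guard created in the match scrutinee is still held inside the arm, where read().await would be waiting, and that assigning a new Some(...) into the slot drops the previous reader:

```rust
use std::sync::{Arc, Mutex};

// Hypothetical stand-in for Reader: records its id in `dropped` when it is
// dropped, so we can see exactly when the old reader goes away.
struct FakeReader {
    id: u32,
    dropped: Arc<Mutex<Vec<u32>>>,
}

impl Drop for FakeReader {
    fn drop(&mut self) {
        self.dropped.lock().unwrap().push(self.id);
    }
}

/// Mirrors `match &mut *reader.lock().await { ... }` in handle_read and reports
/// whether anyone else could lock the slot while the arm is running.
fn slot_is_free_inside_match(slot: &Mutex<Option<FakeReader>>) -> bool {
    // The guard is a temporary in the scrutinee, so it lives until the whole
    // match ends, i.e. across the point where read().await would be waiting.
    match &mut *slot.lock().unwrap() {
        Some(_reader) => slot.try_lock().is_ok(),
        None => slot.try_lock().is_ok(),
    }
}

/// Mirrors `*reader = Some(Reader::new(rx))` in set_stream_halves: the
/// assignment drops whatever reader was stored before.
fn install_reader(slot: &Mutex<Option<FakeReader>>, reader: FakeReader) {
    *slot.lock().unwrap() = Some(reader);
}
```

With tokio's Mutex the guard is held across the .await in the same way, so the reconnect side cannot swap in the new reader until the pending read returns. Once it does get the lock, storing the new value drops the old half.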

You could use a channel and tokio::select (or something like it) to wait on both a "reconnected" channel and the reader. Just send a message on the channel once the reconnect is finished, your reader task will wake up, and you can skip to the next iteration of the loop which will see the new reader. Then your reader task wouldn't be stuck with the old connection.
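The approach works because select! polls all of its branches and, as soon as one completes, drops the others. The sketch below imitates that with only the standard library. The select_signal_or_read helper, the busy-polling block_on, and the ReadyOnSecondPoll future that stands in for the "reconnected" message are all hypothetical and for illustration only; they are not tokio APIs. The read branch takes the lock and never finishes. Once the signal wins and the read future is dropped, the lock is free again:

```rust
use std::future::{poll_fn, Future};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// A waker that does nothing: the toy executor below simply polls in a loop.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

/// Toy single-threaded executor that busy-polls `fut` until it completes.
fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}

/// Pending on the first poll, ready on the second: plays the role of a
/// "reconnected" message that arrives while a read is already in progress.
struct ReadyOnSecondPoll(bool);

impl Future for ReadyOnSecondPoll {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.0 {
            Poll::Ready(())
        } else {
            self.0 = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

/// Stand-in for check_reader: locks the slot, then waits forever, like a read
/// on a connection that never sends anything again.
async fn read_old_connection(slot: Arc<Mutex<&'static str>>) -> &'static str {
    let _guard = slot.lock().unwrap();
    std::future::pending().await
}

#[derive(Debug, PartialEq)]
enum Winner {
    Signal,
    Read(&'static str),
}

/// Minimal two-branch select: polls both futures until one finishes.
async fn select_signal_or_read(
    signal: impl Future<Output = ()>,
    read: impl Future<Output = &'static str>,
) -> Winner {
    let mut signal = Box::pin(signal);
    let mut read = Box::pin(read);
    let winner = poll_fn(|cx| {
        if signal.as_mut().poll(cx).is_ready() {
            return Poll::Ready(Winner::Signal);
        }
        if let Poll::Ready(data) = read.as_mut().poll(cx) {
            return Poll::Ready(Winner::Read(data));
        }
        Poll::Pending
    })
    .await;
    // Both futures are dropped when this function returns; for the losing
    // read, that drop releases the lock guard it was holding.
    winner
}
```

Usage: block_on(select_signal_or_read(ReadyOnSecondPoll(false), read_old_connection(slot.clone()))) returns Winner::Signal, after which slot.try_lock() succeeds.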

Note: select! can lose data if one of the futures isn't cancel safe, because the branches that don't finish are dropped along with any progress they have made. In this specific case I'm not sure that will be a problem, but you should think about it carefully to make sure.
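Here is a standard-library-only sketch of what "not cancel safe" means in practice. The read_frame function and the poll_once_then_drop helper are hypothetical. read_frame collects a frame into a buffer that lives inside the future. If the future is dropped after it has pulled some bytes, which is what select! does to a losing branch, those bytes are gone from the source and never delivered:

```rust
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// A waker that does nothing; we only poll by hand.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

/// Yields once, like waiting for more bytes to arrive on a socket.
struct YieldOnce(bool);

impl Future for YieldOnce {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.0 {
            Poll::Ready(())
        } else {
            self.0 = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

/// Hypothetical framed read: collects `len` bytes from `source` into a buffer
/// that lives inside the future itself (this is what makes it cancel unsafe).
async fn read_frame(source: Arc<Mutex<VecDeque<u8>>>, len: usize) -> Vec<u8> {
    let mut frame = Vec::new();
    while frame.len() < len {
        // The guard is dropped at the end of this statement, before any await.
        let next = source.lock().unwrap().pop_front();
        match next {
            Some(byte) => frame.push(byte),
            None => YieldOnce(false).await,
        }
    }
    frame
}

/// Polls `fut` once and then drops it, like select! does to a losing branch.
/// Returns whether the future had already finished.
fn poll_once_then_drop<F: Future>(fut: F) -> bool {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    let finished = fut.as_mut().poll(&mut cx).is_ready();
    // Dropping the future discards its state, including a partially filled frame.
    drop(fut);
    finished
}
```

Keeping such partial state outside the future, for example in the Reader struct itself, is one way to avoid this. Whether your Reader::read has the problem depends on how it is implemented.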


I tried to refactor this using tokio::select! inside the handle_read task, but got an error:

error: no rules expected the token `}`
   --> src\client\mod.rs:488:17
    |
488 | /                 tokio::select! {
489 | |                     _ = signal_receiver.lock().unwrap().recv().unwrap() => {},
490 | |                     _ => {
491 | |                         match &mut *reader.lock().await {
...   |
503 | |                     },
504 | |                 };
    | |_________________^ no rules expected this token in macro call

This is what I did:

fn handle_read(&mut self) -> JoinHandle<()> {
	// ...
	let reader = Arc::clone(&self._reader);
	let signal_receiver = Arc::clone(&self._signal_receiver);

	tokio::spawn(async move {
		loop {
			tokio::select! {
				_ = signal_receiver.lock().unwrap().recv().unwrap() => {},
				_ => {
					match &mut *reader.lock().await {
						Some(reader) => {
							if let Some(packets) = reader.read().await.ok() {
								input_queue.lock().unwrap().push_back(packets);
							}
						},
						None => {},
					};
				},
			};
		}
	})
}

Could you tell me what I am doing wrong?

I also tried something like:

tokio::select! {
	s = signal_receiver.lock().unwrap().recv().unwrap(), if s => {},
	_ => {
		match &mut *reader.lock().await {
			Some(reader) => {
				if let Some(packets) = reader.read().await.ok() {
					input_queue.lock().unwrap().push_back(packets);
				}
			},
			None => {
				message_income.send_error_message(
					String::from("Not connected to TCP")
				);
			},
		};
	},
};

Same result (an error).

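select! always needs a future in each arm (except the else arm). Each branch has the form pattern = future => handler, so a bare _ => { ... } arm isn't accepted, and signal_receiver.lock().unwrap().recv().unwrap() doesn't produce a future either. You also need some extra pattern matching to handle the inner Option, so for simplicity I just broke that out into its own async function. There may be a better way to do that, though.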

use std::sync::Arc;
use tokio::{net::tcp::OwnedReadHalf, sync::Mutex, task::JoinHandle};

fn main() {}

struct Reader(OwnedReadHalf);

impl Reader {
    async fn read(&mut self) -> Option<()> {
        todo!()
    }
}

fn handle_read(
    reader: Arc<tokio::sync::Mutex<Option<Reader>>>,
    mut signal_receiver: tokio::sync::mpsc::Receiver<()>,
) -> JoinHandle<()> {
    // ...
    let reader = Arc::clone(&reader);

    tokio::spawn(async move {
        loop {
            tokio::select! {
                _ = signal_receiver.recv() => {},
                Some(packets) = check_reader(&reader) => {
                    println!("Got packet! {packets:?}");
                    //input_queue.lock().unwrap().push_back(packets);
                },
                else => {}
            };
        }
    })
}

async fn check_reader(reader: &Arc<Mutex<Option<Reader>>>) -> Option<()> {
    match &mut *reader.lock().await {
        Some(reader) => reader.read().await,
        None => None,
    }
}

I refactored the code and it seems to work now, but it's not clear to me why this doesn't work:

// DOESN'T WORK
async fn read_packets(reader: &Arc<Mutex<Option<Reader>>>) -> Option<Vec<Vec<u8>>> {
	if let Some(reader) = &mut *reader.lock().await {
		if let Some(packets) = reader.read().await.ok() {
			return Some(packets);
		}
	}
	
	None
}

while this works:

// WORKS
async fn read_packets(reader: &Arc<Mutex<Option<Reader>>>) -> Option<Vec<Vec<u8>>> {
	match &mut *reader.lock().await {
		Some(reader) => {
			if let Some(packets) = reader.read().await.ok() {
				return Some(packets);
			}

			None
		},
		_ => None,
	}
}

Both functions seem identical to me. Could you explain why the first one doesn't work and what the difference between the two is?

Can you be more specific about what doesn't work?

Oh, now it works as expected. Not sure what that was.
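For the record, the two read_packets variants behave the same way. In both, the lock guard lives until the outer if let or match finishes, and they return the same value for a working reader, a failing reader and an empty slot. The synchronous sketch below shows this. It uses std::sync::Mutex and a hypothetical FakeReader as assumptions standing in for tokio's Mutex and the async Reader::read:

```rust
use std::sync::Mutex;

// Hypothetical stand-in for Reader: `read` returns a preset result instead of
// reading from a socket (and is synchronous instead of async).
struct FakeReader {
    next: Result<Vec<Vec<u8>>, String>,
}

impl FakeReader {
    fn read(&mut self) -> Result<Vec<Vec<u8>>, String> {
        self.next.clone()
    }
}

// Same shape as the `if let` version of read_packets.
fn read_packets_if_let(reader: &Mutex<Option<FakeReader>>) -> Option<Vec<Vec<u8>>> {
    if let Some(reader) = &mut *reader.lock().unwrap() {
        if let Some(packets) = reader.read().ok() {
            return Some(packets);
        }
    }
    None
}

// Same shape as the `match` version of read_packets.
fn read_packets_match(reader: &Mutex<Option<FakeReader>>) -> Option<Vec<Vec<u8>>> {
    match &mut *reader.lock().unwrap() {
        Some(reader) => {
            if let Some(packets) = reader.read().ok() {
                return Some(packets);
            }
            None
        }
        _ => None,
    }
}
```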

Well, the final working code is as follows:

fn handle_read(&mut self) -> JoinHandle<()> {
	let input_queue = Arc::clone(&self._input_queue);
	let reader = Arc::clone(&self._reader);
	let signal_receiver = Arc::clone(&self._signal_receiver);

	tokio::spawn(async move {
		loop {
			let receiver = &mut *signal_receiver.lock().await;

			tokio::select! {
				_ = receiver.recv() => {},
				Some(packets) = Self::read_packets(&reader) => {
					input_queue.lock().unwrap().push_back(packets);
				},
			};

			sleep(Duration::from_millis(READ_TIMEOUT)).await;
		}
	})
}

async fn read_packets(reader: &Arc<Mutex<Option<Reader>>>) -> Option<Vec<Vec<u8>>> {
	if let Some(reader) = &mut *reader.lock().await {
		if let Some(packets) = reader.read().await.ok() {
			return Some(packets);
		}
	}

	None
}

Thanks very much for the help, @semicoleon!