diff --git a/bin/chainofcommand/Cargo.toml b/bin/chainofcommand/Cargo.toml index 4cc0ec9..c40f77e 100644 --- a/bin/chainofcommand/Cargo.toml +++ b/bin/chainofcommand/Cargo.toml @@ -17,9 +17,11 @@ seahash = "4.1" anyhow = "1.0" fehler = "1.0" crossterm = { version = "0.25", features = ["event-stream"] } -tokio-serial = "5.4" -tokio = { version = "1.21", features = ["full"] } futures = "0.3" +futures-util = { version = "0.3", features = ["io"] } +tokio = { version = "1.21", features = ["full"] } +tokio-util = { version = "0.7", features = ["io", "codec", "io"] } +tokio-stream = { version = "0.1" } +tokio-serial = "5.4" defer = "0.1" -tokio-util = { version = "0.7", features = ["codec"] } bytes = "1.2" diff --git a/bin/chainofcommand/src/main.rs b/bin/chainofcommand/src/main.rs index 0f2e054..b0318e1 100644 --- a/bin/chainofcommand/src/main.rs +++ b/bin/chainofcommand/src/main.rs @@ -1,4 +1,7 @@ #![feature(trait_alias)] +#![allow(stable_features)] +#![feature(let_else)] // stabilised in 1.65.0 +#![feature(slice_take)] use { anyhow::{anyhow, Result}, @@ -11,9 +14,10 @@ use { tty::IsTty, }, defer::defer, - futures::{future::FutureExt, StreamExt}, + futures::{future::FutureExt, Stream}, seahash::SeaHasher, std::{ + fmt::Formatter, fs::File, hash::Hasher, io::{BufRead, BufReader}, @@ -22,42 +26,55 @@ use { }, tokio::{io::AsyncReadExt, sync::mpsc}, tokio_serial::{SerialPortBuilderExt, SerialStream}, + tokio_stream::StreamExt, }; +// mod utf8_codec; + trait Writable = std::io::Write + Send; trait ThePath = AsRef + std::fmt::Display + Clone + Sync + Send + 'static; -async fn expect( - to_console2: &mpsc::Sender>, - from_serial: &mut mpsc::Receiver>, - m: &str, -) -> Result<()> { - if let Some(buf) = from_serial.recv().await { - if buf.len() >= m.len() && String::from_utf8_lossy(&buf[..m.len()]) == m { - if buf.len() > m.len() { - to_console2.send(buf[m.len()..].to_vec()).await?; +trait FramedStream = Stream> + Unpin; + +type Sender = mpsc::Sender>; +type Receiver = mpsc::Receiver>; + +async fn expect(to_console2: &Sender, from_serial: &mut Receiver, m: &str) -> Result<()> { + let mut s = String::new(); + for _x in m.chars() { + let next_char = from_serial.recv().await; + + let Some(Ok(c)) = next_char else { + return Err(anyhow!( + "Failed to receive expected value {:?}: got empty buf", + m, + )); + }; + + match c { + Message::Text(payload) => { + s.push_str(&payload); + to_console2.send(Ok(Message::Text(payload))).await?; } - return Ok(()); + _ => unreachable!(), } - to_console2.send(buf.clone()).await?; + } + if s != m { return Err(anyhow!( "Failed to receive expected value {:?}: got {:?}", m, - buf + s )); } - Err(anyhow!( - "Failed to receive expected value {:?}: got empty buf", - m, - )) + Ok(()) } -async fn load_kernel

(to_console2: &mpsc::Sender>, kernel: P) -> Result<(File, u64)> +async fn load_kernel

(to_console2: &Sender, kernel: P) -> Result<(File, u64)> where P: ThePath, { to_console2 - .send("[>>] Loading kernel image\n".into()) + .send(Ok(Message::Text("⏩ Loading kernel image\n".into()))) .await?; let kernel_file = match std::fs::File::open(kernel.clone()) { @@ -67,38 +84,46 @@ where let kernel_size: u64 = kernel_file.metadata()?.len(); to_console2 - .send(format!("[>>] .. {} ({} bytes)\n", kernel, kernel_size).into()) + .send(Ok(Message::Text( + format!("⏩ .. {} ({} bytes)\n", kernel, kernel_size).into(), + ))) .await?; Ok((kernel_file, kernel_size)) } -async fn send_kernel

( - to_console2: &mpsc::Sender>, - to_serial: &mpsc::Sender>, - from_serial: &mut mpsc::Receiver>, +async fn send_kernel( + to_console2: &Sender, + to_serial: &Sender, + from_serial: &mut Receiver, kernel: P, -) -> Result<()> -where - P: ThePath, -{ +) -> Result<()> { let (kernel_file, kernel_size) = load_kernel(to_console2, kernel).await?; - to_console2.send("⏩ Sending image size\n".into()).await?; - - to_serial.send(kernel_size.to_le_bytes().into()).await?; + to_console2 + .send(Ok(Message::Text("⏩ Sending image size\n".into()))) + .await?; + to_serial + .send(Ok(Message::Binary(Bytes::copy_from_slice( + &kernel_size.to_le_bytes(), + )))) + .await?; // Wait for OK response expect(to_console2, from_serial, "OK").await?; - to_console2.send("⏩ Sending kernel image\n".into()).await?; + to_console2 + .send(Ok(Message::Text("⏩ Sending kernel image\n".into()))) + .await?; let mut hasher = SeaHasher::new(); let mut reader = BufReader::with_capacity(1, kernel_file); loop { let length = { let buf = reader.fill_buf()?; - to_serial.send(buf.into()).await?; + to_serial + .send(Ok(Message::Binary(Bytes::copy_from_slice(buf)))) + .await?; hasher.write(buf); buf.len() }; @@ -110,10 +135,16 @@ where let hashed_value: u64 = hasher.finish(); to_console2 - .send(format!("⏩ Sending image checksum {:x}\n", hashed_value).into()) + .send(Ok(Message::Text( + format!("⏩ Sending image checksum {:x}\n", hashed_value).into(), + ))) .await?; - to_serial.send(hashed_value.to_le_bytes().into()).await?; + to_serial + .send(Ok(Message::Binary(Bytes::copy_from_slice( + &hashed_value.to_le_bytes(), + )))) + .await?; expect(to_console2, from_serial, "OK").await?; @@ -124,8 +155,8 @@ where async fn serial_loop( mut port: tokio_serial::SerialStream, - to_console: mpsc::Sender>, - mut from_console: mpsc::Receiver>, + to_console: Sender, + mut from_console: Receiver, ) -> Result<()> { let mut buf = [0; 256]; loop { @@ -134,8 +165,13 @@ async fn serial_loop( Some(msg) = from_console.recv() => { // debug!("serial write {} bytes", msg.len()); - tokio::io::AsyncWriteExt::write_all(&mut port, msg.as_ref()).await?; - } + match msg.unwrap() { + Message::Text(s) => { + tokio::io::AsyncWriteExt::write_all(&mut port, s.as_bytes()).await?; + }, + Message::Binary(b) => tokio::io::AsyncWriteExt::write_all(&mut port, b.as_ref()).await?, + } + } res = port.read(&mut buf) => { match res { @@ -145,7 +181,9 @@ async fn serial_loop( } Ok(n) => { // debug!("Serial read {n} bytes."); - to_console.send(buf[0..n].to_owned()).await?; + // let codec = Utf8Codec::new(buf); + let s = String::from_utf8_lossy(&buf[0..n]); + to_console.send(Ok(Message::Text(s.to_string()))).await?; } Err(e) => { // if e.kind() == ErrorKind::TimedOut { @@ -162,11 +200,57 @@ async fn serial_loop( } } +// Always send Binary() to serial +// Convert Text() to bytes and send in serial_loop +// Receive and convert bytes to Text() in serial_loop +#[derive(Clone, Debug)] +enum Message { + Binary(Bytes), + Text(String), +} + +// impl Message { +// pub fn len(&self) -> usize { +// match self { +// Message::Binary(b) => b.len(), +// Message::Text(s) => s.len(), +// } +// } +// } + +impl std::fmt::Display for Message { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Message::Binary(b) => { + for c in b { + write!(f, "{})", c)?; + } + Ok(()) + } + Message::Text(s) => write!(f, "{}", s), + } + } +} + +// impl Buf for Message { +// fn remaining(&self) -> usize { +// todo!() +// } +// +// fn chunk(&self) -> &[u8] { +// todo!() +// } +// +// fn advance(&mut self, cnt: usize) { +// todo!() +// } +// } + async fn console_loop

( - to_console2: mpsc::Sender>, - mut from_internal: mpsc::Receiver>, - to_serial: mpsc::Sender>, - mut from_serial: mpsc::Receiver>, + to_console2: Sender, + mut from_internal: Receiver, + to_serial: Sender, + mut from_serial: Receiver, kernel: P, ) -> Result<()> where @@ -183,33 +267,39 @@ where biased; Some(received) = from_internal.recv() => { - for &x in &received[..] { - execute!(w, style::Print(format!("{}", x as char)))?; + if let Ok(message) = received { + execute!(w, style::Print(message))?; + w.flush()?; } - w.flush()?; } - Some(received) = from_serial.recv() => { - // execute!(w, cursor::MoveToNextLine(1), style::Print(format!("[>>] Received {} bytes from serial", from_serial.len())), cursor::MoveToNextLine(1))?; + Some(received) = from_serial.recv() => { // returns Vec + if let Ok(received) = received { + let Message::Text(received) = received else { + unreachable!(); + }; + execute!(w, cursor::MoveToNextLine(1), style::Print(format!("[>>] Received {} bytes from serial", received.len())), cursor::MoveToNextLine(1))?; - for &x in &received[..] { - if x == 0x3 { - // execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received a BREAK"), cursor::MoveToNextLine(1))?; - breaks += 1; - // Await for 3 consecutive \3 to start downloading - if breaks == 3 { - // execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received 3 BREAKs"), cursor::MoveToNextLine(1))?; - breaks = 0; - send_kernel(&to_console2, &to_serial, &mut from_serial, kernel.clone()).await?; - to_console2.send("🦀 Send successful, pass-through\n".into()).await?; + for x in received.chars() { + if x == 0x3 as char { + // execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received a BREAK"), cursor::MoveToNextLine(1))?; + breaks += 1; + // Await for 3 consecutive \3 to start downloading + if breaks == 3 { + // execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received 3 BREAKs"), cursor::MoveToNextLine(1))?; + breaks = 0; + send_kernel(&to_console2, &to_serial, &mut from_serial, kernel.clone()).await?; + to_console2.send(Ok(Message::Text("🦀 Send successful, pass-through\n".into()))).await?; + } + } else { + while breaks > 0 { + execute!(w, style::Print(format!("{}", 3 as char)))?; + breaks -= 1; + } + // TODO decode buf with Utf8Codec here? + execute!(w, style::Print(format!("{}", x)))?; + w.flush()?; } - } else { - while breaks > 0 { - execute!(w, style::Print(format!("{}", 3 as char)))?; - breaks -= 1; - } - execute!(w, style::Print(format!("{}", x as char)))?; - w.flush()?; } } } @@ -221,7 +311,7 @@ where return Ok(()); } if let Some(key) = handle_key_event(key_event) { - to_serial.send(key.to_vec()).await?; + to_serial.send(Ok(Message::Binary(Bytes::copy_from_slice(&key)))).await?; // Local echo execute!(w, style::Print(format!("{:?}", key)))?; w.flush()?; @@ -244,8 +334,15 @@ where P: ThePath, { // read from serial -> to_console==>from_serial -> output to console - let (to_console, from_serial) = mpsc::channel(256); - let (to_console2, from_internal) = mpsc::channel(256); + let (to_console, from_serial) = mpsc::channel::>(256); + let (to_console2, from_internal) = mpsc::channel::>(256); + + // Make a Stream from Receiver + // let stream = ReceiverStream::new(from_serial); + // // Make AsyncRead from Stream + // let async_stream = StreamReader::new(stream); + // // Make FramedRead (Stream+Sink) from AsyncRead + // let from_serial = FramedRead::new(async_stream, Utf8Codec::new()); // read from console -> to_serial==>from_console -> output to serial let (to_serial, from_console) = mpsc::channel(256); @@ -364,7 +461,7 @@ async fn main() -> Result<()> { execute!( stdout, cursor::RestorePosition, - style::Print("[>>] Opening serial port ") + style::Print("⏩ Opening serial port ") )?; // tokio_serial::new() creates a builder with 8N1 setup without flow control by default.