fix(console): 🐛 Fix console I/O on the host side

This commit is contained in:
Berkus Decker 2023-07-10 02:28:33 +03:00 committed by Berkus Decker
parent fa725c51cb
commit 2c91e685bd
2 changed files with 172 additions and 73 deletions

View File

@ -17,9 +17,11 @@ seahash = "4.1"
anyhow = "1.0" anyhow = "1.0"
fehler = "1.0" fehler = "1.0"
crossterm = { version = "0.25", features = ["event-stream"] } crossterm = { version = "0.25", features = ["event-stream"] }
tokio-serial = "5.4"
tokio = { version = "1.21", features = ["full"] }
futures = "0.3" futures = "0.3"
futures-util = { version = "0.3", features = ["io"] }
tokio = { version = "1.21", features = ["full"] }
tokio-util = { version = "0.7", features = ["io", "codec", "io"] }
tokio-stream = { version = "0.1" }
tokio-serial = "5.4"
defer = "0.1" defer = "0.1"
tokio-util = { version = "0.7", features = ["codec"] }
bytes = "1.2" bytes = "1.2"

View File

@ -1,4 +1,7 @@
#![feature(trait_alias)] #![feature(trait_alias)]
#![allow(stable_features)]
#![feature(let_else)] // stabilised in 1.65.0
#![feature(slice_take)]
use { use {
anyhow::{anyhow, Result}, anyhow::{anyhow, Result},
@ -11,9 +14,10 @@ use {
tty::IsTty, tty::IsTty,
}, },
defer::defer, defer::defer,
futures::{future::FutureExt, StreamExt}, futures::{future::FutureExt, Stream},
seahash::SeaHasher, seahash::SeaHasher,
std::{ std::{
fmt::Formatter,
fs::File, fs::File,
hash::Hasher, hash::Hasher,
io::{BufRead, BufReader}, io::{BufRead, BufReader},
@ -22,42 +26,55 @@ use {
}, },
tokio::{io::AsyncReadExt, sync::mpsc}, tokio::{io::AsyncReadExt, sync::mpsc},
tokio_serial::{SerialPortBuilderExt, SerialStream}, tokio_serial::{SerialPortBuilderExt, SerialStream},
tokio_stream::StreamExt,
}; };
// mod utf8_codec;
trait Writable = std::io::Write + Send; trait Writable = std::io::Write + Send;
trait ThePath = AsRef<Path> + std::fmt::Display + Clone + Sync + Send + 'static; trait ThePath = AsRef<Path> + std::fmt::Display + Clone + Sync + Send + 'static;
async fn expect( trait FramedStream = Stream<Item = Result<Message, anyhow::Error>> + Unpin;
to_console2: &mpsc::Sender<Vec<u8>>,
from_serial: &mut mpsc::Receiver<Vec<u8>>, type Sender = mpsc::Sender<Result<Message>>;
m: &str, type Receiver = mpsc::Receiver<Result<Message>>;
) -> Result<()> {
if let Some(buf) = from_serial.recv().await { async fn expect(to_console2: &Sender, from_serial: &mut Receiver, m: &str) -> Result<()> {
if buf.len() >= m.len() && String::from_utf8_lossy(&buf[..m.len()]) == m { let mut s = String::new();
if buf.len() > m.len() { for _x in m.chars() {
to_console2.send(buf[m.len()..].to_vec()).await?; let next_char = from_serial.recv().await;
let Some(Ok(c)) = next_char else {
return Err(anyhow!(
"Failed to receive expected value {:?}: got empty buf",
m,
));
};
match c {
Message::Text(payload) => {
s.push_str(&payload);
to_console2.send(Ok(Message::Text(payload))).await?;
} }
return Ok(()); _ => unreachable!(),
} }
to_console2.send(buf.clone()).await?; }
if s != m {
return Err(anyhow!( return Err(anyhow!(
"Failed to receive expected value {:?}: got {:?}", "Failed to receive expected value {:?}: got {:?}",
m, m,
buf s
)); ));
} }
Err(anyhow!( Ok(())
"Failed to receive expected value {:?}: got empty buf",
m,
))
} }
async fn load_kernel<P>(to_console2: &mpsc::Sender<Vec<u8>>, kernel: P) -> Result<(File, u64)> async fn load_kernel<P>(to_console2: &Sender, kernel: P) -> Result<(File, u64)>
where where
P: ThePath, P: ThePath,
{ {
to_console2 to_console2
.send("[>>] Loading kernel image\n".into()) .send(Ok(Message::Text(" Loading kernel image\n".into())))
.await?; .await?;
let kernel_file = match std::fs::File::open(kernel.clone()) { let kernel_file = match std::fs::File::open(kernel.clone()) {
@ -67,38 +84,46 @@ where
let kernel_size: u64 = kernel_file.metadata()?.len(); let kernel_size: u64 = kernel_file.metadata()?.len();
to_console2 to_console2
.send(format!("[>>] .. {} ({} bytes)\n", kernel, kernel_size).into()) .send(Ok(Message::Text(
format!("⏩ .. {} ({} bytes)\n", kernel, kernel_size).into(),
)))
.await?; .await?;
Ok((kernel_file, kernel_size)) Ok((kernel_file, kernel_size))
} }
async fn send_kernel<P>( async fn send_kernel<P: ThePath>(
to_console2: &mpsc::Sender<Vec<u8>>, to_console2: &Sender,
to_serial: &mpsc::Sender<Vec<u8>>, to_serial: &Sender,
from_serial: &mut mpsc::Receiver<Vec<u8>>, from_serial: &mut Receiver,
kernel: P, kernel: P,
) -> Result<()> ) -> Result<()> {
where
P: ThePath,
{
let (kernel_file, kernel_size) = load_kernel(to_console2, kernel).await?; let (kernel_file, kernel_size) = load_kernel(to_console2, kernel).await?;
to_console2.send("⏩ Sending image size\n".into()).await?; to_console2
.send(Ok(Message::Text("⏩ Sending image size\n".into())))
to_serial.send(kernel_size.to_le_bytes().into()).await?; .await?;
to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(
&kernel_size.to_le_bytes(),
))))
.await?;
// Wait for OK response // Wait for OK response
expect(to_console2, from_serial, "OK").await?; expect(to_console2, from_serial, "OK").await?;
to_console2.send("⏩ Sending kernel image\n".into()).await?; to_console2
.send(Ok(Message::Text("⏩ Sending kernel image\n".into())))
.await?;
let mut hasher = SeaHasher::new(); let mut hasher = SeaHasher::new();
let mut reader = BufReader::with_capacity(1, kernel_file); let mut reader = BufReader::with_capacity(1, kernel_file);
loop { loop {
let length = { let length = {
let buf = reader.fill_buf()?; let buf = reader.fill_buf()?;
to_serial.send(buf.into()).await?; to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(buf))))
.await?;
hasher.write(buf); hasher.write(buf);
buf.len() buf.len()
}; };
@ -110,10 +135,16 @@ where
let hashed_value: u64 = hasher.finish(); let hashed_value: u64 = hasher.finish();
to_console2 to_console2
.send(format!("⏩ Sending image checksum {:x}\n", hashed_value).into()) .send(Ok(Message::Text(
format!("⏩ Sending image checksum {:x}\n", hashed_value).into(),
)))
.await?; .await?;
to_serial.send(hashed_value.to_le_bytes().into()).await?; to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(
&hashed_value.to_le_bytes(),
))))
.await?;
expect(to_console2, from_serial, "OK").await?; expect(to_console2, from_serial, "OK").await?;
@ -124,8 +155,8 @@ where
async fn serial_loop( async fn serial_loop(
mut port: tokio_serial::SerialStream, mut port: tokio_serial::SerialStream,
to_console: mpsc::Sender<Vec<u8>>, to_console: Sender,
mut from_console: mpsc::Receiver<Vec<u8>>, mut from_console: Receiver,
) -> Result<()> { ) -> Result<()> {
let mut buf = [0; 256]; let mut buf = [0; 256];
loop { loop {
@ -134,8 +165,13 @@ async fn serial_loop(
Some(msg) = from_console.recv() => { Some(msg) = from_console.recv() => {
// debug!("serial write {} bytes", msg.len()); // debug!("serial write {} bytes", msg.len());
tokio::io::AsyncWriteExt::write_all(&mut port, msg.as_ref()).await?; match msg.unwrap() {
} Message::Text(s) => {
tokio::io::AsyncWriteExt::write_all(&mut port, s.as_bytes()).await?;
},
Message::Binary(b) => tokio::io::AsyncWriteExt::write_all(&mut port, b.as_ref()).await?,
}
}
res = port.read(&mut buf) => { res = port.read(&mut buf) => {
match res { match res {
@ -145,7 +181,9 @@ async fn serial_loop(
} }
Ok(n) => { Ok(n) => {
// debug!("Serial read {n} bytes."); // debug!("Serial read {n} bytes.");
to_console.send(buf[0..n].to_owned()).await?; // let codec = Utf8Codec::new(buf);
let s = String::from_utf8_lossy(&buf[0..n]);
to_console.send(Ok(Message::Text(s.to_string()))).await?;
} }
Err(e) => { Err(e) => {
// if e.kind() == ErrorKind::TimedOut { // if e.kind() == ErrorKind::TimedOut {
@ -162,11 +200,57 @@ async fn serial_loop(
} }
} }
// Always send Binary() to serial
// Convert Text() to bytes and send in serial_loop
// Receive and convert bytes to Text() in serial_loop
#[derive(Clone, Debug)]
enum Message {
Binary(Bytes),
Text(String),
}
// impl Message {
// pub fn len(&self) -> usize {
// match self {
// Message::Binary(b) => b.len(),
// Message::Text(s) => s.len(),
// }
// }
// }
impl std::fmt::Display for Message {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Message::Binary(b) => {
for c in b {
write!(f, "{})", c)?;
}
Ok(())
}
Message::Text(s) => write!(f, "{}", s),
}
}
}
// impl Buf for Message {
// fn remaining(&self) -> usize {
// todo!()
// }
//
// fn chunk(&self) -> &[u8] {
// todo!()
// }
//
// fn advance(&mut self, cnt: usize) {
// todo!()
// }
// }
async fn console_loop<P>( async fn console_loop<P>(
to_console2: mpsc::Sender<Vec<u8>>, to_console2: Sender,
mut from_internal: mpsc::Receiver<Vec<u8>>, mut from_internal: Receiver,
to_serial: mpsc::Sender<Vec<u8>>, to_serial: Sender,
mut from_serial: mpsc::Receiver<Vec<u8>>, mut from_serial: Receiver,
kernel: P, kernel: P,
) -> Result<()> ) -> Result<()>
where where
@ -183,33 +267,39 @@ where
biased; biased;
Some(received) = from_internal.recv() => { Some(received) = from_internal.recv() => {
for &x in &received[..] { if let Ok(message) = received {
execute!(w, style::Print(format!("{}", x as char)))?; execute!(w, style::Print(message))?;
w.flush()?;
} }
w.flush()?;
} }
Some(received) = from_serial.recv() => { Some(received) = from_serial.recv() => { // returns Vec<char>
// execute!(w, cursor::MoveToNextLine(1), style::Print(format!("[>>] Received {} bytes from serial", from_serial.len())), cursor::MoveToNextLine(1))?; if let Ok(received) = received {
let Message::Text(received) = received else {
unreachable!();
};
execute!(w, cursor::MoveToNextLine(1), style::Print(format!("[>>] Received {} bytes from serial", received.len())), cursor::MoveToNextLine(1))?;
for &x in &received[..] { for x in received.chars() {
if x == 0x3 { if x == 0x3 as char {
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received a BREAK"), cursor::MoveToNextLine(1))?; // execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received a BREAK"), cursor::MoveToNextLine(1))?;
breaks += 1; breaks += 1;
// Await for 3 consecutive \3 to start downloading // Await for 3 consecutive \3 to start downloading
if breaks == 3 { if breaks == 3 {
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received 3 BREAKs"), cursor::MoveToNextLine(1))?; // execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received 3 BREAKs"), cursor::MoveToNextLine(1))?;
breaks = 0; breaks = 0;
send_kernel(&to_console2, &to_serial, &mut from_serial, kernel.clone()).await?; send_kernel(&to_console2, &to_serial, &mut from_serial, kernel.clone()).await?;
to_console2.send("🦀 Send successful, pass-through\n".into()).await?; to_console2.send(Ok(Message::Text("🦀 Send successful, pass-through\n".into()))).await?;
}
} else {
while breaks > 0 {
execute!(w, style::Print(format!("{}", 3 as char)))?;
breaks -= 1;
}
// TODO decode buf with Utf8Codec here?
execute!(w, style::Print(format!("{}", x)))?;
w.flush()?;
} }
} else {
while breaks > 0 {
execute!(w, style::Print(format!("{}", 3 as char)))?;
breaks -= 1;
}
execute!(w, style::Print(format!("{}", x as char)))?;
w.flush()?;
} }
} }
} }
@ -221,7 +311,7 @@ where
return Ok(()); return Ok(());
} }
if let Some(key) = handle_key_event(key_event) { if let Some(key) = handle_key_event(key_event) {
to_serial.send(key.to_vec()).await?; to_serial.send(Ok(Message::Binary(Bytes::copy_from_slice(&key)))).await?;
// Local echo // Local echo
execute!(w, style::Print(format!("{:?}", key)))?; execute!(w, style::Print(format!("{:?}", key)))?;
w.flush()?; w.flush()?;
@ -244,8 +334,15 @@ where
P: ThePath, P: ThePath,
{ {
// read from serial -> to_console==>from_serial -> output to console // read from serial -> to_console==>from_serial -> output to console
let (to_console, from_serial) = mpsc::channel(256); let (to_console, from_serial) = mpsc::channel::<Result<Message>>(256);
let (to_console2, from_internal) = mpsc::channel(256); let (to_console2, from_internal) = mpsc::channel::<Result<Message>>(256);
// Make a Stream from Receiver
// let stream = ReceiverStream::new(from_serial);
// // Make AsyncRead from Stream
// let async_stream = StreamReader::new(stream);
// // Make FramedRead (Stream+Sink) from AsyncRead
// let from_serial = FramedRead::new(async_stream, Utf8Codec::new());
// read from console -> to_serial==>from_console -> output to serial // read from console -> to_serial==>from_console -> output to serial
let (to_serial, from_console) = mpsc::channel(256); let (to_serial, from_console) = mpsc::channel(256);
@ -364,7 +461,7 @@ async fn main() -> Result<()> {
execute!( execute!(
stdout, stdout,
cursor::RestorePosition, cursor::RestorePosition,
style::Print("[>>] Opening serial port ") style::Print(" Opening serial port ")
)?; )?;
// tokio_serial::new() creates a builder with 8N1 setup without flow control by default. // tokio_serial::new() creates a builder with 8N1 setup without flow control by default.