fix(console): 🐛 Fix console I/O on the host side

This commit is contained in:
Berkus Decker 2023-07-10 02:28:33 +03:00 committed by Berkus Decker
parent fa725c51cb
commit 2c91e685bd
2 changed files with 172 additions and 73 deletions

View File

@ -17,9 +17,11 @@ seahash = "4.1"
anyhow = "1.0"
fehler = "1.0"
crossterm = { version = "0.25", features = ["event-stream"] }
tokio-serial = "5.4"
tokio = { version = "1.21", features = ["full"] }
futures = "0.3"
futures-util = { version = "0.3", features = ["io"] }
tokio = { version = "1.21", features = ["full"] }
tokio-util = { version = "0.7", features = ["io", "codec", "io"] }
tokio-stream = { version = "0.1" }
tokio-serial = "5.4"
defer = "0.1"
tokio-util = { version = "0.7", features = ["codec"] }
bytes = "1.2"

View File

@ -1,4 +1,7 @@
#![feature(trait_alias)]
#![allow(stable_features)]
#![feature(let_else)] // stabilised in 1.65.0
#![feature(slice_take)]
use {
anyhow::{anyhow, Result},
@ -11,9 +14,10 @@ use {
tty::IsTty,
},
defer::defer,
futures::{future::FutureExt, StreamExt},
futures::{future::FutureExt, Stream},
seahash::SeaHasher,
std::{
fmt::Formatter,
fs::File,
hash::Hasher,
io::{BufRead, BufReader},
@ -22,42 +26,55 @@ use {
},
tokio::{io::AsyncReadExt, sync::mpsc},
tokio_serial::{SerialPortBuilderExt, SerialStream},
tokio_stream::StreamExt,
};
// mod utf8_codec;
trait Writable = std::io::Write + Send;
trait ThePath = AsRef<Path> + std::fmt::Display + Clone + Sync + Send + 'static;
async fn expect(
to_console2: &mpsc::Sender<Vec<u8>>,
from_serial: &mut mpsc::Receiver<Vec<u8>>,
m: &str,
) -> Result<()> {
if let Some(buf) = from_serial.recv().await {
if buf.len() >= m.len() && String::from_utf8_lossy(&buf[..m.len()]) == m {
if buf.len() > m.len() {
to_console2.send(buf[m.len()..].to_vec()).await?;
trait FramedStream = Stream<Item = Result<Message, anyhow::Error>> + Unpin;
type Sender = mpsc::Sender<Result<Message>>;
type Receiver = mpsc::Receiver<Result<Message>>;
async fn expect(to_console2: &Sender, from_serial: &mut Receiver, m: &str) -> Result<()> {
let mut s = String::new();
for _x in m.chars() {
let next_char = from_serial.recv().await;
let Some(Ok(c)) = next_char else {
return Err(anyhow!(
"Failed to receive expected value {:?}: got empty buf",
m,
));
};
match c {
Message::Text(payload) => {
s.push_str(&payload);
to_console2.send(Ok(Message::Text(payload))).await?;
}
return Ok(());
_ => unreachable!(),
}
to_console2.send(buf.clone()).await?;
}
if s != m {
return Err(anyhow!(
"Failed to receive expected value {:?}: got {:?}",
m,
buf
s
));
}
Err(anyhow!(
"Failed to receive expected value {:?}: got empty buf",
m,
))
Ok(())
}
async fn load_kernel<P>(to_console2: &mpsc::Sender<Vec<u8>>, kernel: P) -> Result<(File, u64)>
async fn load_kernel<P>(to_console2: &Sender, kernel: P) -> Result<(File, u64)>
where
P: ThePath,
{
to_console2
.send("[>>] Loading kernel image\n".into())
.send(Ok(Message::Text(" Loading kernel image\n".into())))
.await?;
let kernel_file = match std::fs::File::open(kernel.clone()) {
@ -67,38 +84,46 @@ where
let kernel_size: u64 = kernel_file.metadata()?.len();
to_console2
.send(format!("[>>] .. {} ({} bytes)\n", kernel, kernel_size).into())
.send(Ok(Message::Text(
format!("⏩ .. {} ({} bytes)\n", kernel, kernel_size).into(),
)))
.await?;
Ok((kernel_file, kernel_size))
}
async fn send_kernel<P>(
to_console2: &mpsc::Sender<Vec<u8>>,
to_serial: &mpsc::Sender<Vec<u8>>,
from_serial: &mut mpsc::Receiver<Vec<u8>>,
async fn send_kernel<P: ThePath>(
to_console2: &Sender,
to_serial: &Sender,
from_serial: &mut Receiver,
kernel: P,
) -> Result<()>
where
P: ThePath,
{
) -> Result<()> {
let (kernel_file, kernel_size) = load_kernel(to_console2, kernel).await?;
to_console2.send("⏩ Sending image size\n".into()).await?;
to_serial.send(kernel_size.to_le_bytes().into()).await?;
to_console2
.send(Ok(Message::Text("⏩ Sending image size\n".into())))
.await?;
to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(
&kernel_size.to_le_bytes(),
))))
.await?;
// Wait for OK response
expect(to_console2, from_serial, "OK").await?;
to_console2.send("⏩ Sending kernel image\n".into()).await?;
to_console2
.send(Ok(Message::Text("⏩ Sending kernel image\n".into())))
.await?;
let mut hasher = SeaHasher::new();
let mut reader = BufReader::with_capacity(1, kernel_file);
loop {
let length = {
let buf = reader.fill_buf()?;
to_serial.send(buf.into()).await?;
to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(buf))))
.await?;
hasher.write(buf);
buf.len()
};
@ -110,10 +135,16 @@ where
let hashed_value: u64 = hasher.finish();
to_console2
.send(format!("⏩ Sending image checksum {:x}\n", hashed_value).into())
.send(Ok(Message::Text(
format!("⏩ Sending image checksum {:x}\n", hashed_value).into(),
)))
.await?;
to_serial.send(hashed_value.to_le_bytes().into()).await?;
to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(
&hashed_value.to_le_bytes(),
))))
.await?;
expect(to_console2, from_serial, "OK").await?;
@ -124,8 +155,8 @@ where
async fn serial_loop(
mut port: tokio_serial::SerialStream,
to_console: mpsc::Sender<Vec<u8>>,
mut from_console: mpsc::Receiver<Vec<u8>>,
to_console: Sender,
mut from_console: Receiver,
) -> Result<()> {
let mut buf = [0; 256];
loop {
@ -134,8 +165,13 @@ async fn serial_loop(
Some(msg) = from_console.recv() => {
// debug!("serial write {} bytes", msg.len());
tokio::io::AsyncWriteExt::write_all(&mut port, msg.as_ref()).await?;
}
match msg.unwrap() {
Message::Text(s) => {
tokio::io::AsyncWriteExt::write_all(&mut port, s.as_bytes()).await?;
},
Message::Binary(b) => tokio::io::AsyncWriteExt::write_all(&mut port, b.as_ref()).await?,
}
}
res = port.read(&mut buf) => {
match res {
@ -145,7 +181,9 @@ async fn serial_loop(
}
Ok(n) => {
// debug!("Serial read {n} bytes.");
to_console.send(buf[0..n].to_owned()).await?;
// let codec = Utf8Codec::new(buf);
let s = String::from_utf8_lossy(&buf[0..n]);
to_console.send(Ok(Message::Text(s.to_string()))).await?;
}
Err(e) => {
// if e.kind() == ErrorKind::TimedOut {
@ -162,11 +200,57 @@ async fn serial_loop(
}
}
// Always send Binary() to serial
// Convert Text() to bytes and send in serial_loop
// Receive and convert bytes to Text() in serial_loop
#[derive(Clone, Debug)]
enum Message {
Binary(Bytes),
Text(String),
}
// impl Message {
// pub fn len(&self) -> usize {
// match self {
// Message::Binary(b) => b.len(),
// Message::Text(s) => s.len(),
// }
// }
// }
impl std::fmt::Display for Message {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Message::Binary(b) => {
for c in b {
write!(f, "{})", c)?;
}
Ok(())
}
Message::Text(s) => write!(f, "{}", s),
}
}
}
// impl Buf for Message {
// fn remaining(&self) -> usize {
// todo!()
// }
//
// fn chunk(&self) -> &[u8] {
// todo!()
// }
//
// fn advance(&mut self, cnt: usize) {
// todo!()
// }
// }
async fn console_loop<P>(
to_console2: mpsc::Sender<Vec<u8>>,
mut from_internal: mpsc::Receiver<Vec<u8>>,
to_serial: mpsc::Sender<Vec<u8>>,
mut from_serial: mpsc::Receiver<Vec<u8>>,
to_console2: Sender,
mut from_internal: Receiver,
to_serial: Sender,
mut from_serial: Receiver,
kernel: P,
) -> Result<()>
where
@ -183,33 +267,39 @@ where
biased;
Some(received) = from_internal.recv() => {
for &x in &received[..] {
execute!(w, style::Print(format!("{}", x as char)))?;
if let Ok(message) = received {
execute!(w, style::Print(message))?;
w.flush()?;
}
w.flush()?;
}
Some(received) = from_serial.recv() => {
// execute!(w, cursor::MoveToNextLine(1), style::Print(format!("[>>] Received {} bytes from serial", from_serial.len())), cursor::MoveToNextLine(1))?;
Some(received) = from_serial.recv() => { // returns Vec<char>
if let Ok(received) = received {
let Message::Text(received) = received else {
unreachable!();
};
execute!(w, cursor::MoveToNextLine(1), style::Print(format!("[>>] Received {} bytes from serial", received.len())), cursor::MoveToNextLine(1))?;
for &x in &received[..] {
if x == 0x3 {
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received a BREAK"), cursor::MoveToNextLine(1))?;
breaks += 1;
// Await for 3 consecutive \3 to start downloading
if breaks == 3 {
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received 3 BREAKs"), cursor::MoveToNextLine(1))?;
breaks = 0;
send_kernel(&to_console2, &to_serial, &mut from_serial, kernel.clone()).await?;
to_console2.send("🦀 Send successful, pass-through\n".into()).await?;
for x in received.chars() {
if x == 0x3 as char {
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received a BREAK"), cursor::MoveToNextLine(1))?;
breaks += 1;
// Await for 3 consecutive \3 to start downloading
if breaks == 3 {
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received 3 BREAKs"), cursor::MoveToNextLine(1))?;
breaks = 0;
send_kernel(&to_console2, &to_serial, &mut from_serial, kernel.clone()).await?;
to_console2.send(Ok(Message::Text("🦀 Send successful, pass-through\n".into()))).await?;
}
} else {
while breaks > 0 {
execute!(w, style::Print(format!("{}", 3 as char)))?;
breaks -= 1;
}
// TODO decode buf with Utf8Codec here?
execute!(w, style::Print(format!("{}", x)))?;
w.flush()?;
}
} else {
while breaks > 0 {
execute!(w, style::Print(format!("{}", 3 as char)))?;
breaks -= 1;
}
execute!(w, style::Print(format!("{}", x as char)))?;
w.flush()?;
}
}
}
@ -221,7 +311,7 @@ where
return Ok(());
}
if let Some(key) = handle_key_event(key_event) {
to_serial.send(key.to_vec()).await?;
to_serial.send(Ok(Message::Binary(Bytes::copy_from_slice(&key)))).await?;
// Local echo
execute!(w, style::Print(format!("{:?}", key)))?;
w.flush()?;
@ -244,8 +334,15 @@ where
P: ThePath,
{
// read from serial -> to_console==>from_serial -> output to console
let (to_console, from_serial) = mpsc::channel(256);
let (to_console2, from_internal) = mpsc::channel(256);
let (to_console, from_serial) = mpsc::channel::<Result<Message>>(256);
let (to_console2, from_internal) = mpsc::channel::<Result<Message>>(256);
// Make a Stream from Receiver
// let stream = ReceiverStream::new(from_serial);
// // Make AsyncRead from Stream
// let async_stream = StreamReader::new(stream);
// // Make FramedRead (Stream+Sink) from AsyncRead
// let from_serial = FramedRead::new(async_stream, Utf8Codec::new());
// read from console -> to_serial==>from_console -> output to serial
let (to_serial, from_console) = mpsc::channel(256);
@ -364,7 +461,7 @@ async fn main() -> Result<()> {
execute!(
stdout,
cursor::RestorePosition,
style::Print("[>>] Opening serial port ")
style::Print(" Opening serial port ")
)?;
// tokio_serial::new() creates a builder with 8N1 setup without flow control by default.