vesper/bin/chainofcommand/src/main.rs

547 lines
18 KiB
Rust

#![feature(trait_alias)]
#![allow(stable_features)]
#![feature(let_else)] // stabilised in 1.65.0
#![feature(slice_take)]
use {
anyhow::{anyhow, Result},
bytes::Bytes,
clap::{value_parser, Arg, ArgAction, Command},
crossterm::{
cursor,
event::{Event, EventStream, KeyCode, KeyEvent, KeyModifiers},
execute, style, terminal,
tty::IsTty,
},
defer::defer,
futures::{future::FutureExt, Stream},
seahash::SeaHasher,
std::{
fmt::Formatter,
fs::File,
hash::Hasher,
io::{BufRead, BufReader},
path::Path,
time::Duration,
},
tokio::{io::AsyncReadExt, sync::mpsc},
tokio_serial::{SerialPortBuilderExt, SerialStream},
tokio_stream::StreamExt,
};
// mod utf8_codec;
// Alias for a writer that is also `Send`.
// NOTE(review): no use of this alias as a bound is visible in this part of the file, but
// with `trait_alias` an alias in scope presumably also makes `std::io::Write` methods such
// as `flush()` callable without importing `Write` — confirm before removing it.
trait Writable = std::io::Write + Send;
// Bound used for the kernel image path argument: convertible to a `Path`, printable with
// `Display`, cloneable, and `Send + Sync + 'static`.
trait ThePath = AsRef<Path> + std::fmt::Display + Clone + Sync + Send + 'static;
// An `Unpin` stream yielding `Result<Message, anyhow::Error>` items.
trait FramedStream = Stream<Item = Result<Message, anyhow::Error>> + Unpin;
// Sending half of a tokio mpsc channel carrying `Result<Message>` items.
type Sender = mpsc::Sender<Result<Message>>;
// Receiving half of a tokio mpsc channel carrying `Result<Message>` items.
type Receiver = mpsc::Receiver<Result<Message>>;
/// Waits until the serial side has delivered the expected string `m`.
///
/// Text messages taken from `from_serial` are accumulated, and each one is forwarded to
/// `to_console2` so it still shows up on the console. Receiving stops once at least as many
/// characters as `m` contains have arrived; the accumulated text must then equal `m`.
///
/// A single message may carry several characters (the serial reader forwards whole read
/// chunks), so progress is measured in received characters rather than in messages.
///
/// # Errors
/// Returns an error if the channel is closed or yields an `Err` before enough characters
/// arrived, if forwarding to the console fails, or if the received text differs from `m`.
///
/// # Panics
/// Panics if a `Message::Binary` arrives on `from_serial`.
async fn expect(to_console2: &Sender, from_serial: &mut Receiver, m: &str) -> Result<()> {
    let wanted = m.chars().count();
    let mut received = 0;
    let mut s = String::new();
    while received < wanted {
        let Some(Ok(c)) = from_serial.recv().await else {
            return Err(anyhow!(
                "Failed to receive expected value {:?}: got empty buf",
                m,
            ));
        };
        match c {
            Message::Text(payload) => {
                received += payload.chars().count();
                s.push_str(&payload);
                to_console2.send(Ok(Message::Text(payload))).await?;
            }
            // The serial reader only ever forwards text to the console side.
            Message::Binary(_) => unreachable!("binary message received from serial"),
        }
    }
    if s != m {
        return Err(anyhow!(
            "Failed to receive expected value {:?}: got {:?}",
            m,
            s
        ));
    }
    Ok(())
}
/// Opens the kernel image at `kernel` and reports its size.
///
/// Two status lines are sent to `to_console2`: one before the file is opened and one with
/// the path and size in bytes once the metadata has been read.
///
/// Returns the opened file together with its length in bytes.
///
/// # Errors
/// Returns an error if a status message cannot be sent, if the file cannot be opened, or if
/// its metadata cannot be read.
async fn load_kernel<P>(to_console2: &Sender, kernel: P) -> Result<(File, u64)>
where
    P: ThePath,
{
    to_console2
        .send(Ok(Message::Text("⏩ Loading kernel image\n".into())))
        .await?;

    // The I/O cause is folded into the message text (not attached as an error source) so
    // that the reason is visible while the error chain seen by callers keeps its shape.
    let kernel_file = File::open(&kernel)
        .map_err(|e| anyhow!("Couldn't open kernel file {}: {}", kernel, e))?;
    let kernel_size: u64 = kernel_file.metadata()?.len();

    to_console2
        .send(Ok(Message::Text(format!(
            "⏩ .. {} ({} bytes)\n",
            kernel, kernel_size
        ))))
        .await?;
    Ok((kernel_file, kernel_size))
}
async fn send_kernel<P: ThePath>(
to_console2: &Sender,
to_serial: &Sender,
from_serial: &mut Receiver,
kernel: P,
) -> Result<()> {
let (kernel_file, kernel_size) = load_kernel(to_console2, kernel).await?;
to_console2
.send(Ok(Message::Text("⏩ Sending image size\n".into())))
.await?;
to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(
&kernel_size.to_le_bytes(),
))))
.await?;
// Wait for OK response
expect(to_console2, from_serial, "OK").await?;
to_console2
.send(Ok(Message::Text("⏩ Sending kernel image\n".into())))
.await?;
let mut hasher = SeaHasher::new();
let mut reader = BufReader::with_capacity(1, kernel_file);
loop {
let length = {
let buf = reader.fill_buf()?;
to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(buf))))
.await?;
hasher.write(buf);
buf.len()
};
if length == 0 {
break;
}
reader.consume(length);
}
let hashed_value: u64 = hasher.finish();
to_console2
.send(Ok(Message::Text(format!(
"⏩ Sending image checksum {:x}\n",
hashed_value
))))
.await?;
to_serial
.send(Ok(Message::Binary(Bytes::copy_from_slice(
&hashed_value.to_le_bytes(),
))))
.await?;
expect(to_console2, from_serial, "OK").await?;
Ok(())
}
// Async reading using Tokio: https://fasterthanli.me/articles/a-terminal-case-of-linux
/// Bridges the serial port and the console channels.
///
/// Messages taken from `from_console` are written to `port` (text as its UTF-8 bytes,
/// binary as-is). Bytes read from `port` are decoded lossily as UTF-8 and forwarded to
/// `to_console` as `Message::Text`.
///
/// Returns `Ok(())` when the port reports end-of-file (a zero-length read).
///
/// # Errors
/// Returns an error if an `Err` item is received from `from_console`, if reading from or
/// writing to the port fails, or if forwarding to `to_console` fails.
async fn serial_loop(
    mut port: tokio_serial::SerialStream,
    to_console: Sender,
    mut from_console: Receiver,
) -> Result<()> {
    let mut buf = [0; 256];
    loop {
        tokio::select! {
            Some(msg) = from_console.recv() => {
                // Propagate an error item instead of panicking the whole task on it.
                match msg? {
                    Message::Text(s) => {
                        tokio::io::AsyncWriteExt::write_all(&mut port, s.as_bytes()).await?;
                    }
                    Message::Binary(b) => {
                        tokio::io::AsyncWriteExt::write_all(&mut port, b.as_ref()).await?;
                    }
                }
            }
            res = port.read(&mut buf) => {
                match res {
                    Ok(0) => return Ok(()),
                    Ok(n) => {
                        // NOTE: each read is decoded on its own, so a multi-byte UTF-8
                        // sequence split across two reads ends up as replacement
                        // characters. TODO: decode incrementally (Utf8Codec).
                        let s = String::from_utf8_lossy(&buf[0..n]);
                        to_console.send(Ok(Message::Text(s.into_owned()))).await?;
                    }
                    Err(e) => return Err(anyhow!(e)),
                }
            }
        }
    }
}
// Unit of data passed over the mpsc channels between the console and serial loops.
//
// Conventions used in this file:
// - Always send Binary() to serial
// - Convert Text() to bytes and send in serial_loop
// - Receive and convert bytes to Text() in serial_loop
#[derive(Clone, Debug)]
enum Message {
// Raw bytes, written to the serial port unchanged.
Binary(Bytes),
// Text; written to the serial port as its UTF-8 bytes, or printed on the console.
Text(String),
}
// impl Message {
// pub fn len(&self) -> usize {
// match self {
// Message::Binary(b) => b.len(),
// Message::Text(s) => s.len(),
// }
// }
// }
/// Human-readable rendering of a [`Message`].
///
/// Text is written verbatim. Binary payloads are written byte by byte, each as its decimal
/// value immediately followed by a `)` character.
impl std::fmt::Display for Message {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Message::Text(text) => f.write_str(text),
            // NOTE(review): the trailing `)` after every byte looks unintentional —
            // confirm the intended binary format before changing it.
            Message::Binary(bytes) => bytes
                .iter()
                .try_for_each(|byte| write!(f, "{})", byte)),
        }
    }
}
// impl Buf for Message {
// fn remaining(&self) -> usize {
// todo!()
// }
//
// fn chunk(&self) -> &[u8] {
// todo!()
// }
//
// fn advance(&mut self, cnt: usize) {
// todo!()
// }
// }
// Drives the interactive console: prints internal status messages and serial output to
// stdout, forwards key presses to the serial side, and starts a kernel upload once three
// consecutive BREAK (0x3) characters have been received from serial.
//
// Parameters:
// - `to_console2`: sender for status text; passed to `send_kernel` and used for the
//   "send successful" line.
// - `from_internal`: receiver for that status text; every `Ok` message is printed.
// - `to_serial`: sender for data destined for the serial port (key presses, kernel upload).
// - `from_serial`: receiver for text coming from the serial port.
// - `kernel`: path of the kernel image; cloned for every upload.
//
// Returns `Ok(())` when Ctrl+C is pressed. Returns an error if writing to stdout fails, if a
// channel send fails, if `send_kernel` fails, or if the key-event stream ends.
async fn console_loop<P>(
to_console2: Sender,
mut from_internal: Receiver,
to_serial: Sender,
mut from_serial: Receiver,
kernel: P,
) -> Result<()>
where
P: ThePath,
{
let mut w = std::io::stdout();
// Number of BREAK characters seen in a row; kept across messages.
let mut breaks = 0;
let mut event_reader = EventStream::new();
loop {
tokio::select! {
// Poll the branches in the order written: status text, serial data, key events.
biased;
Some(received) = from_internal.recv() => {
// `Err` items on the status channel are dropped silently.
if let Ok(message) = received {
execute!(w, style::Print(message))?;
w.flush()?;
}
}
Some(received) = from_serial.recv() => { // yields Result<Message>
// `Err` items on the serial channel are dropped silently.
if let Ok(received) = received {
// Only Text is expected here; a Binary message panics.
let Message::Text(received) = received else {
unreachable!();
};
// `len()` is the byte length of the received text.
execute!(w, cursor::MoveToNextLine(1), style::Print(format!("[>>] Received {} bytes from serial", received.len())), cursor::MoveToNextLine(1))?;
for x in received.chars() {
if x == 0x3 as char {
// BREAK characters are counted and withheld from the terminal for now.
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received a BREAK"), cursor::MoveToNextLine(1))?;
breaks += 1;
// Await for 3 consecutive \3 to start downloading
if breaks == 3 {
// execute!(w, cursor::MoveToNextLine(1), style::Print("[>>] Received 3 BREAKs"), cursor::MoveToNextLine(1))?;
breaks = 0;
// The upload takes further messages from `from_serial` itself while it waits for
// acknowledgements; the remaining characters of this message are processed afterwards.
send_kernel(&to_console2, &to_serial, &mut from_serial, kernel.clone()).await?;
to_console2.send(Ok(Message::Text("🦀 Send successful, pass-through\n".into()))).await?;
}
} else {
// Fewer than three BREAKs in a row: print the withheld ones before the character.
while breaks > 0 {
execute!(w, style::Print(format!("{}", 3 as char)))?;
breaks -= 1;
}
// TODO decode buf with Utf8Codec here?
execute!(w, style::Print(format!("{}", x)))?;
w.flush()?;
}
}
}
}
maybe_event = event_reader.next().fuse() => {
match maybe_event {
Some(Ok(Event::Key(key_event))) => {
// Ctrl+C (with CONTROL as the only modifier) ends the console loop successfully.
if key_event.code == KeyCode::Char('c') && key_event.modifiers == KeyModifiers::CONTROL {
return Ok(());
}
// Keys without a byte encoding (`None`) are ignored.
if let Some(key) = handle_key_event(key_event) {
to_serial.send(Ok(Message::Binary(Bytes::copy_from_slice(&key)))).await?;
// Local echo
// The echo uses the `Debug` formatting of the byte buffer, not the raw bytes.
execute!(w, style::Print(format!("{:?}", key)))?;
w.flush()?;
}
}
// Events other than key presses are ignored.
Some(Ok(_)) => {},
// A read error from the event stream is reported on screen; the loop keeps running.
Some(Err(e)) => {
execute!(w, style::Print(format!("Console read error: {:?}\r", e)))?;
w.flush()?;
},
// The event stream has ended.
None => return Err(anyhow!("woops")),
}
}
}
}
}
/// Wires the serial task and the console loop together and runs until the console loop
/// finishes.
///
/// Three bounded channels (capacity 256) are created:
/// - serial → console: the serial task sends what it reads, `console_loop` receives it;
/// - internal → console: status text produced inside `console_loop` and its helpers;
/// - console → serial: key presses and kernel data, written to the port by the serial task.
///
/// `serial_loop` is spawned as a detached task that owns `port`; its result is not
/// observed here. The value returned is whatever `console_loop` returns.
async fn main_loop<P>(port: SerialStream, kernel: P) -> Result<()>
where
    P: ThePath,
{
    // read from serial -> to_console==>from_serial -> output to console
    let (to_console, from_serial) = mpsc::channel::<Result<Message>>(256);
    let (to_console2, from_internal) = mpsc::channel::<Result<Message>>(256);

    // read from console -> to_serial==>from_console -> output to serial
    let (to_serial, from_console) = mpsc::channel(256);

    // The only sender for serial data is moved into the task. Keeping a second copy alive
    // here would prevent `from_serial` from ever reporting "closed" after the task exits,
    // leaving receivers that await serial data waiting forever.
    tokio::spawn(serial_loop(port, to_console, from_console));

    console_loop(to_console2, from_internal, to_serial, from_serial, kernel).await
    // TODO: framed
    // rx_device -> serial_reader -> app
    // app -> serial_writer -> serial_consumer -> (poll_send to drive) -> serial_sink -> tx_device
}
// From remote_serial -- https://github.com/zhp-rs/remote_serial/ (Licensed under MIT License)
/// Translates a terminal key event into the bytes to send over the serial line.
///
/// - Editing and navigation keys map to fixed control bytes or escape sequences.
/// - With Control held, a lowercase ASCII letter or space is masked with `0x1f`, and the
///   digits `4`..=`7` map to `0x1c`..=`0x1f`.
/// - Every other character is sent as its UTF-8 encoding.
///
/// Returns `None` for keys that have no mapping here.
fn handle_key_event(key_event: KeyEvent) -> Option<Bytes> {
    // Backing storage for encodings that are computed rather than constant.
    let mut scratch = [0u8; 4];
    let encoded: &[u8] = match key_event.code {
        KeyCode::Backspace => b"\x08",
        KeyCode::Tab => b"\x09",
        KeyCode::Enter => b"\x0D",
        KeyCode::Esc => b"\x1b",
        KeyCode::Up => b"\x1b[A",
        KeyCode::Down => b"\x1b[B",
        KeyCode::Right => b"\x1b[C",
        KeyCode::Left => b"\x1b[D",
        KeyCode::Home => b"\x1b[H",
        KeyCode::End => b"\x1b[F",
        KeyCode::Insert => b"\x1b[2~",
        KeyCode::Delete => b"\x1b[3~",
        KeyCode::Char(ch) => {
            let ctrl_held =
                key_event.modifiers & KeyModifiers::CONTROL == KeyModifiers::CONTROL;
            if ctrl_held && (ch.is_ascii_lowercase() || ch == ' ') {
                scratch[0] = (ch as u8) & 0x1f;
                &scratch[..1]
            } else if ctrl_held && ('4'..='7').contains(&ch) {
                // crossterm returns Control-4 thru 7 for \x1c thru \x1f
                scratch[0] = ((ch as u8) + 8) & 0x1f;
                &scratch[..1]
            } else {
                ch.encode_utf8(&mut scratch).as_bytes()
            }
        }
        _ => return None,
    };
    Some(Bytes::copy_from_slice(encoded))
}
// 1. connect to given serial port, e.g. /dev/ttyUSB23234
// 2. Await for \3\3\3 start signal, meanwhile pass-through all traffic to console
// 3. send selected kernel binary with checksum to the target
// 4. go to 2
#[tokio::main]
async fn main() -> Result<()> {
let matches = Command::new("ChainOfCommand - command chainboot protocol")
.about("Use to send freshly built kernel to chainboot-compatible boot loader")
.disable_version_flag(true)
.arg(
Arg::new("port")
.help("The device path to a serial port, e.g. /dev/ttyUSB0")
.required(true),
)
.arg(
Arg::new("baud")
.help("The baud rate to connect at")
.use_value_delimiter(false)
.action(ArgAction::Set)
.value_parser(value_parser!(u32))
.required(true), // .validator(valid_baud),
)
.arg(
Arg::new("kernel")
.long("kernel")
.help("Path of the binary kernel image to send")
.default_value("kernel8.img"),
)
.get_matches();
let port_name = matches
.get_one::<String>("port")
.expect("port must be specified");
let baud_rate = matches
.get_one("baud")
.copied()
.expect("baud rate must be an integer");
let kernel = matches
.get_one::<String>("kernel")
.expect("kernel file must be specified");
// Check that STDIN is a proper tty
if !std::io::stdin().is_tty() {
panic!("Must have a TTY for stdin");
}
// Disable line buffering, local echo, etc.
terminal::enable_raw_mode()?;
defer(|| terminal::disable_raw_mode().unwrap_or(()));
let mut serial_toggle = false;
let mut stdout = std::io::stdout();
execute!(stdout, cursor::SavePosition)?;
loop {
execute!(
stdout,
cursor::RestorePosition,
style::Print("⏩ Opening serial port ")
)?;
// tokio_serial::new() creates a builder with 8N1 setup without flow control by default.
let port = tokio_serial::new(port_name, baud_rate).open_native_async();
if let Err(e) = port {
let cont = match e.kind {
tokio_serial::ErrorKind::NoDevice => true,
tokio_serial::ErrorKind::Io(e)
if e == std::io::ErrorKind::NotFound
|| e == std::io::ErrorKind::PermissionDenied =>
{
true
}
_ => false,
};
if cont {
execute!(
stdout,
cursor::RestorePosition,
style::Print(format!(
"⏳ Waiting for serial port {}\r",
if serial_toggle { "# " } else { " #" }
))
)?;
stdout.flush()?;
serial_toggle = !serial_toggle;
if crossterm::event::poll(Duration::from_millis(1000))? {
if let Event::Key(KeyEvent {
code, modifiers, ..
}) = crossterm::event::read()?
{
if code == KeyCode::Char('c') && modifiers == KeyModifiers::CONTROL {
return Ok(());
}
}
}
continue;
}
return Err(e.into());
}
execute!(
stdout,
style::Print("\n✅ Waiting for handshake, pass-through. 🔌 Power the target now."),
)?;
stdout.flush()?;
// Run in pass-through mode by default.
// Once we receive BREAK (0x3) three times, switch to kernel send mode and upload kernel,
// then switch back to pass-through mode.
// Input from STDIN should pass through to serial
// Input from serial should pass through to STDOUT
let port = port?;
if let Err(e) = main_loop(port, kernel.clone()).await {
execute!(stdout, style::Print(format!("\nError: {:?}\n", e)))?;
stdout.flush()?;
let cont = matches!(e.downcast_ref::<std::io::Error>(),
Some(e) if e.kind() == std::io::ErrorKind::NotFound || e.kind() == std::io::ErrorKind::PermissionDenied)
|| matches!(e.downcast_ref::<tokio_serial::Error>(), Some(e) if e.kind == tokio_serial::ErrorKind::NoDevice)
|| e.downcast_ref::<tokio::sync::mpsc::error::SendError<Vec<u8>>>()
.is_some();
if !cont {
break;
}
} else {
// main_loop() returned Ok() we're good to finish
break;
}
execute!(stdout, cursor::SavePosition)?;
}
Ok(())
}