159 lines
4.1 KiB
Rust
159 lines
4.1 KiB
Rust
pub mod fb128;
|
|
|
|
use crypto_bigint::{Limb, RandomMod, NonZero, Encoding};
|
|
use base64::prelude::{BASE64_STANDARD, Engine};
|
|
use crate::{FBError, Packing, FBKey, Encode};
|
|
use rand::{Rng, prelude::IteratorRandom};
|
|
use std::cmp::PartialOrd;
|
|
|
|
/// Pair of element vectors that the traits in this module operate on.
///
/// Both fields are crate-visible only; outside the crate an instance is
/// obtained through `FalseBottom::init` or `Encode::import`.
pub struct FBObj<T> {
    /// Cipher elements (filled from the `cipher` argument of `Encode::import`).
    pub(crate) c: Vec<T>,
    /// Keybase elements (filled from the `keybase` argument of `Encode::import`).
    pub(crate) r: Vec<T>,
}
|
|
|
|
pub trait FalseBottom<T>: FBBlockOperations<T>
|
|
where
|
|
T: RandomMod + ModPrime<T> + Encoding + PartialOrd + Packing<T>
|
|
{
|
|
const MODULUS: NonZero<T>;
|
|
|
|
fn init(n: usize, k: usize) -> Result<FBObj<T>, FBError>
|
|
where
|
|
T: RandomMod
|
|
{
|
|
if n < k || k < 2 {
|
|
return Err(FBError::InvalidParams);
|
|
}
|
|
let mut rng = rand::thread_rng();
|
|
let r = (0..k)
|
|
.map(|_| T::random_mod(&mut rng, &Self::MODULUS))
|
|
.collect();
|
|
let c = (0..n)
|
|
.map(|_| T::random_mod(&mut rng, &Self::MODULUS))
|
|
.collect();
|
|
|
|
Ok(FBObj { c, r })
|
|
}
|
|
|
|
fn add(&mut self, msg: &[u8]) -> FBKey {
|
|
let indices = T::pack(msg).into_iter()
|
|
.map(|msg_uint| self.add_u128(&msg_uint))
|
|
.collect();
|
|
|
|
FBKey { indices }
|
|
}
|
|
|
|
fn decrypt(&self, key: &FBKey) -> Result<Vec<u8>, FBError> {
|
|
let decr = key.indices.iter()
|
|
.map(|index_row| self.decrypt_u128(&index_row))
|
|
.collect::<Result<Vec<_>, _>>()?;
|
|
let msg = T::unpack(&decr)?;
|
|
|
|
Ok(msg)
|
|
}
|
|
}
|
|
|
|
pub trait ModPrime<T> {
|
|
const PRIME_POS: Limb;
|
|
const PRIME: T;
|
|
fn mul_mod_prime(&self, rhs: &Self) -> Self;
|
|
fn add_mod_prime(&self, rhs: &Self) -> Self;
|
|
fn sub_mod_prime(&self, rhs: &Self) -> Self;
|
|
fn inv_mod_prime(&self) -> Self;
|
|
}
|
|
|
|
|
|
pub trait FBBlockOperations<T>
|
|
where
|
|
T: RandomMod + ModPrime<T>
|
|
{
|
|
const P: T;
|
|
const P_POS: Limb;
|
|
|
|
fn cipher(&self) -> &Vec<T>;
|
|
fn cipher_mut(&mut self) -> &mut Vec<T>;
|
|
fn keybase(&self) -> &Vec<T>;
|
|
|
|
fn add_u128(&mut self, msg_uint: &T) -> Vec<(usize, usize)> {
|
|
let c = self.cipher();
|
|
let r = self.keybase();
|
|
let mut rng = rand::thread_rng();
|
|
let n = rng.gen_range(2..=r.len());
|
|
let mut c_i = (0..c.len()).choose_multiple(&mut rng, n - 1);
|
|
let r_i = (0..r.len()).choose_multiple(&mut rng, n);
|
|
let mut sum = T::ZERO;
|
|
for (&ci, &ri) in c_i.iter().zip( r_i.iter() ) {
|
|
sum = sum.add_mod_prime( &c[ci].mul_mod_prime(&r[ri]) );
|
|
}
|
|
let ri_last = *r_i.last()
|
|
.expect("r_i will contain at least 2 elements");
|
|
let mod_inv = r[ri_last].inv_mod_prime();
|
|
let c_new_el = msg_uint.sub_mod_prime(&sum)
|
|
.mul_mod_prime(&mod_inv);
|
|
let c = self.cipher_mut();
|
|
c.push(c_new_el);
|
|
c_i.push(c.len() - 1);
|
|
let indices = c_i.into_iter()
|
|
.zip(r_i.into_iter())
|
|
.collect();
|
|
|
|
indices
|
|
}
|
|
|
|
fn decrypt_u128(&self, indices: &[(usize, usize)]) -> Result<T, FBError> {
|
|
let (r, c) = (self.keybase(), self.cipher());
|
|
if indices.len() > r.len() {
|
|
return Err(FBError::InvalidKey);
|
|
}
|
|
let mut msg = T::ZERO;
|
|
for &(ci, ri) in indices {
|
|
let c_el = c.get(ci).ok_or(FBError::InvalidKey)?;
|
|
let r_el = r.get(ri).ok_or(FBError::InvalidKey)?;
|
|
msg = msg.add_mod_prime(&c_el.mul_mod_prime(&r_el));
|
|
}
|
|
|
|
Ok(msg)
|
|
}
|
|
}
|
|
|
|
impl<T> Encode<T> for FBObj<T>
|
|
where
|
|
T: Encoding + crypto_bigint::Bounded,
|
|
<T as Encoding>::Repr: Iterator + for <'a> From<&'a [u8]>,
|
|
Vec<u8>: FromIterator<<<T as Encoding>::Repr as Iterator>::Item>
|
|
{
|
|
const BYTES: usize = T::BYTES;
|
|
|
|
fn export(&self) -> (String, String) {
|
|
let c_bytes: Vec<u8> = self.c.iter()
|
|
.flat_map(|bigint| bigint.to_le_bytes())
|
|
.collect();
|
|
let r_bytes: Vec<u8> = self.r.iter()
|
|
.flat_map(|bigint| bigint.to_le_bytes())
|
|
.collect();
|
|
|
|
(BASE64_STANDARD.encode(c_bytes), BASE64_STANDARD.encode(r_bytes))
|
|
}
|
|
|
|
fn import(cipher: &str, keybase: &str) -> Result<Self, FBError>
|
|
where
|
|
Self: Sized
|
|
{
|
|
let c_bytes = BASE64_STANDARD.decode(cipher)
|
|
.map_err(|_| FBError::DecodeError)?;
|
|
let c: Vec<T> = c_bytes.chunks_exact(Self::BYTES)
|
|
.map(|chunk| T::from_le_bytes(chunk.try_into().unwrap()))
|
|
.collect();
|
|
let r_bytes = BASE64_STANDARD.decode(keybase)
|
|
.map_err(|_| FBError::DecodeError)?;
|
|
let r: Vec<T> = r_bytes.chunks_exact(Self::BYTES)
|
|
.map(|chunk| T::from_le_bytes(chunk.try_into().unwrap()))
|
|
.collect();
|
|
if c.len() < r.len() || r.len() < 2 {
|
|
return Err(FBError::InvalidParams);
|
|
}
|
|
|
|
Ok(Self {c, r})
|
|
}
|
|
}
|