Added Threading

This commit is contained in:
K Shiva Kiran 2024-04-18 19:21:23 +05:30
parent 1eca9a6aa0
commit 1726d45647
4 changed files with 59 additions and 53 deletions

View File

@ -1,16 +1,19 @@
// SPDX-License-Identifier: GPL-3.0-or-later // SPDX-License-Identifier: GPL-3.0-or-later
use crate::{FBError, FBKey, FBObj, FBObjTrait, FieldOps, Packing}; use crate::{FBError, FBKey, FBObj, FBObjTrait, FieldOps, Packing};
use crypto_bigint::{NonZero, RandomMod}; use crypto_bigint::{NonZero, RandomMod};
use rand::{Rng, seq::IteratorRandom}; use rand::{rngs::ThreadRng, seq::index, Rng};
use rayon::iter::*;
use std::marker::Send;
use std::sync::RwLock;
pub trait FBAlgo<T> pub trait FBAlgo<T>
where where
Self: BlockOps<T>, Self: BlockOps<T> + Sync + Send,
T: FieldOps + Packing + RandomMod, T: FieldOps + Packing + RandomMod + Send + Sync,
{ {
const MODULUS: NonZero<T>; const MODULUS: NonZero<T>;
/// Creates a new [`FBObj`]. /// Creates a new [`FBObj`].
/// The keybase and ciphertext are initialized from random values. /// The keybase and ciphertext are initialized from random values.
/// Bounds: `2 <= keybase_len <= cipher_len` /// Bounds: `2 <= keybase_len <= cipher_len`
/// # Errors /// # Errors
@ -23,9 +26,10 @@ where
let r = (0..keybase_len) let r = (0..keybase_len)
.map(|_| T::random_mod(&mut rng, &Self::MODULUS)) .map(|_| T::random_mod(&mut rng, &Self::MODULUS))
.collect(); .collect();
let c = (0..cipher_len) let c_vec = (0..cipher_len)
.map(|_| T::random_mod(&mut rng, &Self::MODULUS)) .map(|_| T::random_mod(&mut rng, &Self::MODULUS))
.collect(); .collect();
let c = RwLock::new(c_vec);
Ok(FBObj { c, r }) Ok(FBObj { c, r })
} }
@ -33,8 +37,11 @@ where
/// Adds the provided message to the ciphertext. /// Adds the provided message to the ciphertext.
fn add(&mut self, msg: &[u8]) -> FBKey { fn add(&mut self, msg: &[u8]) -> FBKey {
let indices = T::pack(msg) let indices = T::pack(msg)
.into_iter() .into_par_iter()
.map(|msg_uint| self.add_block(&msg_uint)) .map_init(
|| rand::thread_rng(),
|rng, index_row| self.add_block(rng, &index_row),
)
.collect(); .collect();
FBKey { indices } FBKey { indices }
@ -44,8 +51,7 @@ where
/// # Errors /// # Errors
/// [InvalidKey](FBError::InvalidKey) /// [InvalidKey](FBError::InvalidKey)
fn decrypt(&self, key: &FBKey) -> Result<Vec<u8>, FBError> { fn decrypt(&self, key: &FBKey) -> Result<Vec<u8>, FBError> {
let decr = key.indices let decr = key.indices.iter()
.iter()
.map(|index_row| self.decrypt_block(&index_row)) .map(|index_row| self.decrypt_block(&index_row))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let mut msg = T::unpack(decr)?; let mut msg = T::unpack(decr)?;
@ -55,40 +61,42 @@ where
} }
} }
pub trait BlockOps<T> pub trait BlockOps<T>
where where
Self: FBObjTrait<T>, Self: FBObjTrait<T>,
T: FieldOps + RandomMod, T: FieldOps + RandomMod + Send + Sync,
{ {
fn add_block(&mut self, msg_uint: &T) -> Vec<(usize, usize)> { fn add_block(&self, rng: &mut ThreadRng, msg_uint: &T) -> Vec<(usize, usize)> {
let c = self.cipher();
let r = self.keybase(); let r = self.keybase();
let mut rng = rand::thread_rng();
let n = rng.gen_range(2..=r.len()); let n = rng.gen_range(2..=r.len());
let mut c_i = (0..c.len()).choose_multiple(&mut rng, n-1); let r_i = index::sample(rng, r.len(), n);
let r_i = (0..r.len()).choose_multiple(&mut rng, n); let ri_last = r_i.iter().last()
let mut sum = T::ZERO;
for (&ci, &ri) in c_i.iter().zip( r_i.iter() ) {
sum = sum.field_add( &c[ci].field_mul(&r[ri]) );
}
let ri_last = *r_i.last()
.expect("r_i will contain at least 2 elements"); .expect("r_i will contain at least 2 elements");
let mod_inv = r[ri_last].field_inv(); let ri_last_inv = r[ri_last].field_inv();
let c_new_el = msg_uint.field_sub(&sum) let c_i;
.field_mul(&mod_inv); let c_len;
let c = self.cipher_mut(); {
c.push(c_new_el); let mut c = self.cipher().write().unwrap();
c_i.push(c.len() - 1); c_i = index::sample(rng, c.len(), n - 1);
let indices = c_i.into_iter() let sum = c_i.iter()
.zip(r_i.into_iter()) .zip(r_i.iter())
.map(|(ci, ri)| c[ci].field_mul(&r[ri]))
.reduce(|acc, i| acc.field_add(&i))
.unwrap();
let c_new_el = msg_uint.field_sub(&sum).field_mul(&ri_last_inv);
c.push(c_new_el);
c_len = c.len();
}
let indices = c_i.iter()
.chain([c_len - 1].into_iter())
.zip(r_i.iter())
.collect(); .collect();
indices indices
} }
fn decrypt_block(&self, indices: &[(usize, usize)]) -> Result<T, FBError> { fn decrypt_block(&self, indices: &[(usize, usize)]) -> Result<T, FBError> {
let (c, r) = (self.cipher(), self.keybase()); let (c, r) = (self.cipher().read().unwrap(), self.keybase());
if indices.len() > r.len() { if indices.len() > r.len() {
return Err(FBError::InvalidKey); return Err(FBError::InvalidKey);
} }
@ -96,20 +104,20 @@ where
for &(ci, ri) in indices { for &(ci, ri) in indices {
let c_el = c.get(ci).ok_or(FBError::InvalidKey)?; let c_el = c.get(ci).ok_or(FBError::InvalidKey)?;
let r_el = r.get(ri).ok_or(FBError::InvalidKey)?; let r_el = r.get(ri).ok_or(FBError::InvalidKey)?;
msg = msg.field_add( &c_el.field_mul(&r_el) ); msg = msg.field_add(&c_el.field_mul(&r_el));
} }
Ok(msg) Ok(msg)
} }
} }
#[test] #[test]
fn encrypt_u128() { fn encrypt_u128() {
use crypto_bigint::U128; use crypto_bigint::U128;
let msg = U128::from_u32(100); let msg = U128::from_u32(100);
let mut fb = FBObj::<U128>::init(18, 12).unwrap(); let fb = FBObj::<U128>::init(18, 12).unwrap();
let key = fb.add_block(&msg); let rng = &mut rand::thread_rng();
let key = fb.add_block(rng, &msg);
let decrypted = fb.decrypt_block(&key).unwrap(); let decrypted = fb.decrypt_block(&key).unwrap();
assert_eq!(msg, decrypted); assert_eq!(msg, decrypted);
} }

View File

@ -1,7 +1,8 @@
// SPDX-License-Identifier: GPL-3.0-or-later // SPDX-License-Identifier: GPL-3.0-or-later
use crate::{FBError, FBObj, FBObjTrait}; use crate::{FBError, FBObj, FBObjTrait};
use crypto_bigint::{ArrayEncoding, Bounded, generic_array::GenericArray};
use base64::{prelude::BASE64_STANDARD, Engine}; use base64::{prelude::BASE64_STANDARD, Engine};
use crypto_bigint::{generic_array::GenericArray, ArrayEncoding, Bounded};
use std::sync::RwLock;
pub trait Encode<T> pub trait Encode<T>
where where
@ -10,7 +11,8 @@ where
{ {
/// Returns the byte representation of the ciphertext and keybase. /// Returns the byte representation of the ciphertext and keybase.
fn to_bytes(&self) -> (Vec<u8>, Vec<u8>) { fn to_bytes(&self) -> (Vec<u8>, Vec<u8>) {
let c = self.cipher().iter() let c = self.cipher().read().unwrap()
.iter()
.flat_map(|bigint| bigint.to_le_byte_array()) .flat_map(|bigint| bigint.to_le_byte_array())
.collect(); .collect();
let r = self.keybase().iter() let r = self.keybase().iter()
@ -32,19 +34,17 @@ where
/// # Errors /// # Errors
/// - [InvalidParams](FBError::InvalidParams) - Are the parameters in the wrong order? /// - [InvalidParams](FBError::InvalidParams) - Are the parameters in the wrong order?
fn from_bytes(cipher: &[u8], keybase: &[u8]) -> Result<FBObj<T>, FBError> { fn from_bytes(cipher: &[u8], keybase: &[u8]) -> Result<FBObj<T>, FBError> {
let chunk_to_uint = |chunk| { let chunk_to_uint = |chunk| T::from_le_byte_array(GenericArray::clone_from_slice(chunk));
T::from_le_byte_array( let c_vec: Vec<T> = cipher.chunks_exact(T::BYTES)
GenericArray::clone_from_slice(chunk)
)};
let c: Vec<T> = cipher.chunks_exact(T::BYTES)
.map(chunk_to_uint) .map(chunk_to_uint)
.collect(); .collect();
let r: Vec<T> = keybase.chunks_exact(T::BYTES) let r: Vec<T> = keybase.chunks_exact(T::BYTES)
.map(chunk_to_uint) .map(chunk_to_uint)
.collect(); .collect();
if r.len() > c.len() || r.len() < 2 { if r.len() > c_vec.len() || r.len() < 2 {
return Err(FBError::InvalidParams); return Err(FBError::InvalidParams);
} }
let c = RwLock::new(c_vec);
Ok(FBObj {c, r}) Ok(FBObj {c, r})
} }
@ -54,7 +54,7 @@ where
/// # Errors /// # Errors
/// - [DecodeError](FBError::DecodeError) /// - [DecodeError](FBError::DecodeError)
/// - [InvalidParams](FBError::InvalidParams) - Are the parameters in the wrong order? /// - [InvalidParams](FBError::InvalidParams) - Are the parameters in the wrong order?
fn import(cipher: &str, keybase: &str) -> Result<FBObj<T>, FBError> { fn import(cipher: &str, keybase: &str) -> Result<FBObj<T>, FBError> {
let c_bytes = BASE64_STANDARD.decode(cipher) let c_bytes = BASE64_STANDARD.decode(cipher)
.map_err(|_| FBError::DecodeError)?; .map_err(|_| FBError::DecodeError)?;
let r_bytes = BASE64_STANDARD.decode(keybase) let r_bytes = BASE64_STANDARD.decode(keybase)

View File

@ -3,26 +3,24 @@ pub mod fb128;
use crate::{BlockOps, Encode, FieldOps}; use crate::{BlockOps, Encode, FieldOps};
use crypto_bigint::{ArrayEncoding, Bounded, RandomMod}; use crypto_bigint::{ArrayEncoding, Bounded, RandomMod};
use std::sync::RwLock;
use std::marker::Sync;
/// The False Bottom Object holds the ciphertext and the keybase. The provided type aliases can be used to pick a block size. /// The False Bottom Object holds the ciphertext and the keybase. The provided type aliases can be used to pick a block size.
pub struct FBObj<T> { pub struct FBObj<T> {
pub(crate) c: Vec<T>, pub(crate) c: RwLock<Vec<T>>,
pub(crate) r: Vec<T>, pub(crate) r: Vec<T>,
} }
pub trait FBObjTrait<T> { pub trait FBObjTrait<T> {
fn cipher(&self) -> &Vec<T>; fn cipher(&self) -> &RwLock<Vec<T>>;
fn cipher_mut(&mut self) -> &mut Vec<T>;
fn keybase(&self) -> &Vec<T>; fn keybase(&self) -> &Vec<T>;
} }
impl<T> FBObjTrait<T> for FBObj<T> { impl<T> FBObjTrait<T> for FBObj<T> {
fn cipher(&self) -> &Vec<T> { fn cipher(&self) -> &RwLock<Vec<T>> {
&self.c &self.c
} }
fn cipher_mut(&mut self) -> &mut Vec<T> {
&mut self.c
}
fn keybase(&self) -> &Vec<T> { fn keybase(&self) -> &Vec<T> {
&self.r &self.r
} }
@ -30,10 +28,10 @@ impl<T> FBObjTrait<T> for FBObj<T> {
impl<T> BlockOps<T> for FBObj<T> impl<T> BlockOps<T> for FBObj<T>
where where
T: FieldOps + RandomMod T: FieldOps + RandomMod + Send + Sync
{} {}
impl<T> Encode<T> for FBObj<T> impl<T> Encode<T> for FBObj<T>
where where
T: ArrayEncoding + Bounded, T: ArrayEncoding + Bounded
{} {}

View File

@ -24,7 +24,7 @@ impl FieldOps for U128 {
self.mul_mod_special(rhs, PRIME_POS) self.mul_mod_special(rhs, PRIME_POS)
} }
fn field_inv(&self) -> Self { fn field_inv(&self) -> Self {
self.inv_mod(&PRIME).0 self.inv_odd_mod(&PRIME).0
} }
} }