Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/pose_enc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ fn bench_poseidon<const T: usize, const RATE: usize, const K: u32>(name: &str, c
let mut ref_pos_enc =
PoseidonCipher::<Fr, FULL_ROUND, PARTIAL_ROUND, T, RATE>::new([key.key0, key.key1]);

let ref_cipher = ref_pos_enc.encrypt(&inputs, &Fr::ONE);
let ref_cipher = ref_pos_enc.encrypt(&inputs, &Fr::ONE).unwrap();

print!("\nmsg length: {:?}\n", MESSAGE_CAPACITY);

Expand Down
2 changes: 1 addition & 1 deletion src/encryption/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ fn test_pos_enc() {

//== Poseidon Encryption ==//

let ref_cipher = ref_pos_enc.encrypt(&inputs, &F::ONE);
let ref_cipher = ref_pos_enc.encrypt(&inputs, &F::ONE).unwrap();

//== Circuit ==//

Expand Down
102 changes: 68 additions & 34 deletions src/encryption/poseidon_enc.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use ff::{Field, FromUniformBytes, PrimeField};
use halo2wrong::curves::bn256;
use poseidon::Poseidon;
use rand_core::Error;

use crate::poseidon::{
self,
chip::{FULL_ROUND, PARTIAL_ROUND},
};

pub const MESSAGE_CAPACITY: usize = 2; //max 31
pub const MESSAGE_CAPACITY: usize = 33;
pub const CIPHER_SIZE: usize = MESSAGE_CAPACITY + 1;

#[derive(Copy, Clone, Debug, Default)]
Expand All @@ -32,7 +33,7 @@ impl<F: PrimeField> PoseidonEncKey<F> {
}
}

#[derive(Debug, Clone, Copy, Default)]
#[derive(Debug, Clone, Copy)]
pub struct PoseidonCipher<
F: PrimeField + FromUniformBytes<64>,
const r_f: usize,
Expand All @@ -46,6 +47,20 @@ pub struct PoseidonCipher<
pub cipher: [F; CIPHER_SIZE],
}

impl<F, const r_f: usize, const r_p: usize, const T: usize, const RATE: usize> Default
for PoseidonCipher<F, r_f, r_p, T, RATE>
where
F: PrimeField + FromUniformBytes<64> + Default,
{
fn default() -> Self {
PoseidonCipher {
cipherKey: [F::default(); 2],
cipherByteSize: Default::default(),
cipher: [F::default(); CIPHER_SIZE], // CIPHER_SIZE에 따라
}
}
}

impl<F, const r_f: usize, const r_p: usize, const T: usize, const RATE: usize>
PoseidonCipher<F, r_f, r_p, T, RATE>
where
Expand Down Expand Up @@ -83,7 +98,7 @@ where
]
}

pub fn encrypt(&mut self, message: &[F], nonce: &F) -> [F; CIPHER_SIZE] {
pub fn encrypt(&mut self, message: &[F], nonce: &F) -> Result<[F; CIPHER_SIZE], Error> {
let mut encrypter = Poseidon::<F, T, RATE>::new_enc(
FULL_ROUND,
PARTIAL_ROUND,
Expand Down Expand Up @@ -111,67 +126,86 @@ where
.skip(1)
.zip(inputs.iter())
{
*word = word.add(input);
*word = word.add(input); // c = s + m, m = c - s
if i < MESSAGE_CAPACITY {
// c_n = p(s+m) + m_n
cipher[i] = word.clone();
i += 1;
}
}
if inputs.len() == RATE {
encrypter.update(&inputs);
} else {

encrypter.update(&inputs);
if inputs.len() < RATE {
encrypter.squeeze(0);
}
}
// encrypter.perm_with_input(&[]);

cipher[MESSAGE_CAPACITY] = encrypter.state.words()[1].clone();

// println!("enc_cipher[MESSAGE_CAPACITY]:{:?}", cipher[MESSAGE_CAPACITY]);

self.cipher = cipher;

cipher
Ok(cipher)
}

pub fn decrypt(&mut self, nonce: &F) -> Option<[F; MESSAGE_CAPACITY]> {
let mut encrypter = Poseidon::<F, T, RATE>::new_enc(
pub fn decrypt(
&mut self,
cipher: &[F; CIPHER_SIZE],
nonce: &F,
) -> Result<[F; MESSAGE_CAPACITY], Error> {
let mut decrypter = Poseidon::<F, T, RATE>::new_enc(
FULL_ROUND,
PARTIAL_ROUND,
self.cipherKey[0],
self.cipherKey[1],
);

let mut message = [F::ZERO; MESSAGE_CAPACITY];
decrypter.update(&vec![]);
decrypter.squeeze(0);

encrypter.update(&vec![]);
encrypter.squeeze(0);

let mut state_2 = encrypter.state.words().clone();

(0..MESSAGE_CAPACITY).for_each(|i| {
message[i] = self.cipher[i] - state_2[(i + 1) % T];
state_2[(i + 1) % T] = self.cipher[i];
});

encrypter.update(&mut message);
encrypter.squeeze(0);
let mut message = [F::ZERO; MESSAGE_CAPACITY];
let mut i = 0;

let state_3 = encrypter.state.words().clone();
let parity = cipher[MESSAGE_CAPACITY];

if self.cipher[MESSAGE_CAPACITY] != state_3[1] {
return None;
for chunk in cipher[..MESSAGE_CAPACITY].chunks(RATE) {
for (word, encrypted_word) in
decrypter.state.words().iter_mut().skip(1).zip(chunk.iter())
{
if i < MESSAGE_CAPACITY {
message[i] = encrypted_word.sub(word.clone());
i += 1;
}
}
// println!(">>1. state:{:?}", decrypter.state.words());
let offset = i % RATE;
if offset == 0 {
decrypter.update(&message[i - RATE..i]);
} else {
// if chunk.len() < RATE {
decrypter.update(&message[i - offset..i]);
decrypter.squeeze(0);
// }
}
}
if parity != decrypter.state.words()[1] {
return Err(Error::new("Invalid cipher text"));
}
Some(message)
// println!(">>2. state:{:?}", decrypter.state.words());
Ok(message)
}
}

#[test]
fn test_encryption() {
let mut cipher = PoseidonCipher::<bn256::Fr, 8, 57, 5, 4>::new([bn256::Fr::ZERO; 2]);
let message = [bn256::Fr::ZERO; MESSAGE_CAPACITY];
let message = [bn256::Fr::ONE; MESSAGE_CAPACITY];

println!("message: {:?}", message);

let cipherText = cipher.encrypt(&message, &bn256::Fr::ONE);
println!("encrypted: {:?}", cipherText);
println!("decrypted: {:?}", cipher.decrypt(&bn256::Fr::ONE));
let cipher_text = cipher.encrypt(&message, &bn256::Fr::ONE).unwrap();
println!("encrypted: {:?}", cipher_text);
println!(
"decrypted: {:?}",
cipher.decrypt(&cipher_text, &bn256::Fr::ONE).unwrap()
);
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ impl<F: PrimeField + FromUniformBytes<64>, const T: usize, const RATE: usize> Ci
// // == Encryption Scheme == //
let mut ref_enc =
PoseidonCipher::<F, FULL_ROUND, PARTIAL_ROUND, T, RATE>::new(pose_key);
let encryption_result = ref_enc.encrypt(&self.message, &F::ONE);
let encryption_result = ref_enc.encrypt(&self.message, &F::ONE).unwrap();
let mut expected_result = vec![];
// assign expected result
for result in &encryption_result {
Expand Down