Skip to content

Commit

Permalink
Refactor PRG to single impl
Browse files Browse the repository at this point in the history
  • Loading branch information
myl7 committed Jun 27, 2024
1 parent 1a8849e commit ca62f99
Show file tree
Hide file tree
Showing 15 changed files with 249 additions and 355 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ use rand::prelude::*;
// Matyas-Meyer-Oseas (via AES128) provides 128-bit security and should be enough.
// Hirose (via AES256) still only provides 128-bit security because the output is not chained.
// But Hirose can be helpful is you are forced to choose AES256.
use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dcf::{Dcf, DcfImpl};

let keys: [[u8; 32]; 2] = thread_rng().gen();
let prg = Aes128MatyasMeyerOseasPrg::<16, 2>::new(std::array::from_fn(|i| &keys[i]));
let prg = Aes128MatyasMeyerOseasPrg::<16, 2, 2>::new(std::array::from_fn(|i| &keys[i]));
// DCF for example
let dcf = DcfImpl::<16, 16, _>::new(prg);
```
Expand Down
4 changes: 2 additions & 2 deletions benches/dcf_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
use fss_rs::group::byte::ByteGroup;
use fss_rs::group::Group;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
c: &mut Criterion,
Expand All @@ -16,7 +16,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
4 changes: 2 additions & 2 deletions benches/dcf_eval_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
use fss_rs::group::byte::ByteGroup;
use fss_rs::group::Group;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

const POINT_NUM: usize = 10000;

Expand All @@ -18,7 +18,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
4 changes: 2 additions & 2 deletions benches/dcf_full_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
use fss_rs::group::byte::ByteGroup;
use fss_rs::group::Group;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
c: &mut Criterion,
Expand All @@ -17,7 +17,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new_with_filter(prg, filter_bitn);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
4 changes: 2 additions & 2 deletions benches/dcf_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
use fss_rs::group::byte::ByteGroup;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
c: &mut Criterion,
Expand All @@ -15,7 +15,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
4 changes: 2 additions & 2 deletions benches/dpf_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
use fss_rs::group::byte::ByteGroup;
use fss_rs::group::Group;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
c: &mut Criterion,
Expand All @@ -16,7 +16,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
4 changes: 2 additions & 2 deletions benches/dpf_eval_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
use fss_rs::group::byte::ByteGroup;
use fss_rs::group::Group;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

const POINT_NUM: usize = 10000;

Expand All @@ -18,7 +18,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
4 changes: 2 additions & 2 deletions benches/dpf_full_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
use fss_rs::group::byte::ByteGroup;
use fss_rs::group::Group;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
c: &mut Criterion,
Expand All @@ -17,7 +17,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new_with_filter(prg, filter_bitn);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
4 changes: 2 additions & 2 deletions benches/dpf_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use rand::prelude::*;

use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
use fss_rs::group::byte::ByteGroup;
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;

fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
c: &mut Criterion,
Expand All @@ -15,7 +15,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
let keys_iter = std::array::from_fn(|i| &keys[i]);

let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);

let mut s0s = [[0; OUT_BLEN]; 2];
Expand Down
37 changes: 16 additions & 21 deletions src/dcf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ use rayon::prelude::*;

use crate::group::Group;
use crate::utils::{xor, xor_inplace};
use crate::{decl_prg_trait, Cw, PointFn, Share};

#[cfg(feature = "prg")]
pub mod prg;
use crate::{Cw, PointFn, Prg, Share};

/// Distributed comparison function API.
///
Expand Down Expand Up @@ -64,22 +61,20 @@ where
}
}

decl_prg_trait!(([u8; OUT_BLEN], [u8; OUT_BLEN], bool));

/// [`Dcf`] impl.
///
/// `$\alpha$` itself is not included (or say exclusive endpoint), which means `$f(\alpha)$ = 0`.
pub struct DcfImpl<const IN_BLEN: usize, const OUT_BLEN: usize, P>
where
P: Prg<OUT_BLEN>,
P: Prg<OUT_BLEN, 2>,
{
prg: P,
filter_bitn: usize,
}

impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
where
P: Prg<OUT_BLEN>,
P: Prg<OUT_BLEN, 2>,
{
pub fn new(prg: P) -> Self {
Self {
Expand All @@ -100,7 +95,7 @@ const IDX_R: usize = 1;
impl<const IN_BLEN: usize, const OUT_BLEN: usize, P, G> Dcf<IN_BLEN, OUT_BLEN, G>
for DcfImpl<IN_BLEN, OUT_BLEN, P>
where
P: Prg<OUT_BLEN>,
P: Prg<OUT_BLEN, 2>,
G: Group<OUT_BLEN>,
{
fn gen(
Expand All @@ -119,8 +114,8 @@ where
for i in 0..n {
// MSB is required since we index from high to low in arrays.
let alpha_i = f.alpha.view_bits::<Msb0>()[i];
let [(s0l, v0l, t0l), (s0r, v0r, t0r)] = self.prg.gen(&ss_prev[0]);
let [(s1l, v1l, t1l), (s1r, v1r, t1r)] = self.prg.gen(&ss_prev[1]);
let [([s0l, v0l], t0l), ([s0r, v0r], t0r)] = self.prg.gen(&ss_prev[0]);
let [([s1l, v1l], t1l), ([s1r, v1r], t1r)] = self.prg.gen(&ss_prev[1]);
// MSB is required since we index from high to low in arrays.
let (keep, lose) = if alpha_i {
(IDX_R, IDX_L)
Expand Down Expand Up @@ -201,7 +196,7 @@ where

impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
where
P: Prg<OUT_BLEN>,
P: Prg<OUT_BLEN, 2>,
{
/// Eval with single-threading.
/// See [`Dcf::eval`].
Expand Down Expand Up @@ -255,7 +250,7 @@ where

let cw = &k.cws[layer_i];
// `*_hat` before in-place XOR.
let [(mut sl, vl_hat, mut tl), (mut sr, vr_hat, mut tr)] = self.prg.gen(&s);
let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s);
xor_inplace(&mut sl, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
xor_inplace(&mut sr, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
tl ^= t & cw.tl;
Expand Down Expand Up @@ -291,7 +286,7 @@ where
for i in 0..n {
let cw = &k.cws[i];
// `*_hat` before in-place XOR.
let [(mut sl, vl_hat, mut tl), (mut sr, vr_hat, mut tr)] = self.prg.gen(&s_prev);
let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s_prev);
xor_inplace(&mut sl, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
xor_inplace(&mut sr, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
tl ^= t_prev & cw.tl;
Expand Down Expand Up @@ -326,9 +321,9 @@ pub enum BoundState {
mod tests {
use rand::prelude::*;

use super::prg::Aes256HirosePrg;
use super::*;
use crate::group::byte::ByteGroup;
use crate::prg::Aes256HirosePrg;

const KEYS: &[&[u8; 32]] = &[
b"j9\x1b_\xb3X\xf33\xacW\x15\x1b\x0812K\xb3I\xb9\x90r\x1cN\xb5\xee9W\xd3\xbb@\xc6d",
Expand All @@ -345,7 +340,7 @@ mod tests {

#[test]
fn test_dcf_gen_then_eval() {
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<16, 16, _>::new(prg);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
Expand Down Expand Up @@ -377,7 +372,7 @@ mod tests {

#[test]
fn test_dcf_gen_gt_beta_then_eval() {
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<16, 16, _>::new(prg);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
Expand Down Expand Up @@ -409,7 +404,7 @@ mod tests {

#[test]
fn test_dcf_gen_then_eval_with_filter() {
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<16, 16, _>::new_with_filter(prg, 127);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
Expand Down Expand Up @@ -441,7 +436,7 @@ mod tests {

#[test]
fn test_dcf_gen_then_eval_not_zeros() {
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<16, 16, _>::new(prg);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
Expand All @@ -465,7 +460,7 @@ mod tests {
#[test]
fn test_dcf_full_eval() {
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<2, 16, _>::new(prg);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
Expand All @@ -491,7 +486,7 @@ mod tests {
#[test]
fn test_dcf_full_eval_with_filter() {
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
let dcf = DcfImpl::<2, 16, _>::new_with_filter(prg, 15);
let s0s: [[u8; 16]; 2] = thread_rng().gen();
let f = CmpFn {
Expand Down
Loading

0 comments on commit ca62f99

Please sign in to comment.