Skip to content

Commit

Permalink
RawSubscripts correspond to GEMM and DOT
Browse files Browse the repository at this point in the history
  • Loading branch information
termoshtt committed Mar 5, 2024
1 parent 4417cd3 commit 5cb284d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 66 deletions.
63 changes: 1 addition & 62 deletions einsum-codegen/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
//! and corresponding EBNF-like schema are written in each document page.
//!
use anyhow::{bail, Error, Result};
use crate::{RawSubscript, RawSubscripts};
use nom::{
bytes::complete::*, character::complete::*, combinator::*, multi::*, sequence::*, IResult,
Parser,
};
use std::fmt;

/// index = `a` | `b` | `c` | `d` | `e` | `f` | `g` | `h` | `i` | `j` | `k` | `l` |`m` | `n` | `o` | `p` | `q` | `r` | `s` | `t` | `u` | `v` | `w` | `x` |`y` | `z`;
pub fn index(input: &str) -> IResult<&str, char> {
Expand All @@ -34,66 +33,6 @@ pub fn subscript(input: &str) -> IResult<&str, RawSubscript> {
}
}

/// Subscript of a single tensor as written in an einsum expression.
///
/// A subscript is either a plain run of index characters or a run of
/// indices split around an ellipsis.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RawSubscript {
/// Indices without ellipsis, e.g. `ijk`
Indices(Vec<char>),
/// Indices with ellipsis, e.g. `i...j`.
/// `start` holds the indices written before the ellipsis and `end` those after it.
Ellipsis { start: Vec<char>, end: Vec<char> },
}

impl<const N: usize> PartialEq<[char; N]> for RawSubscript {
fn eq(&self, other: &[char; N]) -> bool {
match self {
RawSubscript::Indices(indices) => indices.eq(other),
_ => false,
}
}
}

/// Render the subscript as a contiguous string of its index characters.
///
/// Plain indices are written one after another. For the ellipsis variant the
/// ellipsis itself is rendered as three underscores (`___`) between the
/// leading and trailing indices.
impl fmt::Display for RawSubscript {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use std::fmt::Write as _;

        // Split into the part before the (optional) ellipsis and the part after it.
        let (head, tail) = match self {
            RawSubscript::Indices(indices) => (indices.as_slice(), None),
            RawSubscript::Ellipsis { start, end } => (start.as_slice(), Some(end.as_slice())),
        };

        head.iter().try_for_each(|&c| f.write_char(c))?;
        if let Some(tail) = tail {
            f.write_str("___")?;
            tail.iter().try_for_each(|&c| f.write_char(c))?;
        }
        Ok(())
    }
}

/// Einsum subscripts, e.g. `ij,jk->ik`
///
/// Holds one [`RawSubscript`] per input tensor plus the optional output subscript.
#[derive(Debug, PartialEq, Eq)]
pub struct RawSubscripts {
/// Input subscripts, e.g. `ij` and `jk` for `ij,jk->ik`
pub inputs: Vec<RawSubscript>,
/// Output subscript. This is `None` when no output is written ("implicit mode").
pub output: Option<RawSubscript>,
}

/// Parse einsum subscripts such as `ij,jk->ik` using the `subscripts` parser.
///
/// # Errors
///
/// Returns an error when the parser fails, or when it succeeds on only a
/// prefix of `input` and leaves anything other than whitespace unparsed.
impl std::str::FromStr for RawSubscripts {
    type Err = Error;

    fn from_str(input: &str) -> Result<Self> {
        use nom::Finish;
        match subscripts(input).finish() {
            // A successful parse of a prefix is not enough: silently dropping the
            // unparsed remainder would accept malformed input as a different,
            // shorter expression.
            Ok((rest, ss)) if rest.trim().is_empty() => Ok(ss),
            _ => bail!("Invalid subscripts: {}", input),
        }
    }
}

/// subscripts = [subscript] {`,` [subscript]} \[ `->` [subscript] \]
pub fn subscripts(input: &str) -> IResult<&str, RawSubscripts> {
let (input, _head) = multispace0(input)?;
Expand Down
78 changes: 74 additions & 4 deletions einsum-codegen/src/subscripts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Einsum subscripts, e.g. `ij,jk->ik`
use crate::{parser::*, *};
use anyhow::Result;
use anyhow::{bail, Error, Result};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use std::{
Expand All @@ -9,6 +9,76 @@ use std::{
str::FromStr,
};

/// Subscript of a single tensor as written in an einsum expression.
///
/// A subscript is either a plain run of index characters or a run of
/// indices split around an ellipsis.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RawSubscript {
/// Indices without ellipsis, e.g. `ijk`
Indices(Vec<char>),
/// Indices with ellipsis, e.g. `i...j`.
/// `start` holds the indices written before the ellipsis and `end` those after it.
Ellipsis { start: Vec<char>, end: Vec<char> },
}

impl<const N: usize> PartialEq<[char; N]> for RawSubscript {
fn eq(&self, other: &[char; N]) -> bool {
match self {
RawSubscript::Indices(indices) => indices.eq(other),
_ => false,
}
}
}

/// Render the subscript as a contiguous string of its index characters.
///
/// Plain indices are written one after another. For the ellipsis variant the
/// ellipsis itself is rendered as three underscores (`___`) between the
/// leading and trailing indices.
impl fmt::Display for RawSubscript {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use std::fmt::Write as _;

        // Split into the part before the (optional) ellipsis and the part after it.
        let (head, tail) = match self {
            RawSubscript::Indices(indices) => (indices.as_slice(), None),
            RawSubscript::Ellipsis { start, end } => (start.as_slice(), Some(end.as_slice())),
        };

        head.iter().try_for_each(|&c| f.write_char(c))?;
        if let Some(tail) = tail {
            f.write_str("___")?;
            tail.iter().try_for_each(|&c| f.write_char(c))?;
        }
        Ok(())
    }
}

/// Einsum subscripts, e.g. `ij,jk->ik`
///
/// Holds one [`RawSubscript`] per input tensor plus the optional output subscript.
#[derive(Debug, PartialEq, Eq)]
pub struct RawSubscripts {
/// Input subscripts, e.g. `ij` and `jk` for `ij,jk->ik`
pub inputs: Vec<RawSubscript>,
/// Output subscript. This is `None` when no output is written ("implicit mode").
pub output: Option<RawSubscript>,
}

/// Parse einsum subscripts such as `ij,jk->ik` using the `subscripts` parser.
///
/// # Errors
///
/// Returns an error when the parser fails, or when it succeeds on only a
/// prefix of `input` and leaves anything other than whitespace unparsed.
impl std::str::FromStr for RawSubscripts {
    type Err = Error;

    fn from_str(input: &str) -> Result<Self> {
        use nom::Finish;
        match subscripts(input).finish() {
            // A successful parse of a prefix is not enough: silently dropping the
            // unparsed remainder would accept malformed input as a different,
            // shorter expression.
            Ok((rest, ss)) if rest.trim().is_empty() => Ok(ss),
            _ => bail!("Invalid subscripts: {}", input),
        }
    }
}

/// Subscripts corresponding to DOT in BLAS
pub fn dot() -> RawSubscripts {
RawSubscripts::from_str("a,a->").unwrap()
}

/// Subscripts corresponding to GEMM in BLAS
pub fn gemm() -> RawSubscripts {
RawSubscripts::from_str("ab,bc->ac").unwrap()
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Subscript {
raw: RawSubscript,
Expand Down Expand Up @@ -46,7 +116,7 @@ impl ToTokens for Subscript {
/// Indices are remapped to start from `a` so that identical subscripts can be recognized, e.g. `i,i->` and `j,j->`
///
/// ```
/// use einsum_codegen::{*, parser::RawSubscript};
/// use einsum_codegen::*;
///
/// let mut names = Namespace::init();
/// let mut ss1 = Subscripts::from_raw_indices(&mut names, "ij,jk,kl->il").unwrap();
Expand Down Expand Up @@ -133,7 +203,7 @@ impl Subscripts {
///
/// ```
/// use std::str::FromStr;
/// use einsum_codegen::{*, parser::*};
/// use einsum_codegen::*;
///
/// // Infer output subscripts for implicit mode
/// let mut names = Namespace::init();
Expand Down Expand Up @@ -241,7 +311,7 @@ impl Subscripts {
/// into `ab` in the second step.
///
/// ```
/// use einsum_codegen::{*, parser::RawSubscript};
/// use einsum_codegen::*;
/// use std::str::FromStr;
/// use maplit::btreeset;
///
Expand Down

0 comments on commit 5cb284d

Please sign in to comment.