Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi argument functions #7

Merged
merged 28 commits into from
Oct 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
/* Crate imports */
use crate::misc::{Function, HashMap, HashSet};
use crate::{
misc::{HashMap, HashSet},
token::Function,
};

#[derive(Debug, Default, PartialEq)]
#[non_exhaustive]
pub struct Context<'a> {
pub vars: HashMap<&'a str, f64>,
pub funcs: HashMap<&'a str, Function>,
pub funcs: HashMap<&'a str, Function<'a>>,
pub expected_vars: Option<HashSet<&'a str>>,
}
29 changes: 18 additions & 11 deletions src/element/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,38 @@
use core::fmt;
/* Crate imports */
use super::Element;
use crate::misc::Function;
use crate::token::Function;

#[derive(Debug, PartialEq, PartialOrd)]
pub struct FunctionCall<'a> {
pub(crate) name: &'a str,
pub(crate) func: Function,
pub(crate) arg: Element<'a>,
pub(crate) desc: &'a Function<'a>,
pub(crate) args: Vec<Element<'a>>,
}

impl<'a> FunctionCall<'a> {
pub const fn new(name: &'a str, func: Function, arg: Element<'a>) -> Self {
Self { name, func, arg }
pub const fn new(desc: &'a Function<'a>, args: Vec<Element<'a>>) -> Self {
Self { desc, args }
}

pub fn new_element(
name: &'a str,
func: Function,
arg: Element<'a>,
desc: &'a Function<'a>,
args: Vec<Element<'a>>,
) -> Element<'a> {
Element::Function(Box::new(Self::new(name, func, arg)))
Element::Function(Box::new(Self::new(desc, args)))
}

pub fn call(&self, args: &[f64]) -> f64 {
(self.desc.func)(args)
}
}

impl fmt::Display for FunctionCall<'_> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "{}({})", self.name, self.arg)
let args = self
.args
.iter()
.map(|arg| format!("{arg}"))
.collect::<Vec<_>>();
write!(fmt, "{}({})", self.desc.name, args.join(", "))
}
}
2 changes: 1 addition & 1 deletion src/element/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl<'a> Element<'a> {
unop.operand.find_variables(vars);
},
Self::Function(ref func) => {
func.arg.find_variables(vars);
func.args.iter().for_each(|arg| arg.find_variables(vars));
},
Self::Number(_) => (),
};
Expand Down
34 changes: 26 additions & 8 deletions src/element/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,37 @@ impl<'a> Simplify<'a> for UnOp<'a> {
impl<'a> Simplify<'a> for FunctionCall<'a> {
#[inline]
fn simplify_for(mut self, var: (&str, f64)) -> Element<'a> {
self.arg = self.arg.simplify_for(var);
self.args = self
.args
.into_iter()
.map(|arg| arg.simplify_for(var))
.collect();
self.simplify()
}

#[inline]
fn simplify(mut self) -> Element<'a> {
self.arg = self.arg.simplify();
match self.arg {
Element::Number(num) => Element::Number((self.func)(num)),
Element::BinOp(_)
| Element::UnOp(_)
| Element::Function(_)
| Element::Variable(_) => self.into(),
// TODO: Not a big fan of the second vector.
// We need to simplify the arguments in all cases, but
// if they are all numbers, we can call the function.
let mut args_values: Vec<f64> = Vec::with_capacity(self.args.len());

self.args = self
.args
.into_iter()
.map(|arg| {
let simplified = arg.simplify();
if let Element::Number(num) = simplified {
args_values.push(num);
}
simplified
})
.collect();

if args_values.len() == self.args.len() {
self.call(&args_values).into()
} else {
self.into()
}
}
}
1 change: 0 additions & 1 deletion src/misc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/* Built-in imports */
use std::collections;
/* Exports */
pub type Function = fn(f64) -> f64;
pub type HashMap<K, V> = collections::HashMap<K, V>;
pub type HashSet<T> = collections::HashSet<T>;
109 changes: 97 additions & 12 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,24 +175,25 @@ impl<'a> ParserImpl<'a> {
.vars
.get(name)
.map(|&value| Identifier::Constant(value))
.or_else(|| {
self.ctx
.funcs
.get(name)
.map(|&func| Identifier::Function(name, func))
})
.or_else(|| self.ctx.funcs.get(name).map(Identifier::Function))
.unwrap_or_else(|| name.into());

let el = match ident {
Identifier::Constant(val) => Element::Number(val),
Identifier::Variable(var) => Element::Variable(var),
Identifier::Function(fn_name, func)
if Some(&b'(') == self.next() =>
{
let el = self.element(precedence::FN_PRECEDENCE)?;
FunctionCall::new_element(fn_name, func, el)
Identifier::Function(func) if Some(&b'(') == self.next() => {
self.cursor += 1;
let args = match func.nb_args {
Some(nb) => self.parse_arguments(nb)?,
None => self.parse_variadic_arguments()?,
};
if self.next() != Some(&b')') {
yeet!(ParseError::new_expected_token(self, b')'));
}
self.cursor += 1;
FunctionCall::new_element(func, args)
},
Identifier::Function(_, _) => {
Identifier::Function(_) => {
yeet!(ParseError::new_expected_token(self, b'('))
},
};
Expand All @@ -211,6 +212,66 @@ impl<'a> ParserImpl<'a> {
Ok(Element::Number(num))
}

fn parse_arguments(
&mut self,
nb_args: u8,
) -> Result<Vec<Element<'a>>, ParseError> {
let args = (1..=nb_args)
.map(|idx| {
let arg = self.element(precedence::NO_PRECEDENCE)?;
// check for comma if not last argument
if idx == nb_args {
return Ok(arg);
}
// if not last argument, check for comma
if self.next() == Some(&b',') {
self.cursor += 1;
} else {
yeet!(ParseError::new_expected_token(self, b','));
}
// if a comma is followed by a closing parenthesis
// it means we have a missing argument
if self.next() == Some(&b')') {
yeet!(ParseError::new_missing_argument(self));
}

Ok(arg)
})
.collect::<Result<Vec<Element<'a>>, ParseError>>();
if self.next() == Some(&b',') {
yeet!(ParseError::new_too_many_arguments(
self,
nb_args,
nb_args + 1
));
}
args
}

fn parse_variadic_arguments(
&mut self,
) -> Result<Vec<Element<'a>>, ParseError> {
let mut args = Vec::new();

loop {
let arg = self.element(precedence::NO_PRECEDENCE)?;
args.push(arg);

// expect either a comma or a closing parenthesis
match self.next() {
Some(&b',') => self.cursor += 1,
Some(&b')') => break,
Some(&tok) => {
yeet!(ParseError::new_unexpected_token(self, tok))
},
None => {
yeet!(ParseError::new_unexpected_end_of_expression(self))
},
}
}
Ok(args)
}

#[inline]
fn take_while(&mut self, predicate: fn(&u8) -> bool) -> &'a str {
let start = self.cursor;
Expand Down Expand Up @@ -301,6 +362,10 @@ pub enum ErrorKind {
ExpectedToken(char),
#[error("Variable not previously declared: `{0}`")]
VariableNotDeclared(String),
#[error("Too many arguments for function call, expected {0} got {1}")]
TooManyArguments(u8, u8),
#[error("Missing argument for function call")]
MissingArgument,
}

impl ParseError {
Expand Down Expand Up @@ -354,4 +419,24 @@ impl ParseError {
src: trust_me!(str::from_utf8_unchecked(parser.input)).to_owned(),
}
}

fn new_too_many_arguments(
parser: &ParserImpl,
expected: u8,
got: u8,
) -> Self {
Self {
kind: ErrorKind::TooManyArguments(expected, got),
span: parser.cursor.into(),
src: trust_me!(str::from_utf8_unchecked(parser.input)).to_owned(),
}
}

fn new_missing_argument(parser: &ParserImpl) -> Self {
Self {
kind: ErrorKind::MissingArgument,
span: parser.cursor.into(),
src: trust_me!(str::from_utf8_unchecked(parser.input)).to_owned(),
}
}
}
50 changes: 36 additions & 14 deletions src/tests/parser/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@
use crate::{
context::Context,
element::{BinOp, Element, FunctionCall},
misc::Function,
token::Operator,
Parser,
token::{Function, Operator},
xprs_fn, Parser,
};

const DOUBLE: Function = |x| x * 2.0_f64;
fn triple(x: f64) -> f64 {
x * 3.0_f64
fn double(x: f64) -> f64 {
x * 2.0_f64
}
const DOUBLE: Function = xprs_fn!("DOUBLE", double, 1);
fn add(x: f64, y: f64) -> f64 {
x + y
}
const ADD: Function = xprs_fn!("ADD", add, 2);
#[allow(clippy::as_conversions, clippy::cast_precision_loss)]
fn mean(args: &[f64]) -> f64 {
args.iter().sum::<f64>() / args.len() as f64
}
const MEAN: Function = xprs_fn!("MEAN", mean);

fn get_parser_with_ctx<'a>() -> Parser<'a> {
let mut ctx = Context::default();
Expand All @@ -19,16 +27,17 @@ fn get_parser_with_ctx<'a>() -> Parser<'a> {
ctx.vars.insert("phi", 1.618_033_988_749_895_f64);

ctx.funcs.insert("double", DOUBLE);
ctx.funcs.insert("triple", triple);
ctx.funcs.insert("add", ADD);

let mut parser = Parser::new_with_ctx(ctx);

parser.ctx_mut().vars.insert("y", 1.0_f64);
parser.ctx_mut().funcs.insert("mean", MEAN);

parser
}

fn get_valid_test_cases<'a>() -> [(&'static str, Element<'a>); 5] {
fn get_valid_test_cases<'a>() -> [(&'static str, Element<'a>); 6] {
[
("y", Element::Number(1.0)),
(
Expand All @@ -54,22 +63,35 @@ fn get_valid_test_cases<'a>() -> [(&'static str, Element<'a>); 5] {
(
"double(2 + phi * x)",
FunctionCall::new_element(
"double",
DOUBLE,
BinOp::new_element(
&DOUBLE,
vec![BinOp::new_element(
Operator::Plus,
Element::Number(2.0),
BinOp::new_element(
Operator::Times,
Element::Number(1.618_033_988_749_895),
Element::Number(2.0),
),
),
)],
),
),
(
"triple(2)",
FunctionCall::new_element("triple", triple, Element::Number(2.0)),
"add(2, 3)",
FunctionCall::new_element(
&ADD,
vec![Element::Number(2.0), Element::Number(3.0)],
),
),
(
"mean(2, 3, 4)",
FunctionCall::new_element(
&MEAN,
vec![
Element::Number(2.0),
Element::Number(3.0),
Element::Number(4.0),
],
),
),
]
}
Expand Down
Loading