From 7706b95125e6f1fbfa09e63f94a6dd0fd409d45c Mon Sep 17 00:00:00 2001 From: Adrien Burgun Date: Sat, 30 Sep 2023 15:32:09 +0200 Subject: [PATCH] :sparkles: Add unary and binary mathematical functions, add Config struct --- examples/procedural-waves.basic | 173 +++++++++++++++++ src/common.rs | 127 ++++++++++++ src/compile/mod.rs | 74 +++---- src/compile/optimize.rs | 26 ++- src/lib.rs | 1 + src/main.rs | 6 +- src/parse/ast.rs | 333 ++++++++++++++++++++------------ src/parse/test.rs | 3 +- src/parse/tokenize.rs | 33 +--- tests/examples.rs | 35 ++++ 10 files changed, 616 insertions(+), 195 deletions(-) create mode 100644 examples/procedural-waves.basic create mode 100644 src/common.rs create mode 100644 tests/examples.rs diff --git a/examples/procedural-waves.basic b/examples/procedural-waves.basic new file mode 100644 index 0000000..8c05059 --- /dev/null +++ b/examples/procedural-waves.basic @@ -0,0 +1,173 @@ +REM This is an example of a more complex program, +REM which spawns procedural waves + +LET wave = 0 +LET timeout = @time + 120000 + +initial_wait: + LET remaining = timeout - @time + IF remaining <= 0 THEN + GOTO main + END IF + + PRINT "[red]Enemies[white] approaching: " + PRINT ceil(remaining / 1000), " s" + PRINT_MESSAGE_MISSION() + + wait(0.5) + + GOTO initial_wait +main: + wave = wave + 1 + REM TODO: difficulty control wave multiplier + LET progression = POW(wave / 4, 0.75) + REM TODO: optimize duplicate operations + progression = MIN(progression / 2 + rand(progression / 2), 10) + + LET units = 2 + SQRT(progression) * 4 + RAND(progression * 2) + REM TODO: difficulty control unit amount + units = CEIL(units * 1) + LET tank_units = FLOOR(RAND(units)) + LET mech_units = FLOOR(RAND(units - tank_units)) + LET air_units = units - tank_units - mech_units + + LET spawnx = 30 + LET spawny = 50 + + GOTO spawn_tank + spawn_tank_end: + GOTO spawn_mech + spawn_mech_end: + LET spawnx = 20 + GOTO spawn_air + spawn_air_end: + + WRITE(wave, cell1, 1) + READ(timeout, 
cell1, 0) + timeout = @time + timeout * 1000 + + main_wait: + LET remaining = timeout - @time + IF remaining <= 0 THEN + GOTO main + END IF + + PRINT "[yellow]Wave ", wave, "[white] - " + PRINT "Next wave: ", ceil(remaining / 1000), " s" + PRINT_MESSAGE_MISSION() + + wait(0.5) + + GOTO main_wait + +spawn_tank: + LET spawned = 0 + spawn_tank_loop: + LET roll = rand(progression) + IF roll >= 3 THEN + IF roll >= 4 THEN + SPAWN(@conquer, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 5.75 + ELSE + SPAWN(@vanquish, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 3.5 + END IF + ELSE + IF roll >= 2 THEN + SPAWN(@precept, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 3.25 + ELSE + REM Small units can unclump easily + IF roll >= 1 THEN + SPAWN(@locus, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 1.0 + ELSE + SPAWN(@stell, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 1.0 + END IF + END IF + END IF + + IF spawnx < 10 THEN + spawnx = 10 + END IF + + spawned = spawned + 1 + IF spawned < tank_units THEN + GOTO spawn_tank_loop + END IF + GOTO spawn_tank_end + +spawn_mech: + LET spawned = 0 + spawn_mech_loop: + LET roll = rand(progression) + IF roll >= 3 THEN + IF roll >= 4 THEN + SPAWN(@collaris, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 5.5 + ELSE + SPAWN(@tecta, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 2.87 + END IF + ELSE + IF roll >= 2 THEN + SPAWN(@anthicus, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 2.62 + ELSE + IF roll >= 1 THEN + SPAWN(@cleroi, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 1.0 + ELSE + SPAWN(@merui, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 1.0 + END IF + END IF + END IF + + IF spawnx < 10 THEN + spawnx = 10 + END IF + + spawned = spawned + 1 + IF spawned < mech_units THEN + GOTO spawn_mech_loop + END IF + GOTO spawn_mech_end + +spawn_air: + LET spawned = 0 + spawn_air_loop: + LET roll = rand(progression) + IF roll >= 3 THEN + IF roll >= 4 THEN + SPAWN(@disrupt, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 
5.75 + ELSE + SPAWN(@quell, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 4.5 + END IF + ELSE + IF roll >= 2 THEN + SPAWN(@obviate, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 3.12 + ELSE + IF roll >= 1 THEN + SPAWN(@avert, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 1 + ELSE + SPAWN(@elude, spawnx, spawny, 0, @crux, _) + spawnx = spawnx - 1 + END IF + END IF + END IF + + IF spawnx < 10 THEN + spawnx = 10 + END IF + + spawned = spawned + 1 + IF spawned < air_units THEN + GOTO spawn_air_loop + END IF + GOTO spawn_air_end diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..0c8b9ed --- /dev/null +++ b/src/common.rs @@ -0,0 +1,127 @@ +use std::collections::HashMap; + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum Operator { + Add, + Sub, + Mul, + Div, + Mod, + RShift, + LShift, + Gt, + Lt, + Gte, + Lte, + Eq, + Neq, + Max, + Min, + Pow, + // etc. +} + +impl Operator { + pub(crate) fn precedence(self) -> u8 { + use Operator as O; + match self { + O::Add | O::Sub => 3, + O::RShift | O::LShift => 4, + O::Mod => 5, + O::Mul | O::Div => 10, + O::Eq | O::Neq | O::Gt | O::Lt | O::Gte | O::Lte => 0, + _ => 128, + } + } + + pub(crate) fn from_fn_name(raw: &str) -> Option { + match raw { + "max" => Some(Self::Max), + "min" => Some(Self::Min), + "pow" => Some(Self::Pow), + _ => None, + } + } +} + +pub(crate) fn format_operator(operator: Operator) -> &'static str { + match operator { + Operator::Eq => "equal", + Operator::Neq => "notEqual", + Operator::Lt => "lessThan", + Operator::Lte => "lessThanEq", + Operator::Gt => "greaterThan", + Operator::Gte => "greaterThanEq", + Operator::Add => "add", + Operator::Sub => "sub", + Operator::Mul => "mul", + Operator::Div => "div", + Operator::Mod => "mod", + Operator::RShift => "shr", + Operator::LShift => "shl", + Operator::Max => "max", + Operator::Min => "min", + Operator::Pow => "pow", + } +} + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum UnaryOperator { + Floor, + Round, + 
Ceil, + Rand, + Sqrt, +} + +impl TryFrom<&str> for UnaryOperator { + type Error = (); + + fn try_from(value: &str) -> Result { + match value { + "floor" => Ok(Self::Floor), + "round" => Ok(Self::Round), + "ceil" => Ok(Self::Ceil), + "rand" => Ok(Self::Rand), + "sqrt" => Ok(Self::Sqrt), + _ => Err(()), + } + } +} + +pub(crate) fn format_unary_operator(operator: UnaryOperator) -> &'static str { + match operator { + UnaryOperator::Floor => "floor", + UnaryOperator::Round => "round", + UnaryOperator::Ceil => "ceil", + UnaryOperator::Rand => "rand", + UnaryOperator::Sqrt => "sqrt", + } +} + +#[derive(Clone, Debug)] +pub struct Config { + pub builtin_functions: HashMap, +} + +impl Default for Config { + fn default() -> Self { + macro_rules! builtin_function { + ( $name:expr, $target_name:expr, $n_args:expr ) => { + (String::from($name), (String::from($target_name), $n_args)) + }; + } + Self { + builtin_functions: HashMap::from([ + builtin_function!("print_flush", "printflush", 1), + builtin_function!("print_message_mission", "message mission", 0), + builtin_function!("read", "read", 3), + builtin_function!("write", "write", 3), + builtin_function!("wait", "wait", 1), + builtin_function!("set_flag", "setflag", 2), + builtin_function!("get_flag", "getflag", 2), + builtin_function!("spawn", "spawn", 6), + ]), + } + } +} diff --git a/src/compile/mod.rs b/src/compile/mod.rs index b19e99a..e1929db 100644 --- a/src/compile/mod.rs +++ b/src/compile/mod.rs @@ -1,6 +1,7 @@ use regex::Regex; -use crate::parse::{BasicAstBlock, BasicAstExpression, Operator}; +use crate::common::*; +use crate::parse::{BasicAstBlock, BasicAstExpression}; mod optimize; pub use optimize::*; @@ -33,6 +34,7 @@ pub enum MindustryOperation { Jump(String), JumpIf(String, Operator, Operand, Operand), Operator(String, Operator, Operand, Operand), + UnaryOperator(String, UnaryOperator, Operand), Set(String, Operand), Generic(String, Vec), } @@ -151,12 +153,27 @@ fn translate_expression( 
Operand::Variable(right_name), )); + res + } + BasicAstExpression::Unary(op, value) => { + let mut res = translate_expression(value.as_ref(), namer, target_name.clone()); + + res.push(MindustryOperation::UnaryOperator( + target_name.clone(), + *op, + Operand::Variable(target_name), + )); + res } } } -pub fn translate_ast(basic_ast: &BasicAstBlock, namer: &mut Namer) -> MindustryProgram { +pub fn translate_ast( + basic_ast: &BasicAstBlock, + namer: &mut Namer, + config: &Config, +) -> MindustryProgram { use crate::parse::BasicAstInstruction as Instr; let mut res = MindustryProgram::new(); @@ -190,12 +207,12 @@ pub fn translate_ast(basic_ast: &BasicAstBlock, namer: &mut Namer) -> MindustryP Operand::Variable(String::from("true")), )); - res.append(&mut translate_ast(true_branch, namer)); + res.append(&mut translate_ast(true_branch, namer, config)); res.push(MindustryOperation::Jump(end_label.clone())); res.push(MindustryOperation::JumpLabel(else_label)); - res.append(&mut translate_ast(false_branch, namer)); + res.append(&mut translate_ast(false_branch, namer, config)); res.push(MindustryOperation::JumpLabel(end_label)); } else { @@ -207,7 +224,7 @@ pub fn translate_ast(basic_ast: &BasicAstBlock, namer: &mut Namer) -> MindustryP Operand::Variable(String::from("true")), )); - res.append(&mut translate_ast(true_branch, namer)); + res.append(&mut translate_ast(true_branch, namer, config)); res.push(MindustryOperation::JumpLabel(end_label)); } @@ -238,18 +255,14 @@ pub fn translate_ast(basic_ast: &BasicAstBlock, namer: &mut Namer) -> MindustryP )); } - match name.as_str() { - "print_flush" => { - if arguments.len() == 1 { - res.push(MindustryOperation::Generic( - String::from("printflush"), - vec![Operand::Variable(argument_names[0].clone())], - )); - } else { - panic!("Invalid amount of arguments: {}", arguments.len()); - } - } - _ => unimplemented!(), + if let Some((target_name, _)) = config.builtin_functions.get(name) { + res.push(MindustryOperation::Generic( + 
target_name.clone(), + argument_names + .into_iter() + .map(|name| Operand::Variable(name)) + .collect(), + )); } } } @@ -322,6 +335,15 @@ impl std::fmt::Display for MindustryProgram { rhs )?; } + MindustryOperation::UnaryOperator(name, operator, lhs) => { + writeln!( + f, + "op {} {} {} 0", + format_unary_operator(*operator), + name, + lhs + )?; + } MindustryOperation::Set(name, value) => { writeln!(f, "set {} {}", name, value)?; } @@ -352,21 +374,3 @@ fn format_condition(operator: Operator) -> &'static str { } } } - -fn format_operator(operator: Operator) -> &'static str { - match operator { - Operator::Eq => "equal", - Operator::Neq => "notEqual", - Operator::Lt => "lessThan", - Operator::Lte => "lessThanEqual", - Operator::Gt => "greaterThan", - Operator::Gte => "greaterThanEqual", - Operator::Add => "add", - Operator::Sub => "sub", - Operator::Mul => "mul", - Operator::Div => "div", - Operator::Mod => "mod", - Operator::RShift => "shr", - Operator::LShift => "shl", - } -} diff --git a/src/compile/optimize.rs b/src/compile/optimize.rs index a9d8d79..03a2017 100644 --- a/src/compile/optimize.rs +++ b/src/compile/optimize.rs @@ -68,7 +68,8 @@ pub fn optimize_set_use(program: MindustryProgram) -> MindustryProgram { if matches!( assigned_var.as_str(), "@this" | "@thisx" | "@thisy" | "@links" - ) { + ) || is_unit_constant(assigned_var.as_str()) + { return true; } if assigned_var.starts_with('@') { @@ -119,7 +120,7 @@ pub fn optimize_set_use(program: MindustryProgram) -> MindustryProgram { for instruction in instructions.iter() { let MindustryOperation::Set(set_name, _) = instruction else { res.push(instruction.clone()); - continue + continue; }; if !tmp_regex.is_match(set_name) { @@ -152,6 +153,27 @@ pub fn optimize_set_use(program: MindustryProgram) -> MindustryProgram { res } +fn is_unit_constant(name: &str) -> bool { + matches!( + name, + "@stell" + | "@locus" + | "@precept" + | "@vanquish" + | "@conquer" + | "@merui" + | "@cleroi" + | "@anthicus" + | "@tecta" + 
| "@collaris" + | "@elude" + | "@avert" + | "@obviate" + | "@quell" + | "@disrupt" + ) +} + // TODO: // - optimize op-jumpif // - optimize jump(1)-label(2)-...instr-label(1) into ...instr-jump(2) diff --git a/src/lib.rs b/src/lib.rs index 347d1fe..626aa04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod common; pub mod compile; pub mod cursor; pub mod parse; diff --git a/src/main.rs b/src/main.rs index 7ffa621..fe917ac 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use basic_to_mindustry::{ + common::Config, compile::{optimize_set_use, translate_ast, Namer}, parse::{build_ast, tokenize}, }; @@ -6,10 +7,11 @@ use basic_to_mindustry::{ fn main() { let path = std::env::args().nth(1).expect("Expected 1 argument"); let source = std::fs::read_to_string(path).expect("Couldn't read input file"); + let config = Config::default(); let tokens = tokenize(&source).unwrap(); - let parsed = build_ast(&tokens).unwrap(); - let transformed = translate_ast(&parsed, &mut Namer::default()); + let parsed = build_ast(&tokens, &config).unwrap(); + let transformed = translate_ast(&parsed, &mut Namer::default(), &config); println!("{}", transformed); diff --git a/src/parse/ast.rs b/src/parse/ast.rs index 0072590..8d166ae 100644 --- a/src/parse/ast.rs +++ b/src/parse/ast.rs @@ -1,4 +1,5 @@ use super::*; +use crate::common::*; use crate::cursor::Cursor; #[derive(Clone, Debug, PartialEq)] @@ -8,26 +9,9 @@ pub enum BasicAstExpression { Variable(String), String(String), Binary(Operator, Box, Box), + Unary(UnaryOperator, Box), } -macro_rules! 
impl_op_basic_ast_expression { - ( $std_op:ty, $fn_name:ident, $self_op:expr ) => { - impl $std_op for BasicAstExpression { - type Output = BasicAstExpression; - - fn $fn_name(self, other: Self) -> Self { - Self::Binary($self_op, Box::new(self), Box::new(other)) - } - } - }; -} - -// These are primarily here for ease of use in testing -impl_op_basic_ast_expression!(std::ops::Add, add, Operator::Add); -impl_op_basic_ast_expression!(std::ops::Sub, sub, Operator::Sub); -impl_op_basic_ast_expression!(std::ops::Mul, mul, Operator::Mul); -impl_op_basic_ast_expression!(std::ops::Div, div, Operator::Div); - #[derive(Clone, Debug, PartialEq)] pub enum BasicAstInstruction { JumpLabel(String), @@ -51,100 +35,7 @@ impl BasicAstBlock { } } -/// Returns the index of the first token matching `needle` -fn find_token_index(tokens: &[BasicToken], needle: BasicToken) -> Result { - tokens - .iter() - .enumerate() - .find(|(_, t)| **t == needle) - .map(|(i, _)| i) - .ok_or(ParseError::MissingToken(needle)) -} - -pub(crate) fn parse_expression( - tokens: &mut Cursor<'_, BasicToken>, -) -> Result { - /// Returns the first non-newline token in `tokens` - fn peek<'a>(tokens: &'a [BasicToken]) -> Option<&'a BasicToken> { - tokens.iter().find(|t| !matches!(t, BasicToken::NewLine)) - } - - /// Parses a single expression item - fn parse_expression_item( - tokens: &mut Cursor<'_, BasicToken>, - ) -> Result { - match tokens.peek(2) { - [BasicToken::Integer(int), ..] => { - tokens.take(1); - Ok(BasicAstExpression::Integer(*int)) - } - [BasicToken::Float(float), ..] => { - tokens.take(1); - Ok(BasicAstExpression::Float(*float)) - } - [BasicToken::Name(_fn_name), BasicToken::OpenParen, ..] => { - unimplemented!("Function calls are not yet supported"); - } - [BasicToken::Name(name), ..] => { - tokens.take(1); - Ok(BasicAstExpression::Variable(name.clone())) - } - [BasicToken::String(string), ..] => { - tokens.take(1); - Ok(BasicAstExpression::String(string.clone())) - } - [BasicToken::OpenParen, ..] 
=> { - tokens.take(1); - let res = parse_expression(tokens)?; - if let Some(BasicToken::CloseParen) = tokens.take(1).get(0) { - Ok(res) - } else { - Err(ParseError::MissingToken(BasicToken::CloseParen)) - } - } - [first, ..] => Err(ParseError::UnexpectedToken(first.clone())), - [] => Err(ParseError::ExpectedOperand), - } - } - - /// Given an lhs and a minimum precedence, eats as many binary operations as possible, - /// recursively calling itself when an operator with a higher precedence is encountered. - /// - /// See https://en.wikipedia.org/wiki/Operator-precedence_parser for more information - fn parse_expression_main( - tokens: &mut Cursor<'_, BasicToken>, - lhs: BasicAstExpression, - min_precedence: u8, - ) -> Result { - let mut ast = lhs; - while let Some(&BasicToken::Operator(operator)) = peek(tokens) { - if operator.precedence() < min_precedence { - break; - } - tokens.take(1); - let mut rhs = parse_expression_item(tokens)?; - while let Some(&BasicToken::Operator(sub_operator)) = peek(tokens) { - if sub_operator.precedence() > operator.precedence() { - rhs = parse_expression_main(tokens, rhs, operator.precedence() + 1)?; - } else { - break; - } - } - - ast = BasicAstExpression::Binary(operator, Box::new(ast), Box::new(rhs)); - } - - Ok(ast) - } - - // Remove starting newlines - let lhs = parse_expression_item(tokens)?; - let res = parse_expression_main(tokens, lhs, 0)?; - - Ok(res) -} - -pub fn build_ast(tokens: &[BasicToken]) -> Result { +pub fn build_ast(tokens: &[BasicToken], config: &Config) -> Result { enum Context { Main, If(BasicAstExpression), @@ -183,7 +74,6 @@ pub fn build_ast(tokens: &[BasicToken]) -> Result { [BasicToken::If, ..] 
=> { tokens.take(1); let then_index = find_token_index(&tokens, BasicToken::Then)?; - // let end_index = find_token_index(&tokens, BasicToken::EndIf)?; let condition = parse_expression(&mut tokens.range(0..then_index))?; @@ -211,7 +101,8 @@ pub fn build_ast(tokens: &[BasicToken]) -> Result { match context_stack.pop().unwrap() { (instructions, Context::If(condition)) => { - let Some((ref mut parent_instructions, _)) = context_stack.last_mut() else { + let Some((ref mut parent_instructions, _)) = context_stack.last_mut() + else { unreachable!("Context::If not wrapped in another context"); }; @@ -222,7 +113,8 @@ pub fn build_ast(tokens: &[BasicToken]) -> Result { )); } (instructions, Context::IfElse(condition, true_branch)) => { - let Some((ref mut parent_instructions, _)) = context_stack.last_mut() else { + let Some((ref mut parent_instructions, _)) = context_stack.last_mut() + else { unreachable!("Context::IfElse not wrapped in another context"); }; @@ -256,13 +148,19 @@ pub fn build_ast(tokens: &[BasicToken]) -> Result { expressions.push((parse_expression(&mut tokens)?, false)); while let Some(BasicToken::Comma) = tokens.get(0) { + tokens.take(1); expressions.push((parse_expression(&mut tokens)?, false)); } instructions.push(BasicAstInstruction::Print(expressions)); } - // TODO: expect newline + match tokens.get(0) { + Some(BasicToken::NewLine) | None => {} + Some(other) => { + return Err(ParseError::UnexpectedToken(other.clone())); + } + } } [BasicToken::Name(fn_name), BasicToken::OpenParen, ..] 
=> { tokens.take(2); @@ -271,25 +169,41 @@ pub fn build_ast(tokens: &[BasicToken]) -> Result { while tokens.get(0) != Some(&BasicToken::CloseParen) { arguments.push(parse_expression(&mut tokens)?); - match tokens.take(1) { - [BasicToken::Comma] => {} - [BasicToken::CloseParen] => break, + match tokens.get(0) { + Some(BasicToken::Comma) => { + tokens.take(1); + } + Some(BasicToken::CloseParen) => break, _ => return Err(ParseError::MissingToken(BasicToken::Comma)), } } - // TODO: expect closeparen - - tokens.take(1); + match tokens.take(1) { + [BasicToken::CloseParen] => {} + [other] => { + return Err(ParseError::UnexpectedToken(other.clone())); + } + _ => { + return Err(ParseError::MissingToken(BasicToken::CloseParen)); + } + } let lowercase_fn_name = fn_name.to_lowercase(); - if matches!(lowercase_fn_name.as_str(), "print_flush") { + if let Some((_, n_args)) = config.builtin_functions.get(&lowercase_fn_name) { + if arguments.len() != *n_args { + return Err(ParseError::InvalidArgumentCount( + lowercase_fn_name, + *n_args, + arguments.len(), + )); + } + instructions.push(BasicAstInstruction::CallBuiltin( lowercase_fn_name, arguments, )); } else { - unimplemented!(); + unimplemented!("User procedure calls are not yet supported!"); } } _ => { @@ -322,3 +236,174 @@ pub fn build_ast(tokens: &[BasicToken]) -> Result { Ok(BasicAstBlock { instructions }) } + +/// Returns the index of the first token matching `needle` +fn find_token_index(tokens: &[BasicToken], needle: BasicToken) -> Result { + tokens + .iter() + .enumerate() + .find(|(_, t)| **t == needle) + .map(|(i, _)| i) + .ok_or(ParseError::MissingToken(needle)) +} + +macro_rules! 
impl_op_basic_ast_expression { + ( $std_op:ty, $fn_name:ident, $self_op:expr ) => { + impl $std_op for BasicAstExpression { + type Output = BasicAstExpression; + + fn $fn_name(self, other: Self) -> Self { + Self::Binary($self_op, Box::new(self), Box::new(other)) + } + } + }; +} + +// These are primarily here for ease of use in testing +impl_op_basic_ast_expression!(std::ops::Add, add, Operator::Add); +impl_op_basic_ast_expression!(std::ops::Sub, sub, Operator::Sub); +impl_op_basic_ast_expression!(std::ops::Mul, mul, Operator::Mul); +impl_op_basic_ast_expression!(std::ops::Div, div, Operator::Div); + +pub(crate) fn parse_expression( + tokens: &mut Cursor<'_, BasicToken>, +) -> Result { + /// Returns the first non-newline token in `tokens` + fn peek<'a>(tokens: &'a [BasicToken]) -> Option<&'a BasicToken> { + tokens.iter().find(|t| !matches!(t, BasicToken::NewLine)) + } + + /// Parses a single expression item + fn parse_expression_item( + tokens: &mut Cursor<'_, BasicToken>, + ) -> Result { + match tokens.peek(2) { + [BasicToken::Integer(int), ..] => { + tokens.take(1); + Ok(BasicAstExpression::Integer(*int)) + } + [BasicToken::Float(float), ..] => { + tokens.take(1); + Ok(BasicAstExpression::Float(*float)) + } + [BasicToken::Name(fn_name), BasicToken::OpenParen, ..] 
=> { + tokens.take(2); + let fn_name_lowercase = fn_name.to_ascii_lowercase(); + let mut arguments = Vec::new(); + while tokens.get(0) != Some(&BasicToken::CloseParen) { + arguments.push(parse_expression(tokens)?); + + match tokens.get(0) { + Some(BasicToken::Comma) => { + tokens.take(1); + } + Some(BasicToken::CloseParen) => break, + _ => return Err(ParseError::MissingToken(BasicToken::Comma)), + } + } + + match tokens.take(1) { + [BasicToken::CloseParen] => {} + [other] => { + return Err(ParseError::UnexpectedToken(other.clone())); + } + _ => { + return Err(ParseError::MissingToken(BasicToken::CloseParen)); + } + } + + if let Ok(unary_operator) = UnaryOperator::try_from(fn_name_lowercase.as_str()) { + if arguments.len() != 1 { + Err(ParseError::InvalidArgumentCount( + fn_name_lowercase, + 1, + arguments.len(), + )) + } else { + Ok(BasicAstExpression::Unary( + unary_operator, + Box::new(arguments.into_iter().next().unwrap()), + )) + } + } else if let Some(binary_operator) = + Operator::from_fn_name(fn_name_lowercase.as_str()) + { + if arguments.len() != 2 { + Err(ParseError::InvalidArgumentCount( + fn_name_lowercase, + 2, + arguments.len(), + )) + } else { + let mut iter = arguments.into_iter(); + Ok(BasicAstExpression::Binary( + binary_operator, + Box::new(iter.next().unwrap()), + Box::new(iter.next().unwrap()), + )) + } + } else { + unimplemented!( + "User function calls are not yet supported! Function: {:?}", + fn_name + ); + } + } + [BasicToken::Name(name), ..] => { + tokens.take(1); + Ok(BasicAstExpression::Variable(name.clone())) + } + [BasicToken::String(string), ..] => { + tokens.take(1); + Ok(BasicAstExpression::String(string.clone())) + } + [BasicToken::OpenParen, ..] => { + tokens.take(1); + let res = parse_expression(tokens)?; + if let Some(BasicToken::CloseParen) = tokens.take(1).get(0) { + Ok(res) + } else { + Err(ParseError::MissingToken(BasicToken::CloseParen)) + } + } + [first, ..] 
=> Err(ParseError::UnexpectedToken(first.clone())), + [] => Err(ParseError::ExpectedOperand), + } + } + + /// Given an lhs and a minimum precedence, eats as many binary operations as possible, + /// recursively calling itself when an operator with a higher precedence is encountered. + /// + /// See https://en.wikipedia.org/wiki/Operator-precedence_parser for more information + fn parse_expression_main( + tokens: &mut Cursor<'_, BasicToken>, + lhs: BasicAstExpression, + min_precedence: u8, + ) -> Result { + let mut ast = lhs; + while let Some(&BasicToken::Operator(operator)) = peek(tokens) { + if operator.precedence() < min_precedence { + break; + } + tokens.take(1); + let mut rhs = parse_expression_item(tokens)?; + while let Some(&BasicToken::Operator(sub_operator)) = peek(tokens) { + if sub_operator.precedence() > operator.precedence() { + rhs = parse_expression_main(tokens, rhs, operator.precedence() + 1)?; + } else { + break; + } + } + + ast = BasicAstExpression::Binary(operator, Box::new(ast), Box::new(rhs)); + } + + Ok(ast) + } + + // Remove starting newlines + let lhs = parse_expression_item(tokens)?; + let res = parse_expression_main(tokens, lhs, 0)?; + + Ok(res) +} diff --git a/src/parse/test.rs b/src/parse/test.rs index af8328f..f07429e 100644 --- a/src/parse/test.rs +++ b/src/parse/test.rs @@ -1,3 +1,4 @@ +use crate::common::*; use crate::cursor::Cursor; use super::*; @@ -243,7 +244,7 @@ fn test_build_ast(raw: &str) -> BasicAstBlock { e, raw ); }); - let parsed = build_ast(&tokens).unwrap_or_else(|e| { + let parsed = build_ast(&tokens, &Default::default()).unwrap_or_else(|e| { panic!( "Error while parsing: {:?}\nProgram:\n```\n{}\n```\nTokens:\n{:#?}", e, raw, tokens diff --git a/src/parse/tokenize.rs b/src/parse/tokenize.rs index 07a5563..0d77529 100644 --- a/src/parse/tokenize.rs +++ b/src/parse/tokenize.rs @@ -1,36 +1,6 @@ +use crate::common::*; use regex::Regex; -#[derive(PartialEq, Eq, Clone, Copy, Debug)] -pub enum Operator { - Add, - Sub, - Mul, - 
Div, - Mod, - RShift, - LShift, - Gt, - Lt, - Gte, - Lte, - Eq, - Neq, - // etc. -} - -impl Operator { - pub(crate) fn precedence(self) -> u8 { - use Operator as O; - match self { - O::Add | O::Sub => 3, - O::RShift | O::LShift => 4, - O::Mod => 5, - O::Mul | O::Div => 10, - O::Eq | O::Neq | O::Gt | O::Lt | O::Gte | O::Lte => 0, - } - } -} - #[derive(PartialEq, Clone, Debug)] pub enum BasicToken { NewLine, @@ -57,6 +27,7 @@ pub enum ParseError { InvalidToken(String), UnexpectedToken(BasicToken), MissingToken(BasicToken), + InvalidArgumentCount(String, usize, usize), ExpectedOperand, } diff --git a/tests/examples.rs b/tests/examples.rs new file mode 100644 index 0000000..dc5b733 --- /dev/null +++ b/tests/examples.rs @@ -0,0 +1,35 @@ +use std::path::Path; + +use basic_to_mindustry::common::Config; +use basic_to_mindustry::compile::{optimize_set_use, translate_ast}; +use basic_to_mindustry::parse::{build_ast, tokenize}; + +#[test] +fn test_examples() { + let config = Config::default(); + for entry in Path::new("./examples/").read_dir().unwrap() { + let Ok(entry) = entry else { continue }; + if entry + .file_name() + .into_string() + .map(|name| name.ends_with(".basic")) + .unwrap_or(false) + { + let file_name = entry.file_name().into_string().unwrap(); + let file = std::fs::read_to_string(entry.path()).unwrap_or_else(|e| { + panic!("Error opening {:?}: {:?}", file_name, e); + }); + + let tokenized = tokenize(&file).unwrap_or_else(|e| { + panic!("Error tokenizing {:?}: {:?}", file_name, e); + }); + let parsed = build_ast(&tokenized, &config).unwrap_or_else(|e| { + panic!("Error parsing {:?}: {:?}", file_name, e); + }); + let translated = translate_ast(&parsed, &mut Default::default(), &config); + let optimized = optimize_set_use(translated); + + let _ = optimized; + } + } +}