diff --git a/src/compile/mod.rs b/src/compile/mod.rs index e1929db..174bc3c 100644 --- a/src/compile/mod.rs +++ b/src/compile/mod.rs @@ -64,7 +64,7 @@ impl MindustryOperation { } #[derive(Debug, Clone)] -pub struct MindustryProgram(Vec); +pub struct MindustryProgram(pub Vec); impl MindustryProgram { pub fn new() -> Self { diff --git a/src/compile/optimize.rs b/src/compile/optimize.rs index 03a2017..87aebd4 100644 --- a/src/compile/optimize.rs +++ b/src/compile/optimize.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use super::*; /// Optimizes away unnecessary `sets` @@ -112,45 +114,7 @@ pub fn optimize_set_use(program: MindustryProgram) -> MindustryProgram { } } - let instructions = res.0; - let mut res = MindustryProgram::new(); - - // Remove unneeded `set`s - // PERF: could be split into a search for all variable operands, and a removal of all unneeded `set`s - for instruction in instructions.iter() { - let MindustryOperation::Set(set_name, _) = instruction else { - res.push(instruction.clone()); - continue; - }; - - if !tmp_regex.is_match(set_name) { - res.push(instruction.clone()); - continue; - } - - // Note: this will give false positives for temporary variable names that get re-used somewhere else - let mut needed = false; - for future_instruction in instructions.iter() { - if future_instruction.operands().iter().any(|operand| { - if let Operand::Variable(use_name) = operand { - if use_name == set_name { - return true; - } - } - false - }) { - needed = true; - break; - } - } - - if needed { - res.push(instruction.clone()); - } - // else don't push - } - - res + optimize_dead_code(res) } fn is_unit_constant(name: &str) -> bool { @@ -175,6 +139,269 @@ fn is_unit_constant(name: &str) -> bool { } // TODO: -// - optimize op-jumpif // - optimize jump(1)-label(2)-...instr-label(1) into ...instr-jump(2) // - shorten temporary variable names + +/// Tries to merge the condition in an `op` into the `jump` itself +pub fn optimize_jump_op(program: MindustryProgram) -> MindustryProgram { + let tmp_regex = Regex::new(r"__tmp_[0-9]+$").unwrap(); + + let mut res = MindustryProgram::new(); + let instructions = program.0; + + for (index, instruction) in instructions.iter().enumerate() { + match instruction { + MindustryOperation::JumpIf(label, operator, lhs, rhs) => { + let (truthiness, var_name) = match ( + operator, + replace_constants(lhs.clone()), + replace_constants(rhs.clone()), + ) { + (Operator::Neq, Operand::Variable(var_name), Operand::Integer(0)) + | (Operator::Eq, Operand::Variable(var_name), Operand::Integer(1)) + | (Operator::Neq, Operand::Integer(0), Operand::Variable(var_name)) + | (Operator::Eq, Operand::Integer(1), Operand::Variable(var_name)) => { + (true, var_name) + } + (Operator::Eq, Operand::Variable(var_name), Operand::Integer(0)) + | (Operator::Neq, Operand::Variable(var_name), Operand::Integer(1)) + | (Operator::Eq, Operand::Integer(0), Operand::Variable(var_name)) + | (Operator::Neq, Operand::Integer(1), Operand::Variable(var_name)) => { + (false, var_name) + } + _ => { + res.push(instruction.clone()); + continue; + } + }; + + if !tmp_regex.is_match(&var_name) { + res.push(instruction.clone()); + continue; + } + + let mut last_op = None; + for prev_instruction in instructions[0..=index].iter().rev().skip(1) { + match prev_instruction { + MindustryOperation::Operator(name, operator, lhs, rhs) + if *name == var_name && is_condition_op(*operator) => + { + last_op = Some((*operator, lhs.clone(), rhs.clone())); + } + MindustryOperation::JumpLabel(_) => break, + _ => {} + } + } + + let Some(last_op) = last_op else { + res.push(instruction.clone()); + continue + }; + + let (operator, lhs, rhs) = if truthiness { + last_op + } else { + ( + match last_op.0 { + Operator::Gt => Operator::Lte, + Operator::Lt => Operator::Gte, + Operator::Gte => Operator::Lt, + Operator::Lte => Operator::Gt, + Operator::Eq => Operator::Neq, + Operator::Neq => Operator::Eq, + _ => unreachable!(), + }, + last_op.1, + last_op.2, + ) + }; + + res.push(MindustryOperation::JumpIf( + label.clone(), + operator, + lhs, + rhs, + )); + } + _ => { + res.push(instruction.clone()); + } + } + } + + return optimize_dead_code(res); + + fn replace_constants(value: Operand) -> Operand { + if let Operand::Variable(var) = &value { + match var.as_str() { + "true" => Operand::Integer(1), + "false" | "null" => Operand::Integer(0), + _ => value, + } + } else { + value + } + } + + fn is_condition_op(op: Operator) -> bool { + matches!( + op, + Operator::Neq + | Operator::Eq + | Operator::Lt + | Operator::Lte + | Operator::Gt + | Operator::Gte + ) + } +} + +/// Tries to remove unnecessary `jump always` instructions +pub fn optimize_jump_always(mut program: MindustryProgram) -> MindustryProgram { + let instructions = &mut program.0; + + let mut substitutions = Vec::new(); + + // Detect `label`-`jump always` pairs + for (index, instruction) in instructions.iter().enumerate() { + let MindustryOperation::JumpLabel(label_from) = instruction else { + continue + }; + + for future_instruction in instructions[index..].iter() { + match future_instruction { + MindustryOperation::JumpLabel(_) => {} + MindustryOperation::Jump(label_to) => { + substitutions.push((label_from.clone(), label_to.clone())); + break; + } + _ => break, + } + } + } + + // Apply transitivity to the pairs + let substitutions = substitutions + .iter() + .map(|(from, to)| { + let mut new_to = to; + let mut history = vec![to]; + + loop { + let mut found = false; + + for (other_from, other_to) in substitutions.iter() { + if other_from == new_to { + // Leave cycles untouched + if history.contains(&other_to) { + return (from.clone(), to.clone()); + } + new_to = other_to; + history.push(other_to); + found = true; + break; + } + } + + if !found { + break; + } + } + + (from.clone(), to.clone()) + }) + .collect::>(); + + for instruction in instructions.iter_mut() { + match instruction { + MindustryOperation::Jump(label) => { + if let Some((_, new_label)) = substitutions.iter().find(|(from, _)| from == label) { + *label = new_label.clone(); + } + } + MindustryOperation::JumpIf(label, _, _, _) => { + if let Some((_, new_label)) = substitutions.iter().find(|(from, _)| from == label) { + *label = new_label.clone(); + } + } + _ => {} + } + } + + optimize_dead_code(program) +} + +fn optimize_dead_code(program: MindustryProgram) -> MindustryProgram { + let instructions = program.0; + let tmp_regex = Regex::new(r"__tmp_[0-9]+$").unwrap(); + let label_regex = Regex::new(r"__label_[0-9]+").unwrap(); + let mut res = MindustryProgram::new(); + + let mut needed_vars = HashSet::new(); + let mut needed_labels = HashSet::new(); + let mut push_var = |operand: &Operand| match operand { + Operand::Variable(name) => { + needed_vars.insert(name.clone()); + } + _ => {} + }; + + for instruction in instructions.iter() { + match instruction { + MindustryOperation::JumpLabel(_) => {} + MindustryOperation::Jump(label) => { + needed_labels.insert(label.clone()); + } + MindustryOperation::JumpIf(label, _, lhs, rhs) => { + needed_labels.insert(label.clone()); + push_var(lhs); + push_var(rhs); + } + MindustryOperation::Operator(_, _, lhs, rhs) => { + push_var(lhs); + push_var(rhs); + } + MindustryOperation::UnaryOperator(_, _, value) => { + push_var(value); + } + MindustryOperation::Set(_, value) => { + push_var(value); + } + MindustryOperation::Generic(_, values) => { + values.iter().for_each(&mut push_var); + } + } + } + + // Remove unneeded `set`s and `op`s + for instruction in instructions.iter() { + match instruction { + MindustryOperation::Set(name, _) | MindustryOperation::Operator(name, _, _, _) => { + if tmp_regex.is_match(name) { + if needed_vars.contains(name) { + res.push(instruction.clone()); + } + // else don't push + } else { + res.push(instruction.clone()); + } + } + MindustryOperation::JumpLabel(label) => { + if label_regex.is_match(label) { + if needed_labels.contains(label) { + res.push(instruction.clone()); + } + // else don't push + } else { + res.push(instruction.clone()); + } + } + _ => { + res.push(instruction.clone()); + continue; + } + }; + } + + res +} diff --git a/src/main.rs b/src/main.rs index fe917ac..0cbfdd4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use basic_to_mindustry::{ common::Config, - compile::{optimize_set_use, translate_ast, Namer}, + compile::{optimize_jump_always, optimize_jump_op, optimize_set_use, translate_ast, Namer}, parse::{build_ast, tokenize}, }; @@ -13,10 +13,12 @@ fn main() { let parsed = build_ast(&tokens, &config).unwrap(); let transformed = translate_ast(&parsed, &mut Namer::default(), &config); - println!("{}", transformed); + // println!("{}", transformed); let optimized = optimize_set_use(transformed); + let optimized = optimize_jump_op(optimized); + let optimized = optimize_jump_always(optimized); - println!("== OPT =="); + // println!("== OPT == ({} instructions)", optimized.0.len()); println!("{}", optimized); }