1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
use crate::expressions::Expression;
use crate::identifier::Identifier;
use crate::package::{Composition, OracleDef, Package, PackageInstance};
use crate::statement::{CodeBlock, Statement};

/// Transformation over a borrowed [`Composition`].
///
/// The `super::Transformation` implementation in this file produces a copy of
/// the wrapped composition in which every package instance has been passed
/// through `var_specify`.
pub struct Transformation<'a>(pub &'a Composition);

impl<'a> super::Transformation for Transformation<'a> {
    type Err = ();
    type Aux = ();

    /// Returns a copy of the wrapped composition whose `pkgs` have each been
    /// rewritten by `var_specify`, using the composition's own name.
    ///
    /// All other fields are taken from a clone of the wrapped composition.
    /// The auxiliary output is `()` and this implementation only ever
    /// returns `Ok`.
    fn transform(&self) -> Result<(Composition, ()), ()> {
        let original: &Composition = self.0;

        // Rewrite every package instance; the composition name is passed along
        // to `var_specify` for each instance.
        let specified = original
            .pkgs
            .iter()
            .map(|inst| var_specify(inst, &original.name))
            .collect();

        let rewritten = Composition {
            pkgs: specified,
            ..original.clone()
        };

        Ok((rewritten, ()))
    }
}

/// Rewrites the identifiers in `block` according to the declarations of the
/// package behind `inst`.
///
/// A scalar identifier is resolved as follows:
/// * if its name appears in the package's `state` list, it becomes
///   `Identifier::State` carrying the instance name and `comp_name`;
/// * otherwise, if its name appears in the package's `params` list, it becomes
///   `Identifier::Params` carrying the same two names;
/// * otherwise it becomes `Identifier::Local`.
///
/// State is checked first, so a name declared in both lists resolves to state.
/// Expressions that are not scalar identifiers are handed back unchanged by
/// the rewrite closure.
///
/// The block is taken by value to keep the existing call sites working; the
/// actual traversal is done by `var_specify_block`, which only borrows.
fn var_specify_helper(inst: &PackageInstance, block: CodeBlock, comp_name: &str) -> CodeBlock {
    var_specify_block(inst, &block, comp_name)
}

/// Borrowing worker for `var_specify_helper`: builds the rewritten block
/// without cloning nested `if`/`else` branches before recursing into them.
fn var_specify_block(inst: &PackageInstance, block: &CodeBlock, comp_name: &str) -> CodeBlock {
    let PackageInstance {
        name,
        pkg: Package { state, params, .. },
        ..
    } = inst;

    // Only captures shared references, so the closure is `Copy` and can be
    // handed to several `map` calls below. The declaration lists are scanned
    // in place rather than cloned for every identifier.
    let fixup = |expr| match expr {
        Expression::Identifier(Identifier::Scalar(id)) => {
            if state.iter().any(|(declared, _)| id == *declared) {
                Expression::Identifier(Identifier::State {
                    name: id,
                    pkgname: name.clone(),
                    compname: comp_name.into(),
                })
            } else if params.iter().any(|(declared, _)| id == *declared) {
                Expression::Identifier(Identifier::Params {
                    name: id,
                    pkgname: name.clone(),
                    compname: comp_name.into(),
                })
            } else {
                Expression::Identifier(Identifier::Local(id))
            }
        }
        _ => expr,
    };

    CodeBlock(
        block
            .0
            .iter()
            // NOTE(review): `Expression::map` presumably applies `fixup` to the
            // expression and its sub-expressions — confirm against its definition.
            .map(|stmt| match stmt {
                Statement::Abort => Statement::Abort,
                Statement::Return(None) => Statement::Return(None),
                Statement::Return(Some(expr)) => Statement::Return(Some(expr.map(fixup))),
                Statement::Assign(id, expr) => {
                    // The assignment target goes through the same rewrite as
                    // any other identifier occurrence.
                    if let Expression::Identifier(id) = fixup(id.to_expression()) {
                        Statement::Assign(id, expr.map(fixup))
                    } else {
                        unreachable!("fixup maps an identifier expression to an identifier expression")
                    }
                }
                Statement::IfThenElse(expr, ifcode, elsecode) => Statement::IfThenElse(
                    expr.map(fixup),
                    var_specify_block(inst, ifcode, comp_name),
                    var_specify_block(inst, elsecode, comp_name),
                ),
                Statement::TableAssign(table, index, expr) => {
                    if let Expression::Identifier(table) = fixup(table.to_expression()) {
                        Statement::TableAssign(table, index.map(fixup), expr.map(fixup))
                    } else {
                        unreachable!("fixup maps an identifier expression to an identifier expression")
                    }
                }
            })
            .collect(),
    )
}

fn var_specify(inst: &PackageInstance, comp_name: &str) -> PackageInstance {
    PackageInstance {
        name: inst.name.clone(),
        params: inst.params.clone(),
        pkg: Package {
            params: inst.pkg.params.clone(),
            state: inst.pkg.state.clone(),
            oracles: inst
                .pkg
                .oracles
                .iter()
                .map(|def| OracleDef {
                    sig: def.sig.clone(),
                    code: var_specify_helper(inst, def.code.clone(), comp_name),
                })
                .collect(),
        },
    }
}