# Compiler Toolchain Dispatch & Analysis.
# Copyright 2011 Clint Banis.  All rights reserved.
#
# Todo: Casting.
# Todo: Make jython compatible
# Todo: Ant script?
# Todo: Generics?
# Todo: statement nesting?
#
# Anonymous subclasses?
#
import optparse
import pdb
import code
import readline
import ConfigParser
import StringIO
import re
import contextlib

# Language tools -- todo: decouple the actual parser implementation from process.
import compiler
from compiler import ast

# Other Tools.
def examine(**kwd):
    """Open an interactive console whose namespace is the keyword arguments.

    The console uses the ``kwd`` dict itself as its locals, so names bound
    during the session end up in ``kwd``.  Binding ``__result__`` in the
    session makes it this function's return value; otherwise None.
    """
    console = code.InteractiveConsole(kwd)
    console.interact()
    # kwd was mutated in place by the session; pick up an optional result.
    return kwd.get('__result__')

# Transformer Compiler Context.
# Initial package of every Module.  None means no 'package' line is emitted
# unless the module prologue supplies one.
DEFAULT_PACKAGE = None

# Java access levels.  VISIBLE_PACKAGE is Java's implicit default and is
# never written out as a keyword.
(VISIBLE_PACKAGE,
 VISIBLE_PUBLIC,
 VISIBLE_PRIVATE,
 VISIBLE_PROTECTED) = ('package', 'public', 'private', 'protected')

# Top-level Containers.
class Module:
    """Top-level container recompiled from the ``compiler`` parse tree of one file.

    Collects the package (from the module docstring prologue), imports,
    ``comment(...)`` / bare-string comments and top-level classes.  ``str()``
    renders the package, imports and classes as Java source.
    """
    def __str__(self):
        parts = []
        if self.package:
            parts.append('package %s;' % self.package)
        if self.imports:
            parts.append('\n'.join('import %s;' % i for i in self.imports))

        parts.extend(str(cls) for cls in self.classes)
        return '\n\n'.join(parts)

    def __init__(self, root):
        self.package = DEFAULT_PACKAGE
        self.imports = []   # qualified names, in first-seen order
        self.comments = []  # (lineno, [string, ...]) pairs
        self.classes = []
        self.loadTree(root)

    def loadTree(self, root):
        self.loadPrologue(root.doc)
        self.loadToplevel(root.node)

    def loadToplevel(self, toplevel):
        # Top-level statements other than imports, comments and classes are
        # ignored.
        for stmt in toplevel.nodes:
            if isinstance(stmt, (ast.Import, ast.From)):
                self.loadImport(stmt)

            elif isinstance(stmt, ast.Discard):
                expr = stmt.expr
                if isinstance(expr, ast.CallFunc):
                    name = expr.node
                    if isinstance(name, ast.Name) and name.name == 'comment':
                        self.loadComments(expr.args, expr.lineno)
                else:
                    self.loadComments([expr], expr.lineno)

            elif isinstance(stmt, ast.Class):
                self.loadClass(stmt)

    def loadPrologue(self, prolog):
        # A module without a docstring has doc None: there is nothing to parse.
        if not prolog:
            return
        for (directive, value) in ParsePrologue(prolog):
            if directive == 'package':
                assert not self.package, 'duplicate package directive'
                self.package = value

    def loadImport(self, imp):
        # 'import a.b' names a.b.  'from a import B' names a.B, and
        # 'from a import *' names a.*, matching Java's import forms.
        # 'as' aliases are ignored: Java has no equivalent.
        if isinstance(imp, ast.From):
            if imp.modname == '__future__':
                # Python compiler directives, not importable names.
                return
            names = ['%s.%s' % (imp.modname, name) for (name, alias) in imp.names]
        else:
            names = [name for (name, alias) in imp.names]

        for name in names:
            if name not in self.imports:
                self.imports.append(name)

    def loadComments(self, args, lineno):
        # Only an argument list made up entirely of string constants counts
        # as a comment.  Returns whether the comment was recorded.
        comments = []
        for a in args:
            if not (isinstance(a, ast.Const) and isinstance(a.value, basestring)):
                return False
            comments.append(a.value)

        self.comments.append((lineno, comments))
        return True

    def loadClass(self, node):
        # assert node.name not in self.classes
        self.classes.append(Class(self, node))

class Class:
    """A Java class or interface recompiled from a Python ``class`` statement.

    Decorators supply modifiers (``@interface``, ``@abstract``, ``@final``,
    ``@public``/``@protected``/``@private``) and ``@implements(...)``.
    Python base classes become ``extends``.  Methods may be grouped under
    nested classes named ``public``, ``private`` or ``protected`` to give them
    that visibility.
    """
    def __str__(self):
        qualifiers = []
        if not self.interface:
            if self.abstract:
                qualifiers.append('abstract')
            elif self.final:
                qualifiers.append('final')

        # Package-private is Java's default and has no keyword.
        if self.visibility != VISIBLE_PACKAGE:
            qualifiers.append(self.visibility)

        struct_type = 'interface' if self.interface else 'class'
        qualifiers = ' '.join(qualifiers)

        parents = ', '.join(self.parents)
        interfaces = ', '.join(self.implements)

        methods = '\n'.join(m.__str__(indent = '    ') for m in self.methods)
        if methods:
            methods = '{\n%s\n}' % methods
        else:
            methods = '    { }'

        return '%s%s%s %s%s%s%s%s\n%s' % (qualifiers, qualifiers and ' ', struct_type, self.name,
                                          ' extends ' if parents else '', parents,
                                          ' implements ' if interfaces else '', interfaces,
                                          methods)

    def __init__(self, module, root):
        self.module = module
        self.interface = False
        self.abstract = False
        self.final = False
        self.visibility = None        # set by a decorator, else package-private
        self.parents = []             # names emitted after 'extends'
        self.implements = []          # names emitted after 'implements'
        self.members = [] # fields
        self.methods = []
        self.visibility_context = []  # stack of enclosing visibility sections

        self.loadTree(root)

        if self.visibility is None:
            self.visibility = VISIBLE_PACKAGE

    def loadTree(self, root):
        self.name = root.name
        self.loadDecorators(root.decorators)
        self.loadBases(root.bases)
        self.loadBody(root.code)

    def loadDecorators(self, decs):
        # decs is None when the class has no decorators.
        if not decs:
            return
        for d in decs:
            if isinstance(d, ast.Name):
                self.loadDecoration(d.name)
            elif isinstance(d, ast.CallFunc):
                n = d.node
                if isinstance(n, ast.Name) and n.name == 'implements':
                    self.loadImplements(d.args)

    def loadImplements(self, args):
        """Record the interfaces named by ``@implements(A, pkg.B, ...)``."""
        assert args, '@implements() needs at least one interface'
        for a in args:
            assert isinstance(a, (ast.Name, ast.Getattr))
            name = getQualifiedName(a)
            if name not in self.implements:
                self.implements.append(name)

    def loadDecoration(self, name):
        if name == 'interface':
            assert not self.final
            self.interface = self.abstract = True
        elif name == 'abstract':
            assert not self.final
            self.abstract = True
        elif name == 'final':
            self.final = True
            assert not self.interface
            assert not self.abstract
        elif name == 'public':
            assert self.visibility is None
            self.visibility = VISIBLE_PUBLIC
        elif name == 'protected':
            assert self.visibility is None
            self.visibility = VISIBLE_PROTECTED
        elif name == 'private':
            assert self.visibility is None
            self.visibility = VISIBLE_PRIVATE

    def loadBases(self, bases):
        for b in bases:
            if isinstance(b, ast.Name):
                b = b.name
                assert b not in self.parents
                self.parents.append(b)

            elif isinstance(b, ast.Getattr):
                self.parents.append(getQualifiedName(b))

    @contextlib.contextmanager
    def Visibility(self, name):
        # Methods loaded inside this block default to visibility *name*.
        self.visibility_context.append(name)
        try: yield name
        finally: del self.visibility_context[-1]

    def currentVisibility(self):
        # Innermost visibility section, or None outside any section.
        try: return self.visibility_context[-1]
        except IndexError: pass

    def loadBody(self, body):
        # Statements other than methods and visibility/static sections are
        # currently ignored.
        for node in body:
            if isinstance(node, ast.Function):
                self.loadMethod(node)
            elif isinstance(node, ast.Class):
                name = node.name
                if name in ['public', 'private', 'protected']:
                    with self.Visibility(name):
                        self.loadBody(node.code)

                elif name == 'static':
                    raise NotImplementedError # XXX class init

    def loadMethod(self, node):
        # Methods sharing a name are all kept: Java permits overloading.
        self.methods.append(Method(self, node, visibility = self.currentVisibility()))

class Method:
    """A Java method or constructor recompiled from a Python ``def``.

    The signature comes from decorators (``@returns(T)``, ``@static``,
    ``@synchronized``, visibility) and from the parameter list.  Every
    parameter's default is its type expression.  Without ``@returns``, a
    leading parameter that has no default names the return type.  A method
    named after its class is a constructor.  The body is recompiled statement
    by statement through a RecompileContext.
    """
    def __str__(self, indent = ''):
        qualifiers = []
        if self.visibility != VISIBLE_PACKAGE:
            qualifiers.append(self.visibility)
        if self.static:
            qualifiers.append('static')
        if self.synchronized:
            qualifiers.append('synchronized')
        qualifiers = ' '.join(qualifiers)

        params = []
        for (name, type, value) in self.parameters:
            if value:
                params.append('%s %s = %s' % (type, name, LiteralValue(value)))
            else:
                params.append('%s %s' % (type, name))
        params = ', '.join(params)

        if self.cls.interface:
            code = ';'
        elif self.code:
            tab = indent + '    '
            code = '%s%s;\n' % (tab, (';\n' + tab).join(self.code))
            code = '\n%s{\n%s%s}\n' % (indent, code, indent)
        else:
            code = '\n%s%s{ }\n' % (indent, '    ')

        # Constructors are declared without a return type; writing one
        # (e.g. 'void') would declare an ordinary method instead.
        if self.constructor:
            declarator = self.name
        else:
            declarator = '%s %s' % (self.returnType or 'void', self.name)

        return '%s%s%s%s(%s)%s' % (indent, qualifiers, ' ' if qualifiers else '',
                                   declarator, params, code)

    def __init__(self, cls, root, visibility = None):
        self.cls = cls
        self.parameters = []          # (name, type, initial-values-or-None)
        self.returnType = None
        self.visibility = visibility  # default from an enclosing section
        self.synchronized = False
        self.static = False
        self.code = []                # recompiled statement strings
        self.doc = None
        self.declarations = set()     # local names already declared
        self.context = self.RecompileContext(self, Recompiler)
        self.loadTree(root)

    def loadTree(self, root):
        self.name = root.name
        self.constructor = (self.name == self.cls.name)
        self.loadSignature(root.decorators, root.argnames, root.defaults)
        self.loadDocumentation(root.doc)
        self.loadStatements(root.code)

    def loadSignature(self, decorators, argnames, defaults):
        visibility = self.visibility
        returnType = None

        if decorators:
            for d in decorators.nodes:
                if isinstance(d, ast.CallFunc):
                    n = d.node
                    if isinstance(n, ast.Name):
                        if n.name == 'returns':
                            assert not self.constructor
                            assert returnType is None
                            a = d.args
                            assert len(a) == 1
                            a = a[0]
                            assert isinstance(a, ast.Name)
                            returnType = a.name

                elif isinstance(d, ast.Name):
                    name = d.name
                    if name == 'synchronized':
                        self.synchronized = True
                    elif name == 'public':
                        assert visibility is None
                        visibility = VISIBLE_PUBLIC
                    elif name == 'private':
                        assert visibility is None
                        visibility = VISIBLE_PRIVATE
                    elif name == 'protected':
                        assert visibility is None
                        visibility = VISIBLE_PROTECTED
                    elif name == 'static':
                        assert not self.static
                        self.static = True

        # Work on copies: these lists belong to the parse tree, which reading
        # the signature must not modify.
        argnames = list(argnames)
        defaults = list(defaults)

        if returnType is None and not self.constructor:
            # Without @returns, a leading parameter that has no default names
            # the return type: 'def f(int, x = char)' returns int.  When every
            # parameter has a default there is no such name; the method
            # returns void.
            if len(argnames) > len(defaults):
                returnType = argnames.pop(0)

        # Each remaining parameter's default is its type expression.  Python
        # aligns defaults with the *trailing* parameters, so surplus leading
        # names would be left without a type.
        if len(argnames) > len(defaults):
            raise TypeError('%s: every parameter needs a type default (parameters: %s)'
                            % (self.name, ', '.join(argnames)))

        params = []
        for (name, default) in zip(argnames, defaults):
            (type, value) = getVariableTypeAndValue(default)
            params.append((name, type, value))

        self.parameters.extend(params)
        self.visibility = visibility or VISIBLE_PACKAGE
        self.returnType = returnType

    def loadDocumentation(self, doc):
        self.doc = doc

    def loadStatements(self, code):
        with self.context.indentation(1):
            for stmt in code:
                self.loadSingleStatement(stmt)

    def loadSingleStatement(self, stmt):
        self.emitStatement(recompile(stmt, self.context))

    class RecompileContext:
        """Per-method recompile context.

        It carries the method's declaration set and an indentation counter,
        and delegates fallback handling to *parent* (the Recompiler).
        """
        def __init__(self, method, parent):
            self.method = method
            self.parent = parent
            self.indent = 0

        def defaultOperation(self, *args, **kwd):
            return self.parent.defaultOperation(*args, **kwd)
        def recompile(self, node, context = None):
            return recompile(node, context or self)

        @property
        def declarations(self):
            return self.method.declarations

        @contextlib.contextmanager
        def indentation(self, amount):
            self.indent += amount
            try: yield self.indent
            finally: self.indent -= amount

    # Statement Building.
    def emitStatement(self, stmt):
        # Handlers return None for statements that produce no Java code.
        if stmt is not None:
            self.code.append(stmt)

# Node Evaluation.
def getQualifiedName(node):
    """Return the dotted source name of a Name/Getattr node, e.g. 'java.util.List'.

    Delegates to the recompiler with no method context.  Its Name and
    Getattr handlers already produce plain and dotted names.  Other node
    types yield whatever the recompiler produces for them.
    """
    return recompile(node)

# Characters that have a dedicated Java escape sequence.  Handling CR and LF
# here matters: Java translates \uXXXX escapes before tokenizing, so writing
# them as \u000d / \u000a would break the string literal.
_JAVA_STRING_ESCAPES = {
    '\b': '\\b',
    '\t': '\\t',
    '\n': '\\n',
    '\f': '\\f',
    '\r': '\\r',
    '"': '\\"',
    '\\': '\\\\',
}

def JavaStringRepr(value):
    """Return *value* as a double-quoted Java string literal.

    - Printable ASCII is copied through.
    - Quotes, backslashes and the common control characters use Java's short
      escapes.
    - Every other character becomes a \\uXXXX escape, using a UTF-16
      surrogate pair above U+FFFF.

    Byte strings are treated one byte per character, i.e. as Latin-1.
    """
    out = ['"']
    for ch in value:
        if ch in _JAVA_STRING_ESCAPES:
            out.append(_JAVA_STRING_ESCAPES[ch])
            continue
        code = ord(ch)
        if 0x20 <= code < 0x7f:
            out.append(ch)
        elif code <= 0xffff:
            out.append('\\u%04x' % code)
        else:
            code -= 0x10000
            out.append('\\u%04x\\u%04x' % (0xd800 + (code >> 10), 0xdc00 + (code & 0x3ff)))
    out.append('"')
    return ''.join(out)

def getConstantNodeValue(node, context):
    """Recompile an ``ast.Const`` node into a Java literal.

    Shares its formatting with LiteralValue: strings become double-quoted
    Java string literals, and anything else uses its Python repr().
    *context* is unused but required by the recompiler-table signature.
    """
    # Todo: recognize single-character strings as Java char literals.
    return LiteralValue(node.value)

def EvaluatePassStatement(node, context):
    """Recompile ``pass`` into nothing.

    Returning None makes Method.emitStatement drop the statement entirely.
    An explicit empty-statement placeholder was considered and left out.
    """
    return None

def LiteralValue(value):
    """Format a Python constant as Java source text.

    Strings go through JavaStringRepr; every other value uses repr().
    """
    return JavaStringRepr(value) if isinstance(value, basestring) else repr(value)

# Java's eight primitive types.
JAVA_PRIMITIVE_TYPES = ('boolean', 'byte', 'short', 'int', 'long', 'char', 'float', 'double')

def IsPrimitive(type):
    """Return True if *type* names a Java primitive type.

    Primitive declarations are initialized with a scalar value rather than
    with ``new T(...)``.
    """
    return type in JAVA_PRIMITIVE_TYPES

def NeedsDeclaration(context, name):
    """Mark *name* as declared in *context* and report whether it is new.

    Returns True the first time *name* is seen in ``context.declarations``
    (adding it), and False on later calls.  A context with no
    ``declarations`` attribute (such as the bare Recompiler) tracks nothing
    and yields None.
    """
    absent = object()
    seen = getattr(context, 'declarations', absent)
    if seen is absent:
        return None
    if name in seen:
        return False
    seen.add(name)
    return True

def getVariableTypeAndValue(expr):
    """Read a type expression as ``(type_name, initial_values)``.

    - ``T`` or ``pkg.T`` gives ``('T' / 'pkg.T', None)``.
    - ``T(c1, c2, ...)`` gives ``('T', [c1, c2, ...])``.  Only ast.Const
      arguments are kept, as raw Python values.
    - ``[T]`` and ``[[T]]`` give ``('T[]', None)`` and ``('T[][]', None)``.

    Any other expression yields None rather than a tuple.
    """
    if isinstance(expr, ast.Name):
        return (expr.name, None)

    if isinstance(expr, ast.Getattr):
        return (getQualifiedName(expr), None)

    if isinstance(expr, ast.CallFunc):
        type_name = getQualifiedName(expr.node)
        # NOTE(review): non-constant arguments are silently discarded here.
        constants = [arg.value for arg in expr.args if isinstance(arg, ast.Const)]
        return (type_name, constants)

    if isinstance(expr, ast.List):
        depth = 0
        while isinstance(expr, ast.List):
            # Only the first element of each nesting level is inspected.
            expr = expr.nodes[0]
            depth += 1
        return ('%s%s' % (getQualifiedName(expr), '[]' * depth), None)

    return None

# Recompiling Operations Table.
def RecompileAssignment(node, context):
    """Recompile ``ast.Assign`` into a Java declaration or assignment.

    Within a method context, the first assignment to a plain name declares
    it.  The right-hand side is read as a type expression:

      x = T            ->  T x
      x = int(5)       ->  int x = 5               (primitive)
      x = Foo(1, 'a')  ->  Foo x = new Foo(1, "a")

    Everything else becomes ``target = value``: later assignments to a
    declared name, attribute and subscript targets, and right-hand sides that
    are not type expressions.  Chained targets (``a = b = v``) stay chained.
    """
    left = node.nodes[0]
    right = node.expr

    if isinstance(left, ast.AssName) and NeedsDeclaration(context, left.name):
        name = left.name
        # todo: allow more complex initializations (like auto-new)
        typed = getVariableTypeAndValue(right)
        # A right-hand side that is not a type expression ('x = a + b')
        # declares nothing; fall through to a plain assignment.
        if typed is not None:
            (type, value) = typed
            if value is None:
                return '%s %s' % (type, name)

            if IsPrimitive(type):
                if not value:
                    # e.g. 'int()': no initial value, declaration only.
                    return '%s %s' % (type, name)
                # Scalar initialization value.
                return '%s %s = %s' % (type, name, value[0])

            values = ', '.join(map(LiteralValue, value))
            return '%s %s = new %s(%s)' % (type, name, type, values)

    targets = ' = '.join(map(context.recompile, node.nodes))
    return '%s = %s' % (targets, context.recompile(right))

def RecompileWith(node, context):
    """Placeholder recompilation of a ``with`` statement: ``With(<expr>)``.

    Only the context expression is rendered.  The ``as`` target
    (``node.vars``) and the body (``node.body``) are not emitted yet.
    """
    # todo: recompile node.body under context.indentation(1).
    return 'With(%s)' % context.recompile(node.expr)

def RecompileGenExpr(node, context):
    """Placeholder recompilation of a generator expression: ``GenExpr(<code>)``."""
    return 'GenExpr(%s)' % context.recompile(node.code)

def RecompileGenExprInner(node, context):
    """Placeholder: ``GenExprInner(<expr>, [<qual>, ...])``.

    ``node.quals`` is a list of qualifier nodes, so each one is recompiled
    individually.  Handing the list itself to the recompiler only reaches
    its ``Unknown(...)`` fallback.
    """
    quals = ', '.join(map(context.recompile, node.quals))
    return 'GenExprInner(%s, [%s])' % (context.recompile(node.expr), quals)

def RecompileGenExprFor(node, context):
    """Placeholder: ``GenExprFor(<assign>, <iter>, [<if>, ...])``.

    ``node.ifs`` is a list of condition nodes, so each one is recompiled
    individually rather than handing the list itself to the recompiler.
    """
    ifs = ', '.join(map(context.recompile, node.ifs))
    return 'GenExprFor(%s, %s, [%s])' % (context.recompile(node.assign),
                                         context.recompile(node.iter),
                                         ifs)

def RecompileCallFunc(node, context):
    """Recompile a function call.

    ``new(T, args...)`` becomes ``new T(args...)``, where T is a plain or
    dotted name.  Any other call is rendered as ``f(args...)``.

    Raises:
        TypeError: if the first argument of ``new(...)`` is not a name.
    """
    # Todo: factor this out for variable initialization of new object instances.
    callee = node.node
    args = node.args
    is_new = isinstance(callee, ast.Name) and callee.name == 'new'
    if not is_new:
        return '%s(%s)' % (context.recompile(callee), ', '.join(map(context.recompile, args)))

    assert len(args) > 0
    class_node = args[0]
    if isinstance(class_node, ast.Getattr):
        class_name = getQualifiedName(class_node)
    elif isinstance(class_node, ast.Name):
        class_name = class_node.name
    else:
        raise TypeError(type(class_node).__name__)

    ctor_args = ', '.join(map(context.recompile, args[1:]))
    return 'new %s(%s)' % (class_name, ctor_args)

def RecompileSubscript(node, context):
    """Recompile ``obj[...]``.

    A single string-constant subscript is member access, so Java members
    whose names are Python keywords can be spelled: ``obj['class']`` becomes
    ``obj.class``.  Otherwise each subscript becomes its own Java index:
    ``a[i]`` stays ``a[i]``, and ``a[i, j]`` becomes ``a[i][j]``.  Java has
    no comma or colon index syntax.
    """
    target = context.recompile(node.expr)
    subs = node.subs
    # For python reserved keywords that aren't reserved in java:
    if len(subs) == 1:
        index = subs[0]
        if isinstance(index, ast.Const) and isinstance(index.value, basestring):
            return '%s.%s' % (target, index.value)

    return target + ''.join('[%s]' % context.recompile(s) for s in subs)

def RecompileJoinedOp(context, op, *segments):
    """Recompile every segment and join the results with ' <op> '."""
    separator = ' %s ' % op
    return separator.join([context.recompile(segment) for segment in segments])

def RecompileBinaryOp(context, op, left, right):
    """Recompile ``left <op> right``."""
    return (' %s ' % op).join([context.recompile(left), context.recompile(right)])

# Node types the recompiler renders as infix expressions.  When one of them
# is an operand of another it is parenthesized, so Java precedence and
# associativity cannot regroup it: Python's a - (b - c) must not come out
# as a - b - c.
_INFIX_NODE_TYPES = (ast.Add, ast.Sub, ast.And)

def _RecompileOperand(node, context):
    """Recompile one infix operand, parenthesizing nested infix expressions."""
    text = context.recompile(node)
    if isinstance(node, _INFIX_NODE_TYPES):
        return '(%s)' % text
    return text

def _RecompileInfix(context, op, operands):
    """Join the recompiled *operands* with ' <op> '."""
    return (' %s ' % op).join([_RecompileOperand(n, context) for n in operands])

def _RecompileAssert(node, context):
    """``assert test[, fail]`` -> ``if (!(test)) { throw new AssertionError([fail]); }``."""
    # The test is parenthesized because '!' binds tighter than '==' etc.
    test = context.recompile(node.test)
    if node.fail is None:
        # No message given; recompiling None would only yield Unknown('None').
        return 'if (!(%s)) { throw new AssertionError(); }' % test
    return 'if (!(%s)) { throw new AssertionError(%s); }' % (test, context.recompile(node.fail))

def _RecompileAugAssign(node, context):
    """``x += y`` -> ``x += y``; node.op is already the full operator ('+=')."""
    return '%s %s %s' % (context.recompile(node.node), node.op, context.recompile(node.expr))

def _RecompileReturn(node, context):
    """``return [value]`` -> ``return [value]``."""
    value = node.value
    # The compiler package represents a bare 'return' as Return(Const(None)).
    if value is None or (isinstance(value, ast.Const) and value.value is None):
        return 'return'
    return 'return %s' % context.recompile(value)

class Recompiler:
    """Default recompiler: translates an AST node by dispatching on its class.

    Handlers live in a class-wide table that maps a node class to
    ``handler(node, context)``.  A handler returns Java source text, or None
    to emit nothing.  Node types without a handler go to defaultOperation.
    """
    # Todo: recursive statements.
    _recompiling_table = {}
    register = _recompiling_table.__setitem__

    @classmethod
    def recompiler(cls, ntype):
        """Decorator registering the decorated function as the handler for *ntype*."""
        def makeRecompiler(function):
            # Store the plain function.  recompile() calls table entries
            # directly, and a staticmethod object is not callable (before
            # Python 3.10).
            cls.register(ntype, function)
            return function
        return makeRecompiler

    @classmethod
    def defaultOperation(cls, node, context):
        """Fallback for unregistered nodes.

        Plain Python scalars are repr()'d.  Anything else becomes a call to
        'Unknown' carrying the node's repr.
        """
        if isinstance(node, (basestring, int, float)):
            return repr(node)

        # Generate a call to method 'Unknown'
        return 'Unknown(%s)' % repr(repr(node))

    @classmethod
    def recompile(cls, node, context = None):
        """Recompile *node*; *context* defaults to the Recompiler class itself."""
        handler = cls._recompiling_table.get(node.__class__, cls.defaultOperation)
        return handler(node, context or cls)

    # Table.
    register(ast.Expression, lambda node, context:context.recompile(node.node))
    register(ast.Add, lambda node, context:_RecompileInfix(context, '+', (node.left, node.right)))
    register(ast.And, lambda node, context:_RecompileInfix(context, '&&', node.nodes))
    register(ast.AssAttr, lambda node, context:'%s.%s' % (context.recompile(node.expr), node.attrname))
    register(ast.AssList, lambda node, context:'[%s]' % ', '.join(map(context.recompile, node.nodes)))
    register(ast.AssName, lambda node, context:'%s' % node.name)
    register(ast.AssTuple, lambda node, context:'(%s)' % ', '.join(map(context.recompile, node.nodes)))
    register(ast.Assert, _RecompileAssert)

    register(ast.Assign, RecompileAssignment)
    register(ast.AugAssign, _RecompileAugAssign)

    register(ast.CallFunc, RecompileCallFunc)
    register(ast.Const, getConstantNodeValue)
    register(ast.Discard, lambda node, context:context.recompile(node.expr))
    register(ast.Getattr, lambda node, context:'%s.%s' % (context.recompile(node.expr), node.attrname))
    register(ast.GenExpr, RecompileGenExpr)
    register(ast.GenExprInner, RecompileGenExprInner)
    register(ast.GenExprFor, RecompileGenExprFor)
    register(ast.Name, lambda node, context:node.name)
    register(ast.Pass, EvaluatePassStatement)
    register(ast.Return, _RecompileReturn)
    register(ast.Stmt, lambda node, context:';\n'.join(map(context.recompile, node.nodes)) + ';\n')
    register(ast.Sub, lambda node, context:_RecompileInfix(context, '-', (node.left, node.right)))
    register(ast.Subscript, RecompileSubscript)
    register(ast.With, RecompileWith)

# Singleton.
recompile = Recompiler.recompile

# Prologue.
# Optional leading '[Section]' header of a prologue; group 1 is its name.
HEAD_PATTERN = re.compile(r'^\s*\[(.*?)\]')

# Section assumed when the prologue does not open with a header.
DEFAULT_HEAD_SECTION = 'CottonHead'

def ReformatPrologue(prolog):
    """Return ``(section_name, ini_text)`` for a module prologue.

    If *prolog* already opens with a ``[Section]`` header, it is returned
    unchanged along with that section's name.  Otherwise a
    DEFAULT_HEAD_SECTION header is prepended.  None (a module without a
    docstring) is treated as an empty prologue.
    """
    if prolog is None:
        prolog = ''

    head = HEAD_PATTERN.match(prolog)
    if head is not None:
        return (head.group(1), prolog)

    return (DEFAULT_HEAD_SECTION, '[%s]\n%s' % (DEFAULT_HEAD_SECTION, prolog))

def ParsePrologue(prolog):
    """Yield ``(option, value)`` pairs from a module prologue written in INI syntax.

    The prologue may open with its own ``[Section]`` header; otherwise the
    default section is assumed (see ReformatPrologue).  Option names come
    back lower-cased, as ConfigParser normalizes them.  A missing or empty
    prologue yields nothing.
    """
    # A module without a docstring has no prologue at all.
    if not prolog:
        return

    # Implementation: INI Config
    cfg = ConfigParser.ConfigParser()
    head, prolog = ReformatPrologue(prolog)
    cfg.readfp(StringIO.StringIO(prolog), filename = '%s Prologue' % head)

    for opt in cfg.options(head):
        yield (opt, cfg.get(head, opt))

# Front-End.
def getCmdlnParser():
    """Build the command-line parser.

    Options: -g/--debug (repeatable count), -e/--examine and -c/--compile.
    """
    parser = optparse.OptionParser()
    option_specs = [
        (('-g', '--debug'), dict(action = 'count', default = 0)),
        (('-e', '--examine'), dict(action = 'store_true')),
        (('-c', '--compile'), dict(action = 'store_true')),
    ]
    for (flags, settings) in option_specs:
        parser.add_option(*flags, **settings)
    return parser

def parseCmdln(argv = None):
    """Parse *argv* into ``(options, args)``; None means optparse's default of sys.argv[1:]."""
    parser = getCmdlnParser()
    return parser.parse_args(argv)

def main(argv = None):
    """Command-line entry point: load each file argument as a Module.

    With --compile the recompiled Java source is printed.  With --examine
    an interactive console is opened with the module bound to ``m``.  With
    --debug the debugger is entered before any file is processed.
    """
    (options, args) = parseCmdln(argv)
    if options.debug:
        pdb.set_trace()

    for filename in args:
        module = Module(compiler.parseFile(filename))
        # todo: stream-oriented output recompiling methodology
        if options.compile:
            print(module)
        if options.examine:
            examine(m = module)

if __name__ == '__main__':
    main()