/*
 * Copyright (c) 2016, 2019, Oracle and/or its affiliates. All rights reserved.
 * ORACLE PROPRIETARY/CONFIDENTIAL. Use is subject to license terms.
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 *
 */

package jdk.jshell;

import com.sun.source.tree.ClassTree;
import com.sun.source.tree.MethodTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.Tree.Kind;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCBlock;
import com.sun.tools.javac.tree.JCTree.JCClassDecl;
import com.sun.tools.javac.tree.JCTree.JCMethodDecl;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
import com.sun.tools.javac.tree.JCTree.Visitor;
import com.sun.tools.javac.util.List;
import com.sun.tools.javac.util.ListBuffer;
import static com.sun.tools.javac.code.Flags.FINAL;
import static com.sun.tools.javac.code.Flags.PUBLIC;
import static com.sun.tools.javac.code.Flags.STATIC;
import static com.sun.tools.javac.code.Flags.INTERFACE;
import static com.sun.tools.javac.code.Flags.ENUM;
import jdk.jshell.Wrap.CompoundWrap;
import jdk.jshell.Wrap.Range;
import jdk.jshell.Wrap.RangeWrap;

/**
 * Produces a "corralled" version of the Wrap for a class or method snippet.
 * <p>
 * The corralled form keeps the declarations but replaces executable code with
 * a block that throws {@code jdk.jshell.spi.SPIResolutionException} carrying
 * the key index given to the constructor:
 * <ul>
 * <li>method and constructor bodies, and instance initializer blocks, become
 *     that throwing block;</li>
 * <li>field initializers are dropped;</li>
 * <li>a throwing public no-argument constructor is added to a non-enum,
 *     non-interface class that has members but declares no constructor;</li>
 * <li>static initializer blocks and the leading enum constants are copied
 *     verbatim.</li>
 * </ul>
 */
class Corraller extends Visitor {

    /** Visitor result field: the Wrap produced for the most recently visited tree. */
    protected Wrap result;

    private final TreeDissector dis;
    private final String resolutionExceptionBlock;
    private final String source;

    /**
     * @param dis      dissector used to map trees to positions in {@code source}
     * @param keyIndex value passed to the generated SPIResolutionException
     * @param source   the source text that the tree positions refer to
     */
    public Corraller(TreeDissector dis, int keyIndex, String source) {
        this.dis = dis;
        this.resolutionExceptionBlock = "\n      { throw new jdk.jshell.spi.SPIResolutionException(" + keyIndex + "); }";
        this.source = source;
    }

    /** Corrals a class declaration; returns {@code null} if it cannot be corralled. */
    public Wrap corralType(ClassTree tree) {
        return corralToWrap(tree);
    }

    /** Corrals a method declaration; returns {@code null} if it cannot be corralled. */
    public Wrap corralMethod(MethodTree tree) {
        return corralToWrap(tree);
    }

    // Corral a top-level tree and prefix it with "public static".  Corralling is
    // best effort: any failure is reported through debugWrap and yields null.
    private Wrap corralToWrap(Tree tree) {
        try {
            JCTree jct = (JCTree) tree;
            Wrap w = new CompoundWrap(
                    "    public static\n    ",
                    corral(jct));
            debugWrap("corralToWrap SUCCESS source: %s -- wrap:\n %s\n", tree, w.wrapped());
            return w;
        } catch (Exception ex) {
            debugWrap("corralToWrap FAIL: %s - %s\n", tree, ex);
            return null;
        }
    }

    // Corral a single node: visit it and hand back (and clear) the visitor result.
    private Wrap corral(JCTree tree) {
        if (tree == null) {
            return null;
        }
        tree.accept(this);
        Wrap tmpResult = this.result;
        this.result = null;
        return tmpResult;
    }

    // Corral one class member.  Initializer blocks are handled here rather than
    // by a visitor method because an instance initializer is replaced by a
    // String (the throwing block), not by a Wrap.
    private Object corralMember(JCTree t) {
        if (t.getKind() == Kind.BLOCK) {
            return (((JCBlock) t).flags & STATIC) != 0
                    ? new RangeWrap(source, dis.treeToRange(t))
                    : resolutionExceptionBlock;
        }
        return corral(t);
    }

    // Return the end of source[begin, end) after backing up over trailing spaces
    // and tabs.  Line terminators are kept because one may end a // comment.
    private int trimTrailingBlanks(int begin, int end) {
        int e = end;
        while (e > begin && (source.charAt(e - 1) == ' ' || source.charAt(e - 1) == '\t')) {
            e--;
        }
        return e;
    }

    private String defaultConstructor(JCClassDecl tree) {
        return "  public " + tree.name.toString() + "() " +
                resolutionExceptionBlock;
    }

    // Append the members of a non-empty enum body: the leading run of
    // constant-like members (VARIABLE members whose modifiers include public,
    // static and final) verbatim, then every other member corralled.
    private void corralEnumBody(JCClassDecl tree, ListBuffer<Object> wrappedDefs) {
        int enumBegin = dis.getStartPosition(tree.defs.head);
        JCTree last = null; // always assigned below, because defs is non-empty
        List<? extends JCTree> l = tree.defs;
        for (; l.nonEmpty(); l = l.tail) {
            last = l.head;
            boolean constantLike = last.getKind() == Kind.VARIABLE
                    && (((JCVariableDecl) last).mods.flags & (PUBLIC | STATIC | FINAL))
                            == (PUBLIC | STATIC | FINAL);
            if (!constantLike) {
                break;  // first non-constant member: process the rest normally
            }
        }
        // The constants end where the next member starts (keeping any ';'
        // separator before it), or else at the end of the last constant.
        int constEnd = l.nonEmpty()
                ? trimTrailingBlanks(enumBegin, dis.getStartPosition(l.head))
                : dis.getEndPosition(last);
        wrappedDefs.append(new RangeWrap(source, new Range(enumBegin, constEnd)));
        for (; l.nonEmpty(); l = l.tail) {
            wrappedDefs.append("\n");
            wrappedDefs.append(corralMember(l.head));
        }
    }

    // Append the corralled members of a non-empty, non-enum class or interface
    // body, adding a throwing default constructor to a class that declares none.
    private void corralClassBody(JCClassDecl tree, boolean isInterface, ListBuffer<Object> wrappedDefs) {
        boolean constructorSeen = false;
        for (JCTree t : tree.defs) {
            wrappedDefs.append("\n   ");
            if (t.getKind() == Kind.METHOD
                    && ((JCMethodDecl) t).name == tree.name.table.names.init) {
                constructorSeen = true;
            }
            wrappedDefs.append(corralMember(t));
        }
        if (!constructorSeen && !isInterface) {
            // A regular class with no constructor: generate one, so that
            // instantiation throws instead of running the implicit constructor.
            wrappedDefs.append("\n ");
            wrappedDefs.append(defaultConstructor(tree));
        }
    }

    /* ***************************************************************************
     * Visitor methods
     ****************************************************************************/

    // Keep the class header verbatim and corral its members.
    @Override
    public void visitClassDef(JCClassDecl tree) {
        boolean isEnum = (tree.mods.flags & ENUM) != 0;
        boolean isInterface = (tree.mods.flags & INTERFACE) != 0;
        int classBegin = dis.getStartPosition(tree);
        int classEnd = dis.getEndPosition(tree);
        ListBuffer<Object> wrappedDefs = new ListBuffer<>();
        int bodyBegin = -1;
        if (tree.defs != null && !tree.defs.isEmpty()) {
            if (isEnum) {
                corralEnumBody(tree, wrappedDefs);
            } else {
                corralClassBody(tree, isInterface, wrappedDefs);
            }
            // the verbatim header runs up to the start of the first member
            bodyBegin = dis.getStartPosition(tree.defs.head);
        }
        Object defs = wrappedDefs.length() == 1
            ? wrappedDefs.first()
            : new CompoundWrap(wrappedDefs.toArray());
        if (bodyBegin < 0) {
            // no member position available: the header ends just after the '{'
            int brace = source.indexOf('{', classBegin);
            if (brace < 0 || brace >= classEnd) {
                throw new IllegalArgumentException("No brace found: " + source.substring(classBegin, classEnd));
            }
            bodyBegin = brace + 1;
        }
        // the header range includes the opening brace
        result = new CompoundWrap(
                new RangeWrap(source, new Range(classBegin, bodyBegin)),
                defs,
                "\n}"
        );
    }

    // Replace the method body, if any, with the throwing block.
    @Override
    public void visitMethodDef(JCMethodDecl tree) {
        int methodBegin = dis.getStartPosition(tree);
        int methodEnd = dis.getEndPosition(tree);
        int bodyBegin = -1;
        if (tree.getBody() != null) {
            bodyBegin = dis.getStartPosition(tree.getBody());
            if (bodyBegin < 0) {
                // No recorded body position: use the first '{' that lies within
                // the method's own (end-exclusive) source range.
                int brace = source.indexOf('{', methodBegin);
                bodyBegin = (brace >= 0 && brace < methodEnd) ? brace : -1;
            }
        }
        if (bodyBegin >= 0) {
            Range noBodyRange = new Range(methodBegin, bodyBegin);
            result = new CompoundWrap(
                    new RangeWrap(source, noBodyRange),
                    resolutionExceptionBlock);
        } else {
            // no body (e.g. abstract or interface method): keep it verbatim
            result = new RangeWrap(source, new Range(methodBegin, methodEnd));
        }
    }

    // Remove the initializer, if present.
    @Override
    public void visitVarDef(JCVariableDecl tree) {
        int begin = dis.getStartPosition(tree);
        int end = dis.getEndPosition(tree);
        if (tree.init == null) {
            result = new RangeWrap(source, new Range(begin, end));
        } else {
            int sinit = dis.getStartPosition(tree.init);
            int eq = source.lastIndexOf('=', sinit);
            if (eq < begin) {
                throw new IllegalArgumentException("Equals not found before init: " + source + " @" + sinit);
            }
            // Keep everything before the '=' minus trailing blanks.  A fixed
            // "eq - 1" cut (Range ends are exclusive) would drop the last
            // character of the name when no blank precedes '=' ("int x=5;").
            int declEnd = trimTrailingBlanks(begin, eq);
            result = new CompoundWrap(new RangeWrap(source, new Range(begin, declEnd)), ";");
        }
    }

    // Any tree kind not handled above cannot be corralled.
    @Override
    public void visitTree(JCTree tree) {
        throw new IllegalArgumentException("Unexpected tree: " + tree);
    }

    // Debug hook; currently a no-op (the commented call shows the intended target).
    void debugWrap(String format, Object... args) {
        //state.debug(this, InternalDebugControl.DBG_WRAP, format, args);
    }
}
