/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.ttl.threadpool.agent;

import com.alibaba.ttl.TtlCallable;
import com.alibaba.ttl.TtlRunnable;
import com.alibaba.ttl.javassist.CannotCompileException;
import com.alibaba.ttl.javassist.ClassPool;
import com.alibaba.ttl.javassist.CtClass;
import com.alibaba.ttl.javassist.CtField;
import com.alibaba.ttl.javassist.CtMethod;
import com.alibaba.ttl.javassist.CtNewMethod;
import com.alibaba.ttl.javassist.LoaderClassPath;
import com.alibaba.ttl.javassist.NotFoundException;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.lang.instrument.ClassFileTransformer;
import java.lang.reflect.Modifier;
import java.security.ProtectionDomain;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * {@link ClassFileTransformer} that rewrites selected JDK concurrency classes with Javassist
 * so that {@code TransmittableThreadLocal} values are propagated to pooled tasks.
 * <p>
 * Classes that are rewritten:
 * <ul>
 * <li>the executor classes in {@link #EXECUTOR_CLASS_NAMES}. Every public, non-static method they
 * declare gets code inserted at its start. That code replaces each {@code java.lang.Runnable} /
 * {@code java.util.concurrent.Callable} argument with
 * {@code TtlRunnable.get(arg, false, true)} / {@code TtlCallable.get(arg, false, true)};</li>
 * <li>{@code java.util.concurrent.ForkJoinTask}. A field initialized from
 * {@code Transmitter.capture()} is added, and {@code doExec} is wrapped so that the captured value is
 * passed to {@code Transmitter.replay(..)} before the original method runs and the result of
 * that call is passed to {@code Transmitter.restore(..)} afterwards.</li>
 * </ul>
 * All other classes are left untouched: {@code null} is returned, as required by the
 * {@link ClassFileTransformer#transform} contract.
 */
public class TtlTransformer
implements ClassFileTransformer {
    private static final Logger logger = Logger.getLogger(TtlTransformer.class.getName());

    private static final String TTL_RUNNABLE_CLASS_NAME = TtlRunnable.class.getName();
    private static final String TTL_CALLABLE_CLASS_NAME = TtlCallable.class.getName();

    private static final String RUNNABLE_CLASS_NAME = "java.lang.Runnable";
    private static final String CALLABLE_CLASS_NAME = "java.util.concurrent.Callable";

    /** Executor classes whose public instance methods get their Runnable/Callable arguments wrapped. */
    private static final Set<String> EXECUTOR_CLASS_NAMES = Collections.unmodifiableSet(new HashSet<String>(Arrays.asList(
            "java.util.concurrent.ThreadPoolExecutor",
            "java.util.concurrent.ScheduledThreadPoolExecutor")));

    private static final String FORK_JOIN_TASK_CLASS_NAME = "java.util.concurrent.ForkJoinTask";
    private static final String TTL_RECURSIVE_ACTION_CLASS_NAME = "com.alibaba.ttl.TtlRecursiveAction";
    private static final String TTL_RECURSIVE_TASK_CLASS_NAME = "com.alibaba.ttl.TtlRecursiveTask";

    /** Javassist-source name of the TTL transmitter used by the generated ForkJoinTask code. */
    private static final String TRANSMITTER_CLASS_NAME = "com.alibaba.ttl.TransmittableThreadLocal.Transmitter";

    /** Name of the field added to ForkJoinTask to hold the captured TTL values. */
    private static final String CAPTURED_FIELD_NAME = "captured$field$add$by$ttl";
    private static final String DO_EXEC_METHOD_NAME = "doExec";
    /** New name under which the original ForkJoinTask#doExec implementation is kept. */
    private static final String RENAMED_DO_EXEC_METHOD_NAME = "original$doExec$method$renamed$by$ttl";

    /**
     * Result meaning "no transformation performed". The {@link ClassFileTransformer#transform} contract
     * requires {@code null} here; a non-null (e.g. empty) array would be handed as the class file to any
     * transformer registered after this one.
     */
    private static final byte[] NO_TRANSFORM = null;

    /**
     * Rewrites the executor classes and {@code ForkJoinTask}; leaves every other class untouched.
     *
     * @param loader              defining loader of the class, {@code null} for the bootstrap loader
     * @param classFile           class name in internal form (e.g. {@code java/util/List}); may be {@code null}
     * @param classBeingRedefined unused
     * @param protectionDomain    unused
     * @param classFileBuffer     the class file bytes to (possibly) transform
     * @return the transformed class file, or {@code null} when the class is not transformed
     * @throws IllegalStateException wrapping any failure raised while transforming a class
     */
    @Override
    public byte[] transform(ClassLoader loader, String classFile, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classFileBuffer) {
        try {
            // No name means there is nothing we could match against; presumably this happens for
            // classes defined without a class file (e.g. VM-anonymous/lambda classes).
            if (classFile == null) {
                return NO_TRANSFORM;
            }
            final String className = toClassName(classFile);

            if (EXECUTOR_CLASS_NAMES.contains(className)) {
                logger.info("Transforming class " + className);
                final CtClass clazz = getCtClass(classFileBuffer, loader);
                for (CtMethod method : clazz.getDeclaredMethods()) {
                    updateMethodOfExecutorClass(clazz, method);
                }
                return clazz.toBytecode();
            }

            if (FORK_JOIN_TASK_CLASS_NAME.equals(className)) {
                logger.info("Transforming class " + className);
                final CtClass clazz = getCtClass(classFileBuffer, loader);
                updateForkJoinTaskClass(className, clazz);
                return clazz.toBytecode();
            }

            // TODO: java.util.TimerTask is not supported yet. The former placeholder branch never
            // transformed anything, and its superclass loop did not advance (infinite-loop hazard).
            return NO_TRANSFORM;
        } catch (Throwable t) {
            final String msg = "Fail to transform class " + classFile + ", cause: " + t.toString();
            if (logger.isLoggable(Level.SEVERE)) {
                logger.log(Level.SEVERE, msg, t);
            }
            throw new IllegalStateException(msg, t);
        }
    }

    /** Converts an internal class name ({@code a/b/C}) to its binary form ({@code a.b.C}). */
    private static String toClassName(String classFile) {
        return classFile.replace('/', '.');
    }

    /**
     * Parses {@code classFileBuffer} into a defrosted (modifiable) {@link CtClass}.
     * Referenced types are resolved through {@code classLoader}, or through the system class loader
     * when {@code classLoader} is {@code null} (the bootstrap loader).
     */
    private static CtClass getCtClass(byte[] classFileBuffer, ClassLoader classLoader) throws IOException {
        final ClassPool classPool = new ClassPool(true);
        final ClassLoader lookupLoader = classLoader == null ? ClassLoader.getSystemClassLoader() : classLoader;
        classPool.appendClassPath(new LoaderClassPath(lookupLoader));

        final CtClass clazz = classPool.makeClass(new ByteArrayInputStream(classFileBuffer), false);
        clazz.defrost();
        return clazz;
    }

    /**
     * For a public, non-static method declared by {@code clazz}, inserts code at the start of the method
     * that replaces every {@code Runnable}/{@code Callable} argument with its TTL wrapper
     * ({@code TtlRunnable.get(x, false, true)} / {@code TtlCallable.get(x, false, true)}).
     * Other methods are left unchanged.
     */
    private static void updateMethodOfExecutorClass(CtClass clazz, CtMethod method) throws NotFoundException, CannotCompileException {
        if (method.getDeclaringClass() != clazz) {
            return;
        }
        final int modifiers = method.getModifiers();
        if (!Modifier.isPublic(modifiers) || Modifier.isStatic(modifiers)) {
            return;
        }

        final CtClass[] parameterTypes = method.getParameterTypes();
        final StringBuilder insertCode = new StringBuilder();
        for (int i = 0; i < parameterTypes.length; ++i) {
            final String wrapperClassName = ttlWrapperClassNameOf(parameterTypes[i].getName());
            if (wrapperClassName == null) {
                continue;
            }
            // Javassist addresses method parameters as $1, $2, ... (1-based).
            final int paramIndex = i + 1;
            final String code = String.format("$%d = %s.get($%d, false, true);", paramIndex, wrapperClassName, paramIndex);
            logger.info("insert code before method " + method + " of class " + method.getDeclaringClass().getName() + ": " + code);
            insertCode.append(code);
        }
        if (insertCode.length() > 0) {
            method.insertBefore(insertCode.toString());
        }
    }

    /** Returns the TTL wrapper class name for a task parameter type, or {@code null} if that type is not wrapped. */
    private static String ttlWrapperClassNameOf(String parameterTypeName) {
        if (RUNNABLE_CLASS_NAME.equals(parameterTypeName)) {
            return TTL_RUNNABLE_CLASS_NAME;
        }
        if (CALLABLE_CLASS_NAME.equals(parameterTypeName)) {
            return TTL_CALLABLE_CLASS_NAME;
        }
        return null;
    }

    /**
     * Instruments {@code ForkJoinTask}:
     * <ol>
     * <li>adds a {@code private final Object} field initialized with {@code Transmitter.capture()};</li>
     * <li>renames the original {@code doExec} and makes it private;</li>
     * <li>adds a new {@code doExec} that calls the renamed method between
     * {@code Transmitter.replay(field)} and {@code Transmitter.restore(backup)}.</li>
     * </ol>
     * Instances of {@code TtlRecursiveAction}/{@code TtlRecursiveTask} call the original method directly
     * (NOTE(review): presumably because those types propagate TTL values themselves — confirm).
     */
    private static void updateForkJoinTaskClass(String className, CtClass clazz) throws CannotCompileException, NotFoundException {
        final CtField capturedField = CtField.make("private final java.lang.Object " + CAPTURED_FIELD_NAME + ";", clazz);
        clazz.addField(capturedField, TRANSMITTER_CLASS_NAME + ".capture();");
        logger.info("add new field " + CAPTURED_FIELD_NAME + " to class " + className);

        final CtMethod doExecMethod = clazz.getDeclaredMethod(DO_EXEC_METHOD_NAME);
        // Copy first: the copy becomes the new public-facing doExec, the original is renamed below.
        final CtMethod newDoExecMethod = CtNewMethod.copy(doExecMethod, DO_EXEC_METHOD_NAME, clazz, null);

        doExecMethod.setName(RENAMED_DO_EXEC_METHOD_NAME);
        // Strip PUBLIC/PROTECTED and add PRIVATE so no illegal access-flag combination can result.
        doExecMethod.setModifiers(doExecMethod.getModifiers() & ~(Modifier.PUBLIC | Modifier.PROTECTED) | Modifier.PRIVATE);

        final String code = "{\n" +
                "if (this instanceof " + TTL_RECURSIVE_ACTION_CLASS_NAME + " || this instanceof " + TTL_RECURSIVE_TASK_CLASS_NAME + ") {\n" +
                "    return " + RENAMED_DO_EXEC_METHOD_NAME + "($$);\n" +
                "}\n" +
                "java.lang.Object backup = " + TRANSMITTER_CLASS_NAME + ".replay(" + CAPTURED_FIELD_NAME + ");\n" +
                "try {\n" +
                "    return " + RENAMED_DO_EXEC_METHOD_NAME + "($$);\n" +
                "} finally {\n" +
                "    " + TRANSMITTER_CLASS_NAME + ".restore(backup);\n" +
                "}\n" +
                "}";
        newDoExecMethod.setBody(code);
        clazz.addMethod(newDoExecMethod);
        logger.info("insert code around method " + doExecMethod + " of class " + className + ": " + code);
    }
}

