 #
 # Copyright (C) 2011 The Reconfigurable Multi-resolutions Profiling Project
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #      http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

import sys
import os
from operator import attrgetter

# Running total of the framework-call counts returned by insThisFile() for every
# processed file; accumulated and printed as "NUM: ..." in the __main__ block.
frame_num = 0

class Tags:
    """One record from the tags file: (tag, type, line, file, keyword).

    A Tags whose fields are all "---" is used as a separator / end marker.
    """

    def __init__(self, tag, type, line, file, keyword):
        self.tag, self.type, self.line, self.file, self.keyword = (
            tag, type, line, file, keyword)

    def __repr__(self):
        fields = (self.tag, self.type, self.line, self.file, self.keyword)
        return repr(fields)

class ClassTags:
    """A class record from the tags file, classified for instrumentation.

    * thread -- the keyword text mentions "extends Thread" or "implements Runnable".
    * inter  -- the class name plus ".java" differs from its file name
                (an inner / secondary class).
    * type   -- "activity", "service", "receiver" or "content" when the keyword
                names the matching Android base class, otherwise "normal".
    """

    def __init__(self, tag, line, file, keyword):
        self.tag = tag
        self.line = line
        self.file = file
        self.thread = "extends Thread" in keyword or "implements Runnable" in keyword
        self.inter = self.file != self.tag + ".java"

        # Checked in order; the first marker found decides the component type.
        component_markers = (
            ("extends Activity", "activity"),
            ("extends Service", "service"),
            ("extends BroadcastReceiver", "receiver"),
            ("extends ContentProvider", "content"),
        )
        self.type = "normal"
        for marker, component in component_markers:
            if marker in keyword:
                self.type = component
                break

    def isThread(self):
        """Return True when the class is a Thread subclass or a Runnable."""
        return self.thread

    def isInterClass(self):
        """Return True when the class name does not match its file name."""
        return self.inter

    def getClassType(self):
        """Return the Android component type, or "normal"."""
        return self.type

class PackageTags:
    """A package record: package name *tag* declared in source file *file*."""

    def __init__(self, tag, file):
        self.tag, self.file = tag, file

class scanFramework:    # collect the names of methods called from the scanned lines
    """Collect the distinct identifiers that appear immediately before '('.

    Names are kept in first-seen order in ``framework_list``; ``number``
    accumulates how many names :meth:`output` has written.
    """

    def __init__(self):
        self.framework_list = []
        self._seen = set()      # mirrors framework_list for O(1) membership tests
        self.number = 0

    def scanLine(self, line):
        """Record every call name found in *line*.

        A name is the run of characters since the last '.', ' ', '(' or ')'
        before a '('.  Empty names, names starting with a quote, and every name
        on a line containing '<' are ignored.  Prints a warning when the
        parentheses on the line are unbalanced.
        """
        # Whether the line contains '<' does not depend on the position, so
        # test it once instead of for every '('.
        has_angle = "<" in line
        name = ""
        count = 0
        for ch in line:
            if ch == "." or ch == " ":
                name = ""
            elif ch == "(":
                if (not has_angle and name != "\n" and name != ""
                        and not name.startswith(("\"", "'"))
                        and name not in self._seen):
                    self._seen.add(name)
                    self.framework_list.append(name)
                count += 1
                name = ""
            elif ch == ")":
                count -= 1
            else:
                name += ch
        if count != 0:  # unbalanced parentheses on this line
            print("!!!count = %d" % (count))

    def output(self, fd):
        """Write each collected name on its own line to *fd* and add their count to ``number``."""
        self.number += len(self.framework_list)
        fd.writelines(name + "\n" for name in self.framework_list)

    def getnum(self):
        """Return the total number of names written by :meth:`output`."""
        return self.number

class insSystem:
    """Build the Java ``probe(...)`` statements inserted into instrumented code.

    ``mod`` selects the call form: in "method" mode probes pass ``caller_name``
    (declared by the statement from :meth:`getStackIns`) and ``true``; in any
    other mode they pass ``"null"`` and ``false``.  Loop-block probe names are
    kept on ``block_name_stack`` so entry and exit probes pair up.
    """

    def __init__(self):
        self.mod = "method"
        self.ins_name = None
        self.ins_block_name = None
        self.block_name_stack = []

    def setMod(self, new_mod):
        """Switch the probe mode (e.g. "method", "thread", "process")."""
        if self.mod != new_mod:
            self.mod = new_mod

    def setInsName(self, package_name, class_name, method_name):
        """Set the probe name to ``package.class.method``."""
        self.ins_name = package_name + "." + class_name + "." + method_name

    def setBlockInsName(self, line_num):    # for loop
        """Name a loop block after the current method and *line_num*, and push it."""
        self.ins_block_name = self.ins_name + "_b_" + str(line_num)
        self.block_name_stack.append(self.ins_block_name)

    def getInsName(self):
        return self.ins_name

    def getBlockInsName(self):
        return self.ins_block_name

    def getStackIns(self):
        """Return the Java statement that declares ``caller_name``."""
        return "String caller_name = Thread.currentThread().getCaller();\n"

    def _format(self, name, reg):
        """Return a probe call for *name*; *reg* is the literal last argument text."""
        if self.mod == "method":
            return "probe(\"%s\", true, caller_name, %s);" % (name, reg)
        return "probe(\"%s\", false, \"null\", %s);" % (name, reg)

    def getIns(self, reg):
        """Return the method probe; *reg* (a bool) becomes ``true``/``false``."""
        return self._format(self.ins_name, str(reg).lower())

    def getBlockIns(self):
        """Return the entry probe for the current loop block."""
        return self._format(self.ins_block_name, "true")

    def getBlockEndIns(self):
        """Pop the innermost block name and return its exit probe ("null" if none)."""
        stack = self.block_name_stack
        name = stack.pop() if stack else "null"
        return self._format(name, "false")

    def getBlockEndInsNoRemove(self):   # not remove from stack
        """Return the innermost block's exit probe without popping it.

        When the stack is empty, "null" is pushed and used.
        """
        if not self.block_name_stack:
            self.block_name_stack.append("null")
        return self._format(self.block_name_stack[-1], "false")

    def insJNI(self, fd):
        """Write the native-library load and ``probe`` declaration to *fd*."""
        print("JNI")
        declarations = (
            r'static { System.loadLibrary("appProbes");}',
            r'static public native void probe(String name, boolean f_type, String cname, boolean reg);',
        )
        for declaration in declarations:
            fd.write(declaration + "\n")

def isProcessTag(tag):
    """Return True when *tag* names one of the Activity lifecycle callbacks."""
    return tag in ("onCreate", "onRestart", "onStart", "onResume",
                   "onPause", "onStop", "onDestroy")

def isBlockTag(line):
    """Return True when *line* opens a loop ("while (", "for (" or "do {")."""
    loop_markers = ("while (", "for (", "do {")
    return any(marker in line for marker in loop_markers)

def getNextTag(each_file_tags):
    """Pop and return the first tag of *each_file_tags*.

    When the list is empty, return a placeholder Tags whose fields are all
    "---"; callers compare ``tag`` against "---" to detect that no tag is left.
    """
    try:
        return each_file_tags.pop(0)
    except IndexError:  # only the empty-list case; other errors must surface
        return Tags("---", "---", "---", "---", "---")

def _copyUntilBrace(line, line_num, file_src_fd, file_new_fd):
    """Copy lines to *file_new_fd* until reaching one that contains '{'.

    Returns ``(line, line_num)`` for that line, which is NOT written.  At end of
    file ``line`` is "" and the loop stops instead of spinning forever.
    """
    while line and line.find("{") == -1:
        file_new_fd.write(line)
        line = file_src_fd.readline()
        line_num += 1
    return line, line_num


def _copyStatement(line, line_num, file_src_fd, file_new_fd, framework_system):
    """Copy a possibly multi-line statement that starts at *line*.

    Lines are written (and scanned for calls) until one contains ';' while the
    '{'/'}' balance counted from *line* is zero, so a ';' inside a braced body
    (e.g. an anonymous class) does not end the statement.  Returns
    ``(line, line_num)`` for the terminating line, which is NOT written or
    scanned.  Stops at end of file (``line == ""``) instead of looping forever.
    """
    re_tail_count = 0
    while line:
        re_tail_count += line.count("{")
        re_tail_count -= line.count("}")
        if line.find(";") != -1 and re_tail_count == 0:
            break
        file_new_fd.write(line)
        if line.find("(") != -1 and line.find(")") != -1:
            framework_system.scanLine(line)
        line = file_src_fd.readline()
        line_num += 1
    return line, line_num


def insThisFile(class_sort_tags, package_sort_tags, each_file_tags, app_probes_fd, frame_list_fd):
    """Instrument one Java source file in place and collect its called names.

    The file is the one named by the first tag of *each_file_tags* (the list is
    presumably all of one file's method tags in line order -- see readTags).
    The matching class and package tags are popped from *class_sort_tags* and
    *package_sort_tags*.

    The original ``sys.argv[2]/<file>`` is renamed to ``<file>~`` and an
    instrumented copy is written back under the original name:
      * JNI declarations are inserted after the class's opening brace;
      * each tagged method gets an entry probe, an exit probe before each
        ``return``/``throw`` and, unless one was just emitted, at its closing
        brace;
      * loop bodies (see isBlockTag) get block entry/exit probes.
    Each inserted probe name is recorded in *app_probes_fd* as
    ``"<name> <mode> 1"``.  *frame_list_fd* receives the file's imports (with
    '.' replaced by a backslash) and "~~~", then the collected call names and
    "---".

    Returns:
        The number of distinct call names collected for this file, or -1 when
        no class or package tag matches the file (nothing is written then).
    """
    ins_system = insSystem()
    framework_system = scanFramework()
    line_num = 0
    tail_count = 0          # '{' minus '}' since the current method's opening brace

    is_scan = False         # inside the body of a tagged method

    return_mod = False
    is_return_end = False
    is_return_head = True   # entry probe not yet emitted for the current method

    field_num = 0
    is_field_mod = False    # inside an anonymous "new X() {" body

    finally_brace_num = 0
    is_finally = False
    is_finally_end = False

    block_num = 0           # nesting depth of instrumented loops
    block_skip_num = 0      # non-loop braces opened inside a loop
    is_block_mod = False
    block_break_mod = False
    block_return_mod = False
    block_return_head = True
    after_block_return_mod = False

    #get tags
    target_tag = getNextTag(each_file_tags)
    class_tag = None
    package_tag = None
    for index in range(0, len(class_sort_tags)):
        if target_tag.file == class_sort_tags[index].file:
            class_tag = class_sort_tags.pop(index)
            break
    for index in range(0, len(package_sort_tags)):
        if target_tag.file == package_sort_tags[index].file:
            package_tag = package_sort_tags.pop(index)
            break
    if class_tag is None or package_tag is None:
        print("Error, no match class or package\n")
        return -1

    #keep the original as "<file>~" and write the instrumented copy under the original name
    path = sys.argv[2] + '/' + target_tag.file
    os.rename(path, path + '~')
    file_src_fd = open(path + '~', 'r')
    file_new_fd = open(path, 'w')

    #framework part: list this file's imports, e.g. "import a.b.C;" -> "a\b\C"
    with open(path + '~', 'r') as src_import_fd:
        for line in src_import_fd:
            if line.startswith("import"):
                line = line.strip("\n")
                line = line[len("import "):len(line) - 1]
                line = line.replace(".", "\\")
                frame_list_fd.write(line + "\n")
    frame_list_fd.write("~~~\n")

    for line in file_src_fd:
        line_num += 1
        #insert class code(definition)
        if line_num == int(class_tag.line):
            line, line_num = _copyUntilBrace(line, line_num, file_src_fd, file_new_fd)
            file_new_fd.write(line)
            ins_system.insJNI(file_new_fd)
        elif is_field_mod:  #skip field
            if line.find("}") != -1 and line.rstrip("\n").endswith(";") and line.find("{") == -1:
                if field_num == 1:
                    is_field_mod = False
                    field_num = 0
                    print("Field end %s" % (line))
                else:
                    field_num -= 1
            elif line.find("new ") != -1 and line.rstrip("\n").endswith("{"):
                if line.find("})") == -1:
                    field_num += 1
            file_new_fd.write(line)
        elif is_block_mod:  #in loop
            if line.find("return ") != -1 or line.find("throw ") != -1:
                if not block_return_head:
                    file_new_fd.write(ins_system.getBlockEndInsNoRemove() + "\n")
                file_new_fd.write(ins_system.getIns(False) + "\n")
                block_return_head = False

                line, line_num = _copyStatement(line, line_num, file_src_fd, file_new_fd, framework_system)
                block_return_mod = True
            elif line.find("continue;") != -1 or line.find("break;") != -1:
                if not block_return_head:
                    file_new_fd.write(ins_system.getBlockEndIns() + "\n")
                block_return_head = False
                block_break_mod = True
            elif line.find("}") != -1 and line.find("{") == -1:
                if block_skip_num != 0:
                    block_skip_num -= 1
                else:
                    block_num -= 1
                    if not block_return_mod:
                        if not block_break_mod:
                            file_new_fd.write(ins_system.getBlockEndIns() + "\n")
                        else:
                            block_break_mod = False
                    else:   #skip this block's exit probe, but still pop its name
                        ins_system.getBlockEndIns()

                if block_num == 0:
                    is_block_mod = False
                    block_return_head = True
                    after_block_return_mod = block_return_mod
                else:
                    block_return_head = False
                block_return_mod = False
            elif line.rstrip("\n").endswith("{"):
                if block_return_head:
                    file_new_fd.write(ins_system.getBlockIns() + "\n")
                    app_probes_fd.write("%s %s 1\n" % (ins_system.getBlockInsName(), "block"))
                    block_return_head = False
                    block_return_mod = False
                if line.find("new ") != -1:
                    print("BLOCK Field %s %s" % (line, target_tag.file))
                    field_num += 1
                    is_field_mod = True
                    block_return_head = False
                    block_return_mod = False
                elif isBlockTag(line):
                    print("Block %s %s" % (line, target_tag.file))
                    block_num += 1
                    block_return_head = True
                    block_return_mod = False
                    ins_system.setBlockInsName(line_num)
                else:
                    block_skip_num += 1
            else:
                if block_return_head:
                    file_new_fd.write(ins_system.getBlockIns() + "\n")
                    app_probes_fd.write("%s %s 1\n" % (ins_system.getBlockInsName(), "block"))
                    block_return_head = False

            file_new_fd.write(line)
            if line.find("(") != -1 and line.find(")") != -1:
                framework_system.scanLine(line)

        #scan each method's content
        elif is_scan:
            tail_count += line.count("{")
            tail_count -= line.count("}")
            if tail_count == -1:    #the method's closing brace
                if not is_return_head:
                    if not is_return_end and not return_mod and not after_block_return_mod and not is_finally_end:
                        file_new_fd.write(ins_system.getIns(False) + " }\n")
                    else:
                        file_new_fd.write(line)
                        print("Tail skip %s %s" %(target_tag.tag, target_tag.file))
                else:
                    file_new_fd.write(line)

                is_return_head = True
                is_return_end = False
                return_mod = False
                after_block_return_mod = False
                is_finally_end = False

                is_scan = False
                tail_count = 0
                target_tag = getNextTag(each_file_tags)
            elif line.find("return ") != -1 or line.find("return;") != -1 or line.find("throw ") != -1:
                #exit probe goes right before the return/throw statement
                if is_return_head:
                    print("Head skip %s %s" % (target_tag.tag, target_tag.file))
                else:
                    file_new_fd.write(ins_system.getIns(False) + "\n")

                if line.find("return new ") != -1 and line.endswith("{\n"):
                    print("Special case for \"return new\"")
                    tail_count -= 1

                line, line_num = _copyStatement(line, line_num, file_src_fd, file_new_fd, framework_system)
                file_new_fd.write(line)
                return_mod = True
                if line.find("(") != -1 and line.find(")") != -1:
                    framework_system.scanLine(line)
            else:
                if is_return_head:
                    if line.find("super(") == -1 and line.find("this(") == -1: #not the special case
                        app_probes_fd.write("%s %s 1\n" % (ins_system.getInsName(), ins_system.mod))
                        if ins_system.mod == "method":
                            file_new_fd.write(ins_system.getStackIns())
                        file_new_fd.write(ins_system.getIns(True) + "\n")
                        is_return_head = False

                        file_new_fd.write(line)
                        if line.find("new ") != -1 and line.rstrip("\n").endswith("{"):
                            print("Field %s %s" % (line, target_tag.file))
                            tail_count -= 1
                            field_num += 1
                            is_field_mod = True
                        elif line.rstrip("\n").endswith("{") and isBlockTag(line):
                            print("Block %s %s" % (line, target_tag.file))
                            tail_count -= 1
                            block_num += 1
                            is_block_mod = True
                            block_return_head = True
                            ins_system.setBlockInsName(line_num)
                        elif line.find("(") != -1 and line.find(")") != -1:
                            framework_system.scanLine(line)

                    else: #special case: Java requires super(...)/this(...) first, so the entry probe follows it
                        file_new_fd.write(line)
                        app_probes_fd.write("%s %s 1\n" % (ins_system.getInsName(), ins_system.mod))
                        if ins_system.mod == "method":
                            file_new_fd.write(ins_system.getStackIns())
                        file_new_fd.write(ins_system.getIns(True) + "\n")
                        is_return_head = False
                else:
                    file_new_fd.write(line)
                    if line.find("new ") != -1 and line.rstrip("\n").endswith("{"):
                        print("Field %s %s" % (line, target_tag.file))
                        tail_count -= 1
                        field_num += 1
                        is_field_mod = True
                    elif line.rstrip("\n").endswith("{") and isBlockTag(line):
                        print("Block %s %s" % (line, target_tag.file))
                        tail_count -= 1
                        block_num += 1
                        is_block_mod = True
                        block_return_head = True
                        ins_system.setBlockInsName(line_num)
                    elif line.find("finally {") != -1:
                        finally_brace_num = tail_count
                        is_finally = True
                    elif line.find("(") != -1 and line.find(")") != -1:
                        framework_system.scanLine(line)

                # a return/throw followed by a line back at the method's top level
                # means the method already exits there: no tail probe is needed
                is_return_end = return_mod and tail_count == 0
                return_mod = False

                if is_finally and tail_count < finally_brace_num:
                    is_finally = False
                    is_finally_end = True
                else:
                    is_finally_end = False
        elif target_tag.tag != "---" and int(target_tag.line) == line_num:
            #start of a tagged method: choose the probe mode for it
            if target_tag.tag == "run" and class_tag.isThread():
                ins_system.setMod("thread")
            elif isProcessTag(target_tag.tag) and class_tag.getClassType() != "normal":
                ins_system.setMod("process")
            else:
                ins_system.setMod("method")
            ins_system.setInsName(package_tag.tag, class_tag.tag, target_tag.tag)

            if line.find("{") != -1 and line.find("}") != -1:
                file_new_fd.write(line)
                print("A space line skip %s %s" % (target_tag.tag, target_tag.file))
                target_tag = getNextTag(each_file_tags)
                continue
            line, line_num = _copyUntilBrace(line, line_num, file_src_fd, file_new_fd)
            file_new_fd.write(line)
            is_scan = True
        else:
            file_new_fd.write(line)

    framework_system.output(frame_list_fd)
    frame_list_fd.write("---\n")
    # Close (and flush) both files before returning; previously these closes
    # sat after the return statement and never ran.
    file_src_fd.close()
    file_new_fd.close()
    return framework_system.getnum()

def readTags():
    """Read the tags file named by ``sys.argv[1]`` and group its records.

    Each line is whitespace-separated ``tag type line file keyword...``; the
    keyword fields are re-joined with single spaces.  Class records for inner
    classes (see ClassTags.isInterClass) and non-class/package records whose
    keyword ends with ';' are dropped.  Blank lines are ignored; lines with
    fewer than four fields are reported and skipped.

    Returns:
        ``(class_sort_tags, package_sort_tags, normal_sort_tags)``: the first
        two sorted by file, the last sorted by (file, line) with a "---"
        separator Tags inserted between consecutive files.
    """
    class_tags = []
    package_tags = []
    normal_tags = []

    with open(sys.argv[1], 'r') as tags_fd:
        for line in tags_fd:
            #[0]tag [1]type [2]line [3]file [4...]keyword
            data = line.split()
            if not data:
                continue    # blank line
            if len(data) < 4:
                print("Skip malformed tag line: %s" % (line.rstrip("\n")))
                continue
            tag, tag_type, tag_line, tag_file = data[:4]
            keyword = " ".join(data[4:])
            if tag_type == "class":
                new_class_tag = ClassTags(tag, tag_line, tag_file, keyword)
                if not new_class_tag.isInterClass():    #skip inner classes
                    class_tags.append(new_class_tag)
            elif tag_type == "package":
                package_tags.append(PackageTags(tag, tag_file))
            elif not keyword.endswith(";"):
                # keywords ending in ';' are presumably bodiless declarations -- nothing to instrument
                normal_tags.append(Tags(tag, tag_type, int(tag_line), tag_file, keyword))

    #Sort tags by file and line number
    class_sort_tags = sorted(class_tags, key=attrgetter('file'))
    package_sort_tags = sorted(package_tags, key=attrgetter('file'))
    ordered_tags = sorted(normal_tags, key=attrgetter('file', 'line'))

    #separate by each file (build a new list instead of inserting while iterating)
    normal_sort_tags = []
    prev_file = None
    for tag in ordered_tags:
        if prev_file is not None and tag.file != prev_file:
            normal_sort_tags.append(Tags("---", "---", "---", "---", "---"))
        normal_sort_tags.append(tag)
        prev_file = tag.file
    return (class_sort_tags, package_sort_tags, normal_sort_tags)

if __name__ == "__main__":
    # Usage: autoInsApp.py <app tags> <app src> <app probes> <frame list>
    if len(sys.argv) < 5:
        print("Error format: autoInsApp.py <app tags>, <app src>, <app probes>, <frame list>")
        sys.exit(1)

    class_sort_tags, package_sort_tags, normal_sort_tags = readTags()

    # normal_sort_tags holds each file's tags separated by "---" markers;
    # instrument one file per group.
    each_file_tags = []
    with open(sys.argv[3], 'w') as app_probes_fd, open(sys.argv[4], 'w') as frame_list_fd:
        for tag in normal_sort_tags:
            if tag.tag == "---":
                file_num = insThisFile(class_sort_tags, package_sort_tags, each_file_tags, app_probes_fd, frame_list_fd)
                if file_num >= 0:   # -1 means no matching class/package tag; don't count it
                    frame_num += file_num
                each_file_tags = []
                app_probes_fd.write("---\n")
            else:
                each_file_tags.append(tag)
        file_num = insThisFile(class_sort_tags, package_sort_tags, each_file_tags, app_probes_fd, frame_list_fd)
        if file_num >= 0:
            frame_num += file_num
    print("NUM: %s" % (frame_num))
    print("Finish")