File size: 3,604 Bytes
8c0b7ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
""" This script extracts the function that contains a specified line number in a C++ file.
"""
from argparse import ArgumentParser
import os
from typing import List, Tuple
import clang.cindex


def get_functions_at_lines(fpath: os.PathLike, lines: List[int], clang_path: os.PathLike = None) -> List[Tuple[str, Tuple[int, int]]]:
    """ Collect the distinct functions that enclose any of the given lines.

    Each line is resolved with ``get_function_at_line``. A line that falls inside a
    function already collected is skipped, so every function appears at most once,
    in the order in which its first matching line occurs in ``lines``.

    Args:
        fpath: C/C++ source file to inspect.
        lines: line numbers to look up, in the same 1-based numbering that
            ``get_function_at_line`` reports for function extents.
        clang_path: optional path to the libclang shared library.

    Returns:
        A list of ``(source_text, (start_line, end_line))`` pairs.
    """
    functions = []
    # Lines already resolved to "no enclosing function"; re-querying them would
    # only re-parse the whole file with clang to get the same answer.
    misses = set()
    for line in lines:
        if line in misses:
            continue

        # Already covered by a function we extracted earlier.
        if any(start <= line <= end for _, (start, end) in functions):
            continue

        body, line_span = get_function_at_line(fpath, line, clang_path=clang_path)
        if body:
            functions.append((body, line_span))
        else:
            misses.add(line)
    return functions


def remove_macros(filename: str, line_numbers: List[int]) -> List[int]:
    """ Strip preprocessor directives from a file in place and remap line indices.

    A line is treated as a directive when its first non-blank character is ``#``.
    A directive whose line ends in a backslash is continued onto the next
    physical line, and those continuation lines are removed as well; otherwise
    the body of a multi-line ``#define`` would be left behind as stray code.

    Args:
        filename: path of the file to rewrite. The file is overwritten.
        line_numbers: 0-based line indices (positions in the file's list of
            lines) to translate into indices in the rewritten file. NOTE: clang
            reports 1-based lines, so such values must be decremented first.

    Returns:
        The translated 0-based indices, in file order, for each distinct entry of
        ``line_numbers`` that exists in the file. An index that pointed at a
        removed line maps to the index of the last kept line before it (``-1``
        if no line was kept before it).
    """
    wanted = set(line_numbers)  # O(1) membership instead of scanning the list per line
    with open(filename, 'r') as f:
        lines = f.readlines()

    new_line_numbers = []
    new_lines = []
    num_removed = 0
    in_directive = False  # True while consuming backslash-continued directive lines
    for i, line in enumerate(lines):
        if in_directive or line.lstrip().startswith('#'):
            num_removed += 1
            # A trailing backslash splices the next physical line into this directive.
            in_directive = line.rstrip().endswith('\\')
        else:
            new_lines.append(line)

        if i in wanted:
            new_line_numbers.append(i - num_removed)

    with open(filename, 'w') as f:
        f.write(''.join(new_lines))

    return new_line_numbers


def get_function_at_line(filename, line_number, clang_path=None):
    """ Return the source text and line span of the function enclosing a line.

    The file is parsed with libclang (no extra compiler arguments) and the AST is
    searched depth-first for the first function-like cursor (free function,
    method, constructor, destructor, conversion function or function template)
    whose extent covers ``line_number``. Declarations that come from
    ``#include``d files are skipped: their line numbers refer to the header, not
    to ``filename``, so matching them would slice the wrong lines.

    Args:
        filename: path of the C/C++ source file to parse.
        line_number: 1-based line number to look up.
        clang_path: optional path to the libclang shared library; only applied
            if libclang has not been loaded yet.

    Returns:
        ``(source_text, (start_line, end_line))`` where ``source_text`` is lines
        ``start_line``..``end_line`` (inclusive) of ``filename``, or
        ``(None, None)`` if the file could not be parsed or no function covers
        the line.
    """
    if clang_path and not clang.cindex.Config.loaded:
        # The library path can only be configured before libclang is loaded.
        clang.cindex.Config.set_library_file(clang_path)
    index = clang.cindex.Index.create()

    try:
        translation_unit = index.parse(filename)
    except clang.cindex.TranslationUnitLoadError:
        return None, None

    function_kinds = (
        clang.cindex.CursorKind.FUNCTION_DECL,
        clang.cindex.CursorKind.CXX_METHOD,
        clang.cindex.CursorKind.CONSTRUCTOR,
        clang.cindex.CursorKind.DESTRUCTOR,
        clang.cindex.CursorKind.CONVERSION_FUNCTION,
        clang.cindex.CursorKind.FUNCTION_TEMPLATE,
    )

    main_spelling = translation_unit.spelling
    main_path = os.path.realpath(filename)
    is_main_by_name = {}  # memoizes the realpath comparison per file name

    def in_main_file(node):
        # Only cursors located in the parsed file itself are relevant; the
        # translation unit also exposes every declaration from included headers.
        location_file = node.location.file
        if location_file is None:
            return False
        name = location_file.name
        if name not in is_main_by_name:
            is_main_by_name[name] = (
                name == main_spelling or os.path.realpath(name) == main_path
            )
        return is_main_by_name[name]

    def find_function(node):
        # Return the outermost function-like cursor whose extent covers the line.
        if node.kind in function_kinds:
            extent = node.extent
            if extent.start.line <= line_number <= extent.end.line:
                return node

        for child in node.get_children():
            if not in_main_file(child):
                # Prune whole subtrees from included headers.
                continue
            result = find_function(child)
            if result is not None:
                return result
        return None

    # Start from the root node (translation unit) and find the function
    function_node = find_function(translation_unit.cursor)
    if function_node is None:
        return None, None

    start_line = function_node.extent.start.line
    end_line = function_node.extent.end.line
    with open(filename, 'r') as f:
        lines = f.readlines()
    # clang line numbers are 1-based; the slice end is exclusive, so end_line is kept.
    return ''.join(lines[start_line - 1:end_line]), (start_line, end_line)


if __name__ == "__main__":
    # Command-line entry point: print the function enclosing FILENAME:LINE_NUMBER.
    parser = ArgumentParser(description="Extract the function that contains a specified line number in a C++ file.")
    parser.add_argument("filename", help="The C++ file to analyze")
    parser.add_argument("line_number", type=int, help="The line number to search for")
    parser.add_argument("--clang_path", help="Path to libclang.so if necessary")
    args = parser.parse_args()

    body, span = get_function_at_line(args.filename, args.line_number, clang_path=args.clang_path)
    if body is None:
        # Print only the message; the span is None here and would otherwise be
        # appended to the output as a stray "None".
        print(f"No function found at line {args.line_number}")
    else:
        print(body, span)