from __future__ import annotations

from typing import Union

import libcst as cst
from libcst._nodes.module import Module

DocstringNode = Union[cst.Module, cst.ClassDef, cst.FunctionDef]


def _is_docstring_expr(node: cst.CSTNode) -> bool:
    """Reports whether ``node`` is an expression statement holding a text string literal.

    Args:
        node: A small statement taken from a statement line or an inline suite.

    Returns:
        True for a plain or implicitly concatenated string literal that does not
        evaluate to bytes, False otherwise.
    """
    if not isinstance(node, cst.Expr):
        return False
    value = node.value
    if not isinstance(value, (cst.SimpleString, cst.ConcatenatedString)):
        return False
    # A bytes literal in docstring position is not a docstring.
    return not isinstance(value.evaluated_value, bytes)


def _decorator_name(expression: cst.BaseExpression) -> str | None:
    """Returns the trailing identifier of a decorator expression.

    ``@name`` and ``@name(...)`` yield ``"name"``; the dotted forms ``@pkg.name``
    and ``@pkg.name(...)`` yield ``"name"`` as well.

    Args:
        expression: The expression stored on a ``cst.Decorator`` node.

    Returns:
        The identifier, or None when the expression has another shape.
    """
    if isinstance(expression, cst.Call):
        expression = expression.func
    if isinstance(expression, cst.Name):
        return expression.value
    if isinstance(expression, cst.Attribute):
        return expression.attr.value
    return None


def _expand_inline_suite(
    suite: cst.SimpleStatementSuite,
    docstring: cst.SimpleStatementLine,
) -> cst.IndentedBlock:
    """Turns a one-line body (``def f(): pass``) into an indented block led by ``docstring``.

    A statement line cannot live inside a ``SimpleStatementSuite``, so the inline
    small statements are moved onto their own line below the docstring.

    Args:
        suite: The inline body of a class or function.
        docstring: The docstring statement to place first in the new block.

    Returns:
        An indented block holding the docstring followed by the remaining statements.
    """
    small_statements = list(suite.body)
    # An inline body may itself start with a docstring; replace it instead of duplicating it.
    if small_statements and _is_docstring_expr(small_statements[0]):
        small_statements = small_statements[1:]

    if not small_statements:
        # Keep any trailing comment of the original line by hanging it after the colon.
        return cst.IndentedBlock(body=[docstring], header=suite.trailing_whitespace)

    remainder = cst.SimpleStatementLine(
        body=small_statements,
        trailing_whitespace=suite.trailing_whitespace,
    )
    return cst.IndentedBlock(body=[docstring, remainder])


def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine | None:
    """Extracts the docstring statement from a module, class or function node.

    Args:
        body: The module, class or function node to inspect.

    Returns:
        The first statement of the node when it is a docstring, None otherwise.
        Bodies written inline on the definition line (``def f(): "doc"``) have no
        statement line and therefore yield None.
    """
    # A module holds its statements directly; classes and functions wrap them in a suite.
    statements = body.body if isinstance(body, cst.Module) else body.body.body
    if not statements:
        return None

    first = statements[0]
    if not isinstance(first, cst.SimpleStatementLine):
        return None
    if not first.body:
        return None
    if not _is_docstring_expr(first.body[0]):
        return None
    return first


def has_decorator(node: DocstringNode, name: str) -> bool:
    """Checks whether a node carries a decorator with the given name.

    Args:
        node: The node to inspect; nodes without a ``decorators`` attribute
            (such as modules) never match.
        name: The decorator identifier to look for, e.g. ``"overload"``.

    Returns:
        True if any decorator is ``@name``, ``@name(...)``, ``@pkg.name`` or
        ``@pkg.name(...)``.
    """
    decorators = getattr(node, "decorators", ())
    return any(_decorator_name(item.decorator) == name for item in decorators)


class DocstringCollector(cst.CSTVisitor):
    """A visitor class for collecting docstrings from a CST.

    Attributes:
        stack: A list to keep track of the current path in the CST. The module
            contributes an empty string, classes and functions their names.
        docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
    """

    def __init__(self):
        super().__init__()
        self.stack: list[str] = []
        self.docstrings: dict[tuple[str, ...], cst.SimpleStatementLine] = {}

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, node: cst.Module) -> None:
        return self._leave(node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        return self._leave(node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        return self._leave(node)

    def _leave(self, node: DocstringNode) -> None:
        """Pops the current path and records the node's docstring under it, if any."""
        # Snapshot the path before popping so the key includes the node itself.
        key = tuple(self.stack)
        self.stack.pop()

        # Overload stubs share their name with the real implementation; skip them.
        if has_decorator(node, "overload"):
            return

        statement = get_docstring_statement(node)
        if statement is not None:
            self.docstrings[key] = statement


class DocstringTransformer(cst.CSTTransformer):
    """A transformer class for replacing docstrings in a CST.

    Attributes:
        stack: A list to keep track of the current path in the CST. The module
            contributes an empty string, classes and functions their names.
        docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        super().__init__()
        self.stack: list[str] = []
        self.docstrings = docstrings

    def visit_Module(self, node: cst.Module) -> bool | None:
        self.stack.append("")

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Pops the current path and installs the docstring registered for it, if any.

        Args:
            original_node: The node as it was before its children were transformed.
            updated_node: The node after its children were transformed.

        Returns:
            ``updated_node`` with its docstring replaced or inserted, or unchanged
            when no docstring is registered for the path or the node is an overload.
        """
        # Snapshot the path before popping so the key includes the node itself.
        key = tuple(self.stack)
        self.stack.pop()

        if has_decorator(updated_node, "overload"):
            return updated_node

        statement = self.docstrings.get(key)
        if statement is None:
            return updated_node

        original_statement = get_docstring_statement(original_node)

        if isinstance(updated_node, cst.Module):
            body = updated_node.body
            if original_statement is not None:
                return updated_node.with_changes(body=(statement, *body[1:]))
            # No docstring yet: insert one, separated from the code by an empty line.
            return updated_node.with_changes(body=(statement, cst.EmptyLine(), *body))

        suite = updated_node.body
        if isinstance(suite, cst.SimpleStatementSuite):
            # One-line bodies cannot hold a statement line; expand them into a block.
            return updated_node.with_changes(body=_expand_inline_suite(suite, statement))

        remaining = suite.body[1:] if original_statement is not None else suite.body
        return updated_node.with_changes(body=suite.with_changes(body=(statement, *remaining)))


def merge_docstring(code: str, documented_code: str) -> str:
    """Merges the docstrings from the documented code into the original code.

    Docstrings are matched by the path of enclosing class and function names, so
    only definitions present in both sources are affected.

    Args:
        code: The original code.
        documented_code: The documented code.

    Returns:
        The original code with the docstrings from the documented code.
    """
    code_tree = cst.parse_module(code)
    documented_code_tree = cst.parse_module(documented_code)

    visitor = DocstringCollector()
    documented_code_tree.visit(visitor)

    transformer = DocstringTransformer(visitor.docstrings)
    modified_tree = code_tree.visit(transformer)
    return modified_tree.code