from typing import List, Tuple, Dict, Optional
class TypingCollector(cst.CSTVisitor):
    """Collect the signature of every function definition it visits.

    Each function's parameters and return annotation are stored in
    ``self.annotations`` under a key made of the names of the enclosing
    class definitions followed by the function's own name, outermost first.
    """

    def __init__(self):
        # Run the base-class initialiser so that whatever state it sets up
        # exists on this instance as well.
        super().__init__()
        # stack for storing the canonical name of the current function;
        # holds one plain name per enclosing class/function definition
        self.stack: List[str] = []
        # store the annotations
        self.annotations: Dict[
            Tuple[str, ...],  # key: tuple of canonical class/function name
            Tuple[cst.Parameters, Optional[cst.Annotation]],  # value: (params, returns)
        ] = {}

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
        """Push the class name so nested definitions are keyed beneath it."""
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        """Pop the class name pushed by ``visit_ClassDef``."""
        self.stack.pop()

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        """Record ``(params, returns)`` for this function under its full name.

        Returns:
            ``False``, so the body of the function is not traversed.
        """
        self.stack.append(node.name.value)
        # Snapshot the stack as a tuple: it is hashable and is not affected by
        # the pop performed in leave_FunctionDef.
        self.annotations[tuple(self.stack)] = (node.params, node.returns)
        # pyi files don't support inner functions, return False to stop the
        # traversal.
        return False

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Pop the function name pushed by ``visit_FunctionDef``."""
        self.stack.pop()
class TypingTransformer(cst.CSTTransformer):
    """Rewrite function signatures using previously collected annotations.

    ``annotations`` maps a tuple of enclosing class/function names (outermost
    first, ending with the function's own name) to the ``(params, returns)``
    pair to install on the function found under that name.
    """

    def __init__(
        self,
        annotations: Dict[
            Tuple[str, ...], Tuple[cst.Parameters, Optional[cst.Annotation]]
        ],
    ):
        # Run the base-class initialiser so that whatever state it sets up
        # exists on this instance as well.
        super().__init__()
        # stack for storing the canonical name of the current function;
        # holds one plain name per enclosing class/function definition
        self.stack: List[str] = []
        # store the annotations
        self.annotations: Dict[
            Tuple[str, ...],  # key: tuple of canonical class/function name
            Tuple[cst.Parameters, Optional[cst.Annotation]],  # value: (params, returns)
        ] = annotations

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
        """Push the class name so nested definitions are looked up beneath it."""
        self.stack.append(node.name.value)

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.CSTNode:
        """Pop the class name and keep the (possibly updated) class node."""
        self.stack.pop()
        return updated_node

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        """Push the function name and skip the function body.

        Returns:
            ``False``, so the body of the function is not traversed.
        """
        self.stack.append(node.name.value)
        # pyi files don't support inner functions, return False to stop the
        # traversal.
        return False

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.CSTNode:
        """Install the collected signature on this function, if one exists.

        Returns:
            ``updated_node`` with its ``params`` and ``returns`` replaced when
            the function's full name is a key of ``self.annotations``;
            otherwise ``updated_node`` unchanged.
        """
        # Build the lookup key before popping: it must include this
        # function's own name.
        key = tuple(self.stack)
        self.stack.pop()
        if key not in self.annotations:
            return updated_node
        params, returns = self.annotations[key]
        # NOTE: the whole parameter list and return annotation are replaced by
        # the collected ones, not merged with what the function already has.
        return updated_node.with_changes(params=params, returns=returns)
# Driver: collect signatures from one tree and apply them to another.
# `stub_tree` and `source_tree` are not defined in this part of the file;
# presumably they are parsed libcst modules (stub and implementation) -- confirm.
visitor = TypingCollector()
# Fills visitor.annotations with (params, returns) keyed by qualified name.
stub_tree.visit(visitor)
transformer = TypingTransformer(visitor.annotations)
# The result of visiting with the transformer is kept as `modified_tree`.
modified_tree = source_tree.visit(transformer)