import sys
sys.path.append("../../")
import libcst as cst
def is_call_with_booleans(node: cst.Call) -> bool:
    """Return True if every argument value of *node* is the bare name ``True`` or ``False``.

    This is the hand-written approach: each argument's value node is inspected
    directly with ``isinstance`` and a string comparison. Only ``Arg.value`` is
    examined; keyword names and star markers on the arguments are ignored. A call
    with no arguments yields True (nothing fails the check).
    """
    return all(
        # A non-Name value can never be the literal True/False; a Name must also
        # spell exactly one of those two identifiers.
        isinstance(argument.value, cst.Name)
        and argument.value.value in ("True", "False")
        for argument in node.args
    )
# Sample call nodes: ``foo(True)`` has only boolean-literal arguments, while
# ``foo(None)`` has a Name argument that is not True/False.
call_1 = cst.Call(
    func=cst.Name("foo"),
    args=(cst.Arg(cst.Name("True")),),
)
# Print the result: as a bare expression statement in a script it would be
# computed and silently discarded.
print(is_call_with_booleans(call_1))  # expected: True
call_2 = cst.Call(
    func=cst.Name("foo"),
    args=(cst.Arg(cst.Name("None")),),
)
print(is_call_with_booleans(call_2))  # expected: False
import libcst.matchers as m
def better_is_call_with_booleans(node: cst.Call) -> bool:
    """Return True if every argument value of *node* is the name ``True`` or ``False``.

    Same check as ``is_call_with_booleans``, but each argument value is tested
    against a single matcher alternation (``m.Name("True") | m.Name("False")``)
    instead of an ``isinstance`` test plus a string comparison. A call with no
    arguments yields True.
    """
    boolean_literal = m.Name("True") | m.Name("False")
    # all() stops at the first argument that fails to match, like an early return.
    return all(m.matches(argument.value, boolean_literal) for argument in node.args)
# Matcher-based check on the same sample calls. Print the results; as bare
# expression statements in a script they would be silently discarded.
print(better_is_call_with_booleans(call_1))  # expected: True
print(better_is_call_with_booleans(call_2))  # expected: False
def best_is_call_with_booleans(node: cst.Call) -> bool:
    """Return True if *node* is a call whose arguments are all ``True``/``False`` names.

    The whole shape is expressed as one matcher on the ``Call`` node.
    ``ZeroOrMore`` accepts any number of matching arguments, including none, so
    a call with no arguments also matches.
    """
    only_boolean_args = m.Call(
        args=(m.ZeroOrMore(m.Arg(value=m.Name("True") | m.Name("False"))),),
    )
    return m.matches(node, only_boolean_args)
# Single-matcher check on the same sample calls. Print the results; as bare
# expression statements in a script they would be silently discarded.
print(best_is_call_with_booleans(call_1))  # expected: True
print(best_is_call_with_booleans(call_2))  # expected: False
class BoolInverter(cst.CSTTransformer):
    """Swap ``True``/``False`` names inside calls whose arguments are all boolean literals.

    A depth counter (``in_call``) records how many matching ``Call`` nodes
    enclose the current traversal position. ``leave_Name`` rewrites a name only
    while that counter is positive.
    """

    # Matches a call whose every argument value is the name True or False (a
    # call with no arguments matches too, via ZeroOrMore). visit_Call and
    # leave_Call share this single definition so the increment and decrement
    # conditions cannot drift apart and unbalance the counter.
    _BOOL_ARGS_CALL = m.Call(
        args=(m.ZeroOrMore(m.Arg(m.Name("True") | m.Name("False"))),),
    )

    def __init__(self) -> None:
        # Run the base-class initializer. Skipping it leaves any state that
        # libcst's transformer base classes set up in __init__ uninitialized.
        super().__init__()
        # Number of matching Call nodes enclosing the current position.
        self.in_call: int = 0

    def visit_Call(self, node: cst.Call) -> None:
        """Enter a call: bump the depth counter if the call matches."""
        if m.matches(node, self._BOOL_ARGS_CALL):
            self.in_call += 1

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Leave a call: undo the increment made by ``visit_Call`` for it."""
        # Test original_node, i.e. the same node visit_Call saw. This keeps the
        # decrement paired with the increment even though child Names may have
        # been rewritten in updated_node.
        if m.matches(original_node, self._BOOL_ARGS_CALL):
            self.in_call -= 1
        return updated_node

    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        """Invert True/False while inside a matching call; leave other names unchanged."""
        if self.in_call <= 0:
            return updated_node
        if updated_node.value == "True":
            return updated_node.with_changes(value="False")
        if updated_node.value == "False":
            return updated_node.with_changes(value="True")
        return updated_node
# Sample program with three calls: one with only boolean-literal arguments,
# one with integer arguments, and one with no arguments.
source = (
    "def some_func(*params: object) -> None:\n"
    " pass\n"
    "\n"
    "some_func(True, False)\n"
    "some_func(1, 2, 3)\n"
    "some_func()\n"
)
module = cst.parse_module(source)
print(source)
# Rewrite with the counter-based transformer and show the resulting code.
new_module = module.visit(visitor=BoolInverter())
print(new_module.code)
class BetterBoolInverter(m.MatcherDecoratableTransformer):
    """Swap ``True``/``False`` names inside calls whose arguments are all boolean literals.

    Instead of a hand-maintained depth counter, ``m.call_if_inside`` restricts
    ``leave_Name`` so that it runs only while the traversal is within a ``Call``
    matching the decorator's pattern.
    """

    @m.call_if_inside(
        m.Call(args=(m.ZeroOrMore(m.Arg(m.Name("True") | m.Name("False"))),))
    )
    def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        flipped = {"True": "False", "False": "True"}.get(updated_node.value)
        if flipped is None:
            # Not a boolean literal (e.g. the callee's own name): keep it as-is.
            return updated_node
        return updated_node.with_changes(value=flipped)
# Same rewrite, driven by the call_if_inside decorator instead of a counter.
new_module = module.visit(visitor=BetterBoolInverter())
print(new_module.code)
class BestBoolInverter(m.MatcherDecoratableTransformer):
    """Swap ``True``/``False`` names inside calls whose arguments are all boolean literals.

    Two decorators do all of the filtering:
    - ``call_if_inside`` limits the handler to matching calls.
    - ``leave`` limits it to Name nodes spelled True or False.

    The method body therefore only has to choose the opposite spelling.
    """

    @m.call_if_inside(
        m.Call(args=(m.ZeroOrMore(m.Arg(m.Name("True") | m.Name("False"))),))
    )
    @m.leave(m.Name("True") | m.Name("False"))
    def invert_bool_literal(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name:
        # The leave matcher guarantees the value is "True" or "False" here.
        if updated_node.value == "True":
            return updated_node.with_changes(value="False")
        return updated_node.with_changes(value="True")
# Same rewrite, with both the scope and the node filter expressed as decorators.
new_module = module.visit(visitor=BestBoolInverter())
print(new_module.code)