Saturday, August 28, 2010

A Simple Python NodeVisitor Example

The Python ast module helps you parse Python code into its abstract syntax tree and manipulate that tree. This is immensely useful, e.g., when you want to inspect code for safety or correctness. This page walks you through a very simple example to get you up and running quickly.

A Very Simple Parser

The first step in analyzing a program is parsing the textual source code into an in-memory walkable tree. The ast module takes away the hard task of implementing a generic lexer and Python parser with the function ast.parse(). The module also takes care of walking the tree. All you have to do is implement analysis code for the types of statements you want to analyze. To do this, you implement a class that derives from ast.NodeVisitor and override only the member functions that are pertinent to your analysis.
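As a minimal sketch of that first step, the snippet below parses a one-line program and looks at the resulting tree. The variable names source, tree and first_stmt are only illustrative; ast.parse and ast.dump are part of the standard module. The exact text printed by ast.dump differs between Python versions, so it is not reproduced here.

```python
import ast

source = "import foo"         # any string of Python code will do
tree = ast.parse(source)      # returns the root node of the tree

# The root node keeps the top-level statements in its 'body' list.
first_stmt = tree.body[0]
print(type(tree).__name__)          # Module
print(type(first_stmt).__name__)    # Import

# ast.dump renders the whole tree as text, which is handy for exploring.
print(ast.dump(tree))
```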

The tree walker ("node visitor") calls a callback member function when it encounters a particular Python construct, and there can be one callback for each type of node in the tree. For instance, a member named visit_Import is called each time the walker encounters an import statement.

Enough with the theory..

As code is often the best documentation, let's look at a parser that does one thing: print a message for each import statement.

 
import ast

class FirstParser(ast.NodeVisitor):

    def __init__(self):
        pass

    def visit_Import(self, stmt_import):
        # retrieve the name from the node object
        # normally, there is just a single alias
        for alias in stmt_import.names:
            print('import name "%s"' % alias.name)
            print('import object %s' % alias)

        # allow the walker to continue to the statement's children
        super(FirstParser, self).generic_visit(stmt_import)

For the code snippet "import foo", this produces
 
import name "foo"
import object <_ast.alias object at 0x7f05b871a690>

Implementing visit_.. Callbacks

You can define callbacks of the form visit_<type> for each of the left-hand symbols and right-hand constructors defined in the abstract grammar for Python. Thus, visit_stmt and visit_Import are both valid callbacks. Left-hand symbols may be abstract classes, however, and callbacks for those are never called. The callbacks for their concrete implementations, listed on the right-hand side, are: visit_stmt is never called, but visit_Import implements a concrete type of statement and will be called for all import statements.
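A small experiment makes the difference visible. This is only a sketch: the class StmtProbe and its called list are hypothetical names introduced for illustration.

```python
import ast

class StmtProbe(ast.NodeVisitor):
    """Hypothetical visitor that records which callbacks actually fire."""

    def __init__(self):
        self.called = []

    def visit_stmt(self, node):
        # 'stmt' is an abstract left-hand symbol. The walker dispatches on
        # the concrete class name of each node, so this is never reached.
        self.called.append('visit_stmt')

    def visit_Import(self, node):
        # 'Import' is a concrete constructor of 'stmt', so this is reached.
        self.called.append('visit_Import')

probe = StmtProbe()
probe.visit(ast.parse("import foo\nimport bar"))
print(probe.called)   # ['visit_Import', 'visit_Import']
```

Both import nodes are still instances of ast.stmt, so isinstance checks against the abstract class work even though its callback is never invoked.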

In the common case, when a node has no associated visit_<type> member, the walker calls the member generic_visit, which ensures that the walk recurses to the children of that node. You may have a member defined for those children, even if you did not define one for the node itself. When you do define a visit_<type> member, that function is called and generic_visit is no longer called automatically. You are then responsible for ensuring that the children are visited, by calling generic_visit explicitly in your member (unless you expressly intend to stop the recursion).
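The effect of leaving out generic_visit can be sketched as follows. NameCollector and its recurse flag are hypothetical names made up for this illustration; the flag only exists to show both behaviours side by side.

```python
import ast

class NameCollector(ast.NodeVisitor):
    """Hypothetical visitor: collects variable names and optionally
    recurses into binary expressions."""

    def __init__(self, recurse):
        self.recurse = recurse
        self.names = []

    def visit_BinOp(self, node):
        # Without this call, the walk stops here and the operands of the
        # expression are never visited.
        if self.recurse:
            self.generic_visit(node)

    def visit_Name(self, node):
        self.names.append(node.id)

tree = ast.parse("a = b + c")

with_recursion = NameCollector(recurse=True)
with_recursion.visit(tree)

without_recursion = NameCollector(recurse=False)
without_recursion.visit(tree)

print(with_recursion.names)      # ['a', 'b', 'c']
print(without_recursion.names)   # ['a']
```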

Using the returned objects

Each callback function visit_<type>(self, object) is called with an object of a class particular to the given type. All these classes derive from the abstract class ast.AST. As a result, each has a member _fields, along with members specific to the class. The names member shown in the first example is specific to the Import object passed to visit_Import, for instance. Note that this corresponds to the argument of the Import constructor in the abstract grammar. In general, I believe that these arguments are the class-specific members, although I could not find any definite documentation on this.
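One quick way to check this for a given node class is to print its _fields and compare the result with the grammar. The sketch below does so for Import and for the alias objects it contains; the variable names node and alias are only illustrative.

```python
import ast

# _fields is defined on the node classes themselves.
print(ast.Import._fields)   # ('names',)
print(ast.alias._fields)    # ('name', 'asname')

# On an instance, each field name is also available as an attribute.
node = ast.parse("import foo as f").body[0]
alias = node.names[0]
for field in alias._fields:
    print('%s = %s' % (field, getattr(alias, field)))
```

For these two classes the field names match the constructor arguments listed in the abstract grammar.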

The _fields Member

The following snippet gives an example of how iterating over the fields of a node with ast.iter_fields returns all its children. Given the input "a = b + 1", the member function

    def visit_BinOp(self, stmt_binop):
        for child in ast.iter_fields(stmt_binop):
            print('child %s' % str(child))

generates the output

 
  child ('left', <_ast.Name object at 0x7f05b871a710>) 
  child ('op', <_ast.Add object at 0x7f05b8715610>) 
  child ('right', <_ast.Num object at 0x7f05b871a750>)

For each child, the generator returns a tuple consisting of name and child object. A quick look at the abstract syntax grammar shows that indeed all child classes again correspond to symbols in the grammar: Name, Add and Num.
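The class names in this output depend on the interpreter version. The sketch below records only the field name and class name of each child, which is easier to compare across versions; ChildTypes is a hypothetical helper class. On Python 3.8 and later the literal 1 is reported as Constant rather than Num.

```python
import ast

class ChildTypes(ast.NodeVisitor):
    """Hypothetical visitor that records (field name, class name) pairs."""

    def __init__(self):
        self.children = []

    def visit_BinOp(self, node):
        # iter_fields yields (name, value) tuples, one per entry in _fields.
        for name, child in ast.iter_fields(node):
            self.children.append((name, type(child).__name__))
        self.generic_visit(node)

probe = ChildTypes()
probe.visit(ast.parse("a = b + 1"))
for pair in probe.children:
    print(pair)
```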

Calling the Parser

This brings us to the last step: how to actually pass input to the parser and generate output. Assuming you have a string containing Python code, this string is parsed into an in-memory tree and the tree walked with your callbacks using:


 
code = "a = b + 5"
tree = ast.parse(code)
parser = FirstParser()
parser.visit(tree)

A Warning on Modifying Code

Tree walking is not just useful for inspecting code; you can also use it to modify the parse tree. The reference documentation (see below) is very clear on the fact that you should not use NodeVisitor for this purpose. Instead, derive from the NodeTransformer class, whose members are expected to return a replacement object for each object with which they are called.
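A minimal sketch of such a transformer is shown below. The class AddToSub is a made-up example that rewrites every addition into a subtraction; ast.copy_location and ast.fix_missing_locations are standard helpers that fill in the position information which compile() expects on new nodes.

```python
import ast

class AddToSub(ast.NodeTransformer):
    """Hypothetical transformer: rewrites every 'x + y' into 'x - y'."""

    def visit_BinOp(self, node):
        # Transform the children first, so nested expressions are handled.
        self.generic_visit(node)
        if isinstance(node.op, ast.Add):
            replacement = ast.BinOp(left=node.left, op=ast.Sub(),
                                    right=node.right)
            # Carry over line and column information from the old node.
            return ast.copy_location(replacement, node)
        # Returning the node itself leaves it unchanged.
        return node

tree = ast.parse("7 + 5", mode="eval")
tree = AddToSub().visit(tree)
ast.fix_missing_locations(tree)   # fill in positions for newly created nodes

result = eval(compile(tree, "<ast>", "eval"))
print(result)   # 2
```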

Further Reading


Feedback

I wrote this mini tutorial because I failed to find one when I first started using the ast module. That said, I'm no expert at it and not even a full-time Python programmer. If you spot errors or see room for improvement, don't hesitate to post a message.

Complete Example

The snippets above combine into the following example, which contains minor tweaks to avoid code duplication and improve readability:
 
import ast

class FirstParser(ast.NodeVisitor):

    def __init__(self):
        pass

    def continue_walk(self, stmt):
        '''Helper: visit a node's children'''
        # named continue_walk because 'continue' is a reserved word
        super(FirstParser, self).generic_visit(stmt)

    def parse(self, code):
        '''Parse text into a tree and walk the result'''
        tree = ast.parse(code)
        self.visit(tree)

    def visit_Import(self, stmt_import):
        # retrieve the name from the node object
        # normally, there is just a single alias
        for alias in stmt_import.names:
            print('import name "%s"' % alias.name)
            print('import object %s' % alias)

        self.continue_walk(stmt_import)

    def visit_BinOp(self, stmt_binop):
        print('expression: ')
        for child in ast.iter_fields(stmt_binop):
            print('  child %s ' % str(child))

        self.continue_walk(stmt_binop)


parser = FirstParser()
parser.parse('import foo')
parser.parse('a = b + 5')
