JUnitUseExpectedRule.java
001 /**
002  * BSD-style license; for more info see http://pmd.sourceforge.net/license.html
003  */
004 package net.sourceforge.pmd.lang.java.rule.migrating;
005 
006 import java.util.ArrayList;
007 import java.util.List;
008 
009 import net.sourceforge.pmd.lang.ast.Node;
010 import net.sourceforge.pmd.lang.java.ast.ASTAnnotation;
011 import net.sourceforge.pmd.lang.java.ast.ASTBlock;
012 import net.sourceforge.pmd.lang.java.ast.ASTBlockStatement;
013 import net.sourceforge.pmd.lang.java.ast.ASTCatchStatement;
014 import net.sourceforge.pmd.lang.java.ast.ASTClassOrInterfaceBodyDeclaration;
015 import net.sourceforge.pmd.lang.java.ast.ASTMethodDeclaration;
016 import net.sourceforge.pmd.lang.java.ast.ASTName;
017 import net.sourceforge.pmd.lang.java.ast.ASTThrowStatement;
018 import net.sourceforge.pmd.lang.java.ast.ASTTryStatement;
019 import net.sourceforge.pmd.lang.java.rule.junit.AbstractJUnitRule;
020 
021 /**
022  * This rule finds code like this:
023  *
024  <pre>
025  * public void testFoo() {
026  *     try {
027  *         doSomething();
028  *         fail(&quot;should have thrown an exception&quot;);
029  *     } catch (Exception e) {
030  *     }
031  * }
032  </pre>
033  *
034  * In JUnit 4, use
035  *
036  <pre>
037  *  &#064;Test(expected = Exception.class)
038  </pre>
039  *
040  @author acaplan
041  *
042  */
043 public class JUnitUseExpectedRule extends AbstractJUnitRule {
044 
045     @Override
046     public Object visit(ASTClassOrInterfaceBodyDeclaration node, Object data) {
047         boolean inAnnotation = false;
048         for (int i = 0; i < node.jjtGetNumChildren(); i++) {
049             Node child = node.jjtGetChild(i);
050             if (child instanceof ASTAnnotation) {
051                 ASTName annotationName = child.getFirstDescendantOfType(ASTName.class);
052                 if ("Test".equals(annotationName.getImage())) {
053                     inAnnotation = true;
054                     continue;
055                 }
056             }
057             if (child instanceof ASTMethodDeclaration) {
058                 boolean isJUnitMethod = isJUnitMethod((ASTMethodDeclarationchild, data);
059                 if (inAnnotation || isJUnitMethod) {
060                     List<Node> found = new ArrayList<Node>();
061                     found.addAll((List<Node>visit((ASTMethodDeclarationchild, data));
062                     for (Node name : found) {
063                         addViolation(data, name);
064                     }
065                 }
066             }
067             inAnnotation = false;
068         }
069 
070         return super.visit(node, data);
071     }
072 
073     @Override
074     public Object visit(ASTMethodDeclaration node, Object data) {
075         List<ASTTryStatement> catches = node.findDescendantsOfType(ASTTryStatement.class);
076         List<Node> found = new ArrayList<Node>();
077         if (catches.isEmpty()) {
078             return found;
079         }
080         for (ASTTryStatement trySt : catches) {
081             ASTCatchStatement cStatement = getCatch(trySt);
082             if (cStatement != null) {
083                 ASTBlock block = (ASTBlockcStatement.jjtGetChild(1);
084                 if (block.jjtGetNumChildren() != 0) {
085                     continue;
086                 }
087                 List<ASTBlockStatement> blocks = trySt.jjtGetChild(0).findDescendantsOfType(ASTBlockStatement.class);
088                 if (blocks.isEmpty()) {
089                     continue;
090                 }
091                 ASTBlockStatement st = blocks.get(blocks.size() 1);
092                 ASTName name = st.getFirstDescendantOfType(ASTName.class);
093                 if (name != null && st.equals(name.getNthParent(5)) && "fail".equals(name.getImage())) {
094                     found.add(name);
095                     continue;
096                 }
097                 ASTThrowStatement th = st.getFirstDescendantOfType(ASTThrowStatement.class);
098                 if (th != null && st.equals(th.getNthParent(2))) {
099                     found.add(th);
100                     continue;
101                 }
102             }
103         }
104         return found;
105     }
106 
107     private ASTCatchStatement getCatch(Node n) {
108         for (int i = 0; i < n.jjtGetNumChildren(); i++) {
109             if (n.jjtGetChild(iinstanceof ASTCatchStatement) {
110                 return (ASTCatchStatementn.jjtGetChild(i);
111             }
112         }
113         return null;
114     }
115 }