diff --git a/CHANGELOG.md b/CHANGELOG.md index cd0c47c1a..28df1506a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - #719 Allows the in-memory db to be shared across threads (@tkrabel) - #720 create one sqlite3.Connection per thread using a thread local (@tkrabel) - #715 change AutoImport's `get_modules` to be case sensitive (@bagel897) +- #734 raise exception when extracting the start of a block without the end # Release 1.10.0 diff --git a/rope/refactor/extract.py b/rope/refactor/extract.py index c16c7b0c1..1b1659fad 100644 --- a/rope/refactor/extract.py +++ b/rope/refactor/extract.py @@ -444,8 +444,10 @@ def __call__(self, info): def base_conditions(self, info): if info.region[1] > info.scope_region[1]: raise RefactoringError("Bad region selected for extract method") + end_line = info.region_lines[1] end_scope = info.global_scope.get_inner_scope_for_line(end_line) + if end_scope != info.scope and end_scope.get_end() != end_line: raise RefactoringError("Bad region selected for extract method") try: @@ -497,6 +499,14 @@ def multi_line_conditions(self, info): raise RefactoringError( "Extracted piece should contain complete statements." ) + unbalanced_region_finder = _UnbalancedRegionFinder( + info.region_lines[0], info.region_lines[1] + ) + unbalanced_region_finder.visit(info.pymodule.ast_node) + if unbalanced_region_finder.error: + raise RefactoringError( + "Extracted piece cannot contain the start of a block without the end." + ) def _is_region_on_a_word(self, info): if ( @@ -1093,6 +1103,34 @@ def _ClassDef(self, node): pass +class _UnbalancedRegionFinder(_BaseErrorFinder): + """ + Flag an error if we are including the start of a block without the end. + We detect this by ensuring there is no AST node that starts inside the + selected range but ends outside of it. + """ + + def __init__(self, line_start: int, line_end: int): + self.error = False + self.line_start = line_start + self.line_end = line_end + + def generic_visit(self, node: ast.AST): + if not hasattr(node, "end_lineno"): + super().generic_visit(node) # Visit children + return + ends_before_range_starts = node.end_lineno < self.line_start + starts_after_range_ends = node.lineno > self.line_end + if ends_before_range_starts or starts_after_range_ends: + return # Don't visit children + starts_on_or_after_range_start = node.lineno >= self.line_start + ends_after_range_ends = node.end_lineno > self.line_end + if starts_on_or_after_range_start and ends_after_range_ends: + self.error = True + return # Don't visit children + super().generic_visit(node) # Visit children + + class _GlobalFinder(ast.RopeNodeVisitor): def __init__(self): self.globals_ = OrderedSet() diff --git a/ropetest/refactor/extracttest.py b/ropetest/refactor/extracttest.py index ed4054c06..e8c2fc206 100644 --- a/ropetest/refactor/extracttest.py +++ b/ropetest/refactor/extracttest.py @@ -1149,6 +1149,64 @@ def xxx_test_raising_exception_on_function_parens(self): end = code.rindex(")") + 1 with self.assertRaises(rope.base.exceptions.RefactoringError): self.do_extract_method(code, start, end, "new_func") + + def test_raising_exception_on_incomplete_block(self): + code = dedent("""\ + if True: + a = 1 + b = 2 + """) + start = code.index("if") + end = code.index("1") + 1 + with self.assertRaises(rope.base.exceptions.RefactoringError): + self.do_extract_method(code, start, end, "new_func") + + def test_raising_exception_on_incomplete_block_2(self): + code = dedent("""\ + if True: + a = 1 + # + b = 2 + """) + start = code.index("if") + end = code.index("1") + 1 + with self.assertRaises(rope.base.exceptions.RefactoringError): + self.do_extract_method(code, start, end, "new_func") + + def test_raising_exception_on_incomplete_block_3(self): + code = dedent("""\ + if True: + a = 1 + + b = 2 + """) + start = code.index("if") + end = code.index("1") + 1 + with self.assertRaises(rope.base.exceptions.RefactoringError): + self.do_extract_method(code, start, end, "new_func") + + def test_raising_exception_on_incomplete_block_4(self): + code = dedent("""\ + # + if True: + a = 1 + b = 2 + """) + start = code.index("#") + end = code.index("1") + 1 + with self.assertRaises(rope.base.exceptions.RefactoringError): + self.do_extract_method(code, start, end, "new_func") + + def test_raising_exception_on_incomplete_block_5(self): + code = dedent("""\ + if True: + if 0: + a = 1 + """) + start = code.index("if") + end = code.index("0:") + 2 + with self.assertRaises(rope.base.exceptions.RefactoringError): + self.do_extract_method(code, start, end, "new_func") def test_extract_method_and_extra_blank_lines(self): code = dedent("""\