1 # Copyright 2015 The Chromium Authors. All rights reserved.
2 # Use of this source code is governed by a BSD-style license that can be
3 # found in the LICENSE file.
import functools
import os
import sys

from catapult_base import refactor
def Run(sources, target, files_to_update):
  """Move modules and update imports.

  Imports (and the references made through them) in `files_to_update` are
  rewritten first; only afterwards are the source files renamed on disk.

  Args:
    sources: List of source module or package paths.
    target: Destination module or package path.
    files_to_update: Modules whose imports we should check for changes.
  """
  # TODO(dtu): Support moving classes and functions.
  moves = tuple(_Move(source, target) for source in sources)

  # Update imports and references.
  refactor.Transform(functools.partial(_Update, moves), files_to_update)

  # Move files. This happens after the rewrite because computing module names
  # (_ModulePath) inspects the filesystem, so the sources must still be in
  # their original location while imports are being updated.
  for move in moves:
    os.rename(move.source_path, move.target_path)
def _Update(moves, module):
  """Rewrite the imports in `module` that refer to any of `moves`.

  Run binds this with functools.partial and passes it to refactor.Transform.

  Args:
    moves: Sequence of _Move objects to apply.
    module: The refactor.Module whose imports and references are updated.
  """
  for import_statement in module.FindAll(refactor.Import):
    for move in moves:
      try:
        if move.UpdateImportAndReferences(module, import_statement):
          # The import now points at its new location. Stop here so a later
          # move cannot match (and rewrite again) the already-rewritten path.
          break
      except NotImplementedError as e:
        # Report and continue: one unsupported import should not abort the
        # whole refactor. sys.stderr.write works on both Python 2 and 3,
        # unlike the Python-2-only `print >> sys.stderr` statement.
        sys.stderr.write(
            'Error updating %s: %s\n' % (module.file_path, e))
class _Move(object):
  """One move of a module file or package directory from source to target.

  Callers read the accessors below as plain attributes (e.g.
  `move.source_path`), so they are properties.
  """

  def __init__(self, source, target):
    """Resolves and stores the source and destination paths.

    Args:
      source: Path of the module file or package directory to move.
      target: Destination path. If it is an existing directory, the source is
          placed inside it under its own base name.
    """
    self._source_path = os.path.realpath(source)
    self._target_path = os.path.realpath(target)

    # Moving into an existing directory keeps the source's base name.
    if os.path.isdir(self._target_path):
      self._target_path = os.path.join(
          self._target_path, os.path.basename(self._source_path))

  @property
  def source_path(self):
    """Absolute, symlink-resolved path of the source."""
    return self._source_path

  @property
  def target_path(self):
    """Absolute path that the source will be renamed to."""
    return self._target_path

  @property
  def source_module_path(self):
    """Dotted module name of the source, as computed by _ModulePath."""
    return _ModulePath(self._source_path)

  @property
  def target_module_path(self):
    """Dotted module name of the target, as computed by _ModulePath."""
    return _ModulePath(self._target_path)

  def UpdateImportAndReferences(self, module, import_statement):
    """Update an import statement in a module and all its references.

    Args:
      module: The refactor.Module to update.
      import_statement: The refactor.Import to update.

    Returns:
      True if the import statement was updated, or False if the import statement
      needed no updating.
    """
    statement_path_parts = import_statement.path.split('.')
    source_path_parts = self.source_module_path.split('.')
    # Compare whole dotted components so 'a.bc' is not treated as inside 'a.b'.
    if source_path_parts != statement_path_parts[:len(source_path_parts)]:
      return False

    # Update import statement.
    old_name_parts = import_statement.name.split('.')
    new_name_parts = ([self.target_module_path] +
                      statement_path_parts[len(source_path_parts):])
    import_statement.path = '.'.join(new_name_parts)
    # NOTE(review): read back after assigning .path; presumably refactor.Import
    # derives .name from the path. TODO confirm.
    new_name = import_statement.name

    # Update references that go through the old import name.
    for reference in module.FindAll(refactor.Reference):
      reference_parts = reference.value.split('.')
      if old_name_parts != reference_parts[:len(old_name_parts)]:
        continue

      new_reference_parts = [new_name] + reference_parts[len(old_name_parts):]
      reference.value = '.'.join(new_reference_parts)

    return True
def _BaseDir(module_path):
  """Return the nearest ancestor directory of module_path that is not a package.

  Starting from module_path (or its containing directory when module_path is
  not a directory), climb upward while the current directory contains an
  __init__.py. Module names are computed relative to the directory returned.

  Args:
    module_path: Path to a module file or package directory.

  Returns:
    The first directory, walking upward, that has no __init__.py (or the
    filesystem root if every ancestor has one).
  """
  if not os.path.isdir(module_path):
    module_path = os.path.dirname(module_path)

  while '__init__.py' in os.listdir(module_path):
    parent = os.path.dirname(module_path)
    if parent == module_path:
      # Reached the filesystem root: dirname() no longer makes progress, so
      # stop instead of looping forever.
      break
    module_path = parent

  return module_path
def _ModulePath(module_path):
  """Convert a module file or package path into a dotted module name.

  The name is taken relative to _BaseDir(module_path), with the file
  extension dropped and path separators turned into dots.

  Args:
    module_path: Path to a .py file, a package directory, or a package's
        __init__.py.

  Returns:
    The dotted module name, e.g. 'pkg.sub.mod' for pkg/sub/mod.py when pkg and
    pkg/sub are packages.
  """
  # A package's __init__.py stands for the package directory itself.
  if os.path.basename(module_path) == '__init__.py':
    module_path = os.path.dirname(module_path)

  relative = os.path.relpath(module_path, _BaseDir(module_path))
  stem = os.path.splitext(relative)[0]
  return '.'.join(stem.split(os.sep))