
Commit 451d838

dbogunowicz, rahul-tuli, and bogunowicz@arrival.com authored
[Fix] Fully functional FSDP one-shot process (#2305)
* Update tests; diff updated on compressed tensors side
* Style
* Initial commit
* fix the FSDP name stripping
* cleanup after rebase
* refactoring

Co-authored-by: Rahul Tuli <rahul@neuralmagic.com>
Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
1 parent 56b7854 commit 451d838

File tree

4 files changed: +13 -4 lines changed


src/sparseml/modifiers/quantization/gptq/pytorch.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from sparseml.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper
 from sparseml.modifiers.utils.layer_compressor import LayerCompressor
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
+from sparseml.utils.fsdp.context import fix_fsdp_module_name


 __all__ = ["GPTQModifierPyTorch"]
@@ -116,6 +117,7 @@ def initialize_compression(
         self.layer_compressors_ = []

         for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
+            name = fix_fsdp_module_name(name)
             _LOGGER.info(f"Preparing {name} for compression")
             args = self._pruning_arguments()
             comp_cls = self._compression_class()

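For context on the change above: iterating an FSDP-wrapped model yields dotted module names that contain the wrapper segment, so the name is normalized before it is logged and used for compression setup. A minimal sketch of that pattern, where the dictionary and module names are illustrative stand-ins for self.compressible_layers_ (only fix_fsdp_module_name and the loop shape come from the diff):

from sparseml.utils.fsdp.context import fix_fsdp_module_name

# illustrative stand-in for self.compressible_layers_ on an FSDP-wrapped model;
# the wrapper segment "_fsdp_wrapped_module" appears inside the dotted names
compressible_layers = {
    "model._fsdp_wrapped_module.layers.0": None,
    "model._fsdp_wrapped_module.layers.1": None,
}

for idx, (name, layer) in enumerate(compressible_layers.items()):
    name = fix_fsdp_module_name(name)  # -> "model.layers.0", "model.layers.1"
    print(f"Preparing {name} for compression")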
src/sparseml/pytorch/utils/sparsification.py

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ def __init__(
         self.state_dict = state_dict

         if self.state_dict is not None:
+            # when analyzing an FSDP model, the state_dict does not differentiate
+            # between trainable and non-trainable parameters
+            # (e.g. it can contain buffers) this means that the
+            # self.trainable_parameters may be overestimated
             self.trainable_params = [param for _, param in state_dict.items()]
         else:
             self.trainable_params = list(

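The new comment concerns counting parameters from a state_dict. A minimal, self-contained sketch (plain torch, not SparseML code) of why that count can be an overestimate: buffers such as BatchNorm running statistics appear in state_dict() but are not trainable parameters.

import torch

bn = torch.nn.BatchNorm1d(4)

# state_dict() includes buffers alongside parameters
print(sorted(bn.state_dict().keys()))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']

# named_parameters() lists only the trainable tensors
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']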
src/sparseml/utils/fsdp/context.py

Lines changed: 7 additions & 3 deletions
@@ -30,7 +30,7 @@
     "fix_fsdp_module_name",
 ]

-FSDP_WRAPPER_NAME = "_fsdp_wrapped_module."
+FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"


 def summon_full_params_context(model, offload_to_cpu: bool = False):
@@ -61,9 +61,13 @@ def main_process_first_context():

 def fix_fsdp_module_name(name: str) -> str:
     """
-    Remove FSDP wrapper prefixes from a module name
+    Remove FSDP wrapper prefixes from a module name.
+    Accounts for scenario where FSDP_WRAPPER_NAME is
+    at the end of the name, as well as in the middle.

     :param name: name to strip
     :return: stripped name
     """
-    return name.replace(FSDP_WRAPPER_NAME, "")
+    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
+        "." + FSDP_WRAPPER_NAME, ""
+    )

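To illustrate the behavioral change: dropping the trailing dot from FSDP_WRAPPER_NAME and replacing both the "wrapper." and ".wrapper" forms lets the helper strip the segment whether it sits in the middle or at the end of a dotted name. A standalone sketch mirroring the new logic (the example names are illustrative):

FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"

def fix_fsdp_module_name(name: str) -> str:
    # strip the wrapper segment both mid-name ("<wrapper>.child")
    # and at the end of the name ("parent.<wrapper>")
    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
        "." + FSDP_WRAPPER_NAME, ""
    )

print(fix_fsdp_module_name("_fsdp_wrapped_module.model.layers.0"))  # model.layers.0
print(fix_fsdp_module_name("model.layers.0._fsdp_wrapped_module"))  # model.layers.0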
src/sparseml/utils/pytorch/module.py

Lines changed: 0 additions & 1 deletion
@@ -188,7 +188,6 @@ def get_layer(target: str, module: Module) -> Tuple[str, Module]:


 def set_layer(target: str, layer: Module, module: Module) -> Module:
-    target = fix_fsdp_module_name(target)
     with summon_full_params_context(module):
         # importing here to avoid circular import
         from sparseml.utils.fsdp.helpers import maybe_get_wrapped
