"Fossies" - the Fresh Open Source Software Archive

Member "transformers-4.21.1/examples/research_projects/movement-pruning/bertarize.py" (4 Aug 2022, 5157 Bytes) of package /linux/misc/transformers-4.21.1.tar.gz:
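
The script below converts a checkpoint fine-pruned with the emmental masked models from this research project into a plain pytorch_model.bin that standard transformers classes can load directly. An illustrative invocation (the flags mirror the argparse options defined at the bottom of the file; the paths and the threshold value are placeholders):

    python bertarize.py \
        --pruning_method topK \
        --threshold 0.10 \
        --model_name_or_path serialization_dir/fine_pruned_model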


# Copyright 2020-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Once a model has been fine-pruned, the weights that are masked during the forward pass can be pruned once and for all.
For instance, once a model from the :class:`~emmental.MaskedBertForSequenceClassification` is trained, it can be saved
(and then loaded) as a standard :class:`~transformers.BertForSequenceClassification`.
"""

import argparse
import os
import shutil

import torch

from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer


def main(args):
    pruning_method = args.pruning_method
    threshold = args.threshold

    model_name_or_path = args.model_name_or_path.rstrip("/")
    target_model_path = args.target_model_path

    print(f"Load fine-pruned model from {model_name_or_path}")
    model = torch.load(os.path.join(model_name_or_path, "pytorch_model.bin"))
    pruned_model = {}

    for name, tensor in model.items():
        # Layers that were never masked during fine-pruning (embeddings, LayerNorm,
        # pooler, task heads, biases) are copied over unchanged.
        if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "classifier" in name or "qa_output" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        elif "bias" in name:
            pruned_model[name] = tensor
            print(f"Copied layer {name}")
        else:
            if pruning_method == "magnitude":
                # Magnitude pruning derives the mask from the weight values themselves.
                mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "topK":
                if "mask_scores" in name:
                    continue
                # `name` ends in "weight"; swap that suffix for "mask_scores" to fetch
                # the learned scores associated with this weight matrix.
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                mask = TopKBinarizer.apply(scores, threshold)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "sigmoied_threshold":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                mask = ThresholdBinarizer.apply(scores, threshold, True)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            elif pruning_method == "l0":
                if "mask_scores" in name:
                    continue
                prefix_ = name[:-6]
                scores = model[f"{prefix_}mask_scores"]
                # Hard-concrete (L0) gate: stretch the sigmoid of the scores to (l, r)
                # and clamp to [0, 1].
                l, r = -0.1, 1.1
                s = torch.sigmoid(scores)
                s_bar = s * (r - l) + l
                mask = s_bar.clamp(min=0.0, max=1.0)
                pruned_model[name] = tensor * mask
                print(f"Pruned layer {name}")
            else:
                raise ValueError("Unknown pruning method")

    if target_model_path is None:
        target_model_path = os.path.join(
            os.path.dirname(model_name_or_path), f"bertarized_{os.path.basename(model_name_or_path)}"
        )

    if not os.path.isdir(target_model_path):
        shutil.copytree(model_name_or_path, target_model_path)
        print(f"\nCreated folder {target_model_path}")

    torch.save(pruned_model, os.path.join(target_model_path, "pytorch_model.bin"))
    print("\nPruned model saved! See you later!")
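

# For reference only: a minimal sketch (an illustration, not emmental's actual
# implementation) of the kind of mask a TopK-style binarizer produces, i.e. keep
# the `threshold` fraction of entries with the highest scores and zero out the
# rest. The emmental binarizers imported above are the real, autograd-aware
# versions used in main(); this helper is not called anywhere in the script.
def _topk_mask_sketch(scores, threshold):
    k = max(1, int(threshold * scores.numel()))
    # The k-th largest value is the (numel - k + 1)-th smallest value.
    kth_largest = scores.flatten().kthvalue(scores.numel() - k + 1).values
    return (scores >= kth_largest).to(scores.dtype)

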
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pruning_method",
        choices=["l0", "magnitude", "topK", "sigmoied_threshold"],
        type=str,
        required=True,
        help=(
            "Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
            " sigmoied_threshold = Soft movement pruning)"
        ),
    )
    parser.add_argument(
        "--threshold",
        type=float,
        required=False,
        help=(
            "For `magnitude` and `topK`, it is the level of remaining weights (in %%) in the fine-pruned model."
            " For `sigmoied_threshold`, it is the threshold tau against which the (sigmoied) scores are compared."
            " Not needed for `l0`."
        ),
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Folder containing the model that was previously fine-pruned",
    )
    parser.add_argument(
        "--target_model_path",
        default=None,
        type=str,
        required=False,
        help="Folder where the pruned (bertarized) model will be saved",
    )

    args = parser.parse_args()

    main(args)
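
Once the script has run, the saved checkpoint is a plain state dict, so (as the module docstring notes) it can be reloaded with the standard transformers classes. A minimal sketch, assuming the target folder also received the config and tokenizer files copied over by shutil.copytree above; the path is a placeholder (by default the script writes to "bertarized_<original folder name>" next to the original folder):

    from transformers import BertForSequenceClassification

    target_model_path = "serialization_dir/bertarized_fine_pruned_model"  # placeholder

    model = BertForSequenceClassification.from_pretrained(target_model_path)
    model.eval()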