"Fossies" - the Fresh Open Source Software Archive

Member "transformers-4.21.1/examples/research_projects/jax-projects/big_bird/bigbird_flax.py" (4 Aug 2022, 11714 Bytes) of package /linux/misc/transformers-4.21.1.tar.gz:


As a special service "Fossies" has tried to format the requested source page into HTML format using (guessed) Python source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively you can here view or download the uninterpreted source code file.

import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable

from tqdm.auto import tqdm

import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule


class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
    """
    BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.

    This way we can still load its weights with FlaxBigBirdForQuestionAnswering.
    """

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = True

    def setup(self):
        super().setup()
        self.cls = nn.Dense(5, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        cls_out = self.cls(outputs[2])
        return outputs[:2] + (cls_out,)


class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
    module_class = FlaxBigBirdForNaturalQuestionsModule

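
# Editor's usage sketch, not part of the original script: one way to instantiate the head above.
# BigBirdConfig.from_pretrained and from_pretrained(..., config=...) are standard transformers
# APIs; the concrete values mirror the Args defaults below.
def _example_load_model(args: "Args") -> FlaxBigBirdForNaturalQuestions:
    config = BigBirdConfig.from_pretrained(
        args.model_id,
        block_size=args.block_size,
        num_random_blocks=args.num_random_blocks,
    )
    return FlaxBigBirdForNaturalQuestions.from_pretrained(args.model_id, config=config)
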

def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
    def cross_entropy(logits, labels, reduction=None):
        """
        Args:
            logits: (..., num_classes), e.g. (bsz, seqlen) for start/end logits or (bsz, 5) for the category head
            labels: (...), integer class ids indexing the last axis of `logits`
        """
        vocab_size = logits.shape[-1]
        labels = (labels[..., None] == jnp.arange(vocab_size)[None]).astype("f4")
        logits = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.sum(labels * logits, axis=-1)
        if reduction is not None:
            loss = reduction(loss)
        return loss

    cross_entropy = partial(cross_entropy, reduction=jnp.mean)
    start_loss = cross_entropy(start_logits, start_labels)
    end_loss = cross_entropy(end_logits, end_labels)
    pooled_loss = cross_entropy(pooled_logits, pooler_labels)
    return (start_loss + end_loss + pooled_loss) / 3

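
# Editor's sketch, not part of the original file: a tiny shape check for the loss above, using a
# toy batch of 2 sequences of length 8 and the 5 answer categories produced by the `cls` head.
# All names and values here are illustrative.
def _example_loss_shapes():
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
    start_logits = jax.random.normal(k1, (2, 8))
    end_logits = jax.random.normal(k2, (2, 8))
    pooled_logits = jax.random.normal(k3, (2, 5))
    start_labels = jnp.array([1, 3])
    end_labels = jnp.array([2, 5])
    pooler_labels = jnp.array([0, 4])
    # returns a scalar: the mean of the three cross-entropy terms
    return calculate_loss_for_nq(
        start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels
    )
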

@dataclass
class Args:
    model_id: str = "google/bigbird-roberta-base"
    logging_steps: int = 3000
    save_steps: int = 10500

    block_size: int = 128
    num_random_blocks: int = 3

    batch_size_per_device: int = 1
    max_epochs: int = 5

    # tx_args
    lr: float = 3e-5
    init_lr: float = 0.0
    warmup_steps: int = 20000
    weight_decay: float = 0.0095

    save_dir: str = "bigbird-roberta-natural-questions"
    base_dir: str = "training-expt"
    tr_data_path: str = "data/nq-training.jsonl"
    val_data_path: str = "data/nq-validation.jsonl"

    def __post_init__(self):
        os.makedirs(self.base_dir, exist_ok=True)
        self.save_dir = os.path.join(self.base_dir, self.save_dir)
        self.batch_size = self.batch_size_per_device * jax.device_count()


@dataclass
class DataCollator:

    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

    def __call__(self, batch):
        batch = self.collate_fn(batch)
        batch = jax.tree_map(shard, batch)
        return batch

    def collate_fn(self, features):
        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
        batch = {
            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
        }
        return batch

    def fetch_inputs(self, input_ids: list):
        inputs = [self._fetch_inputs(ids) for ids in input_ids]
        return zip(*inputs)

    def _fetch_inputs(self, input_ids: list):
        attention_mask = [1 for _ in range(len(input_ids))]
        while len(input_ids) < self.max_length:
            input_ids.append(self.pad_id)
            attention_mask.append(0)
        return input_ids, attention_mask

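
# Editor's sketch, not part of the original file: what the collator produces for a toy batch.
# It pads every example to `max_length` and `shard` adds a leading device axis, so the number of
# examples must be divisible by jax.device_count() (assumed to be 1 or 2 here).
def _example_collate():
    collator = DataCollator(pad_id=0, max_length=16)  # 16 instead of 4096 to keep it readable
    features = {
        "input_ids": [[5, 6, 7], [8, 9, 10, 11]],
        "start_token": [1, 0],
        "end_token": [2, 3],
        "category": [0, 4],
    }
    batch = collator(features)
    # e.g. on a single device: batch["input_ids"].shape == (1, 2, 16)
    return batch
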

def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        yield dict(batch)

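# Editor's note, not part of the original file: `dataset` above is assumed to behave like a
# Hugging Face `datasets.Dataset` (it supports len(), `.shuffle(seed=...)` and slice indexing
# returning a dict of columns); any trailing partial batch is dropped.
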

@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):
    def loss_fn(params):
        start_labels = model_inputs.pop("start_labels")
        end_labels = model_inputs.pop("end_labels")
        pooled_labels = model_inputs.pop("pooled_labels")

        outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
        start_logits, end_logits, pooled_logits = outputs

        return state.loss_fn(
            start_logits,
            start_labels,
            end_logits,
            end_labels,
            pooled_logits,
            pooled_labels,
        )

    drp_rng, new_drp_rng = jax.random.split(drp_rng)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    grads = jax.lax.pmean(grads, "batch")

    state = state.apply_gradients(grads=grads)
    return state, metrics, new_drp_rng


@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
    start_logits, end_logits, pooled_logits = outputs

    loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return metrics

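# Editor's note, not part of the original file: both step functions above are pmapped, so they
# expect a replicated TrainState (see jax_utils.replicate in Trainer.create_state), one dropout
# PRNG key per device, and model inputs whose leading axis is jax.device_count(), which is the
# layout DataCollator.__call__ produces via `shard`.
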

class TrainState(train_state.TrainState):
    loss_fn: Callable = struct.field(pytree_node=False)


@dataclass
class Trainer:
    args: Args
    data_collator: Callable
    train_step_fn: Callable
    val_step_fn: Callable
    model_save_fn: Callable
    logger: wandb
    scheduler_fn: Callable = None

    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
        params = model.params
        state = TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=tx,
            loss_fn=calculate_loss_for_nq,
        )
        if ckpt_dir is not None:
            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
            tx_args = {
                "lr": args.lr,
                "init_lr": args.init_lr,
                "warmup_steps": args.warmup_steps,
                "num_train_steps": num_train_steps,
                "weight_decay": args.weight_decay,
            }
            tx, lr = build_tx(**tx_args)
            # Rebuild the custom TrainState (not the base train_state.TrainState) so that
            # `state.loss_fn` is still available to train_step/val_step after restoring.
            state = TrainState(
                step=step,
                apply_fn=model.__call__,
                params=params,
                tx=tx,
                opt_state=opt_state,
                loss_fn=calculate_loss_for_nq,
            )
            self.args = args
            self.data_collator = data_collator
            self.scheduler_fn = lr
            model.params = params
        state = jax_utils.replicate(state)
        return state

    def train(self, state, tr_dataset, val_dataset):
        args = self.args
        total = len(tr_dataset) // args.batch_size

        rng = jax.random.PRNGKey(0)
        drp_rng = jax.random.split(rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
            i = 0
            for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
                batch = self.data_collator(batch)
                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
                running_loss += jax_utils.unreplicate(metrics["loss"])
                i += 1
                if i % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / i
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = dict(
                        step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
                    )
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if i % args.save_steps == 0:
                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)

    def evaluate(self, state, dataset):
        dataloader = get_batched_dataset(dataset, self.args.batch_size)
        total = len(dataset) // self.args.batch_size
        running_loss = jnp.array(0, dtype=jnp.float32)
        i = 0
        for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
            batch = self.data_collator(batch)
            metrics = self.val_step_fn(state, **batch)
            running_loss += jax_utils.unreplicate(metrics["loss"])
            i += 1
        return running_loss / i

    def save_checkpoint(self, save_dir, state):
        state = jax_utils.unreplicate(state)
        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
        self.model_save_fn(save_dir, params=state.params)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
        print("DONE")

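
# Editor's sketch, not part of the original file: one way to wire the pieces above together.
# The wandb project name is illustrative; `model.save_pretrained` and `model.config.pad_token_id`
# are standard transformers attributes.
def _example_build_trainer(args: Args, model: FlaxBigBirdForNaturalQuestions, lr_schedule) -> Trainer:
    return Trainer(
        args=args,
        data_collator=DataCollator(pad_id=model.config.pad_token_id, max_length=4096),
        train_step_fn=train_step,
        val_step_fn=val_step,
        model_save_fn=model.save_pretrained,
        logger=wandb.init(project="bigbird-natural-questions"),  # illustrative project name
        scheduler_fn=lr_schedule,  # the schedule returned by build_tx(...)
    )
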

def restore_checkpoint(save_dir, state):
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    args = joblib.load(os.path.join(save_dir, "args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator


def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    decay_steps = num_train_steps - warmup_steps
    warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
    lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
    return lr

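
# Editor's sketch, not part of the original file: the schedule warms up linearly from init_lr to lr
# over `warmup_steps`, then decays linearly towards 1e-7 over the remaining steps. The numbers below
# are illustrative.
def _example_schedule_values():
    schedule = scheduler_fn(lr=3e-5, init_lr=0.0, warmup_steps=100, num_train_steps=1000)
    return [schedule(step) for step in (0, 50, 100, 1000)]
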

def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    def weight_decay_mask(params):
        # Exclude biases and LayerNorm scales from weight decay, keyed on the flattened parameter path.
        params = traverse_util.flatten_dict(params)
        mask = {k: (k[-1] != "bias" and k[-2:] != ("LayerNorm", "scale")) for k in params}
        return traverse_util.unflatten_dict(mask)

    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)

    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr
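

# Editor's sketch, not part of the original file: building the optimizer and the replicated train
# state with the helpers above. `num_train_examples` and the step-count formula follow the loop in
# Trainer.train; all names are illustrative.
def _example_create_state(args: Args, model, trainer: Trainer, num_train_examples: int):
    num_train_steps = (num_train_examples // args.batch_size) * args.max_epochs
    tx, lr_schedule = build_tx(
        lr=args.lr,
        init_lr=args.init_lr,
        warmup_steps=args.warmup_steps,
        num_train_steps=num_train_steps,
        weight_decay=args.weight_decay,
    )
    trainer.scheduler_fn = lr_schedule
    return trainer.create_state(model, tx, num_train_steps=num_train_steps)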