"Fossies" - the Fresh Open Source Software Archive 
# Source: transformers-4.21.1/examples/research_projects/jax-projects/big_bird/bigbird_flax.py
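"""Utilities for fine-tuning BigBird on Natural Questions with Flax/JAX.

The file defines the question-answering model with an extra answer-category
head, the multi-head loss, a padding/sharding data collator, pmapped train and
eval steps, a small Trainer with checkpoint save/restore helpers, and the
optimizer/schedule builders.
"""
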
import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable

from tqdm.auto import tqdm

import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule


class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
    """
    BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.

    This way we can load its weights with FlaxBigBirdForQuestionAnswering.
    """

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = True

    def setup(self):
        super().setup()
        # extra head projecting the pooled output onto the 5 answer categories
        self.cls = nn.Dense(5, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        # outputs contains (start_logits, end_logits, pooled_output, ...)
        outputs = super().__call__(*args, **kwargs)
        cls_out = self.cls(outputs[2])
        return outputs[:2] + (cls_out,)


class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
    module_class = FlaxBigBirdForNaturalQuestionsModule


def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
    def cross_entropy(logits, labels, reduction=None):
        """
        Args:
            logits: bsz, seqlen, vocab_size
            labels: bsz, seqlen
        """
        vocab_size = logits.shape[-1]
        labels = (labels[..., None] == jnp.arange(vocab_size)[None]).astype("f4")
        logits = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.sum(labels * logits, axis=-1)
        if reduction is not None:
            loss = reduction(loss)
        return loss

    cross_entropy = partial(cross_entropy, reduction=jnp.mean)
    start_loss = cross_entropy(start_logits, start_labels)
    end_loss = cross_entropy(end_logits, end_labels)
    pooled_loss = cross_entropy(pooled_logits, pooler_labels)
    return (start_loss + end_loss + pooled_loss) / 3

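# Note: the nested `cross_entropy` above is ordinary softmax cross-entropy
# written with a one-hot mask; for a single position it reduces to
# -jax.nn.log_softmax(logits)[label]. The three head losses (start token,
# end token, answer category) are averaged with equal weight.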

@dataclass
class Args:
    model_id: str = "google/bigbird-roberta-base"
    logging_steps: int = 3000
    save_steps: int = 10500

    block_size: int = 128
    num_random_blocks: int = 3

    batch_size_per_device: int = 1
    max_epochs: int = 5

    # tx_args
    lr: float = 3e-5
    init_lr: float = 0.0
    warmup_steps: int = 20000
    weight_decay: float = 0.0095

    save_dir: str = "bigbird-roberta-natural-questions"
    base_dir: str = "training-expt"
    tr_data_path: str = "data/nq-training.jsonl"
    val_data_path: str = "data/nq-validation.jsonl"

    def __post_init__(self):
        os.makedirs(self.base_dir, exist_ok=True)
        self.save_dir = os.path.join(self.base_dir, self.save_dir)
        # global batch size across all available devices
        self.batch_size = self.batch_size_per_device * jax.device_count()


@dataclass
class DataCollator:

    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

    def __call__(self, batch):
        batch = self.collate_fn(batch)
        # shard arrays so each local device gets one slice
        batch = jax.tree_map(shard, batch)
        return batch

    def collate_fn(self, features):
        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
        batch = {
            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
        }
        return batch

    def fetch_inputs(self, input_ids: list):
        inputs = [self._fetch_inputs(ids) for ids in input_ids]
        return zip(*inputs)

    def _fetch_inputs(self, input_ids: list):
        # pad each sequence to `max_length` and build the matching attention mask
        attention_mask = [1 for _ in range(len(input_ids))]
        while len(input_ids) < self.max_length:
            input_ids.append(self.pad_id)
            attention_mask.append(0)
        return input_ids, attention_mask

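# Hypothetical usage sketch (not part of the original file): the collator
# expects a dict of lists with the keys produced by the Natural Questions
# preprocessing assumed by this project ("input_ids", "start_token",
# "end_token", "category"), e.g.
#
#   collator = DataCollator(pad_id=0, max_length=4096)  # pad_id: tokenizer's pad token id
#   model_inputs = collator(batch_dict)  # each array gains a leading device axis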

def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    # iterate over full batches only; the trailing partial batch is dropped
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        yield dict(batch)


@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):
    def loss_fn(params):
        start_labels = model_inputs.pop("start_labels")
        end_labels = model_inputs.pop("end_labels")
        pooled_labels = model_inputs.pop("pooled_labels")

        outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
        start_logits, end_logits, pooled_logits = outputs

        return state.loss_fn(
            start_logits,
            start_labels,
            end_logits,
            end_labels,
            pooled_logits,
            pooled_labels,
        )

    # use a fresh dropout key for this step and hand a new one back for the next step
    drp_rng, new_drp_rng = jax.random.split(drp_rng)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    # average loss and gradients across devices before applying the update
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    grads = jax.lax.pmean(grads, "batch")

    state = state.apply_gradients(grads=grads)
    return state, metrics, new_drp_rng

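# `train_step` above and `val_step` below are pmapped over the "batch" axis, so
# every argument (the replicated state, the per-device RNG keys and the model
# inputs) must carry a leading device axis. That layout is produced by `shard`
# inside DataCollator.__call__ and by `jax_utils.replicate(state)` in
# Trainer.create_state.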

@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
    start_logits, end_logits, pooled_logits = outputs

    loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return metrics


class TrainState(train_state.TrainState):
    # the loss function is static configuration, so keep it out of the pytree
    loss_fn: Callable = struct.field(pytree_node=False)


@dataclass
class Trainer:
    args: Args
    data_collator: Callable
    train_step_fn: Callable
    val_step_fn: Callable
    model_save_fn: Callable
    logger: wandb
    scheduler_fn: Callable = None

    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
        params = model.params
        state = TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=tx,
            loss_fn=calculate_loss_for_nq,
        )
        if ckpt_dir is not None:
            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
            tx_args = {
                "lr": args.lr,
                "init_lr": args.init_lr,
                "warmup_steps": args.warmup_steps,
                "num_train_steps": num_train_steps,
                "weight_decay": args.weight_decay,
            }
            tx, lr = build_tx(**tx_args)
            # rebuild the custom TrainState (including loss_fn) so training can
            # resume from the restored step and optimizer state
            state = TrainState(
                step=step,
                apply_fn=model.__call__,
                params=params,
                tx=tx,
                opt_state=opt_state,
                loss_fn=calculate_loss_for_nq,
            )
            self.args = args
            self.data_collator = data_collator
            self.scheduler_fn = lr
            model.params = params
        state = jax_utils.replicate(state)
        return state

    def train(self, state, tr_dataset, val_dataset):
        args = self.args
        total = len(tr_dataset) // args.batch_size

        rng = jax.random.PRNGKey(0)
        # one dropout RNG key per device
        drp_rng = jax.random.split(rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
            i = 0
            for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
                batch = self.data_collator(batch)
                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
                running_loss += jax_utils.unreplicate(metrics["loss"])
                i += 1
                if i % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / i
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = dict(
                        step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
                    )
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if i % args.save_steps == 0:
                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)

    def evaluate(self, state, dataset):
        dataloader = get_batched_dataset(dataset, self.args.batch_size)
        total = len(dataset) // self.args.batch_size
        running_loss = jnp.array(0, dtype=jnp.float32)
        i = 0
        for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
            batch = self.data_collator(batch)
            metrics = self.val_step_fn(state, **batch)
            running_loss += jax_utils.unreplicate(metrics["loss"])
            i += 1
        return running_loss / i

    def save_checkpoint(self, save_dir, state):
        # save model weights, optimizer state, args, collator and the global step
        state = jax_utils.unreplicate(state)
        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
        self.model_save_fn(save_dir, params=state.params)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
        print("DONE")


def restore_checkpoint(save_dir, state):
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    args = joblib.load(os.path.join(save_dir, "args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator


def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    # linear warmup from init_lr to lr, then linear decay towards ~0
    decay_steps = num_train_steps - warmup_steps
    warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
    lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
    return lr


def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    def weight_decay_mask(params):
        # exclude biases and LayerNorm scale parameters from weight decay;
        # the decision is made from the flattened parameter path, not its value
        params = traverse_util.flatten_dict(params)
        mask = {k: (k[-1] != "bias" and k[-2:] != ("LayerNorm", "scale")) for k in params}
        return traverse_util.unflatten_dict(mask)

    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)

    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr
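

# ---------------------------------------------------------------------------
# Hypothetical end-to-end usage sketch (not part of the original file). The
# project's training script is the real entry point; the guarded snippet below
# only illustrates, under assumed jsonl data files matching Args.tr_data_path /
# Args.val_data_path (columns "input_ids", "start_token", "end_token",
# "category"), how the pieces defined above fit together. The wandb project
# name and pad_id below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from datasets import load_dataset

    args = Args()
    tr_dataset = load_dataset("json", data_files=args.tr_data_path)["train"]
    val_dataset = load_dataset("json", data_files=args.val_data_path)["train"]

    model = FlaxBigBirdForNaturalQuestions.from_pretrained(
        args.model_id, block_size=args.block_size, num_random_blocks=args.num_random_blocks
    )

    num_train_steps = (len(tr_dataset) // args.batch_size) * args.max_epochs
    tx, lr = build_tx(
        lr=args.lr,
        init_lr=args.init_lr,
        warmup_steps=args.warmup_steps,
        num_train_steps=num_train_steps,
        weight_decay=args.weight_decay,
    )

    trainer = Trainer(
        args=args,
        data_collator=DataCollator(pad_id=0, max_length=4096),  # pad_id: tokenizer's pad token id
        train_step_fn=train_step,
        val_step_fn=val_step,
        model_save_fn=model.save_pretrained,
        logger=wandb.init(project="bigbird-natural-questions"),
        scheduler_fn=lr,
    )

    state = trainer.create_state(model, tx, num_train_steps)
    trainer.train(state, tr_dataset, val_dataset)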