PhoBERT Vietnamese Sentiment Analysis on the UIT-VSFC dataset with transformers and PyTorch Lightning
The full notebook is available here.
Introduction
PhoBERT: Pre-trained language models for Vietnamese
PhoBERT models are the SOTA language models for Vietnamese. There are two versions: PhoBERT base and PhoBERT large. Their pretraining approach is based on RoBERTa, which optimizes the BERT pre-training procedure for more robust performance. PhoBERT has achieved SOTA results on many downstream tasks such as POS tagging, dependency parsing, NER and NLI. You can read more about PhoBERT here.
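For example, both checkpoints are published by VinAI on the Hugging Face Hub and can be loaded with transformers. This is a minimal sketch that only loads the raw encoder; the fine-tuning code later in this post uses AutoModelForSequenceClassification instead:
from transformers import AutoModel, AutoTokenizer

# PhoBERT base -- the checkpoint fine-tuned later in this post.
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
phobert_base = AutoModel.from_pretrained("vinai/phobert-base")

# PhoBERT large is available under the "vinai/phobert-large" identifier.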
UIT-VSFC: Vietnamese Students’ Feedback Corpus
Vietnamese Students’ Feedback Corpus (UIT-VSFC) is a resource of over 16,000 sentences that are human-annotated for two different tasks: sentiment-based and topic-based classification.
In this project, we will apply PhoBERT to the sentiment classification task on the UIT-VSFC dataset, using pytorch-lightning and transformers.
Install required packages
%%capture
!pip install pytorch-lightning
!pip install torchmetrics
!pip install transformers
!pip install datasets
Import required packages
import os
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.request import urlretrieve
import pandas as pd
from tqdm import tqdm
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from datasets import load_dataset
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          DataCollatorWithPadding)
# For reproducibility
pl.utilities.seed.seed_everything(seed=2401, workers=True)
Global seed set to 2401
2401
Define the dataset and dataloader classes and utility functions
class TqdmUpTo(tqdm):
    """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py"""

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """
        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize  # pylint: disable=attribute-defined-outside-init
        self.update(blocks * bsize - self.n)  # will also set self.n = blocks * bsize
def download_url(url, filename, directory='.'):
    """Download a file from url to filename, with a progress bar."""
    if not os.path.exists(directory):
        os.makedirs(directory)
    path = os.path.join(directory, filename)
    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        urlretrieve(url, path, reporthook=t.update_to, data=None)  # nosec
    return path
def _load_data_from(data_dir: Union[str, Path]):
    fnames = ['sentiments.txt', 'sents.txt', 'topics.txt']
    sentiments = []
    sents = []
    topics = []
    for name in fnames:
        with open(f"{data_dir}/{name}", 'r') as f:
            if name == "sentiments.txt":
                sentiments = [int(line.strip()) for line in f.readlines()]
            elif name == "sents.txt":
                sents = [line.strip() for line in f.readlines()]
            else:
                topics = [int(line.strip()) for line in f.readlines()]
    return sents, sentiments, topics
def _save_to_csv(file_path: Union[str, Path], data):
    sents, sentiments, topics = data
    df = pd.DataFrame({
        "sents": sents,
        "labels": sentiments,
        "topics": topics
    })
    df.to_csv(file_path, index=False)
    return file_path
Define the UIT_VSFC datamodule class. You can read more here.
= "https://drive.google.com/uc?export=download&id=1zg7cbRF2nFuJ2Q-AB63xlKuwEX3dTBsx"
DS_URL
class UIT_VSFC(pl.LightningDataModule):
"""
The Twitter dataset is ndwritten character digits derived from the NIST Special Database 19
"""
def __init__(self, tokenizer, opts: Dict[str, Any]):
super().__init__()
self.tokenizer = tokenizer
self.batch_size = opts['batch_size']
self.num_workers = opts['num_workers']
self.on_gpu = opts['on_gpu']
self.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
self.dataset = None
self.mapping = {"negative": 0, "neutral": 1, "positive": 2}
self.inverse_mapping = {v: k for k, v in enumerate(self.mapping)}
    def prepare_data(self, *args, **kwargs) -> None:
        data_dir = 'download/UIT_VSFC'
        data_path = 'download/UIT_VSFC.zip'
        if not os.path.exists(data_path):
            # Download the data
            data_path = download_url(DS_URL, "UIT_VSFC.zip", "download")
        if not os.path.exists(data_dir):
            # Unzip file
            with zipfile.ZipFile(data_path, 'r') as zip_ref:
                zip_ref.extractall(data_dir)
        # Load and save data to csv
        for path in ["train", "dev", "test"]:
            data = _load_data_from(f"download/UIT_VSFC/{path}")
            if path == "train":
                self.train_path = _save_to_csv(f'{path}.csv', data)
            elif path == "dev":
                self.dev_path = _save_to_csv(f'{path}.csv', data)
            else:
                self.test_path = _save_to_csv(f'{path}.csv', data)
    def setup(self, stage: str = None) -> None:
        def encode(sample):
            return self.tokenizer(sample['sents'], truncation=True)

        raw_datasets = load_dataset('csv', data_files={'train': self.train_path, 'dev': self.dev_path,
                                                       'test': self.test_path})
        self.dataset = raw_datasets.map(encode, batched=True)
        self.dataset = self.dataset.remove_columns(['sents', 'topics'])
        self.dataset.set_format("torch")  # Set the format of the datasets so they return PyTorch tensors instead of lists.
    def train_dataloader(self):
        return DataLoader(
            self.dataset['train'],
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
            collate_fn=self.data_collator
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset['dev'],
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
            collate_fn=self.data_collator
        )

    def test_dataloader(self):
        return DataLoader(
            self.dataset['test'],
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
            collate_fn=self.data_collator
        )
    def __repr__(self):
        basic = f"UIT-VSFC Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\n"
        if self.dataset is None:
            return basic
        batch = next(iter(self.train_dataloader()))
        data = (
            f"Train/val/test sizes: {len(self.dataset['train'])}, {len(self.dataset['dev'])}, {len(self.dataset['test'])}\n"
            f"Input_ids shape: {batch['input_ids'].shape}"
        )
        return basic + data
Implementation
We will use a transformers model and wrap it in a pytorch-lightning LightningModule. This keeps the code cleaner and easier to debug. You can read more about the pytorch-lightning model class here.
class PhoBERT(pl.LightningModule):
    def __init__(self, lr, weight_decay):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained("vinai/phobert-base", num_labels=3)
        self.lr = lr
        self.weight_decay = weight_decay
        # Define metrics
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.val_f1 = torchmetrics.F1(num_classes=3)
        self.test_acc = torchmetrics.Accuracy()
        self.test_f1 = torchmetrics.F1(num_classes=3)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss, logits = outputs.loss, outputs.logits
        sentiments = batch['labels']
        scores = F.softmax(logits, dim=-1)
        self.train_acc(scores, sentiments)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss, logits = outputs.loss, outputs.logits
        sentiments = batch['labels']
        scores = F.softmax(logits, dim=-1)
        self.val_acc(scores, sentiments)
        self.val_f1(scores, sentiments)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        logits = outputs.logits
        sentiments = batch['labels']
        scores = F.softmax(logits, dim=-1)
        self.test_acc(scores, sentiments)
        self.test_f1(scores, sentiments)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True, logger=True)
        self.log('test_f1', self.test_f1, on_step=False, on_epoch=True, logger=True)
Training
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
options = {
    "on_gpu": True,
    "batch_size": 32,
    "num_workers": 4
}
datamodule = UIT_VSFC(tokenizer, options)

tb_logger = pl_loggers.TensorBoardLogger('logs/')  # create logger for tensorboard

# hyper-parameters
lr = 2e-5
max_epochs = 10
weight_decay = 0.01

model = PhoBERT(lr, weight_decay)

checkpoint_callback = ModelCheckpoint(
    monitor='val_f1',  # save the model with the best validation F1
    dirpath='checkpoints',
    mode='max',
)
trainer = pl.Trainer(gpus=1, max_epochs=max_epochs, logger=tb_logger, callbacks=[checkpoint_callback], deterministic=True)
# trainer = pl.Trainer(fast_dev_run=True)  # Debug
# trainer = pl.Trainer(overfit_batches=0.1, max_epochs=max_epochs)  # Debug
trainer.fit(model, datamodule)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at vinai/phobert-base were not used when initializing RobertaForSequenceClassification: ['lm_head.decoder.bias', 'lm_head.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at vinai/phobert-base and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using custom data configuration default-d20422fbdfea28fe
Downloading and preparing dataset csv/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-d20422fbdfea28fe/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0...
Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-d20422fbdfea28fe/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0. Subsequent calls will reuse this data.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
---------------------------------------------------------------
0 | model | RobertaForSequenceClassification | 135 M
1 | train_acc | Accuracy | 0
2 | val_acc | Accuracy | 0
3 | val_f1 | F1 | 0
4 | test_acc | Accuracy | 0
5 | test_f1 | F1 | 0
---------------------------------------------------------------
135 M Trainable params
0 Non-trainable params
135 M Total params
540.002 Total estimated model params size (MB)
Global seed set to 2401
trainer.test(ckpt_path=checkpoint_callback.best_model_path, test_dataloaders=datamodule.test_dataloader())
Using custom data configuration default-d20422fbdfea28fe
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-d20422fbdfea28fe/0.0.0/2dc6629a9ff6b5697d82c25b73731dd440507a69cbce8b425db50b751e8fcfd0)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.9314592480659485, 'test_f1': 0.9314592480659485}
--------------------------------------------------------------------------------
[{'test_acc': 0.9314592480659485, 'test_f1': 0.9314592480659485}]
Discussion
Results on the test dataset:

Method | Accuracy (%) | F1 (%) |
---|---|---|
PhoBERT | 93.1 | 93.1 |
MaxEnt (paper) | 87.9 | 87.9 |
We haven’t tuned the model, yet we still get a better result than the one reported in the UIT-VSFC paper.
To tune the model, there are a few things we need to set up:
- Wandb/TensorBoard: these tools help us visualize the loss per epoch and other relevant metric information. Using that information, we can come up with ideas for tuning the model.
- Wandb sweep: this tool lets us define the ranges of hyperparameters we want to tune (see the sketch below).
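As a rough illustration (not part of the original notebook), a W&B sweep over the hyperparameters used above could be configured as follows. The ranges, trial count, and project name are placeholders, and train_one_run is a hypothetical helper that reuses the PhoBERT class and datamodule defined earlier:
import wandb

sweep_config = {
    "method": "bayes",                                  # search strategy: grid / random / bayes
    "metric": {"name": "val_f1", "goal": "maximize"},   # same name as the metric logged in validation_step
    "parameters": {
        "lr": {"min": 1e-5, "max": 5e-5},
        "weight_decay": {"values": [0.0, 0.01, 0.1]},
    },
}

def train_one_run():
    wandb.init()
    cfg = wandb.config
    model = PhoBERT(cfg.lr, cfg.weight_decay)
    wandb_logger = pl_loggers.WandbLogger()             # forwards the logged metrics (including val_f1) to W&B
    trainer = pl.Trainer(gpus=1, max_epochs=max_epochs, logger=wandb_logger, deterministic=True)
    trainer.fit(model, datamodule)

sweep_id = wandb.sweep(sweep_config, project="phobert-uit-vsfc")  # placeholder project name
wandb.agent(sweep_id, function=train_one_run, count=10)           # run 10 trials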
Lessons
- When using transformers’ models, we should create our datasets with the datasets library, since it makes the preprocessing step easier and cleaner.
- Be careful with the metrics, for example, micro- vs. macro-averaged F1.
- Micro-averaged F1 and accuracy are equal when every instance must be classified into one (and only one) class, which is why test_acc and test_f1 match above (a quick check is shown below).
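To see this concretely, here is a standalone toy example (made-up tensors) using the same torchmetrics classes as above; note that F1's default averaging is micro in the torchmetrics version used here:
import torch
import torchmetrics

# Toy predictions/labels for a 3-class problem (made-up values).
preds = torch.tensor([0, 2, 1, 2, 0, 1, 1, 2])
labels = torch.tensor([0, 1, 1, 2, 0, 2, 1, 2])

accuracy = torchmetrics.Accuracy()
micro_f1 = torchmetrics.F1(num_classes=3)                   # default average='micro'
macro_f1 = torchmetrics.F1(num_classes=3, average='macro')

print(accuracy(preds, labels))  # 0.75
print(micro_f1(preds, labels))  # 0.75 -- identical to accuracy
print(macro_f1(preds, labels))  # ~0.78 -- generally different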