from IPython.display import Image
Image(filename='architecture.png')
Interactive Attention Networks for Aspect-Level Sentiment Classification
The full notebook is available here.
Introduction
Previous studies have recognized the importance of targets in aspect-based sentiment analysis. However, they all ignore modeling the targets separately; in other words, they do not have a separate model to learn the target text. In this paper, the authors proposed an architecture with two sub-networks that model both the contexts and the targets. They argued that targets and contexts can be modeled separately but should learn from their interaction.
When “short” is collocated with “battery life”, the sentiment tends to be negative. But when “short” is used with “spoon” in the context “Short fat noodle spoon, relatively deep some curva”, the sentiment can be neutral. The next problem is how to model targets and contexts precisely at the same time. First, target and context can determine each other’s representation. For example, when we see the target “picture quality”, the context word “clear-cut” is naturally associated with the target, and vice versa: “picture quality” is first connected with “clear-cut”.
Also, contexts and targets both consist of many words, and different words may contribute differently to the final representation. Therefore, in this paper, the authors created two attention mechanisms to capture the important information in both contexts and targets.
The model is called interactive attention network (IAN). It is based on LSTMs and the attention mechanism.
The text input is first converted to embeddings, which are then fed into LSTMs. After that, the authors average the hidden states of the context LSTM to get the initial representation of the context (the pool vector in the figure), and do the same for the target LSTM. They then use the target pool vector in the context attention computation and vice versa. They argue that with this design, the target and context can influence the generation of each other's representations interactively. Lastly, they concatenate the target and context vectors and feed the result into a softmax layer to do the classification.
Explained using query, key, value: for the target LSTM in Figure 1, the pool vector is computed by averaging the hidden states of the LSTM. That vector plays the role of the query. Each hidden state vector serves as both a key and a value. The attention is then computed in three steps:
- Step 1: Calculate the similarity score between the query vector and all the key vectors.
- Step 2: Normalize the scores using softmax.
- Step 3: Compute the final representation as a weighted average of the value vectors using the normalized scores.
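As a concrete illustration of these three steps, here is a minimal, self-contained sketch (tensor sizes are made up, padding is ignored, and the similarity is a plain dot product; the actual Attention module in the implementation below uses a learned linear layer plus tanh and masks out padding):

import torch
import torch.nn.functional as F

B, L, H = 2, 5, 8                       # batch size, sequence length, hidden size
hidden_states = torch.randn(B, L, H)    # LSTM hidden states: the keys and the values
query = hidden_states.mean(dim=1)       # pool vector used as the query, B x H

# Step 1: similarity between the query and every key (dot product here)
scores = torch.bmm(hidden_states, query.unsqueeze(-1)).squeeze(-1)           # B x L
# Step 2: normalize the scores with softmax
weights = F.softmax(scores, dim=-1)                                          # B x L
# Step 3: weighted average of the value vectors
representation = torch.bmm(weights.unsqueeze(1), hidden_states).squeeze(1)   # B x H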
Install and Import required packages
%%capture
!pip install pytorch-lightning
!pip install torchmetrics
import os
import pickle
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.request import urlretrieve
import numpy as np
from tqdm import tqdm
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchtext
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import (pack_padded_sequence, pad_packed_sequence,
                                pad_sequence)
from torch.utils.data import DataLoader, Dataset, random_split
from torchtext.data import get_tokenizer
from torchtext.vocab import Vectors, Vocab
# For reproducibility
pl.utilities.seed.seed_everything(seed=2401, workers=True)
Global seed set to 2401
2401
Define dataset, dataloader class and utility functions
class TqdmUpTo(tqdm):
"""From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py"""
def update_to(self, blocks=1, bsize=1, tsize=None):
"""
Parameters
----------
blocks: int, optional
Number of blocks transferred so far [default: 1].
bsize: int, optional
Size of each block (in tqdm units) [default: 1].
tsize: int, optional
Total size (in tqdm units). If [default: None] remains unchanged.
"""
if tsize is not None:
self.total = tsize
self.update(blocks * bsize - self.n)
class Tokenizer():
    def __init__(self, tokenizer: Any, is_lower=True):
        self.counter = Counter(['<pad>', '<unk>'])
        self.tokenizer = tokenizer
        self.update_vocab()  # builds the initial vocab containing only '<pad>' and '<unk>'
        self.is_lower = is_lower

    def update_vocab(self):
        sorted_by_freq_tuples = sorted(self.counter.items(), key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)
        self.vocab = torchtext.vocab.vocab(ordered_dict, min_freq=1)

    def fit_on_texts(self, texts: List[str]):
        """
        Updates internal vocabulary based on a list of texts.
        """
        # lower and tokenize texts to sequences
        for text in texts:
            self.counter.update(self.tokenizer(text))
        self.update_vocab()

    def texts_to_sequences(self, texts: List[str], reverse: bool=False, tensor: bool=True) -> List[List[int]]:
        word2idx = self.vocab.get_stoi()
        sequences = []
        for text in texts:
            if self.is_lower:
                text = text.lower()
            seq = [word2idx.get(word, word2idx['<unk>']) for word in self.tokenizer(text)]
            if reverse:
                seq = seq[::-1]
            if tensor:
                seq = torch.tensor(seq)
            sequences.append(seq)
        return sequences

    def text_to_sequence(self, text: str, reverse: bool=False, tensor: bool=True) -> List[int]:
        if self.is_lower:
            text = text.lower()
        word2idx = self.vocab.get_stoi()
        seq = [word2idx.get(word, word2idx['<unk>']) for word in self.tokenizer(text)]
        if reverse:
            seq = seq[::-1]
        if tensor:
            seq = torch.tensor(seq)
        return seq
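A quick, illustrative usage of this Tokenizer (the sentences here are made up; the real vocabulary is built from the SemEval data in the training section below):

tokenizer = Tokenizer(get_tokenizer("basic_english"))
tokenizer.fit_on_texts(["The battery life is short.", "Great picture quality!"])
print(tokenizer.text_to_sequence("the picture quality is great"))
# a 1-D LongTensor of word indices; out-of-vocabulary words map to the '<unk>' index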
def download_url(url, filename, directory='.'):
    """Download a file from url to filename, with a progress bar."""
    if not os.path.exists(directory):
        os.makedirs(directory)
    path = os.path.join(directory, filename)

    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        urlretrieve(url, path, reporthook=t.update_to, data=None)  # nosec
    return path
def load_data_from(path: Union[str, Path]):
    sentences = []
    targets = []
    sentiments = []
    with open(path, 'r') as f:
        lines = f.readlines()
    for index in range(0, len(lines), 3):
        text = lines[index].lower().strip()
        target = lines[index + 1].lower().strip()
        text = text.replace('$t$', target)

        sentences.append(text)
        targets.append(target)
        sentiments.append(int(lines[index + 2].strip()))
    return sentences, targets, sentiments
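load_data_from assumes the *.seg files from the ABSA-PyTorch repository linked below, where each example spans three lines: the sentence with the target replaced by a $T$ placeholder, the target phrase, and the sentiment label (-1, 0, or 1). An illustrative, made-up entry would look roughly like this:

the $T$ is too short for such an expensive laptop
battery life
-1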
def _preprocess_data(data, tokenizer):
    sentences, targets, sentiments = data

    # Create sentence sequences and aspect sequences
    sequences = tokenizer.texts_to_sequences(sentences)
    target_seqs = tokenizer.texts_to_sequences(targets)
    sentiments = torch.tensor(sentiments) + 1

    # pad sequences
    seq_lens = torch.tensor([len(seq) for seq in sequences])
    target_lens = torch.tensor([len(target_seq) for target_seq in target_seqs])

    sequences = pad_sequence(sequences, batch_first=True)
    target_seqs = pad_sequence(target_seqs, batch_first=True)

    assert len(sequences) == len(sentiments)
    assert len(sequences) == len(target_seqs)

    all_data = []
    for i in range(len(sentiments)):
        sample = {
            'context_seq': sequences[i],
            'context_len': seq_lens[i],
            'target_seq': target_seqs[i],
            'target_len': target_lens[i],
            'sentiment': sentiments[i]
        }
        all_data.append(sample)
    return all_data
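For reference, each element returned by _preprocess_data is a plain dict of tensors. A minimal check could look like this (it assumes train_path and tokenizer, which are created in the training section further below):

samples = _preprocess_data(load_data_from(train_path), tokenizer)
print(samples[0].keys())
# dict_keys(['context_seq', 'context_len', 'target_seq', 'target_len', 'sentiment'])
print(samples[0]['sentiment'])  # 0 (negative), 1 (neutral) or 2 (positive)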
def build_vocab(tokenizer, data):
    sentences = data[0]
    tokenizer.fit_on_texts(sentences)
def load_pretrained_word_embeddings(options: Dict[str, Any]):
return torchtext.vocab.GloVe(options['name'], options['dim'])
def create_embedding_matrix(word_embeddings: Vectors, vocab: Vocab, path: Union[str, Path]):
    if os.path.exists(path):
        print(f'loading embedding matrix from {path}')
        embedding_matrix = pickle.load(open(path, 'rb'))
    else:
        embedding_matrix = torch.zeros((len(vocab), word_embeddings.dim),
                                       dtype=torch.float)

        # words that are not available in the pretrained word embeddings will be zeros
        for word, index in vocab.get_stoi().items():
            embedding_matrix[index] = word_embeddings.get_vecs_by_tokens(word)

        # save embedding matrix
        pickle.dump(embedding_matrix, open(path, 'wb'))
    return embedding_matrix
class SemEvalDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# Restaurant
RES_TRAIN_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Restaurants_Train.xml.seg'
RES_TEST_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Restaurants_Test_Gold.xml.seg'

# Laptop
LAP_TRAIN_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Laptops_Train.xml.seg'
LAP_TEST_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Laptops_Test_Gold.xml.seg'
class SemEval2014(pl.LightningDataModule):
    def __init__(self, tokenizer, opts):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = opts['batch_size']
        self.num_workers = opts['num_workers']
        self.on_gpu = opts['on_gpu']
        self.mapping = {"negative": 0, "neutral": 1, "positive": 2}
        self.inverse_mapping = {v: k for k, v in self.mapping.items()}

    def prepare_data(self) -> None:
        self.train_path = 'download/SemEval2014/train.xml'
        self.test_path = 'download/SemEval2014/test.xml'
        if not os.path.exists(self.train_path):
            print("Downloading train dataset")
            self.train_path = download_url(RES_TRAIN_DS_URL, 'train.xml', 'download/SemEval2014')
        if not os.path.exists(self.test_path):
            print("Downloading test dataset")
            self.test_path = download_url(RES_TEST_DS_URL, 'test.xml', 'download/SemEval2014')

    def setup(self, stage: str = None) -> None:
        if stage == 'fit' or stage is None:
            # Load data from files
            train_data = load_data_from(self.train_path)
            valid_data = load_data_from(self.test_path)
            self.train_data = _preprocess_data(train_data, self.tokenizer)
            self.val_data = _preprocess_data(valid_data, self.tokenizer)
        elif stage == 'test' or stage is None:
            test_data = load_data_from(self.test_path)
            self.test_data = _preprocess_data(test_data, self.tokenizer)

    def train_dataloader(self):
        # Create Dataset object
        train_ds = SemEvalDataset(self.train_data)
        # Create Dataloader
        return DataLoader(
            train_ds,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def val_dataloader(self):
        val_ds = SemEvalDataset(self.val_data)
        return DataLoader(
            val_ds,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def test_dataloader(self):
        test_ds = SemEvalDataset(self.test_data)
        return DataLoader(
            test_ds,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def __repr__(self):
        basic = f"SemEval2014 Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\n"
        if self.train_data is None and self.val_data is None and self.test_data is None:
            return basic

        batch = next(iter(self.train_dataloader()))
        cols = ['context_seq', 'context_len', 'target_seq', 'target_len', 'sentiment']
        context_seqs, context_lens, target_seqs, target_lens, sentiments = [batch[col] for col in cols]
        data = (
            f"Train/val/test sizes: {len(self.train_data)}, {len(self.val_data)}, {len(self.test_data)}\n"
            f"Batch context_seqs stats: {(context_seqs.shape, context_seqs.dtype)}\n"
            f"Batch context_lens stats: {(context_lens.shape, context_lens.dtype)}\n"
            f"Batch target_seqs stats: {(target_seqs.shape, target_seqs.dtype)}\n"
            f"Batch target_lens stats: {(target_lens.shape, target_lens.dtype)}\n"
            f"Batch sentiments stats: {(sentiments.shape, sentiments.dtype)}\n"
        )
        return basic + data
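A hypothetical smoke test of the data module (it assumes a fitted tokenizer, as built in the training section below; the options values are arbitrary):

options = {"on_gpu": False, "batch_size": 4, "num_workers": 0}
dm = SemEval2014(tokenizer, options)
dm.prepare_data()                          # downloads train.xml / test.xml if missing
dm.setup('fit')                            # builds train_data and val_data
batch = next(iter(dm.train_dataloader()))
print(batch['context_seq'].shape)          # e.g. torch.Size([4, <max context length>])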
Implementation
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, max_seq_len, seq_lens, device):
        # Calculate similarity between query and key vectors
        score = torch.tanh(torch.matmul(
            self.linear(key), query.transpose(-2, -1))).squeeze(-1)  # BxL
        # Mask out padding scores
        att_mask = torch.arange(max_seq_len, device=device)[None, :] < seq_lens[:, None]
        score = score.masked_fill(att_mask == False, float('-inf'))
        softmax_score = F.softmax(score, dim=-1).unsqueeze(2)  # BxLx1
        out = torch.matmul(value.transpose(-2, -1), softmax_score).squeeze()  # BxH
        return out
class IAN(pl.LightningModule):
    def __init__(self, embedding_matrix, hidden_size, num_layers=1,
                 num_classes=3, batch_first=True, lr=1e-3, dropout=0, l2reg=0.0):
        super().__init__()
        embedding_dim = embedding_matrix.shape[1]
        self.batch_first = batch_first
        self.lr = lr
        self.l2reg = l2reg

        # Define architecture components
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix)
        self.target_lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=batch_first, dropout=dropout)
        self.context_lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=batch_first, dropout=dropout)
        self.context_attn = Attention(hidden_size)
        self.target_attn = Attention(hidden_size)
        self.linear = nn.Linear(hidden_size * 2, num_classes)

        # Define metrics
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

        # Initialize layer parameters
        # for layer in [self.context_lstm, self.target_lstm,
        #               self.context_attention, self.target_attention, self.linear]:
        #     nn.init.uniform_(layer.weight, a=-0.1, b=0.1)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.l2reg)
        return optim

    def forward(self, input):
        cols = ["context_seq", "context_len", "target_seq", "target_len"]
        padded_context_seqs, context_lens, padded_target_seqs, target_lens = [input[col] for col in cols]
        padded_context_embeddings = self.embedding(padded_context_seqs)
        padded_target_embeddings = self.embedding(padded_target_seqs)

        context_seqs_pack = pack_padded_sequence(padded_context_embeddings, context_lens.cpu(),
                                                 batch_first=self.batch_first, enforce_sorted=False)
        target_seqs_pack = pack_padded_sequence(padded_target_embeddings, target_lens.cpu(),
                                                batch_first=self.batch_first, enforce_sorted=False)

        H_context, _ = self.context_lstm(context_seqs_pack)
        H_target, _ = self.target_lstm(target_seqs_pack)

        # Unpack to get the full hidden states
        padded_H_context, _ = pad_packed_sequence(H_context, batch_first=self.batch_first)  # BxLxH
        padded_H_target, _ = pad_packed_sequence(H_target, batch_first=self.batch_first)  # BxLxH

        # Compute the initial representation for target and context
        c_avg = torch.mean(padded_H_context, dim=1, keepdim=True)  # Bx1xH
        t_avg = torch.mean(padded_H_target, dim=1, keepdim=True)  # Bx1xH

        c_max_seq_len = torch.max(context_lens)
        final_c = self.context_attn(t_avg, padded_H_context, padded_H_context,
                                    c_max_seq_len, context_lens, self.device)

        t_max_seq_len = torch.max(target_lens)
        final_t = self.target_attn(c_avg, padded_H_target, padded_H_target,
                                   t_max_seq_len, target_lens, self.device)

        final_vector = torch.cat([final_t, final_c], dim=-1)  # Bx2H
        out = self.linear(final_vector)
        logits = torch.tanh(out)
        return logits

    def training_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.train_acc(scores, sentiments)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True,
                 prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.val_acc(scores, sentiments)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        scores = F.softmax(logits, dim=-1)
        self.test_acc(scores, sentiments)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True, logger=True)
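A forward-pass smoke test of IAN with a random embedding matrix and a dummy batch (all sizes below are made up; the real training run uses the GloVe embedding matrix and the SemEval data module):

dummy_embeddings = torch.randn(100, 50)             # pretend vocab of 100 words, 50-dim vectors
ian = IAN(embedding_matrix=dummy_embeddings, hidden_size=16)
dummy_batch = {
    'context_seq': torch.randint(1, 100, (4, 12)),  # batch of 4, padded context length 12
    'context_len': torch.tensor([12, 10, 7, 5]),
    'target_seq': torch.randint(1, 100, (4, 3)),    # padded target length 3
    'target_len': torch.tensor([3, 1, 2, 3]),
}
print(ian(dummy_batch).shape)                       # torch.Size([4, 3]) -> 3 sentiment classes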
# Training
processed_train_data = _preprocess_data(train_data, tokenizer)
train_path = download_url(RES_TRAIN_DS_URL, 'train.xml', 'download/SemEval2014')
test_path = download_url(RES_TEST_DS_URL, 'test.xml', 'download/SemEval2014')

train_data = load_data_from(train_path)
test_data = load_data_from(test_path)

all_sentences = train_data[0] + test_data[0]
tokenizer = Tokenizer(get_tokenizer("basic_english"))
build_vocab(tokenizer, [all_sentences])
384kB [00:00, 1.71MB/s]
120kB [00:00, 626kB/s]
word_embeddings = load_pretrained_word_embeddings({"name": "42B", "dim": 300})
.vector_cache/glove.42B.300d.zip: 1.88GB [05:53, 5.31MB/s]
100%|█████████▉| 1916797/1917494 [04:10<00:00, 8178.84it/s]
train_path = download_url(RES_TRAIN_DS_URL, 'train.xml', 'download/SemEval2014')
test_path = download_url(RES_TEST_DS_URL, 'test.xml', 'download/SemEval2014')

train_data = load_data_from(train_path)
test_data = load_data_from(test_path)

all_sentences = train_data[0] + test_data[0]
tokenizer = Tokenizer(get_tokenizer("basic_english"))

build_vocab(tokenizer, [all_sentences])
options = {
    "on_gpu": True,
    "batch_size": 16,
    "num_workers": 2
}
datamodule = SemEval2014(tokenizer, options)
embedding_matrix = create_embedding_matrix(word_embeddings, tokenizer.vocab, "embedding_matrix.dat")
384kB [00:00, 7.09MB/s]
120kB [00:00, 2.81MB/s]
loading embedding matrix from embedding_matrix.dat
torch.autograd.set_detect_anomaly(True)
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',  # save the model with the best validation accuracy
    dirpath='checkpoints',
    mode='max',
)
tb_logger = pl_loggers.TensorBoardLogger('logs/')  # create logger for tensorboard

# Set hyper-parameters
lr = 1e-3
hidden_size = 300
aspect_embedding_dim = 300
num_epochs = 30
l2reg = 1e-5
dropout = 0.0

trainer = pl.Trainer(gpus=1, max_epochs=num_epochs, logger=tb_logger, callbacks=[checkpoint_callback], deterministic=True)
# trainer = pl.Trainer(fast_dev_run=True, gpus=1)  # Debug
# trainer = pl.Trainer(overfit_batches=0.025, max_epochs=num_epochs, gpus=1)  # Debug
model = IAN(embedding_matrix=embedding_matrix, hidden_size=hidden_size, lr=lr, l2reg=l2reg, dropout=dropout)
trainer.fit(model, datamodule)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------
0 | embedding | Embedding | 1.4 M
1 | target_lstm | LSTM | 722 K
2 | context_lstm | LSTM | 722 K
3 | context_attn | Attention | 90.3 K
4 | target_attn | Attention | 90.3 K
5 | linear | Linear | 1.8 K
6 | train_acc | Accuracy | 0
7 | val_acc | Accuracy | 0
8 | test_acc | Accuracy | 0
-------------------------------------------
1.6 M Trainable params
1.4 M Non-trainable params
3.0 M Total params
11.921 Total estimated model params size (MB)
Global seed set to 2401
trainer.test(ckpt_path=checkpoint_callback.best_model_path, test_dataloaders=datamodule.test_dataloader())
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.7866071462631226}
--------------------------------------------------------------------------------
[{'test_acc': 0.7866071462631226}]
Discussion
Our result:
Model | Restaurants |
---|---|
IAN | 0.786 |
Paper results:
Model | Restaurants | Laptops |
---|---|---|
No-target | 0.772 | 0.708 |
No-interaction | 0.769 | 0.706 |
Target2Content | 0.775 | 0.712 |
IAN | 0.786 | 0.721 |
From the two tables above, we can see that our training reaches the same accuracy on the Restaurants dataset as reported in the paper.
Analysis:
The authors analyzed the effectiveness of the IAN model by comparing it with three other models, all based on LSTMs and the attention mechanism. The No-target model does not model the representation of the target at all. The second model, No-interaction, uses two LSTM networks to model the representations of the target and the context via their own local attentions, but without any interaction. The third model, Target2Content, also employs two LSTM networks to learn target and context representations, but only uses the pooled target vector in the context attention computation. The difference from IAN is that IAN also uses the pooled context vector in the target attention computation.
The results verify that the target should be modeled separately and that target representations contribute to judging the sentiment polarity of a target.
The improvement on the Restaurant category is much smaller than on the Laptop category. The authors explain this by pointing out that the Restaurant dataset has about 9% more single-word targets than the Laptop one; in other words, the Laptop dataset has more multi-word targets. In IAN, targets are modeled by LSTM networks and interactive attention, which are more effective at modeling long targets than short ones.
You can read more about the case study, in which the authors analyze the attention scores at inference time, here.
Lessons
- Pass the device in the forward function instead of in __init__.
- When masking, create a mask matrix and multiply it with the matrix we want to mask; this avoids the error caused by modifying a tensor in-place. We can also use masked_fill (note that the variant with a trailing underscore, masked_fill_, is the in-place one); see the short sketch after this list.
- When training is slow, check the number of model parameters!
- When using functions that take a dim argument, set it explicitly to avoid subtle bugs. For example, in this implementation, calling squeeze() without a dim after computing the similarity score between the query and key vectors breaks when the sequence length of the key is 1; see the sketch after this list.
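A short sketch of the last two points on toy tensors (assuming torch is already imported):

scores = torch.randn(2, 5)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).bool()

# masked_fill returns a new tensor; masked_fill_ (trailing underscore) modifies
# scores in place, which is what triggers autograd's in-place modification errors.
masked = scores.masked_fill(~mask, float('-inf'))

# squeeze() without a dim drops *every* size-1 dimension, which silently changes
# the shape when the sequence length happens to be 1; squeeze(-1) is explicit.
x = torch.randn(3, 1, 1)       # batch of 3, sequence length 1, one score per step
print(x.squeeze().shape)       # torch.Size([3])   -- the sequence dim collapsed too
print(x.squeeze(-1).shape)     # torch.Size([3, 1])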
Suggestions for Readers:
- Try larger word embeddings to see whether the metrics improve.
- Train on the Laptop dataset.
- Have fun :)