Interactive Attention Networks for Aspect-Level Sentiment Classification

Published July 4, 2021

The full notebook is available here.

Introduction

Previous studies have recognized the importance of targets in aspect-based sentiment analysis. However, they all ignore modeling the targets separately; in other words, they do not have a dedicated model to learn representations of the target text. In this paper, the authors propose an architecture with two sub-networks that model both the contexts and the targets. They argue that targets and contexts can be modeled separately, but should learn from their interaction.

The paper motivates this with an example: when “short” is collocated with “battery life”, the sentiment tends to be negative. But when “short” is used with “spoon” in the context “Short fat noodle spoon, relatively deep some curva”, the sentiment can be neutral. The next problem is how to simultaneously model targets and contexts precisely. Target and context can determine each other's representations. For example, when we see the target “picture quality”, the context word “clear-cut” is naturally associated with it, and vice versa: seeing “clear-cut”, we first connect it with “picture quality”.

Also, contexts and targets both consist of many words, and different words may contribute differently to the final representation. Therefore, the authors introduce two attention mechanisms to capture the important information in both contexts and targets.

from IPython.display import Image
Image(filename='architecture.png')  # Figure 1: the IAN architecture from the paper

The model is called the interactive attention network (IAN). It is based on LSTMs and the attention mechanism.

The text inputs are first converted to embeddings and then fed into LSTMs. After that, the authors average the hidden states of the context LSTM to get the initial representation of the context (the pool vector in the figure), and do the same for the target LSTM. The target pool vector is then used in the context attention computation and vice versa. They argue that, with this design, the target and context can influence the generation of each other's representations interactively. Lastly, the target and context vectors are concatenated and fed into a softmax layer for classification.

Explained using query, key, and value: for the target LSTM in Figure 1, the pool vector is computed by averaging the hidden states of the LSTM. That vector plays the role of the query, while each hidden state serves as both a key and a value. The attention output is then computed in three steps:

  1. Calculate the similarity score between the query vector and all the key vectors.
  2. Normalize the scores using softmax.
  3. Compute the final representation as the weighted average of the value vectors, using the normalized scores as weights.
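
As a minimal, self-contained sketch of these three steps (using a plain dot-product similarity for brevity; the IAN implementation below uses a learned weight matrix and tanh instead):

import torch
import torch.nn.functional as F

B, L, H = 2, 4, 8                       # batch size, sequence length, hidden size
hidden_states = torch.randn(B, L, H)    # LSTM hidden states: the keys and values
query = hidden_states.mean(dim=1)       # pool vector (average of hidden states): the query, BxH

# Step 1: similarity scores between the query and every key (dot product here for simplicity)
scores = torch.bmm(hidden_states, query.unsqueeze(-1)).squeeze(-1)           # BxL

# Step 2: normalize the scores with softmax
weights = F.softmax(scores, dim=-1)                                          # BxL

# Step 3: weighted average of the value vectors using the normalized scores
representation = torch.bmm(weights.unsqueeze(1), hidden_states).squeeze(1)   # BxH
print(representation.shape)  # torch.Size([2, 8])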

Install and Import required packages

%%capture
!pip install pytorch-lightning
!pip install torchmetrics
import os
import pickle
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.request import urlretrieve

import numpy as np
from tqdm import tqdm

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchtext
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import (pack_padded_sequence, pad_packed_sequence,
                                pad_sequence)
from torch.utils.data import DataLoader, Dataset, random_split
from torchtext.data import get_tokenizer
from torchtext.vocab import Vectors, Vocab

# For reproducibility
pl.utilities.seed.seed_everything(seed=2401, workers=True)
Global seed set to 2401
2401

Define dataset and dataloader classes and utility functions

class TqdmUpTo(tqdm):
    """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py"""

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """
        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize  
        self.update(blocks * bsize - self.n)  

class Tokenizer():
    def __init__(self, tokenizer: Any, is_lower=True):
        self.counter = Counter(['<pad>', '<unk>'])
        self.tokenizer = tokenizer
        self.is_lower = is_lower
        self.update_vocab()  # sets self.vocab in place

    def update_vocab(self):
        sorted_by_freq_tuples = sorted(self.counter.items(), key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)
        self.vocab = torchtext.vocab.vocab(ordered_dict, min_freq=1)

    def fit_on_texts(self, texts: List[str]):
        """
        Updates internal vocabulary based on a list of texts.
        """
        # tokenize texts (lowercasing if configured) and update the token counter
        for text in texts:
            if self.is_lower:
                text = text.lower()
            self.counter.update(self.tokenizer(text))
        self.update_vocab()

    def texts_to_sequences(self, texts: List[str], reverse: bool=False, tensor: bool=True) -> List[List[int]]:
        word2idx = self.vocab.get_stoi()
        sequences = []
        for text in texts:
            if self.is_lower:
                text = text.lower()
            seq = [word2idx.get(word, word2idx['<unk>']) for word in self.tokenizer(text)]
            if reverse:
                seq = seq[::-1] 
            if tensor:
                seq = torch.tensor(seq)
            sequences.append(seq)
        return sequences
    
    def text_to_sequence(self, text: str, reverse: bool=False, tensor: bool=True) -> List[int]:
        if self.is_lower:
            text = text.lower()
        word2idx = self.vocab.get_stoi()
        seq = [word2idx.get(word, word2idx['<unk>']) for word in self.tokenizer(text)]
        if reverse:
            seq = seq[::-1]
        if tensor:
            seq = torch.tensor(seq)
        return seq  
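
A quick toy usage of the Tokenizer class above (illustrative sentences only; the real vocabulary is built later from the SemEval data):

# Toy example: build a tiny vocabulary and encode a sentence with an out-of-vocabulary word
toy_tok = Tokenizer(get_tokenizer("basic_english"))
toy_tok.fit_on_texts(["The battery life is short .", "Great picture quality !"])
print(toy_tok.text_to_sequence("battery life is amazing", tensor=False))
# the unseen word "amazing" maps to the <unk> index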

def download_url(url, filename, directory='.'):
    """Download a file from url to filename, with a progress bar."""
    if not os.path.exists(directory):
        os.makedirs(directory)
    path = os.path.join(directory, filename)

    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        urlretrieve(url, path, reporthook=t.update_to, data=None)  # nosec
    return  path

def load_data_from(path: Union[str, Path]):
    sentences = []
    targets = []
    sentiments = []
    with open(path, 'r') as f:
        lines = f.readlines()
        for index in range(0, len(lines), 3):
            text = lines[index].lower().strip()
            target = lines[index+1].lower().strip()
            text = text.replace('$t$', target)
            sentences.append(text)
            targets.append(target)
            sentiments.append(int(lines[index+2].strip()))
    return sentences, targets, sentiments
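
The .seg files used here follow a simple layout: each sample spans three consecutive lines, namely the sentence with the aspect term replaced by a $T$ placeholder, the aspect term itself, and the polarity label (-1, 0, or 1; shifted to 0/1/2 later). A small illustration with made-up lines (the file name and sentences are invented for demonstration):

# Illustration only: write two made-up samples in the 3-line-per-sample layout and parse them
toy_lines = (
    "the $T$ is really short .\n"
    "battery life\n"
    "-1\n"
    "$T$ was delicious !\n"
    "the pizza\n"
    "1\n"
)
with open("toy.seg", "w") as f:
    f.write(toy_lines)

sentences, targets, sentiments = load_data_from("toy.seg")
print(sentences)   # placeholders replaced by the targets, text lowercased
print(targets)     # ['battery life', 'the pizza']
print(sentiments)  # [-1, 1]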

def _preprocess_data(data, tokenizer):
    sentences, targets, sentiments = data
    
    # Create sentence sequences and aspects sequences
    sequences = tokenizer.texts_to_sequences(sentences)
    target_seqs = tokenizer.texts_to_sequences(targets)
    sentiments = torch.tensor(sentiments) + 1  # shift labels from {-1, 0, 1} to {0, 1, 2}

    # pad sequences
    seq_lens = torch.tensor([len(seq) for seq in sequences])
    target_lens = torch.tensor([len(target_seq) for target_seq in target_seqs])

    sequences = pad_sequence(sequences, batch_first=True)
    target_seqs = pad_sequence(target_seqs, batch_first=True)

    assert len(sequences) == len(sentiments)
    assert len(sequences) == len(target_seqs)

    all_data = []
    for i in range(len(sentiments)):
        sample = {
            'context_seq': sequences[i],
            'context_len': seq_lens[i],
            'target_seq': target_seqs[i],
            'target_len': target_lens[i],
            'sentiment': sentiments[i]
        }
        all_data.append(sample)
    return all_data

def build_vocab(tokenizer, data):
    sentences = data[0]
    tokenizer.fit_on_texts(sentences)

def load_pretrained_word_embeddings(options: Dict[str, Any]):
    return torchtext.vocab.GloVe(options['name'], options['dim'])

def create_embedding_matrix(word_embeddings: Vectors, vocab: Vocab, path: Union[str, Path]):
    if os.path.exists(path):
        print(f'loading embedding matrix from {path}')
        embedding_matrix = pickle.load(open(path, 'rb'))
    else:
        embedding_matrix = torch.zeros((len(vocab), word_embeddings.dim), 
                                       dtype=torch.float)

        # words that are not available in the pretrained word embeddings will be zero vectors
        for word, index in vocab.get_stoi().items():
            embedding_matrix[index] = word_embeddings.get_vecs_by_tokens(word)

        # save embedding matrix
        pickle.dump(embedding_matrix, open(path, 'wb'))
    return embedding_matrix
class SemEvalDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]
# Restaurant
RES_TRAIN_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Restaurants_Train.xml.seg'
RES_TEST_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Restaurants_Test_Gold.xml.seg'

# Laptop
LAP_TRAIN_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Laptops_Train.xml.seg'
LAP_TEST_DS_URL = 'https://raw.githubusercontent.com/songyouwei/ABSA-PyTorch/master/datasets/semeval14/Laptops_Test_Gold.xml.seg'

class SemEval2014(pl.LightningDataModule):
    def __init__(self, tokenizer, opts):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = opts['batch_size']
        self.num_workers = opts['num_workers']
        self.on_gpu = opts['on_gpu']

        self.mapping = {"negative": 0, "neutral": 1, "positive": 2} 
        self.inverse_mapping = {v: k for k, v in self.mapping.items()}

    def prepare_data(self) -> None:
        self.train_path = 'download/SemEval2014/train.xml'
        self.test_path = 'download/SemEval2014/test.xml'
        
        if not os.path.exists(self.train_path):
            print("Downloading train dataset")
            self.train_path = download_url(RES_TRAIN_DS_URL, 'train.xml', 'download/SemEval2014')

        if not os.path.exists(self.test_path):
            print("Downloading test dataset")
            self.test_path = download_url(RES_TEST_DS_URL, 'test.xml', 'download/SemEval2014')
        
    def setup(self, stage: str = None) -> None:
        if stage == 'fit' or stage is None:
            # Load data from files
            train_data = load_data_from(self.train_path)
            valid_data = load_data_from(self.test_path)
            self.train_data = _preprocess_data(train_data, self.tokenizer)
            self.val_data = _preprocess_data(valid_data, self.tokenizer)

        if stage == 'test' or stage is None:
            test_data = load_data_from(self.test_path)
            self.test_data = _preprocess_data(test_data, self.tokenizer)
            
    def train_dataloader(self):
        # Create Dataset object
        train_ds = SemEvalDataset(self.train_data)
        # Create Dataloader
        return DataLoader(
            train_ds,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        ) 

    def val_dataloader(self):
        val_ds = SemEvalDataset(self.val_data)
        return DataLoader(
            val_ds,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        ) 

    def test_dataloader(self):
        test_ds = SemEvalDataset(self.test_data)
        return DataLoader(
            test_ds,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        ) 

    def __repr__(self):
        basic = f"SemEval2014 Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\n"
        if self.train_data is None and self.val_data is None and self.test_data is None:
            return basic
        batch = next(iter(self.train_dataloader()))
        cols = ['context_seq', 'context_len', 'target_seq', 'target_len', 'sentiment']
        context_seqs, context_lens, target_seqs, target_lens, sentiments = [batch[col] for col in cols]
        data = (
            f"Train/val/test sizes: {len(self.train_data)}, {len(self.val_data)}, {len(self.test_data)}\n"
            f"Batch context_seqs stats: {(context_seqs.shape, context_seqs.dtype)}\n"
            f"Batch context_lens stats: {(context_lens.shape, context_lens.dtype)}\n"
            f"Batch target_seqs stats: {(target_seqs.shape, target_seqs.dtype)}\n"
            f"Batch target_lens stats: {(target_lens.shape, target_lens.dtype)}\n"
            f"Batch sentiments stats: {(sentiments.shape, sentiments.dtype)}\n"
        )
        return basic + data

Implementation

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, max_seq_len, seq_lens, device):
        # Calculate similarity between query and key vectors
        score = torch.tanh(torch.matmul(
            self.linear(key), query.transpose(-2,-1))).squeeze(-1) # BxL

        # Mask out padding score
        att_mask = torch.arange(max_seq_len, device=device)[None,:] < seq_lens[:, None]
        score = score.masked_fill(att_mask == False, float('-inf'))
        softmax_score = F.softmax(score, dim=-1).unsqueeze(2) #BxLx1
        out = torch.matmul(value.transpose(-2,-1), softmax_score).squeeze(-1) # BxH (dim set explicitly; see lesson 4)
        return out
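
A quick shape check for the Attention module defined above (random tensors; the second example sequence is padded to length 5 but only its first 3 positions receive attention):

# Sanity check with random tensors: query = pooled vector, keys/values = LSTM hidden states
attn = Attention(hidden_size=8)
B, L, H = 2, 5, 8
query = torch.randn(B, 1, H)        # Bx1xH pool vector
states = torch.randn(B, L, H)       # BxLxH hidden states (used as both keys and values)
seq_lens = torch.tensor([5, 3])     # true lengths before padding
out = attn(query, states, states, max_seq_len=L, seq_lens=seq_lens, device='cpu')
print(out.shape)  # torch.Size([2, 8])
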
class IAN(pl.LightningModule):
    def __init__(self, embedding_matrix, hidden_size, num_layers=1, 
                 num_classes=3, batch_first=True, lr=1e-3, dropout=0, l2reg=0.0):
        super().__init__()
        embedding_dim = embedding_matrix.shape[1]
        self.batch_first = batch_first
        self.lr = lr
        self.l2reg = l2reg
        # Define architecture components
        self.embedding = nn.Embedding.from_pretrained(embedding_matrix)
        self.target_lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=batch_first, dropout=dropout)
        self.context_lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=batch_first, dropout=dropout)
        self.context_attn = Attention(hidden_size)
        self.target_attn = Attention(hidden_size)
        self.linear = nn.Linear(hidden_size*2, num_classes)

        # Define metrics
        self.train_acc = torchmetrics.Accuracy() 
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

        # Initialize layer parameters
        # for layer in [self.context_lstm, self.target_lstm, 
        #               self.context_attention, self.target_attention, self.linear]:
        #     nn.init.uniform_(layer.weight, a=-0.1, b=0.1)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.l2reg)
        return optim

    def forward(self, input):
        cols = ["context_seq", "context_len", "target_seq", "target_len"]
        padded_context_seqs, context_lens, padded_target_seqs, target_lens = [input[col] for col in cols]
        padded_context_embeddings = self.embedding(padded_context_seqs)
        padded_target_embeddings = self.embedding(padded_target_seqs)
        
        context_seqs_pack = pack_padded_sequence(padded_context_embeddings, context_lens.cpu(), 
                                                 batch_first=self.batch_first, enforce_sorted=False)
        target_seqs_pack = pack_padded_sequence(padded_target_embeddings, target_lens.cpu(),
                                                 batch_first=self.batch_first, enforce_sorted=False)

        H_context, _ = self.context_lstm(context_seqs_pack)
        H_target, _ = self.target_lstm(target_seqs_pack)

        # Unpack to get the full hidden states
        padded_H_context, _ = pad_packed_sequence(H_context, batch_first=self.batch_first) # BxLxH
        padded_H_target, _ = pad_packed_sequence(H_target, batch_first=self.batch_first) # BxLxH

        # Compute the initial representation for target and context
        c_avg = torch.mean(padded_H_context, dim=1, keepdim=True) #Bx1xH
        t_avg = torch.mean(padded_H_target, dim=1, keepdim=True) #Bx1xH

        c_max_seq_len = torch.max(context_lens)
        final_c = self.context_attn(t_avg, padded_H_context, padded_H_context, 
                                         c_max_seq_len, context_lens, self.device)
        
        t_max_seq_len = torch.max(target_lens)
        final_t = self.target_attn(c_avg, padded_H_target, padded_H_target, 
                                        t_max_seq_len, target_lens, self.device)
        
        final_vector = torch.cat([final_t, final_c], dim=-1) # Bx2H
        out = self.linear(final_vector)
        logits = torch.tanh(out)
        return logits

    def training_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.train_acc(scores, sentiments)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, 
                 prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.val_acc(scores, sentiments)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        scores = F.softmax(logits, dim=-1)
        self.test_acc(scores, sentiments)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True, logger=True)
        

Training

train_path = download_url(RES_TRAIN_DS_URL, 'train.xml', 'download/SemEval2014')
test_path = download_url(RES_TEST_DS_URL, 'test.xml', 'download/SemEval2014')

train_data = load_data_from(train_path)
test_data = load_data_from(test_path)

all_sentences = train_data[0] + test_data[0] 
tokenizer = Tokenizer(get_tokenizer("basic_english"))
build_vocab(tokenizer, [all_sentences])
384kB [00:00, 1.71MB/s]                            
120kB [00:00, 626kB/s]                             
word_embeddings = load_pretrained_word_embeddings({"name": "42B", "dim": 300})
.vector_cache/glove.42B.300d.zip: 1.88GB [05:53, 5.31MB/s]                            
100%|█████████▉| 1916797/1917494 [04:10<00:00, 8178.84it/s]

options = {
    "on_gpu": True,
    "batch_size": 16,
    "num_workers": 2
}
datamodule = SemEval2014(tokenizer, options)
embedding_matrix = create_embedding_matrix(word_embeddings, tokenizer.vocab, "embedding_matrix.dat")
loading embedding matrix from embedding_matrix.dat
torch.autograd.set_detect_anomaly(True)  # for debugging autograd errors; slows training
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc', # save the model with the best validation accuracy
    dirpath='checkpoints',
    mode='max',
)

tb_logger = pl_loggers.TensorBoardLogger('logs/') # create logger for tensorboard

# Set hyper-parameters
lr = 1e-3 
hidden_size = 300
aspect_embedding_dim = 300
num_epochs = 30
l2reg = 1e-5
dropout = 0.0

trainer = pl.Trainer(gpus=1, max_epochs=num_epochs, logger=tb_logger, callbacks=[checkpoint_callback], deterministic=True)
# trainer = pl.Trainer(fast_dev_run=True, gpus=1) #Debug 
# trainer = pl.Trainer(overfit_batches=0.025, max_epochs=num_epochs, gpus=1) #Debug
model = IAN(embedding_matrix=embedding_matrix, hidden_size=hidden_size, lr=lr, l2reg=l2reg, dropout=dropout)
trainer.fit(model, datamodule)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | embedding    | Embedding | 1.4 M 
1 | target_lstm  | LSTM      | 722 K 
2 | context_lstm | LSTM      | 722 K 
3 | context_attn | Attention | 90.3 K
4 | target_attn  | Attention | 90.3 K
5 | linear       | Linear    | 1.8 K 
6 | train_acc    | Accuracy  | 0     
7 | val_acc      | Accuracy  | 0     
8 | test_acc     | Accuracy  | 0     
-------------------------------------------
1.6 M     Trainable params
1.4 M     Non-trainable params
3.0 M     Total params
11.921    Total estimated model params size (MB)
Global seed set to 2401
trainer.test(ckpt_path=checkpoint_callback.best_model_path, test_dataloaders=datamodule.test_dataloader())
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.7866071462631226}
--------------------------------------------------------------------------------
[{'test_acc': 0.7866071462631226}]

Discussion

Our result:

Model   Restaurants
IAN     0.786

Paper result:

Model            Restaurants   Laptops
No-target        0.772         0.708
No-interaction   0.769         0.706
Target2Content   0.775         0.712
IAN              0.786         0.721

From the two tables above, we can see that our training reaches the same accuracy on the Restaurants dataset (0.786) as reported in the paper.

Analysis:

The authors analyzed the effectiveness of the IAN model by comparing it with three other model variants, all based on LSTMs and the attention mechanism. The No-target model does not model the representation of the target. The second model, No-interaction, uses two LSTM networks to model the representations of the target and the context via their own local attentions, but without any interaction. Next, the Target2Content model also employs two LSTM networks to learn target and context representations, but only uses the target pool vector in the context attention computation. The difference from IAN is that IAN also uses the context pool vector in the target attention computation.

The results verify that the target should be modeled separately and that target representations contribute to judging the sentiment polarity of a target.

The improvement on the Restaurant dataset is much smaller than on the Laptop dataset. The authors explain this by pointing out that the Restaurant dataset has about 9% more single-word targets than the Laptop one; in other words, the Laptop dataset has more multi-word targets. In IAN, targets are modeled by LSTM networks and interactive attention, which are more effective at modeling long targets than short ones.

You can read more about the case study, in which the authors analyze the attention scores at inference time, here.

Lesson

  1. Pass the device in the forward function instead of __init__.
  2. When masking, create a mask matrix and multiply it with the tensor we want to mask; this avoids the error caused by modifying a tensor in place. Alternatively, PyTorch's masked_fill / masked_fill_ can be used (the trailing underscore marks the in-place variant); this implementation uses the out-of-place masked_fill.
  3. When training is slow, check the number of model parameters!
  4. When using functions that take a dim argument, set it explicitly to avoid subtle bugs. For example, in this implementation, calling squeeze() without a dim after computing the similarity score between the query and key vectors fails when the key sequence length is 1 (see the short example after this list).
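
A tiny illustration of the squeeze pitfall from lesson 4 (toy tensors, not from the model code):

import torch

scores = torch.randn(4, 1, 1)    # batch of 4, key sequence length 1, one score each
print(scores.squeeze().shape)    # torch.Size([4]) -- both size-1 dims are dropped, breaking later matmuls
print(scores.squeeze(-1).shape)  # torch.Size([4, 1]) -- only the intended last dim is removed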

Suggestions for Readers:

  1. Try larger word embeddings to see whether they improve the metrics.
  2. Train on the Laptop dataset.
  3. Have fun :)