Attention based LSTM for Aspect level Sentiment Classification
The full notebook is available here.
Introduction
Aspect-level sentiment classification is a fine-grained task in sentiment analysis. Since it provides more complete and in-depth results, aspect-level sentiment analysis has received much attention in recent years. In this paper, we reveal that the sentiment polarity of a sentence is not only determined by the content but is also highly related to the concerned aspect. For instance, in “The appetizers are ok, but the service is slow.”, for the aspect taste the polarity is positive, while for service the polarity is negative. Therefore, it is worthwhile to explore the connection between an aspect and the content of a sentence. To this end, we propose an Attention-based Long Short-Term Memory Network for aspect-level sentiment classification. The attention mechanism can concentrate on different parts of a sentence when different aspects are taken as input. We experiment on the SemEval 2014 dataset and results show that our model achieves state-of-the-art performance on aspect-level sentiment classification.
Install required packages
%%capture
!pip install pytorch-lightning
!pip install torchmetrics
Import required packages
import os
import pickle
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.request import urlretrieve
import numpy as np
from tqdm import tqdm
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchtext
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn.utils.rnn import (pack_padded_sequence, pad_packed_sequence,
                                pad_sequence)
from torch.utils.data import DataLoader, Dataset, random_split
from torchtext.data import get_tokenizer
from torchtext.vocab import Vectors, Vocab
import xml.etree.ElementTree as ET
# For reproducibility
pl.utilities.seed.seed_everything(seed=2401, workers=True)
Global seed set to 2401
2401
Define the dataset, data module class and utility functions
class TqdmUpTo(tqdm):
    """From https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py"""

    def update_to(self, blocks=1, bsize=1, tsize=None):
        """
        Parameters
        ----------
        blocks: int, optional
            Number of blocks transferred so far [default: 1].
        bsize: int, optional
            Size of each block (in tqdm units) [default: 1].
        tsize: int, optional
            Total size (in tqdm units). If [default: None] remains unchanged.
        """
        if tsize is not None:
            self.total = tsize
        self.update(blocks * bsize - self.n)
class Tokenizer():
    def __init__(self, tokenizer: Any):
        self.counter = Counter(['<pad>', '<unk>'])
        self.tokenizer = tokenizer
        self.update_vocab()  # builds self.vocab

    def update_vocab(self):
        sorted_by_freq_tuples = sorted(self.counter.items(), key=lambda x: x[1], reverse=True)
        ordered_dict = OrderedDict(sorted_by_freq_tuples)
        self.vocab = torchtext.vocab.vocab(ordered_dict, min_freq=1)

    def fit_on_texts(self, texts: List[str]):
        """
        Updates internal vocabulary based on a list of texts.
        """
        # lower and tokenize texts to sequences
        for text in texts:
            self.counter.update(self.tokenizer(text))
        self.update_vocab()

    def texts_to_sequences(self, texts: List[str], reverse: bool = False, tensor: bool = True) -> List[torch.Tensor]:
        word2idx = self.vocab.get_stoi()
        sequences = []
        for text in texts:
            seq = [word2idx.get(word, word2idx['<unk>']) for word in self.tokenizer(text)]
            if reverse:
                seq = seq[::-1]
            if tensor:
                seq = torch.tensor(seq)
            sequences.append(seq)
        return sequences
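As a quick sanity check, the tokenizer can be exercised on a couple of toy sentences (a hypothetical example, not part of the training pipeline):

# Hypothetical usage of the Tokenizer defined above
toy_tokenizer = Tokenizer(get_tokenizer("basic_english"))
toy_tokenizer.fit_on_texts(["The appetizers are ok", "but the service is slow"])
print(len(toy_tokenizer.vocab))  # vocabulary size, including <pad> and <unk>
# Words never seen during fit_on_texts fall back to the <unk> index
print(toy_tokenizer.texts_to_sequences(["the service is ok", "totally unseen words"]))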
def download_url(url, filename, directory='.'):
    """Download a file from url to filename, with a progress bar."""
    if not os.path.exists(directory):
        os.makedirs(directory)
    path = os.path.join(directory, filename)

    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
        urlretrieve(url, path, reporthook=t.update_to, data=None)  # nosec
    return path
def _preprocess_data(data, tokenizer, mapping_polarity, mapping_aspect):
    sentences, text_aspects, text_sentiments = data

    # Create sentence sequences and aspect sequences
    sequences = tokenizer.texts_to_sequences(sentences)

    sentiments = []
    for val in text_sentiments:
        sentiments.append(mapping_polarity[val])
    sentiments = torch.tensor(sentiments)

    aspects = []
    for val in text_aspects:
        aspects.append(mapping_aspect[val])
    aspects = torch.tensor(aspects)

    # pad sequences
    seq_lens = torch.tensor([len(seq) for seq in sequences])
    sequences = pad_sequence(sequences, batch_first=True)

    assert len(sequences) == len(sentiments)
    assert len(sequences) == len(aspects)

    all_data = []
    for i in range(len(sentiments)):
        sample = {
            'sequence': sequences[i],
            'seq_len': seq_lens[i],
            'aspect': aspects[i],
            'sentiment': sentiments[i]
        }
        all_data.append(sample)
    return all_data
def load_data_from(path: Union[str, Path]):
    tree = ET.parse(path)
    root = tree.getroot()
    sentences = []
    aspects = []
    sentiments = []
    for sent in root:
        text = ''
        for i, child in enumerate(sent):
            if i == 0:
                text = child.text
            elif i == 2:
                for aspect in child:
                    # Get polarities
                    polarity = aspect.attrib['polarity'].lower().strip()
                    if polarity == "conflict":
                        continue
                    sentiments.append(polarity)
                    # Get aspects
                    asp = aspect.attrib['category'].lower().strip()
                    if asp == 'anecdotes/miscellaneous':
                        aspects.append('miscellaneous')
                    else:
                        aspects.append(asp)
                    # Get sentences (one copy per aspect, so the lists stay aligned)
                    sentences.append(text)
    return sentences, aspects, sentiments
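For reference, here is a tiny, hypothetical file in the SemEval-2014 aspect-category XML layout this parser assumes, with <text> as the first child of each <sentence> and <aspectCategories> as the third:

# Hypothetical miniature of the XML layout assumed by load_data_from
toy_xml = """
<sentences>
  <sentence id="1">
    <text>The appetizers are ok, but the service is slow.</text>
    <aspectTerms/>
    <aspectCategories>
      <aspectCategory category="food" polarity="neutral"/>
      <aspectCategory category="service" polarity="negative"/>
    </aspectCategories>
  </sentence>
</sentences>
"""
with open('toy.xml', 'w') as f:
    f.write(toy_xml)
print(load_data_from('toy.xml'))
# The sentence is repeated once per aspect category, so the three returned
# lists (sentences, aspects, sentiments) stay aligned.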
def build_vocab(tokenizer, data):
    sentences = data[0]
    tokenizer.fit_on_texts(sentences)

def load_pretrained_word_embeddings(options: Dict[str, Any]):
    return torchtext.vocab.GloVe(options['name'], options['dim'])
def create_embedding_matrix(word_embeddings: Vectors, vocab: Vocab, path: Union[str, Path]):
    if os.path.exists(path):
        print(f'loading embedding matrix from {path}')
        embedding_matrix = pickle.load(open(path, 'rb'))
    else:
        embedding_matrix = torch.zeros((len(vocab), word_embeddings.dim),
                                       dtype=torch.float)
        # words that are not available in the pretrained word embeddings stay zeros
        for word, index in vocab.get_stoi().items():
            embedding_matrix[index] = word_embeddings.get_vecs_by_tokens(word)
        # save embedding matrix
        pickle.dump(embedding_matrix, open(path, 'wb'))
    return embedding_matrix
class SemEvalDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
TRAIN_DS_URL = 'https://raw.githubusercontent.com/zhangjf-nlp/ATAE-LSTM/master/data/restaurants-train.xml'
VALID_DS_URL = 'https://raw.githubusercontent.com/zhangjf-nlp/ATAE-LSTM/master/data/restaurants-trial.xml'
TEST_DS_URL = "https://raw.githubusercontent.com/AlexYangLi/ABSA_Keras/master/raw_data/semeval14_restaurant/Restaurants_Test_Gold.xml"
class SemEval2014(pl.LightningDataModule):
    def __init__(self, tokenizer, opts):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = opts['batch_size']
        self.num_workers = opts['num_workers']
        self.on_gpu = opts['on_gpu']
        self.mapping = {"negative": 0, "neutral": 1, "positive": 2}
        self.mapping_aspects = {'ambience': 0, 'food': 1, 'miscellaneous': 2,
                                'price': 3, 'service': 4}
        self.inverse_mapping = {v: k for k, v in self.mapping.items()}

    def prepare_data(self, *args, **kwargs) -> None:
        self.train_path = 'download/SemEval2014/train.xml'
        self.valid_path = 'download/SemEval2014/valid.xml'
        self.test_path = 'download/SemEval2014/test.xml'
        if not os.path.exists(self.train_path):
            print("Downloading train dataset")
            self.train_path = download_url(TRAIN_DS_URL, 'train.xml', 'download/SemEval2014')
        if not os.path.exists(self.valid_path):
            print("Downloading valid dataset")
            self.valid_path = download_url(VALID_DS_URL, 'valid.xml', 'download/SemEval2014')
        if not os.path.exists(self.test_path):
            print("Downloading test dataset")
            self.test_path = download_url(TEST_DS_URL, 'test.xml', 'download/SemEval2014')

    def setup(self, stage: str = None) -> None:
        if stage == 'fit' or stage is None:
            # Load data from files
            train_data = load_data_from(self.train_path)
            valid_data = load_data_from(self.valid_path)
            self.train_data = _preprocess_data(train_data, self.tokenizer, self.mapping, self.mapping_aspects)
            self.val_data = _preprocess_data(valid_data, self.tokenizer, self.mapping, self.mapping_aspects)
        elif stage == 'test' or stage is None:
            test_data = load_data_from(self.test_path)
            self.test_data = _preprocess_data(test_data, self.tokenizer, self.mapping, self.mapping_aspects)

    def train_dataloader(self):
        # Create Dataset object
        train_ds = SemEvalDataset(self.train_data)
        # Create Dataloader
        return DataLoader(
            train_ds,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def val_dataloader(self):
        val_ds = SemEvalDataset(self.val_data)
        return DataLoader(
            val_ds,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def test_dataloader(self):
        test_ds = SemEvalDataset(self.test_data)
        return DataLoader(
            test_ds,
            shuffle=False,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.on_gpu,
        )

    def __repr__(self):
        basic = f"SemEval2014 Dataset\nNum classes: {len(self.mapping)}\nMapping: {self.mapping}\n"
        if self.train_data is None and self.val_data is None and self.test_data is None:
            return basic
        batch = next(iter(self.train_dataloader()))
        sequences, seq_lens, aspects, sentiments = batch['sequence'], batch['seq_len'], batch['aspect'], batch['sentiment']
        data = (
            f"Train/val/test sizes: {len(self.train_data)}, {len(self.val_data)}, {len(self.test_data)}\n"
            f"Batch sequences stats: {(sequences.shape, sequences.dtype)}\n"
            f"Batch seq_lens stats: {(seq_lens.shape, seq_lens.dtype)}\n"
            f"Batch aspects stats: {(aspects.shape, aspects.dtype)}\n"
            f"Batch sentiments stats: {(sentiments.shape, sentiments.dtype)}\n"
        )
        return basic + data
Implementation
AT-LSTM
LSTM with Aspect Embedding (AE-LSTM)
Aspect information is important when classifying a sentence: we may get different polarities for different aspects. The authors propose to learn an embedding vector for each aspect.
Attention-based LSTM (AT-LSTM)
A standard LSTM cannot detect which part of the sentence is important for aspect-level sentiment classification. The authors design an attention mechanism that captures the key part of the sentence in response to a given aspect.
from IPython.display import Image
Image(filename='/Users/minhdang/Desktop/AT-LSTM.png')
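In the paper's notation, with H = [h_1, ..., h_N] the matrix of LSTM hidden states, v_a the aspect embedding, e_N a vector of ones and h_N the last hidden state, the attention layer computes:

$$
M = \tanh\!\left(\begin{bmatrix} W_h H \\ W_v v_a \otimes e_N \end{bmatrix}\right), \qquad
\alpha = \operatorname{softmax}\!\left(w^{\top} M\right), \qquad
r = H \alpha^{\top}
$$

$$
h^{*} = \tanh\!\left(W_p r + W_x h_N\right), \qquad
y = \operatorname{softmax}\!\left(W_s h^{*} + b_s\right)
$$

In the implementation below, W_h, W_v, w, W_p, W_x and W_s correspond to linear_h, linear_v, linear, linear_p, linear_x and linear_s; because the tensors are batch-first, the transposes look slightly different in code, and padded positions are masked with -inf before the softmax.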
class AT_LSTM(pl.LightningModule):
    def __init__(self, embeddings, hidden_size, aspect_hidden_size, num_layers=1, num_classes=3,
                 batch_first=True, lr=1e-3, dropout=0, l2reg=0.01):
        super().__init__()
        embedding_dim = embeddings.shape[1]
        self.embedding = nn.Embedding.from_pretrained(embeddings)  # load pre-trained word embeddings
        self.aspect_embedding = nn.Embedding(5, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=batch_first, dropout=dropout)
        self.linear_h = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_v = nn.Linear(aspect_hidden_size, aspect_hidden_size, bias=False)
        self.linear_p = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_x = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear = nn.Linear(hidden_size + aspect_hidden_size, 1)
        self.linear_s = nn.Linear(hidden_size, num_classes)
        self.batch_first = batch_first
        self.lr = lr
        self.l2reg = l2reg

        # Define metrics
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.val_f1 = torchmetrics.F1(num_classes=3, average='macro')
        self.test_acc = torchmetrics.Accuracy()
        self.test_f1 = torchmetrics.F1(num_classes=3, average='macro')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.l2reg)
        return optim

    def forward(self, input):
        sequences, seq_lens, aspect_seqs = input['sequence'], input['seq_len'], input['aspect']

        # Convert sequences to word embeddings
        embeds = self.embedding(sequences)
        # Get the max sequence length
        max_seq_len = torch.max(seq_lens)
        # Convert aspects to aspect embeddings
        aspect_embeds = self.aspect_embedding(aspect_seqs)

        packed_embeds = pack_padded_sequence(embeds, seq_lens.cpu(), batch_first=self.batch_first, enforce_sorted=False)
        H, (h, c) = self.lstm(packed_embeds)
        padded_H, lens = pad_packed_sequence(H, batch_first=True)

        Wh_H = self.linear_h(padded_H)
        Wv_va = self.linear_v(aspect_embeds)
        Wv_va = Wv_va.unsqueeze(1).repeat(1, max_seq_len, 1)
        M = torch.tanh(torch.cat([Wh_H, Wv_va], dim=-1))

        # Calculate attention scores
        score = self.linear(M).squeeze()
        # Create mask to zero out attention scores for padding tokens
        att_mask = torch.arange(max_seq_len, device=self.device)[None, :] < seq_lens[:, None]
        score[~att_mask] = float('-inf')

        alpha = F.softmax(score, dim=-1).unsqueeze(2)
        r = torch.matmul(padded_H.transpose(-2, -1), alpha).squeeze()
        final_h = torch.tanh(self.linear_p(r) + self.linear_x(h[-1]))
        out = self.linear_s(final_h)
        return out

    def training_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.train_acc(scores, sentiments)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.val_acc(scores, sentiments)
        self.val_f1(scores, sentiments)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        scores = F.softmax(logits, dim=-1)
        self.test_acc(scores, sentiments)
        self.test_f1(scores, sentiments)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True, logger=True)
        self.log('test_f1', self.test_f1, on_step=False, on_epoch=True, logger=True)
ATAE-LSTM
To take further advantage of aspect information, the input aspect embedding is appended to each word input vector. By doing that, the inter-dependence between words and the input aspect can be modeled, as sketched after the figure below.
from IPython.display import Image
Image(filename='/Users/minhdang/Desktop/ATAE-LSTM.png')
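Concretely, the only change relative to AT-LSTM is the input to the LSTM: the aspect embedding is repeated along the time dimension and concatenated to every word embedding, so the LSTM input size doubles. A minimal sketch (shapes are assumptions for illustration, mirroring the forward pass below):

# Sketch of the ATAE-LSTM input construction
# embeds:        (batch, seq_len, emb_dim)   word embeddings
# aspect_embeds: (batch, emb_dim)            one aspect embedding per sentence
aspect_embeds_repeat = aspect_embeds.unsqueeze(1).repeat(1, embeds.shape[1], 1)
lstm_input = torch.cat([embeds, aspect_embeds_repeat], dim=-1)  # (batch, seq_len, 2 * emb_dim)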
class ATAE_LSTM(pl.LightningModule):
    def __init__(self, embeddings, hidden_size, aspect_hidden_size, num_layers=1, num_classes=3,
                 batch_first=True, lr=1e-3, dropout=0, l2reg=0.01):
        super().__init__()
        embedding_dim = embeddings.shape[1]
        self.embedding = nn.Embedding.from_pretrained(embeddings)  # load pre-trained word embeddings
        self.aspect_embedding = nn.Embedding(5, embedding_dim)
        self.lstm = nn.LSTM(2 * embedding_dim, hidden_size, num_layers, batch_first=batch_first, dropout=dropout)
        self.linear_h = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_v = nn.Linear(aspect_hidden_size, aspect_hidden_size, bias=False)
        self.linear_p = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_x = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear = nn.Linear(hidden_size + aspect_hidden_size, 1)
        self.linear_s = nn.Linear(hidden_size, num_classes)
        self.batch_first = batch_first
        self.lr = lr
        self.l2reg = l2reg

        # Define metrics
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.val_f1 = torchmetrics.F1(num_classes=3, average='macro')
        self.test_acc = torchmetrics.Accuracy()
        self.test_f1 = torchmetrics.F1(num_classes=3, average='macro')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.l2reg)
        return optim

    def forward(self, input):
        sequences, seq_lens, aspect_seqs = input['sequence'], input['seq_len'], input['aspect']

        # Convert sequences to word embeddings
        embeds = self.embedding(sequences)
        # Get the max sequence length
        max_seq_len = torch.max(seq_lens)
        # Convert aspects to aspect embeddings
        aspect_embeds = self.aspect_embedding(aspect_seqs)
        # Repeat the aspect vector across dimension 1
        aspect_embeds_repeat = aspect_embeds.unsqueeze(1).repeat(1, embeds.shape[1], 1)
        # Append the aspect vector to each input word vector
        embeds = torch.cat([embeds, aspect_embeds_repeat], dim=-1)

        packed_embeds = pack_padded_sequence(embeds, seq_lens.cpu(), batch_first=self.batch_first, enforce_sorted=False)
        H, (h, c) = self.lstm(packed_embeds)
        padded_H, lens = pad_packed_sequence(H, batch_first=True)

        Wh_H = self.linear_h(padded_H)
        Wv_va = self.linear_v(aspect_embeds)
        Wv_va = Wv_va.unsqueeze(1).repeat(1, max_seq_len, 1)
        M = torch.tanh(torch.cat([Wh_H, Wv_va], dim=-1))

        # Calculate attention scores
        score = self.linear(M).squeeze()
        # Create mask to zero out attention scores for padding tokens
        att_mask = torch.arange(max_seq_len, device=self.device)[None, :] < seq_lens[:, None]
        score[~att_mask] = float('-inf')

        alpha = F.softmax(score, dim=-1).unsqueeze(2)
        r = torch.matmul(padded_H.transpose(-2, -1), alpha).squeeze()
        final_h = torch.tanh(self.linear_p(r) + self.linear_x(h[-1]))
        out = self.linear_s(final_h)
        return out

    def training_step(self, batch, batch_idx):
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.train_acc(scores, sentiments)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        loss = F.cross_entropy(logits, sentiments)
        scores = F.softmax(logits, dim=-1)
        self.val_acc(scores, sentiments)
        self.val_f1(scores, sentiments)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        sentiments = batch['sentiment']
        logits = self.forward(batch)
        scores = F.softmax(logits, dim=-1)
        self.test_acc(scores, sentiments)
        self.test_f1(scores, sentiments)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True, logger=True)
        self.log('test_f1', self.test_f1, on_step=False, on_epoch=True, logger=True)
Training
word_embeddings = load_pretrained_word_embeddings({"name": "840B", "dim": 300})

train_path = download_url(TRAIN_DS_URL, 'train.xml', 'download/SemEval2014')
valid_path = download_url(VALID_DS_URL, 'valid.xml', 'download/SemEval2014')
test_path = download_url(TEST_DS_URL, 'test.xml', 'download/SemEval2014')

train_data = load_data_from(train_path)
valid_data = load_data_from(valid_path)
test_data = load_data_from(test_path)

all_sentences = train_data[0] + valid_data[0] + test_data[0]
tokenizer = Tokenizer(get_tokenizer("basic_english"))

build_vocab(tokenizer, [all_sentences])

options = {
    "on_gpu": True,
    "batch_size": 64,
    "num_workers": 2
}
datamodule = SemEval2014(tokenizer, options)

embedding_matrix = create_embedding_matrix(word_embeddings, tokenizer.vocab, "embedding_matrix.dat")
1.18MB [00:00, 9.47MB/s]
40.0kB [00:00, 1.78MB/s]
352kB [00:00, 12.3MB/s]
loading embedding matrix from embedding_matrix.dat
AT-LSTM
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',  # save the model with the best validation accuracy
    dirpath='checkpoints',
    mode='max',
)

tb_logger = pl_loggers.TensorBoardLogger('logs/')  # create logger for tensorboard

# Set hyper-parameters
lr = 1e-3
hidden_size = 300
aspect_embedding_dim = 300
num_epochs = 30
l2reg = 0.0

trainer = pl.Trainer(gpus=1, max_epochs=num_epochs, logger=tb_logger, callbacks=[checkpoint_callback], deterministic=True)
# trainer = pl.Trainer(fast_dev_run=True)  # Debug
# trainer = pl.Trainer(overfit_batches=0.025, max_epochs=num_epochs)  # Debug

model = AT_LSTM(embedding_matrix, hidden_size, aspect_embedding_dim, lr=lr, l2reg=l2reg)
trainer.fit(model, datamodule)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
------------------------------------------------
0 | embedding | Embedding | 1.3 M
1 | aspect_embedding | Embedding | 1.5 K
2 | lstm | LSTM | 722 K
3 | linear_h | Linear | 90.0 K
4 | linear_v | Linear | 90.0 K
5 | linear_p | Linear | 90.0 K
6 | linear_x | Linear | 90.0 K
7 | linear | Linear | 601
8 | linear_s | Linear | 903
9 | train_acc | Accuracy | 0
10 | val_acc | Accuracy | 0
11 | val_f1 | F1 | 0
12 | test_acc | Accuracy | 0
13 | test_f1 | F1 | 0
------------------------------------------------
1.1 M Trainable params
1.3 M Non-trainable params
2.4 M Total params
9.708 Total estimated model params size (MB)
Global seed set to 2401
trainer.test(ckpt_path=checkpoint_callback.best_model_path)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.8150510191917419, 'test_f1': 0.6747192740440369}
--------------------------------------------------------------------------------
[{'test_acc': 0.8150510191917419, 'test_f1': 0.6747192740440369}]
ATAE-LSTM
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',  # save the model with the best validation accuracy
    dirpath='checkpoints',
    mode='max',
)

tb_logger = pl_loggers.TensorBoardLogger('logs/')  # create logger for tensorboard

# Set hyper-parameters
lr = 1e-3
hidden_size = 300
aspect_embedding_dim = 300
num_epochs = 30
l2reg = 0.0

trainer = pl.Trainer(gpus=1, max_epochs=num_epochs, logger=tb_logger, callbacks=[checkpoint_callback], deterministic=True)
# trainer = pl.Trainer(fast_dev_run=True)  # Debug
# trainer = pl.Trainer(overfit_batches=0.025, max_epochs=num_epochs)  # Debug

model = ATAE_LSTM(embedding_matrix, hidden_size, aspect_embedding_dim, lr=lr, l2reg=l2reg)
trainer.fit(model, datamodule)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
------------------------------------------------
0 | embedding | Embedding | 1.3 M
1 | aspect_embedding | Embedding | 1.5 K
2 | lstm | LSTM | 1.1 M
3 | linear_h | Linear | 90.0 K
4 | linear_v | Linear | 90.0 K
5 | linear_p | Linear | 90.0 K
6 | linear_x | Linear | 90.0 K
7 | linear | Linear | 601
8 | linear_s | Linear | 903
9 | train_acc | Accuracy | 0
10 | val_acc | Accuracy | 0
11 | val_f1 | F1 | 0
12 | test_acc | Accuracy | 0
13 | test_f1 | F1 | 0
------------------------------------------------
1.4 M Trainable params
1.3 M Non-trainable params
2.8 M Total params
11.148 Total estimated model params size (MB)
Global seed set to 2401
trainer.test(ckpt_path=checkpoint_callback.best_model_path)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.8227040767669678, 'test_f1': 0.6682634353637695}
--------------------------------------------------------------------------------
[{'test_acc': 0.8227040767669678, 'test_f1': 0.6682634353637695}]
Discussion
Since I could not find the exact dataset used in the paper, a direct comparison with the paper's results is not possible. Instead, I compare my implementation with other implementations from GitHub.
Our results
| Model | Accuracy | F1 (macro) |
|---|---|---|
| AT-LSTM | 0.81 | 0.674 |
| ATAE-LSTM | 0.82 | 0.668 |
There is a class imbalance among the positive, negative and neutral classes in the train, validation and test sets. That is why the macro F1 is lower than the accuracy.
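A small, hypothetical example using the same torchmetrics metrics as above illustrates the effect: when a classifier mostly predicts the majority class, accuracy stays high while the minority classes pull the macro F1 down.

# Toy illustration with hypothetical predictions (not taken from the model)
preds = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])     # always predict "positive"
targets = torch.tensor([2, 2, 2, 2, 2, 2, 2, 2, 0, 1])   # skewed label distribution
acc = torchmetrics.Accuracy()
f1 = torchmetrics.F1(num_classes=3, average='macro')
print(acc(preds, targets))  # 0.8
print(f1(preds, targets))   # ~0.30, since the two minority classes each get F1 = 0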