Auto Text Classification
In [ ]:
Copied!
# Move the working directory up two levels (presumably from the notebook's
# folder to the repository root) so that relative paths such as "./data/"
# used later resolve from there. This must run exactly once: every extra
# call moves another two levels up the directory tree.
import os
import sys

os.chdir("../../")
In [ ]:
Copied!
# Suppress Python warnings for the rest of the session to keep notebook
# output readable. A single call installs the "ignore" filter; repeating it
# adds nothing.
import warnings

warnings.filterwarnings("ignore")
In [ ]:
Copied!
!pip install git+https://github.com/gradsflow/gradsflow@main -q
!pip install "lightning-flash[text]" -q
!pip install git+https://github.com/gradsflow/gradsflow@main -q
!pip install "lightning-flash[text]" -q
In [ ]:
Copied!
from flash.text import TextClassificationData
from flash.core.data.utils import download_data
from gradsflow import AutoTextClassifier
from flash.text import TextClassificationData
from flash.core.data.utils import download_data
from gradsflow import AutoTextClassifier
In [ ]:
Copied!
# Download the IMDB dataset archive into ./data/ (fetched once). The CSV
# files read when building the data module (data/imdb/*.csv) are expected
# to come from this archive.
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
./data/imdb.zip: 0%| | 0/15575 [00:00<?, ?KB/s]
In [ ]:
Copied!
# Build the text-classification data module once from the IMDB CSV files.
# Positional arguments: the "review" column is the input text and the
# "sentiment" column is the target. Batches hold 64 examples.
datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    batch_size=64,
)
Using custom data configuration default-27edd8668677dc0d Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-27edd8668677dc0d/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
0%| | 0/1 [00:00<?, ?it/s]
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-27edd8668677dc0d/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-250d4864068a61ff.arrow Using custom data configuration default-a15e2e740b1162cd Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-a15e2e740b1162cd/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
0%| | 0/1 [00:00<?, ?it/s]
Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-a15e2e740b1162cd/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519/cache-53740c8fce8e4ca8.arrow
In [ ]:
Copied!
# Hyperparameter search space: the optimizer is chosen from the list and the
# learning rate from the (low, high) range.
suggested_conf = dict(
    optimizer=["adam", "sgd"],
    lr=(5e-4, 1e-3),
)

# AutoML wrapper: tries the listed backbone(s) for 3 trials of 1 epoch each,
# optimizing validation accuracy, with pruning of poor trials enabled.
model = AutoTextClassifier(
    datamodule,
    suggested_backbones=["prajjwal1/bert-tiny"],
    suggested_conf=suggested_conf,
    max_epochs=1,
    optimization_metric="val_accuracy",
    n_trials=3,
    prune=True,
)
print("AutoTextClassifier initialised!")

# Run the search exactly once; re-running repeats every trial.
# NOTE(review): gpu=1/3 presumably gives each trial a third of a GPU so the
# three trials can share one device. Confirm against the gradsflow docs.
model.hp_tune(gpu=1 / 3)
In [ ]:
Copied!
# Tabulate the tuning results: one row per trial, with its metrics and the
# sampled config (backbone, lr, optimizer).
model.analysis.dataframe()
Out[ ]:
val_accuracy | train_accuracy | time_this_iter_s | should_checkpoint | done | timesteps_total | episodes_total | training_iteration | trial_id | experiment_id | ... | hostname | node_ip | time_since_restore | timesteps_since_restore | iterations_since_restore | warmup_time | config/backbone | config/lr | config/optimizer | logdir | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.8208 | 0.875000 | 58.556785 | True | False | NaN | NaN | 1 | 29bf7_00000 | 16c9bff4926a49168f54720c0443e38b | ... | 4af1932bdf0c | 172.28.0.2 | 58.556785 | 0 | 1 | 0.005947 | prajjwal1/bert-tiny | 0.000523 | adam | /root/ray_results/optimization_objective_2022-... |
1 | 0.5960 | 0.578125 | 59.146531 | True | False | NaN | NaN | 1 | 29bf7_00001 | fc357681df964fc9ab18fc84a6667e7c | ... | 4af1932bdf0c | 172.28.0.2 | 59.146531 | 0 | 1 | 0.004946 | prajjwal1/bert-tiny | 0.000919 | sgd | /root/ray_results/optimization_objective_2022-... |
2 | 0.8236 | 0.843750 | 38.021982 | True | False | NaN | NaN | 1 | 29bf7_00002 | f8ea8f938eba436aa5d060f77efcd871 | ... | 4af1932bdf0c | 172.28.0.2 | 38.021982 | 0 | 1 | 0.004159 | prajjwal1/bert-tiny | 0.000639 | adam | /root/ray_results/optimization_objective_2022-... |
3 rows × 24 columns
In [ ]:
Copied!
from flash import Trainer

# Trainer on a single device; the accelerator (GPU or CPU) is selected
# automatically.
trainer = Trainer(accelerator="auto", devices=1)
GPU available: True, used: True TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs
In [ ]:
Copied!
# Validate, once, the model held by the AutoTextClassifier (presumably the
# one selected by the search) on the validation split of `datamodule`.
trainer.validate(model.model, datamodule=datamodule)
Missing logger folder: /lightning_logs LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Validation: 0it [00:00, ?it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric      ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       val_accuracy        │    0.7681461572647095     │
│     val_cross_entropy     │    0.3843795359134674     │
└───────────────────────────┴───────────────────────────┘
Out[ ]:
[{'val_accuracy': 0.7681461572647095, 'val_cross_entropy': 0.3843795359134674}]
Last update:
May 18, 2022