-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpseudo_labels.py
More file actions
37 lines (25 loc) · 1.08 KB
/
pseudo_labels.py
File metadata and controls
37 lines (25 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# -*- coding: utf-8 -*-
"""pseudo_labels.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1DBldKepiIfjJAyfavm71XrCJ8fUT8N2N
"""
# Notebook-only step: `!pip install transformers` is IPython shell magic and
# is a syntax error in a plain .py file. Run it in a Colab/Jupyter cell, or
# install the package beforehand with `pip install transformers`.
import pandas as pd
from tqdm import tqdm
from transformers import pipeline

# --- News: assign a zero-shot pseudo label to the first documents ----------

# Zero-shot classifier backed by the BART-large MNLI checkpoint.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
labels = ["politics", "sports", "business", "technology"]

# The CSV has no header row. Malformed rows are skipped rather than aborting
# the load. `on_bad_lines="skip"` (pandas >= 1.3) replaces the
# `error_bad_lines=False` keyword, which was removed in pandas 2.0.
data = pd.read_csv(
    "/content/drive/MyDrive/WeSTClass/News/combined_file.csv",
    header=None,
    on_bad_lines="skip",
)
data.columns = ["actual_labels", "document"]

# Label at most the first 2000 rows; clamp to the frame length so a shorter
# file does not raise IndexError on `iloc`.
num_to_label = min(2000, len(data))
for i in tqdm(range(num_to_label)):
    sequence = data.iloc[i, 1]  # column 1 is "document"
    # The pipeline result's "labels" list is ordered best-first, so element 0
    # is the predicted class for this document.
    data.loc[i, "pseudo_label_bart-l-m"] = classifier(sequence, labels)["labels"][0]

# Persist the frame (including the new pseudo-label column) back to Drive.
data.to_csv("/content/drive/MyDrive/WeSTClass/News/combined_file_latest4.csv")
# For Movies:
# Same zero-shot setup as the News section, with sentiment-style candidate
# labels. The imports and the pipeline are repeated so this section can be
# run on its own (e.g. as a separate notebook cell).
import pandas as pd
from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
labels = ["good", "bad"]

# NOTE(review): "/" looks like a placeholder path — point this at the movies
# CSV before running.
# `on_bad_lines="skip"` (pandas >= 1.3) replaces the `error_bad_lines=False`
# keyword, which was removed in pandas 2.0; malformed rows are still skipped.
data = pd.read_csv("/", header=None, on_bad_lines="skip")
data.columns = ["actual_labels", "document"]