大学のスパコンはジョブ管理システム PBSによってジョブスケジュールを行なっているのですが,
最近変わったのか,Pythonで並列処理をしようとするとなぜかエラーが出るようになりました。
エラー
=>> PBS: job killed: ncpus 20.8 exceeded limit 16
コード
import os
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
name = MPI.Get_processor_name()
import csv
source=[]
target=[]
with open("test.csv","r",encoding="utf-8") as f:
cdata=csv.reader(f,delimiter=",")
for i,items in enumerate(cdata):
if i==0:
continue
source.append(items[0])
target.append(items[1])
data_lst=[]
for src,tgt in zip(source,target):
data_lst.append((src,tgt))
# http://muscle199x.blog.fc2.com/blog-entry-70.html?sp
def split_list(lst, n):
list_size = len(lst)
a = list_size // n
b = list_size % n
return [lst[i * a + (i if i < b else b):(i + 1) * a + (i + 1 if i < b else b)] for i in range(n)]
from transformers import MT5ForConditionalGeneration, MT5Model, MT5EncoderModel, T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained("google/mt5-base")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
if rank == 0:
data = split_list(data_lst, size)
else:
data = None
def calc(my_data):
res=[]
for src,tgt in my_data:
inputs = tokenizer.encode("summarize: " + src, return_tensors='pt', max_length=512, truncation=True)
summary_ids = model.generate(inputs, max_length=50, min_length=10, length_penalty=5., num_beams=2)
summary = tokenizer.decode(summary_ids[0])
res.append((src,tgt,summary))
return res
my_data = comm.scatter(data, root=0)
for r in range(comm.size):
if rank == r:
print("scatter.[%d] %d" % (rank, len(my_data)))
comm.Barrier()
my_res = calc(my_data)
def flatten(l):
try:
return ([item for sublist in l for item in sublist])
except:
import traceback
traceback.print_exc()
print(l)
# gather
res = MPI.COMM_WORLD.gather(my_res, root=0)
if rank == 0:
res=flatten(res)
for src,tgt,summary in res:
print(f"source:{src}")
print(f"target:{tgt}")
print(f"summary:{summary}")
print("-----")
これで実行スクリプトはこんな感じ
#!/usr/bin/bash #PBS -N sample #PBS -j oe #PBS -l select=1:ncpus=64:mpiprocs=64 #PBS -q SINGLE NPROCS=`cat $PBS_NODEFILE|wc -l` cd ${PBS_O_WORKDIR} mpirun -np ${NPROCS} python mt5.py
解決方法
ここにありました。
簡単に解決するためには,Pythonコードの一番上に
import os
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
これを入れるだけ。Numpyを呼ぶ前に呼ぶ必要があるので,シェルスクリプトで呼んでもいいですし,Pythonの一番上で環境変数に設定してやってもいいです。