当前位置: 移动技术网 > IT编程>脚本编程>Python > HuggingFace的DistilBERT学习笔记-MyToolkit

HuggingFace的DistilBERT学习笔记-MyToolkit

2020年07月30日  | 移动技术网IT编程  | 我要评论
HuggingFace的DistilBERT学习笔记-顺序版学习DistillBERT My Toolkitpython 代码规范os模块argparse 模块logging 模块shutil 模块深浅拷贝 地址引用还是值引用 函数 循环读写 json pickle numpy torch等todo torch的DistributedParall多机多卡训练todo from torch.utils.data import BatchSampler, DataLoader, RandomSampler, D

DistillBERT My Toolkit

学习Facebook的DistillBERT中所使用的工具包

python 代码规范

函数名、文件名、变量名:big_apple
类名:BigApple
导入argument后要进行sanity_checks(args),包括 todo

os模块

# os.path.dirname/abspath/isdir/isfile/exits/join/curdir/
# os.getcwd() get current work directory
# os.mkdir("dir")
import os
shell_path = os.getcwd()
file_path_name = os.path.abspath(__file__)
file_path = os.path.dirname(os.path.abspath(__file__))
print(shell_path) #/home/zhangmengyu/distill
print(file_path_name) #/home/zhangmengyu/distill/test.py
print(file_path) #/home/zhangmengyu/distill
print(os.path.isfile(file_path_name)) #True
print(os.path.isdir(file_path)) #True

# create logs file folder
logs_dir = os.path.join(os.getcwd(), "logs") # '/home/zhangmengyu/distill/logs'
logs_dir = os.path.join(os.path.curdir, "logs") # './logs'
if os.path.exists(logs_dir) and os.path.isdir(logs_dir):
    pass
else:
    os.mkdir(logs_dir)

argparse 模块

主要用到了 type default required choices=[1,2] action=“store_true” help等参数
但是要注意,如果用了action=“store_true”,那么这个flag出现就是true,不出现就是false。

import argparse
parser = argparse.ArgumentParser(description="study argparse module")
parser.add_argument("--bool_arg", action="store_true", help="bool arg: use --bool_arg")
parser.add_argument("--str_arg", type=str, default="default arg", required=True, choices=["test1", "test2"], help="str arg: use --str_arg str")
parser.add_argument("--int_arg", type=int, default="default arg", required=True, choices=[1,2], help="str arg: use --int_arg int")
args = parser.parse_args()
print(args) # Namespace(bool_arg=True, int_arg=1, str_arg='test1')
# shell$ python test.py - -bool_arg - -str_arg test1 - -int_arg 1

logging 模块

用的比较多的是Logger Formatter Handler

import logging
import os
import logging.handlers

LEVELS = {'NOSET': logging.NOTSET,
          'DEBUG': logging.DEBUG,
          'INFO': logging.INFO,
          'WARNING': logging.WARNING,
          'ERROR': logging.ERROR,
          'CRITICAL': logging.CRITICAL}

## choice 1 可以同时支持输出到console 输出到文件 输出到回滚文件##
# create logs file folder
logs_dir = os.path.join(os.getcwd(), "logs") # '/home/zhangmengyu/distill/logs'
logs_dir = os.path.join(os.path.curdir, "logs") # './logs'
if os.path.exists(logs_dir) and os.path.isdir(logs_dir):
    pass
else:
    os.mkdir(logs_dir)

# init logger
logger = logging.getLogger(__name__)
formatter = logging.Formatter("%(asctime)s - %(levelname)8s - %(name)10s - PID: %(process)d -  %(message)s")
logger.setLevel(logging.INFO)

# define a rotating file handler
rotatingFileHandler = logging.handlers.RotatingFileHandler(
    filename="test_rotating.txt",
    maxBytes=1024 * 1024 * 50,
    backupCount=5,
) # 新的run的log 会 append 到旧的里面
rotatingFileHandler.setFormatter(formatter)
logger.addHandler(rotatingFileHandler)

# define a handler whitch writes messages to sys
console = logging.StreamHandler()
console.setFormatter(formatter)
logger.addHandler(console)

# define a file handler
FileHandler = logging.FileHandler(
    filename="test_file.txt",
    mode="w",
)
FileHandler.setFormatter(formatter)
logger.addHandler(FileHandler)

## choice 2 ##
# 或者可以使用 basic config
# 要么输出到console 要么输出到文件 根据filename参数是否为空来决定 默认filemode="a"
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d -  %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

## use logger example ##
logger.info(type(a))
logger.info(f"Saaaapecial tokens {a}")
logger.error("error XXX")

shutil 模块

shutil.rmtree(path) 递归的删除文件夹

深浅拷贝 地址引用还是值引用 函数 循环

todo这个内容理解的还不深刻,每次用到都会忘记。
核心是?引用的都是地址,不管是for aa in a还是函数,只不过当对地址里面的值进行操作的时候。如果遇到【数字 string tuple】不可变类型,则不会更改,如果遇到【list dict】可变类型则会修改。
有个例外,不管什么时候,不管是在for循环里还是在函数体里还是在main中,对变量进行重新赋值都会重新开辟空间,不会对原来的值进行修改。

a=[1,2,[3,4]]
b=a
b[0]=0
b[2][0]=0
a[1]=0
print(id(b))
print(id(a))
print(a) #[0, 0, [0, 4]]
print(b) #[0, 0, [0, 4]]

aa=0
a = [1, 2, 3]
for aa in a:
    aa = 0
print(a) # [1,2,3] 

for i in range(len(a)):
    a[i] = 0
print(a) # [0, 0, 0]

读写 json pickle numpy torch等

json和pickle读写方式类似
numpy和torch读写方式类似
正常存储时把data存到f中,这就是正常的顺序
但是numpy 由于有一个np.savez(f, key_a=data_a, key_b=data_b)所以会导致numpy的顺序和其他三个不同

json
dumps是将dict转化成json字符串格式,loads是将json字符串转化成dict格式。
dump和load也是类似的功能,只是与文件操作结合起来了。
只能序列化python的object

with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
    json.dump(vars(args), f, indent=4)
with open('data.json','r') as f:
    data = json.load(f)

pickle
pickle序列化后的数据,可读性差,人一般无法识别。
可以序列化函数和类,但是要让pickle找到类或函数的定义

with open("test.txt","wb") as f:
    pickle.dump(a, f) #重点在于rb和wb 二进制形式dump和load
with open("test.txt","rb") as f:
    d = pickle.load(f)

numpy 要注意的是后缀npy npz要写对 不然会自动加上去

    import numpy as np
    a = np.random.random_sample([10,5])
    np.save("./test.npy", a)
    a_load = np.load("test.npy")
    b = np.random.randint([2,1])
    np.savez("./test.npz", aa=a, bb=b)
    data=np.load("./test.npz")
    print(data["aa"])
    print(data["bb"])

torch

torch.save(data, path)
data = torch.load(path)

todo torch的DistributedParall多机多卡训练

from torch.utils.data.distributed import DistributedSampler
init_gpu_params(args)
set_seed(args)

todo from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Dataset

看了一部分的Dataset和DataLoader,学到的知识暂时解了迷惑,但是还有很多地方没有学会,总要一点一点学习的
DataLoader是一个封装的类,可以看作一个黑盒子,这个黑盒子和DataLoadery一起配合使用,用来返回一个batch的tensor格式的数据。要注意什么时候是一个batch的tensor数据(即inputs.tensor,labels.tensor),什么时候是一个数据对(input, label),什么时候是一个list列表类型的数据([input_1, label_1], [input_2, label_2]…)

对DataSet不管__init__的数据格式是怎么存储的,只需要实现__len__(self)和__getitem__(self, index)(输出index对应的一个数据对)和collate_fn(也就是例子代码的batch_sequence,实际上是用于DataLoader的collate_fn)

DataLoader的数据处理流大概是这样的,首先根据Dataset的__len__(self)和__getitem__(self, index)(输出index对应的一个数据对)和batch_size和sampler,输出一个list列表类型的数据([input_1, label_1], [input_2, label_2]…),然后将这个list传递给collate_fn,collate_fn返回一个batch的tensor数据(即inputs.tensor,labels.tensor)

from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)

if params.n_gpu <= 1:
    sampler = RandomSampler(dataset)
else:
    sampler = DistributedSampler(dataset)

my_dataloader = DataLoader(
	dataset=train_lm_seq_dataset, 
	batch_sampler=sampler,  
	collate_fn=dataset.batch_sequences
	)

class LmSeqsDataset(Dataset):
    """Custom Dataset wrapping language modeling sequences.

    Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.

    Input:
    ------
        params: `NameSpace` parameters
        data: `List[np.array[int]]
    """

    def __init__(self, params, data):
        self.params = params
        self.token_ids = np.array(data)
        self.lengths = np.array([len(t) for t in data])

    def __getitem__(self, index):
        return (self.token_ids[index], self.lengths[index])

    def __len__(self):
        return len(self.lengths)

    def batch_sequences(self, batch):
        """
        Do the padding and transform into torch.tensor.
        """
        token_ids = [t[0] for t in batch]
        lengths = [t[1] for t in batch]
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        if self.params.mlm:
            pad_idx = self.params.special_tok_ids["pad_token"]
        else:
            pad_idx = self.params.special_tok_ids["unk_token"]
        tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths)  # (bs)
        return tk_t, lg_t

todo tqdm

这是一个看起来最高大上,用起来也最高大上的,学起来最简单的一个工具包呀
和for aa in a:是相同的,不同的是可以定义desc 可以定义disable=True则不显示进度条

from tqdm import tqdm
char_test = [11,22,33,44,55,66,77,88,99]
import time
tchar = tqdm(char_test,desc="desc", disable=False)
for t in tchar: # 或者for t in trange(100):
    time.sleep(1.0)
    print(f"processing {t}")
    # desc:  22%|████████▋         | 2/9 [00:02<00:07,  1.00s/it]processing 33

todo optimizer

本文地址:https://blog.csdn.net/weixin_43526074/article/details/107636816

如您对本文有疑问或者有任何想说的,请点击进行留言回复,万千网友为您解惑!

相关文章:

验证码:
移动技术网