class FractalImageGenerator(DatasetGenerator):
    """
    Generates 3-channel synthetic fractal images with the provided shape.

    Uses the algorithm from the article: https://arxiv.org/abs/2103.13023

    Images are produced in two parallel stages: first a set of IFS parameter
    "categories" is sampled, then every (category, weight vector) combination
    is rendered, passed through ``colorize`` and ``augment`` and saved as
    ``<index>.png`` in the output directory.
    """

    # File names of the colorization model artifacts, looked up in (and, if
    # missing, downloaded to) the model directory.
    _MODEL_PROTO_FILENAME = "colorization_deploy_v2.prototxt"
    _MODEL_WEIGHTS_FILENAME = "colorization_release_v2.caffemodel"
    _HULL_PTS_FILE_NAME = "pts_in_hull.npy"
    # Package resource read with np.loadtxt and handed to ``augment``.
    _COLORS_FILE = "background_colors.txt"

    def __init__(
        self,
        output_dir: str,
        count: int,
        shape: Tuple[int, int],
        model_path: str = get_datumaro_cache_dir(),
    ) -> None:
        """
        Args:
            output_dir: directory the generated images are saved to.
            count: number of images to generate, must be positive.
            shape: ``(height, width)`` of the generated images.
            model_path: directory that holds the colorization model files or
                a writable directory they can be downloaded to. The default
                is computed once, when the class body is executed.
        """
        assert 0 < count, "Image count cannot be lesser than 1"

        self._count = count
        self._output_dir = output_dir
        self._model_dir = model_path

        # os.cpu_count() returns None when the number of CPUs cannot be
        # determined; min(None, count) would raise TypeError, so fall back to
        # a single worker in that case.
        self._cpu_count = min(os.cpu_count() or 1, self._count)

        assert len(shape) == 2
        self._height, self._width = shape

        self._weights = self._create_weights(IFSFunction.NUM_PARAMS)
        # A sampled category is accepted once the fraction of non-zero pixels
        # in its preview render reaches this threshold.
        self._threshold = 0.2
        # Used both as the iteration count for final renders and as the upper
        # bound on sampling attempts per category.
        self._iterations = 200000
        # Iteration count for the preview render done while sampling.
        self._num_of_points = 100000

        self._initialize_params()

    def generate_dataset(self) -> None:
        """
        Generate ``count`` images and save them to the output directory.

        Ensures the colorization model files are available, samples the
        category parameters in a process pool, then renders and saves the
        images in a second process pool.
        """
        log.info(
            "Generation of '%d' 3-channel images with height = '%d' and width = '%d'",
            self._count,
            self._height,
            self._width,
        )

        self._download_colorization_model(self._model_dir)

        mp_ctx = get_context("spawn")  # On Mac 10.15 and Python 3.7 fork leads to hangs

        # Stage 1: one parameter set per category, each sampled with its own
        # deterministically seeded RNG.
        with mp_ctx.Pool(processes=self._cpu_count) as pool:
            try:
                params = pool.map(
                    self._generate_category,
                    [Random(i) for i in range(self._categories)],  # nosec B311
                )
            finally:
                pool.close()
                pool.join()

        # Build per-image weights and parameters: every category is paired
        # with every weight vector, each repeated ``self._instances`` times,
        # then both sequences are truncated to the requested image count.
        instances_weights = np.repeat(self._weights, self._instances, axis=0)
        weight_per_img = np.tile(instances_weights, (self._categories, 1))
        params = np.array(params, dtype=object)
        repeated_params = np.repeat(params, self._weights.shape[0] * self._instances, axis=0)
        repeated_params = repeated_params[: self._count]
        weight_per_img = weight_per_img[: self._count]
        assert weight_per_img.shape[0] == len(repeated_params) == self._count

        # Stage 2: split the work into contiguous batches. Every batch gets
        # the global indices of its images, which determine both the output
        # file names and the per-image RNG seeds.
        splits = min(self._cpu_count, self._count)
        params_per_proc = np.array_split(repeated_params, splits)
        weights_per_proc = np.array_split(weight_per_img, splits)

        generation_params = []
        offset = 0
        for param, w in zip(params_per_proc, weights_per_proc):
            indices = list(range(offset, offset + len(param)))
            offset += len(param)
            generation_params.append((param, w, indices))

        with mp_ctx.Pool(processes=self._cpu_count) as pool:
            try:
                pool.starmap(self._generate_image_batch, generation_params)
            finally:
                pool.close()
                pool.join()

    @scoped
    def _generate_image_batch(
        self, params: np.ndarray, weights: np.ndarray, indices: List[int]
    ) -> None:
        """
        Render, colorize, augment and save one batch of images.

        Args:
            params: per-image IFS parameter sets for this batch.
            weights: per-image weight vectors, aligned with ``params``.
            indices: global image indices, aligned with ``params``; used as
                the RNG seed and as the output file name of each image.
        """
        scope_add(suppress_computation_warnings())

        proto = osp.join(self._model_dir, self._MODEL_PROTO_FILENAME)
        model = osp.join(self._model_dir, self._MODEL_WEIGHTS_FILENAME)
        npy = osp.join(self._model_dir, self._HULL_PTS_FILE_NAME)
        pts_in_hull = np.load(npy).transpose().reshape(2, 313, 1, 1).astype(np.float32)

        with open_text(__package__, self._COLORS_FILE) as f:
            background_colors = np.loadtxt(f)

        # The network is loaded once per batch and reused for all its images.
        net = cv.dnn.readNetFromCaffe(proto, model)
        net.getLayer(net.getLayerId("class8_ab")).blobs = [pts_in_hull]
        net.getLayer(net.getLayerId("conv8_313_rh")).blobs = [np.full([1, 313], 2.606, np.float32)]

        for i, param, w in zip(indices, params, weights):
            image = self._generate_image(
                Random(i),  # nosec B311
                param,
                self._iterations,
                self._height,
                self._width,
                draw_point=False,
                weight=w,
            )
            color_image = colorize(image, net)
            aug_image = augment(Random(i), color_image, background_colors)  # nosec B311
            save_image(
                osp.join(self._output_dir, "{:06d}.png".format(i)), aug_image, create_dir=True
            )

    def _generate_image(
        self,
        rng: Random,
        params: np.ndarray,
        iterations: int,
        height: int,
        width: int,
        draw_point: bool = True,
        weight: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Render a single fractal image from a set of IFS parameters.

        Args:
            rng: random generator passed to the ``IFSFunction``.
            params: rows of ``IFSFunction.NUM_PARAMS`` coefficients followed
                by one extra value (the row's probability, as filled in by
                ``_generate_category``).
            iterations: value passed to ``IFSFunction.calculate``.
            height: height passed to ``IFSFunction.draw``.
            width: width passed to ``IFSFunction.draw``.
            draw_point: flag passed to ``IFSFunction.draw``.
            weight: optional weight vector passed along with every row.

        Returns:
            The image returned by ``IFSFunction.draw``.
        """
        ifs_function = IFSFunction(rng, prev_x=0.0, prev_y=0.0)
        for param in params:
            ifs_function.add_param(
                param[: ifs_function.NUM_PARAMS], param[ifs_function.NUM_PARAMS], weight
            )
        ifs_function.calculate(iterations)
        img = ifs_function.draw(height, width, draw_point)
        return img

    @scoped
    def _generate_category(
        self, rng: Random, base_h: int = 512, base_w: int = 512
    ) -> np.ndarray:
        """
        Sample the IFS parameters of one category.

        Random parameter sets are drawn until the preview render has at least
        ``self._threshold`` of its pixels non-zero, or ``self._iterations``
        attempts have been made; the last sampled set is returned either way.

        Args:
            rng: random generator used for sampling and for the preview render.
            base_h: height of the preview render.
            base_w: width of the preview render.

        Returns:
            A float32 array of shape ``(param_size, IFSFunction.NUM_PARAMS + 1)``
            with the normalized probability in the last column.
        """
        scope_add(suppress_computation_warnings())

        pixels = -1
        i = 0
        while pixels < self._threshold and i < self._iterations:
            param_size = rng.randint(2, 7)
            params = np.zeros((param_size, IFSFunction.NUM_PARAMS + 1), dtype=np.float32)

            # Starts slightly above zero so that the normalization below
            # never divides by zero.
            sum_proba = 1e-5
            for p_idx in range(param_size):
                # Unpacking into six names requires IFSFunction.NUM_PARAMS == 6.
                a, b, c, d, e, f = [rng.uniform(-1.0, 1.0) for _ in range(IFSFunction.NUM_PARAMS)]
                # Unnormalized probability: |det| of the 2x2 matrix [[a, b], [c, d]].
                prob = abs(a * d - b * c)
                sum_proba += prob
                params[p_idx] = a, b, c, d, e, f, prob
            params[:, IFSFunction.NUM_PARAMS] /= sum_proba

            fractal_img = self._generate_image(rng, params, self._num_of_points, base_h, base_w)
            pixels = np.count_nonzero(fractal_img) / (base_h * base_w)
            i += 1
        return params

    def _initialize_params(self) -> None:
        """
        Derive the number of categories and instances from the image count.

        Trims the weight matrix when fewer images than weight vectors are
        requested, then picks ``self._instances`` and ``self._categories`` so
        that ``categories * instances * number_of_weights >= count``.
        """
        if self._count < self._weights.shape[0]:
            self._weights = self._weights[: self._count, :]

        instances_categories = np.ceil(self._count / self._weights.shape[0])
        self._instances = np.ceil(np.sqrt(instances_categories)).astype(int)
        self._categories = np.ceil(instances_categories / self._instances).astype(int)

    @staticmethod
    def _create_weights(num_params):
        """
        Build the matrix of weight vectors.

        The first row is all ones; every following row is all ones with a
        single component shifted by -0.8, -0.4, +0.4 or +0.8.

        Args:
            num_params: length of each weight vector.

        Returns:
            An array of shape ``(1 + 4 * num_params, num_params)``.
        """
        # weights from https://openaccess.thecvf.com/content/ACCV2020/papers/Kataoka_Pre-training_without_Natural_Images_ACCV_2020_paper.pdf
        BASE_WEIGHTS = np.ones((num_params,))
        WEIGHT_INTERVAL = 0.4
        INTERVAL_MULTIPLIERS = (-2, -1, 1, 2)

        weight_vectors = [BASE_WEIGHTS]
        for weight_index in range(num_params):
            for multiplier in INTERVAL_MULTIPLIERS:
                modified_weights = BASE_WEIGHTS.copy()
                modified_weights[weight_index] += multiplier * WEIGHT_INTERVAL
                weight_vectors.append(modified_weights)
        weights = np.array(weight_vectors)
        return weights

    @classmethod
    def _download_colorization_model(cls, save_dir: str) -> None:
        """
        Make sure all colorization model files are present in ``save_dir``.

        Files that already exist are left untouched; missing ones are
        downloaded and verified against a pinned size and SHA-512 checksum.

        Args:
            save_dir: directory with the model files or a writable directory
                to download them to.

        Raises:
            ValueError: if some file is missing and ``save_dir`` is not writable.
            Exception: if a download fails; the original error is chained.
        """
        prototxt_file_name = cls._MODEL_PROTO_FILENAME
        caffemodel_file_name = cls._MODEL_WEIGHTS_FILENAME
        hull_file_name = cls._HULL_PTS_FILE_NAME

        proto_path = osp.join(save_dir, prototxt_file_name)
        model_path = osp.join(save_dir, caffemodel_file_name)
        hull_path = osp.join(save_dir, hull_file_name)

        # A read-only directory is fine as long as nothing has to be downloaded.
        if not (
            osp.exists(proto_path) and osp.exists(model_path) and osp.exists(hull_path)
        ) and not os.access(save_dir, os.W_OK):
            raise ValueError(
                "Please provide a path to a colorization model directory or "
                "a path to a writable directory to download the model"
            )

        for url, filename, size, sha512_checksum in [
            (
                f"https://raw.githubusercontent.com/richzhang/colorization/a1642d6ac6fc80fe08885edba34c166da09465f6/colorization/models/{prototxt_file_name}",
                prototxt_file_name,
                9945,
                "e3dd9188771202bd296623510bcf527b41c130fc9bae584e61dcdf66917b8c4d147b7b838fec0685568f7f287235c34e8b8e9c0482b555774795be89f0442820",
            ),
            (
                f"http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/{caffemodel_file_name}",
                caffemodel_file_name,
                128946764,
                "3d773dd83cfcf8e846e3a9722a4d302a3b7a0f95a0a7ae1a3d3ef5fe62eecd617f4f30eefb1d8d6123be4a8f29f7c6e64f07b36193f45710b549f3e4796570f1",
            ),
            (
                f"https://raw.githubusercontent.com/richzhang/colorization/a1642d6ac6fc80fe08885edba34c166da09465f6/colorization/resources/{hull_file_name}",
                hull_file_name,
                5088,
                "bf59a8a4e74b18948e4aeaa430f71eb8603bd9dbbce207ea086dd0fb976a34672beaeea6f1233a21687da710e0f8d36e86133a8532265dfda52994a7d6f0dbf5",
            ),
        ]:
            save_path = osp.join(save_dir, filename)
            if osp.exists(save_path):
                continue

            log.info("Downloading the '%s' file to '%s'", filename, save_dir)
            try:
                cls._download_file(
                    url, save_path, expected_size=size, expected_checksum=sha512_checksum
                )
            except Exception as e:
                raise Exception(f"Failed to download the '{filename}' file: {str(e)}") from e

    @staticmethod
    @scoped
    def _download_file(
        url: str,
        output_path: str,
        *,
        timeout: int = 60,
        expected_size: int,
        expected_checksum: str,
    ) -> None:
        """
        Download ``url`` to ``output_path`` and verify size and checksum.

        The data is streamed into ``output_path + ".tmp"`` and renamed to
        ``output_path`` only after the SHA-512 checksum matches; on error the
        temporary file is removed.

        Args:
            url: address to download from.
            output_path: destination path, must not exist yet.
            timeout: timeout passed to ``requests.get``.
            expected_size: maximum number of bytes accepted from the server.
            expected_checksum: expected SHA-512 hex digest (case-insensitive).

        Raises:
            Exception: if the temporary file already exists, if more than
                ``expected_size`` bytes are received or if the checksum does
                not match.
        """
        BLOCK_SIZE = 2**20

        assert not osp.exists(output_path)

        tmp_path = output_path + ".tmp"
        if osp.exists(tmp_path):
            raise Exception(f"Can't write temporary file '{tmp_path}' - file exists")

        response = requests.get(url, timeout=timeout, stream=True)
        on_exit_do(response.close)

        response.raise_for_status()

        checksum_counter = hashlib.sha512()
        actual_size = 0
        with open(tmp_path, "wb") as fd:
            # Registered only after the file was created, so that a failure
            # before this point does not try to remove a non-existent file.
            on_error_do(os.unlink, tmp_path)

            for chunk in response.iter_content(chunk_size=BLOCK_SIZE):
                actual_size += len(chunk)
                if actual_size > expected_size:
                    # There is also the Content-Length header, but it can be corrupted or invalid
                    # for different reasons
                    raise Exception(
                        f"The downloaded file has unexpected size, expected {expected_size}."
                    )

                checksum_counter.update(chunk)
                fd.write(chunk)

        # A truncated download is not caught by the size check above, but it
        # fails the checksum comparison here.
        actual_checksum = checksum_counter.hexdigest()
        if actual_checksum.lower() != expected_checksum.lower():
            raise Exception("The downloaded file has unexpected checksum")

        # The file is closed at this point, so the rename is safe.
        os.rename(tmp_path, output_path)