from fastai.gen_doc.nbdoc import *
from fastai.vision import *
from fastai import *
import PIL
The fastai library is built such that the pictures loaded are wrapped in an Image. This Image contains the array of pixels associated with the picture, and also has many built-in functions that help the fastai library process the transformations applied to the corresponding image. There are also subclasses for special types of image-like objects:
- ImageSegment for segmentation masks
- ImageBBox for bounding boxes

See the following sections for documentation of all the details of these classes. But first, let's have a quick look at the main functionality you'll need to know about.
Opening an image and converting it to an Image object is easily done by using the open_image function:
img = open_image('imgs/cat_example.jpg')
img
To look at the picture that this Image contains, you can also use its show method. It will show a resized version and has more options to customize the display.
img.show()
This show method can take a few arguments (see the documentation of show_image for details) but the two we will use the most in this documentation are:
- ax, which is the matplotlib.pyplot axes on which we want to show the image
- title, which is an optional title we can give to the image.

_,axs = plt.subplots(1,4,figsize=(12,4))
for i,ax in enumerate(axs): img.show(ax=ax, title=f'Copy {i+1}')
If you're interested in the tensor of pixels, it's stored in the data attribute of an Image.
img.data.shape
torch.Size([3, 500, 394])
Image
is the class that wraps every picture in the fastai library. It is subclassed to create ImageSegment
and ImageBBox
when dealing with segmentation and object detection tasks.
show_doc(Image, title_level=3)
Most of the functions of the Image
class deal with the internal pipeline of transforms, so they are only shown at the end of this page. The easiest way to create one is through the function open_image
.
show_doc(open_image)
img = open_image('imgs/cat_example.jpg')
img
In a Jupyter Notebook, the representation of an Image is its underlying picture (shown at its full size). On top of containing the tensor of pixels of the Image (and automatically doing the conversion after decoding the image), this class contains various methods for the implementation of transforms. The Image.show method also accepts more arguments:
show_doc(Image.show, arg_comments ={
'ax': 'matplotlib.pyplot axes on which show the image',
'figsize': 'Size of the figure',
'title': 'Title to display on top of the graph',
'hide_axis': 'If True, the axis of the graph are hidden',
'cmap': 'Color map to use',
'y': 'Potential target to be superposed on the same graph (mask, bounding box, points)'
}, full_name='Image.show')
Image.show
[source]
Image.show
(ax
:Axes
=None
,figsize
:tuple
=(3, 3)
,title
:Optional
[str
]=None
,hide_axis
:bool
=True
,cmap
:str
='viridis'
,y
:Any
=None
,kwargs
)
This allows us to completely customize the display of an Image
. We'll see examples of the y
functionality below with segmentation and bounding boxes tasks, for now here is an example using the other features.
img.show(figsize=(2, 1), title='Little kitten')
img.show(figsize=(10,5), title='Big kitten')
An Image object also has a few attributes that can be useful:
- Image.data gives you the underlying tensor of pixels
- Image.shape gives you the size of that tensor (channels x height x width)
- Image.size gives you the size of the image (height x width)

img.data, img.shape, img.size
(tensor([[[0.1294, 0.0863, 0.0392, ..., 0.4706, 0.4941, 0.4863], [0.0745, 0.0471, 0.0392, ..., 0.4706, 0.4863, 0.4863], [0.0706, 0.0510, 0.0627, ..., 0.4784, 0.4784, 0.4784], ..., [0.3059, 0.3647, 0.3686, ..., 0.5412, 0.5725, 0.5725], [0.3294, 0.4000, 0.4039, ..., 0.5882, 0.5765, 0.5765], [0.3843, 0.4627, 0.4667, ..., 0.6471, 0.5725, 0.5725]], [[0.0235, 0.0000, 0.0000, ..., 0.3490, 0.3686, 0.3725], [0.0000, 0.0000, 0.0000, ..., 0.3569, 0.3725, 0.3725], [0.0000, 0.0000, 0.0157, ..., 0.3647, 0.3686, 0.3686], ..., [0.3882, 0.4588, 0.4627, ..., 0.6471, 0.6784, 0.6784], [0.4118, 0.4941, 0.4980, ..., 0.6941, 0.6824, 0.6824], [0.4667, 0.5569, 0.5608, ..., 0.7529, 0.6784, 0.6784]], [[0.0980, 0.0863, 0.1059, ..., 0.1765, 0.2078, 0.2078], [0.0706, 0.0745, 0.1137, ..., 0.1922, 0.2078, 0.2157], [0.1020, 0.1176, 0.1647, ..., 0.2078, 0.2118, 0.2157], ..., [0.4941, 0.5608, 0.5647, ..., 0.7294, 0.7608, 0.7529], [0.5176, 0.5961, 0.6000, ..., 0.7765, 0.7647, 0.7569], [0.5725, 0.6588, 0.6627, ..., 0.8353, 0.7608, 0.7529]]]), torch.Size([3, 500, 394]), torch.Size([500, 394]))
For a segmentation task, the target is usually a mask. The fastai library represents it as an ImageSegment
object.
show_doc(ImageSegment, title_level=3)
To easily open a mask, the function open_mask
plays the same role as open_image
:
show_doc(open_mask)
open_mask
[source]
open_mask
(fn
:PathOrStr
,div
=False
,convert_mode
='L'
) →ImageSegment
Return ImageSegment object created from the mask in file fn. If div, divides pixel values by 255.
open_mask('imgs/mask_example.png')
From time to time, you may encounter mask data as a run-length encoded string instead of a picture.
df = pd.read_csv('imgs/mask_rle_sample.csv')
encoded_str = df.iloc[1]['rle_mask'];
df[:2]
  | img | rle_mask |
---|---|---|
0 | 00087a6bd4dc_01.jpg | 879386 40 881253 141 883140 205 885009 17 8850... |
1 | 00087a6bd4dc_02.jpg | 873779 4 875695 7 877612 9 879528 12 881267 15... |
You can also read a mask in run-length encoding, with an extra argument shape for the image size:
mask = open_mask_rle(df.iloc[0]['rle_mask'], shape=(1918, 1280)).resize((1,128,128))
mask
show_doc(open_mask_rle)
open_mask_rle
[source]
open_mask_rle
(mask_rle
:str
,shape
:Tuple
[int
,int
]) →ImageSegment
Return ImageSegment object created from the run-length encoded string in mask_rle with size in shape.
The open_mask_rle function simply makes use of the helper function rle_decode:
rle_decode(encoded_str, (1912, 1280)).shape
(1912, 1280)
show_doc(rle_decode)
rle_decode
[source]
rle_decode
(mask_rle
:str
,shape
:Tuple
[int
,int
]) →ndarray
Return an image array from run-length encoded string
You can also convert an ImageSegment to run-length encoding.
type(mask)
fastai.vision.image.ImageSegment
rle_encode(mask.data)
'5943 21 6070 25 6197 26 6324 28 6452 29 6579 30 6707 31 6835 31 6962 32 7090 33 7217 34 7345 35 7473 35 7595 2 7600 36 7722 5 7728 37 7766 4 7850 43 7894 5 7978 43 8022 5 8106 49 8238 44 8366 40 8494 41 8621 42 8748 44 8875 46 9003 47 9130 48 9258 49 9386 49 9513 50 9641 51 9769 51 9897 51 10024 52 10152 53 10280 53 10408 53 10536 53 10664 53 10792 53 10920 53 11048 53 11176 53 11304 53 11432 53 11560 53 11688 53 11816 53 11944 53 12072 53 12200 53 12328 53 12456 53 12584 53 12712 53 12840 53 12968 53 13097 51 13225 51 13353 51 13481 51 13610 49 13742 44 13880 30'
show_doc(rle_encode)
rle_encode
[source]
rle_encode
(img
:ndarray
) →str
Return run-length encoding string from an image array
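To make the idea of run-length encoding more concrete, here is a minimal NumPy sketch. The names toy_rle_encode and toy_rle_decode are hypothetical and are not the fastai functions; the sketch assumes a binary mask, 1-based start positions and row-major flattening, which may differ from the convention used by your data.

```python
import numpy as np

def toy_rle_encode(mask):
    """Encode a binary mask as 'start length start length ...' (1-based starts)."""
    # Pad with zeros so runs touching either end still produce a start and an end.
    flat = np.concatenate([[0], np.asarray(mask, dtype=np.uint8).flatten(), [0]])
    # Positions where the value changes, expressed as 1-based pixel positions.
    changes = np.where(flat[1:] != flat[:-1])[0] + 1
    # Every second entry marks the end of a run; turn it into a run length.
    changes[1::2] -= changes[0::2]
    return " ".join(str(v) for v in changes)

def toy_rle_decode(rle, shape):
    """Rebuild a binary mask of the given (rows, cols) shape from the string."""
    tokens = [int(t) for t in rle.split()]
    starts, lengths = tokens[0::2], tokens[1::2]
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start - 1:start - 1 + length] = 1
    return flat.reshape(shape)

toy_mask = np.array([[0, 1, 1],
                     [0, 0, 1]], dtype=np.uint8)
encoded = toy_rle_encode(toy_mask)                 # pairs of (start, length)
decoded = toy_rle_decode(encoded, toy_mask.shape)  # round trip back to the mask
```

Each pair in the string says where a run of ones starts and how long it is, so a mask with a few large blobs compresses to a short string.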
An ImageSegment object has the same properties as an Image. The only difference is that when applying the transformations to an ImageSegment, it will ignore the functions that deal with lighting and keep values of 0 and 1. As explained earlier, it's easy to show the segmentation mask over the associated Image by using the y argument of show_image.
img = open_image('imgs/car_example.jpg')
mask = open_mask('imgs/mask_example.png')
_,axs = plt.subplots(1,3, figsize=(8,4))
img.show(ax=axs[0], title='no mask')
img.show(ax=axs[1], y=mask, title='masked')
mask.show(ax=axs[2], title='mask only', alpha=1.)
When the targets are a bunch of points, the following class will help.
show_doc(ImagePoints, doc_string=False, title_level=3)
Create an ImagePoints object from a flow of coordinates. Coordinates need to be scaled to the range (-1,1), which will be done in the initialization if scale is left as True. The convention is to have point coordinates in the form [y,x] unless y_first is set to False.
img = open_image('imgs/face_example.jpg')
pnts = torch.load('points.pth')
pnts = ImagePoints(FlowField(img.size, pnts))
img.show(y=pnts)
Note that the raw points are gathered in a FlowField
object, which is a class that wraps together a bunch of coordinates with the corresponding image size. In fastai, we expect points to have the y coordinate first by default. The underlying data of pnts
is the flow of points scaled from -1 to 1 (again with the y coordinate first):
pnts.data[:10]
tensor([[-0.1875, -0.6000], [-0.0500, -0.5875], [ 0.0750, -0.5750], [ 0.2125, -0.5750], [ 0.3375, -0.5375], [ 0.4500, -0.4875], [ 0.5250, -0.3750], [ 0.5750, -0.2375], [ 0.5875, -0.1000], [ 0.5750, 0.0375]])
For an object detection task, the target is a bounding box containing the object in the picture.
show_doc(ImageBBox, doc_string=False, title_level=3)
class
ImageBBox
[source]
ImageBBox
(flow
:FlowField
,scale
:bool
=True
,y_first
:bool
=True
,labels
:Collection
=None
,classes
:dict
=None
,pad_idx
:int
=0
) ::ImagePoints
Create an ImageBBox object from a flow of coordinates. Those coordinates are expected to be in a FlowField with an underlying flow of size 4N, if we have N bboxes, describing for each box the top left, top right, bottom left, bottom right corners. Coordinates need to be scaled to the range (-1,1), which will be done in the initialization if scale is left as True. The convention is to have point coordinates in the form [y,x] unless y_first is set to False. labels is an optional collection of labels, which should be the same size as flow. pad_idx is used if the set of transforms somehow leaves the image without any bounding boxes.
To create an ImageBBox, you can use the create helper function that takes the height of the input image, the width of the input image, and a list of bounding boxes. Each bounding box is represented by a list of four numbers: the coordinates of the corners of the box with the following convention: top, left, bottom, right.
show_doc(ImageBBox.create, arg_comments={
'bboxes': 'list of bboxes (each of those being four integers with the top, left, bottom, right convention)',
'h': 'height of the input image',
'w': 'width of the input image',
'labels': 'labels of the images',
'pad_idx': 'padding index that will be used to group the ImageBBox in a batch'
})
create
[source]
create
(h
:int
,w
:int
,bboxes
:Collection
[Collection
[int
]],labels
:Collection
=None
,classes
:dict
=None
,pad_idx
:int
=0
) →ImageBBox
Create an ImageBBox object from bboxes
.
We need to pass the dimensions of the input image so that ImageBBox can internally create the FlowField. Again, the Image.show method will display the bounding box on the same image if it's passed as a y argument.
img = open_image('imgs/car_bbox.jpg')
bbox = ImageBBox.create(*img.size, [[96, 155, 270, 351]])
img.show(y=bbox)
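To illustrate how a box given as top, left, bottom, right relates to the four corner points described above, here is a small NumPy sketch. The helper toy_bbox_corners is hypothetical and only mirrors the corner order stated in this page (top left, top right, bottom left, bottom right); it is not the library's internal code.

```python
import numpy as np

def toy_bbox_corners(bboxes):
    """Expand [top, left, bottom, right] boxes into four [y, x] corners each."""
    b = np.asarray(bboxes, dtype=float)
    top, left, bottom, right = b[:, 0], b[:, 1], b[:, 2], b[:, 3]
    corners = np.stack([
        np.stack([top, left], axis=1),      # top left
        np.stack([top, right], axis=1),     # top right
        np.stack([bottom, left], axis=1),   # bottom left
        np.stack([bottom, right], axis=1),  # bottom right
    ], axis=1)
    # One row per corner: shape (4 * N, 2) for N boxes.
    return corners.reshape(-1, 2)

corners = toy_bbox_corners([[96, 155, 270, 351]])
```

For N boxes this gives 4N points, which matches the flow size mentioned in the ImageBBox description.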
To help with the conversion of images or to show them, we use these helper functions:
show_doc(show_image)
show_doc(pil2tensor)
pil2tensor
[source]
pil2tensor
(image
:ndarray
,dtype
:dtype
) →Tensor
Convert PIL style image
array to torch style image tensor.
pil2tensor(PIL.Image.open('imgs/cat_example.jpg').convert("RGB"), np.float32).div_(255).size()
torch.Size([3, 500, 394])
pil2tensor(PIL.Image.open('imgs/cat_example.jpg').convert("I"), np.float32).div_(255).size()
torch.Size([1, 500, 394])
pil2tensor(PIL.Image.open('imgs/mask_example.png').convert("L"), np.float32).div_(255).size()
torch.Size([1, 128, 128])
pil2tensor(np.random.rand(224,224,3).astype(np.float32), np.float32).size()
torch.Size([3, 224, 224])
pil2tensor(PIL.Image.open('imgs/cat_example.jpg'), np.float32).div_(255).size()
torch.Size([3, 500, 394])
pil2tensor(PIL.Image.open('imgs/mask_example.png'), np.float32).div_(255).size()
torch.Size([1, 128, 128])
show_doc(image2np)
image2np
[source]
image2np
(image
:Tensor
) →ndarray
Convert from torch style image
to numpy/matplotlib style.
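The two conversions above boil down to moving the channel axis. Here is a minimal NumPy-only sketch of that idea; toy_hwc_to_chw and toy_chw_to_hwc are hypothetical names, they return arrays rather than torch tensors, and the handling of single-channel images is an assumption made for illustration.

```python
import numpy as np

def toy_hwc_to_chw(image, dtype=np.float32):
    """PIL/matplotlib layout (height, width, channels) -> (channels, height, width)."""
    a = np.asarray(image)
    if a.ndim == 2:
        # Grayscale images have no channel axis, so add one.
        a = a[:, :, None]
    return np.transpose(a, (2, 0, 1)).astype(dtype)

def toy_chw_to_hwc(image):
    """(channels, height, width) -> (height, width, channels) for plotting."""
    a = np.transpose(np.asarray(image), (1, 2, 0))
    # Drop the channel axis for single-channel images.
    return a[:, :, 0] if a.shape[2] == 1 else a

rgb = np.random.rand(224, 224, 3)
chw = toy_hwc_to_chw(rgb)
gray = toy_hwc_to_chw(np.zeros((128, 128)))
back = toy_chw_to_hwc(chw)
```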
show_doc(scale_flow)
scale_flow
[source]
scale_flow
(flow
,to_unit
=True
)
Scale the coords in flow
to -1/1 or the image size depending on to_unit
.
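As a rough illustration of this scaling, the sketch below maps pixel coordinates to the range (-1,1) and back with NumPy. The function toy_scale_coords is a hypothetical stand-in that works on plain arrays instead of a FlowField, and it assumes coordinates and size are both given as (height, width).

```python
import numpy as np

def toy_scale_coords(coords, size, to_unit=True):
    """Scale [y, x] coords to (-1, 1) if to_unit, else back to pixel units."""
    half = np.asarray(size, dtype=float) / 2
    coords = np.asarray(coords, dtype=float)
    return coords / half - 1 if to_unit else (coords + 1) * half

size = (100, 200)  # (height, width)
pixel_points = np.array([[0, 0], [50, 100], [100, 200]])
unit_points = toy_scale_coords(pixel_points, size)              # corners map to -1 and 1
restored = toy_scale_coords(unit_points, size, to_unit=False)   # back to pixels
```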
show_doc(bb2hw)
bb2hw
[source]
bb2hw
(a
:Collection
[int
]) →ndarray
Convert bounding box points from (width,height,center) to (height,width,top,left).
All the transforms available for data augmentation in computer vision are defined in the vision.transform module. When we want to apply them to an Image
, we use this method:
show_doc(Image.apply_tfms, arg_comments={
'tfms': '`Transform` or list of `Transform`',
'do_resolve': 'if False, the values of random parameters are kept from the last draw',
'xtra': 'extra arguments to pass to the transforms',
'size': 'desired target size',
'mult': 'makes sure the final size is a multiple of mult',
'resize_method': 'how to get to the final size (crop, pad, squish)',
'padding_mode': "how to pad the image ('zeros', 'border', 'reflection')"
})
apply_tfms
[source]
apply_tfms
(tfms
:Union
[Callable
,Collection
[Callable
]],do_resolve
:bool
=True
,xtra
:Optional
[Dict
[Callable
,dict
]]=None
,size
:Union
[int
,TensorImageSize
,NoneType
]=None
,mult
:int
=32
,resize_method
:ResizeMethod
=<ResizeMethod.CROP: 1>
,padding_mode
:str
='reflection'
,kwargs
:Any
) →Tensor
Apply all tfms
- do_resolve
: bind random args - size
, mult
used to crop/pad.
Before showing examples, let's take a few moments to comment on those arguments a bit more:
- do_resolve decides if we resolve the random arguments by drawing new numbers or not. The intended use is to have the tfms applied to the input x with do_resolve=True, then, if the target y needs data augmentation too (if it's a segmentation mask or bounding box), apply the tfms to y with do_resolve=False.
- mult default value is very important to make sure your image can pass through most recent CNNs: they divide the size of the input image by 2 multiple times, so both dimensions of your picture should be multiples of at least 32. Only change the value of this parameter if you know it will be accepted by your model.

Here are a few helper functions to help us load the examples we saw before.
def get_class_ex(): return open_image('imgs/cat_example.jpg')
def get_seg_ex(): return open_image('imgs/car_example.jpg'), open_mask('imgs/mask_example.png')
def get_pnt_ex():
    img = open_image('imgs/face_example.jpg')
    pnts = torch.load('points.pth')
    return img, ImagePoints(FlowField(img.size, pnts))
def get_bb_ex():
    img = open_image('imgs/car_bbox.jpg')
    return img, ImageBBox.create(*img.size, [[96, 155, 270, 351]])
Now let's grab our usual bunch of transforms and see what they do.
tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img = get_class_ex().apply_tfms(tfms[0], size=224)
    img.show(ax=ax)
Now let's check what it gives for a segmentation task. Note that, as instructed by the documentation of apply_tfms, we first apply the transforms to the input, and then apply them to the target while adding do_resolve=False.
tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img,mask = get_seg_ex()
    # keep the returned objects: they hold the transformed input and target
    img = img.apply_tfms(tfms[0], size=224)
    mask = mask.apply_tfms(tfms[0], do_resolve=False, size=224)
    img.show(ax=ax, y=mask)
Internally, each transform saves the values it randomly picked into a dictionary called resolved, which it can reuse for the target.
tfms[0][4]
RandTransform(tfm=TfmAffine (zoom), kwargs={'row_pct': (0, 1), 'col_pct': (0, 1), 'scale': (1.0, 1.1)}, p=0.75, resolved={'row_pct': 0.6607397753488364, 'col_pct': 0.17601725428112858, 'scale': 1.0604327245601588}, do_run=True, is_random=True)
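The following pure-Python sketch illustrates the idea of drawing random values once and reusing them for the target. ToyRandTransform and shift are hypothetical names invented for this illustration; the real RandTransform has more features (such as p and do_run) that are left out here.

```python
import random

class ToyRandTransform:
    """Hypothetical stand-in: draws its random arguments once, then can reuse them."""
    def __init__(self, func, **ranges):
        self.func, self.ranges, self.resolved = func, ranges, {}

    def resolve(self):
        # Draw one value per argument from its (low, high) range and remember it.
        self.resolved = {k: random.uniform(lo, hi) for k, (lo, hi) in self.ranges.items()}

    def __call__(self, x, do_resolve=True):
        # With do_resolve=False the values from the last draw are kept.
        if do_resolve or not self.resolved:
            self.resolve()
        return self.func(x, **self.resolved)

def shift(values, offset):
    return [v + offset for v in values]

tfm = ToyRandTransform(shift, offset=(-5.0, 5.0))
new_x = tfm([0.0, 1.0, 2.0])                        # input: draws a new offset
drawn = dict(tfm.resolved)
new_y = tfm([10.0, 11.0, 12.0], do_resolve=False)   # target: reuses the same offset
```

Because the second call does not redraw, the input and its target are moved by exactly the same amount, which is what keeps a mask or a set of points aligned with its image.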
For points, ImagePoints
will apply the transforms to the coordinates.
tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img,pnts = get_pnt_ex()
    # keep the returned objects: they hold the transformed input and target
    img = img.apply_tfms(tfms[0], size=224)
    pnts = pnts.apply_tfms(tfms[0], do_resolve=False, size=224)
    img.show(ax=ax, y=pnts)
Now for the bounding box, the ImageBBox will automatically update the coordinates of the two opposite corners in its data attribute.
tfms = get_transforms()
_, axs = plt.subplots(2,4,figsize=(12,6))
for ax in axs.flatten():
    img,bbox = get_bb_ex()
    # keep the returned objects: they hold the transformed input and target
    img = img.apply_tfms(tfms[0], size=224)
    bbox = bbox.apply_tfms(tfms[0], do_resolve=False, size=224)
    img.show(ax=ax, y=bbox)
As explained in the transform module, to indicate to a Transform
how to randomize an argument, we use a type annotation by a random function. Here is the list of the available functions.
show_doc(rand_bool)
rand_bool
[source]
rand_bool
(p
:float
,size
:Optional
[List
[int
]]=None
) →BoolOrTensor
Draw 1 or shape=size random booleans (True occurring with probability p).
rand_bool(0.5, 8)
tensor([1, 0, 1, 0, 1, 0, 0, 0], dtype=torch.uint8)
show_doc(uniform)
uniform
[source]
uniform
(low
:Number
,high
:Number
=None
,size
:Optional
[List
[int
]]=None
) →FloatOrTensor
Draw 1 or shape=size
random floats from uniform dist: min=low
, max=high
.
uniform(0,1,(8,))
tensor([0.7970, 0.1545, 0.7247, 0.5526, 0.7966, 0.4069, 0.7329, 0.4941])
show_doc(uniform_int)
uniform_int
[source]
uniform_int
(low
:int
,high
:int
,size
:Optional
[List
[int
]]=None
) →IntOrTensor
Generate int or tensor size
of ints between low
and high
(included).
uniform_int(0,2,(8,))
tensor([0, 1, 1, 0, 2, 2, 2, 2])
show_doc(log_uniform, doc_string=False)
log_uniform
[source]
log_uniform
(low
,high
,size
:Optional
[List
[int
]]=None
) →FloatOrTensor
Picks a random number (or tensor size
) between log(low
) and log(high
), then returns its exponential (so that it's between low
and high
in the end).
log_uniform(0.5,2,(8,))
tensor([1.3962, 0.5576, 1.2281, 1.6216, 1.5609, 1.5811, 1.3120, 0.7624])
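The description of log_uniform can be sketched with the standard library alone. The function toy_log_uniform below is a hypothetical scalar version written for illustration, not the fastai implementation (which can also return tensors).

```python
import math
import random

def toy_log_uniform(low, high):
    """Draw uniformly between log(low) and log(high), then take the exponential."""
    return math.exp(random.uniform(math.log(low), math.log(high)))

random.seed(0)  # fixed seed so the sketch is reproducible
samples = [toy_log_uniform(0.5, 2) for _ in range(1000)]
# 1 is the geometric middle of 0.5 and 2, so roughly half the draws fall below it.
below_one = sum(s < 1 for s in samples) / len(samples)
```

Sampling this way is convenient for scale-like parameters, where halving and doubling should be equally likely.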
Typically, a data augmentation operation will randomly modify an image input. This operation can apply to pixels (when we modify the contrast or brightness for instance) or to coordinates (when we do a rotation, a zoom or a resize). The operations that apply to pixels can easily be coded in numpy/pytorch, directly on an array/tensor but the ones that modify the coordinates are a bit more tricky.
They usually come in three steps: first we create a grid of coordinates for our picture: this is an array of size h * w * 2 (h for height, w for width in the rest of this post) that contains in position i,j two floats representing the position of the pixel (i,j) in the picture. They could simply be the integers i and j, but since most transformations are centered with the center of the picture as origin, they’re usually rescaled to go from -1 to 1, (-1,-1) being the top left corner of the picture, (1,1) the bottom right corner (and (0,0) the center), and this can be seen as a regular grid of size h * w. Here is what our grid would look like for a 5px by 5px image.
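Such a grid can be sketched in a few lines of NumPy. The helper toy_coord_grid is hypothetical, and storing each position as (row, column) is an assumption made for this illustration; the library's internal layout may order the two coordinates differently.

```python
import numpy as np

def toy_coord_grid(h, w):
    """Return an (h, w, 2) array of coordinates rescaled to go from -1 to 1."""
    rows = np.linspace(-1, 1, h)
    cols = np.linspace(-1, 1, w)
    # indexing="ij" keeps the first axis for rows and the second for columns.
    return np.stack(np.meshgrid(rows, cols, indexing="ij"), axis=-1)

grid = toy_coord_grid(5, 5)
top_left, center, bottom_right = grid[0, 0], grid[2, 2], grid[4, 4]
```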
Then, we apply the transformation to modify this grid of coordinates. For instance, if we want to apply an affine transformation (like a rotation) we will transform each of those vectors x of size 2 by A @ x + b at every position in the grid. This will give us the new coordinates, as seen here in the case of our previous grid.
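Here is a minimal NumPy sketch of applying A @ x + b at every position of a 5 by 5 grid, using a 30 degree rotation as the example. The helper toy_apply_affine and the choice of angle are assumptions made for illustration.

```python
import numpy as np

def toy_apply_affine(grid, A, b):
    """grid has shape (h, w, 2); each last-axis vector x becomes A @ x + b."""
    return grid @ np.asarray(A).T + np.asarray(b)

axis = np.linspace(-1, 1, 5)
grid = np.stack(np.meshgrid(axis, axis, indexing="ij"), axis=-1)

theta = np.deg2rad(30)
A = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
b = np.zeros(2)
rotated = toy_apply_affine(grid, A, b)
# Some rotated corners now lie outside the range from -1 to 1.
largest_coord = np.abs(rotated).max()
```

The center stays in place, every point keeps its distance to the center, and the corners end up outside the original square.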
There are two problems that arise after the transformation: the first one is that the pixel values won’t fall exactly on the grid, and the other is that we can get values that get out of the grid (one of the coordinates is greater than 1 or lower than -1).
To solve the first problem, we use an interpolation. If we forget the rescale for a minute and go back to coordinates being integers, the result of our transformation gives us float coordinates, and we need to decide, for each (i,j), which pixel value in the original picture we need to take. The most basic interpolation called nearest neighbor would just round the floats and take the nearest integers. If we think in terms of the grid of coordinates (going from -1 to 1), the result of our transformation gives a point that isn’t in the grid, and we replace it by its nearest neighbor in the grid.
To be smarter, we can perform a bilinear interpolation. This takes an average of the values of the pixels corresponding to the four points in the grid surrounding the result of our transformation, with weights depending on how close we are to each of those points. This comes at a computational cost though, so this is where we have to be careful.
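The two interpolation schemes can be compared on a tiny 2 by 2 picture. The functions toy_nearest and toy_bilinear are hypothetical helpers that work in integer pixel coordinates (not the rescaled grid) and assume the requested position lies inside the picture.

```python
import numpy as np

def toy_nearest(img, i, j):
    """Round the float position and take the closest pixel."""
    h, w = img.shape
    ri = min(max(int(round(i)), 0), h - 1)
    rj = min(max(int(round(j)), 0), w - 1)
    return img[ri, rj]

def toy_bilinear(img, i, j):
    """Weighted average of the four pixels surrounding the float position."""
    h, w = img.shape
    i0, j0 = int(np.floor(i)), int(np.floor(j))
    i1, j1 = min(i0 + 1, h - 1), min(j0 + 1, w - 1)
    di, dj = i - i0, j - j0  # how far we are from the top-left neighbour
    top = (1 - dj) * img[i0, j0] + dj * img[i0, j1]
    bottom = (1 - dj) * img[i1, j0] + dj * img[i1, j1]
    return (1 - di) * top + di * bottom

img = np.array([[0.0, 1.0],
                [2.0, 3.0]])
nearest_value = toy_nearest(img, 0.2, 0.8)    # closest pixel is img[0, 1]
bilinear_value = toy_bilinear(img, 0.5, 0.5)  # equal mix of all four pixels
```

Nearest neighbour needs one lookup per output pixel, while the bilinear version needs four lookups and a few multiplications, which is the extra cost mentioned above.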
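As for the values that go out of the picture, we treat them by padding, using one of the modes accepted by the padding_mode argument of apply_tfms ('zeros', 'border' or 'reflection').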
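Three common padding strategies can be illustrated on a single row of pixels with np.pad. This is only an analogy: pairing NumPy's 'constant', 'edge' and 'reflect' modes with the 'zeros', 'border' and 'reflection' values of padding_mode is an assumption, and the exact behavior at the border may differ in the library.

```python
import numpy as np

row = np.array([1, 2, 3])

# Fill the outside with zeros (black pixels).
zeros_pad = np.pad(row, 2, mode="constant", constant_values=0)
# Repeat the value found at the border.
border_pad = np.pad(row, 2, mode="edge")
# Mirror the content of the row around its ends.
reflect_pad = np.pad(row, 2, mode="reflect")
```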
Usually, data augmentation libraries have separated the different operations. So for a resize, we’ll go through the three steps above, then if we do a random rotation, we’ll go through those steps again, then again for a zoom, etc. The fastai library works differently in the sense that it does all the transformations on the coordinates at the same time, so that we only do those three steps once. This matters most for the last step (the interpolation), which is the heaviest in computation.
The first thing is that we can regroup all affine transforms in just one pass (because an affine transform composed by an affine transform is another affine transform). This is already done in some other libraries but we pushed it one step further. We integrated the resize, the crop and any non-affine transformation of the coordinates in the same process. Let’s dig in!
The grid of coordinates is created at the target size, new_h, new_w (and not h, w). This takes care of the resize operation.

Note that the transforms operating on pixels are applied in two phases:
This is why all transforms have an attribute (such as TfmAffine, TfmCoord, TfmCrop or TfmPixel) so that the fastai library can regroup them and apply them all together at the right step. In terms of implementation:
- _affine_grid is responsible for creating the grid of coordinates
- _affine_mult is in charge of doing the affine multiplication on that grid
- _grid_sample is the function that is responsible for the interpolation step

TODO: add a comparison of speeds.
Adding a new transformation doesn't impact performance much (since the costly steps are done only once). In contrast, with libraries that use classic data augmentation implementations, each added augmentation usually results in a longer training time.
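The reason affine transforms can be regrouped cheaply is that multiplying their matrices gives a single matrix with the same effect. The sketch below checks this with NumPy on a few points; toy_rotation and toy_zoom are hypothetical helpers using 3 by 3 matrices in homogeneous coordinates, and the way the zoom is parametrized is an assumption.

```python
import numpy as np

def toy_rotation(degrees):
    t = np.deg2rad(degrees)
    return np.array([[np.cos(t), -np.sin(t), 0.0],
                     [np.sin(t),  np.cos(t), 0.0],
                     [0.0, 0.0, 1.0]])

def toy_zoom(scale):
    return np.array([[1 / scale, 0.0, 0.0],
                     [0.0, 1 / scale, 0.0],
                     [0.0, 0.0, 1.0]])

# Three points as homogeneous column vectors (last row is all ones).
points = np.array([[-1.0, -1.0, 1.0],
                   [0.5, 0.25, 1.0],
                   [1.0, 1.0, 1.0]]).T

# Two passes over the coordinates: rotate, then zoom.
step_by_step = toy_zoom(1.1) @ (toy_rotation(20) @ points)
# One pass: multiply the small matrices first, then touch the coordinates once.
combined = toy_zoom(1.1) @ toy_rotation(20)
one_pass = combined @ points
```

Multiplying two 3 by 3 matrices costs almost nothing compared to going over every coordinate of the grid, so adding one more affine transform to the list is nearly free.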
In terms of final result, doing only one interpolation also gives a better result. If we stack several transforms and do an interpolation on each one, we approximate the true value of our coordinates in some way. This tends to blur the image a bit, which often negatively affects performance. By regrouping all the transformations together and only doing this step at the end, the image is often less blurry and the model often performs better.
See how the same rotation then zoom done separately (so there are two interpolations):
is blurrier than regrouping the transforms and doing just one interpolation:
show_doc(ResizeMethod, doc_string=False)
Enum
= [CROP, PAD, SQUISH, NO]
Resize methods to transform an image to a given size:
The basic class that defines transformation in the fastai library is Transform
.
show_doc(Transform, title_level=3,
alt_doc_string="Create a `Transform` for `func` and assign it a priority `order`.")
show_doc(RandTransform, title_level=3, doc_string=False)
Create a Transform from func that can be randomized. Each argument of func in kwargs is analyzed and if it has a type annotation that is a random function, this function will be called to pick a value for it. This value will be stored in the resolved dictionary. Following the same idea, p is the probability for func to be called and do_run will be set to True if it was the case, False otherwise. Lastly, setting is_random to False allows sending specific values for each parameter.
show_doc(RandTransform.resolve)
show_doc(TfmAffine, title_level=3, doc_string=False)
Decorate func
to make it an affine transform; func
should return the 3 by 3 matrix representing the transform. The default order
is 5 for such transforms.
show_doc(TfmCoord, title_level=3, doc_string=False)
Decorate func
to make it a coord transform; func
should take two mandatory arguments: c
(the flow of coordinate) and img_size
(the size of the corresponding image) and return the modified flow of coordinates. The default order
is 4 for such transforms.
show_doc(TfmLighting, title_level=3, doc_string=False)
Decorate func
to make it a lighting transform; func
takes the logits of the pixel tensor and changes them. The default order
is 8 for such transforms.
show_doc(TfmPixel, title_level=3, doc_string=False)
Decorate func
to make it a pixel transform; func
takes the pixel tensor and modifies it. The default order
is 10 for such transforms.
show_doc(TfmCrop, title_level=3, doc_string=False)
Decorate func
to make it a crop transform; This is a special case of TfmPixel
with order
set to 99.
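The default order values listed above can be used to sort a mixed list of transforms. The sketch below does this with plain tuples; the transform names are hypothetical, and it assumes that a lower order means the transform is applied earlier.

```python
# (name, order) pairs using the default orders given for each kind of transform.
toy_tfms = [
    ("crop", 99),         # TfmCrop
    ("brightness", 8),    # TfmLighting
    ("rotate", 5),        # TfmAffine
    ("flip_pixels", 10),  # TfmPixel
    ("warp", 4),          # TfmCoord
]

# Sort by order, whatever order the user listed the transforms in.
ordered = [name for name, order in sorted(toy_tfms, key=lambda t: t[1])]
```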
To help with the conversion to logits for the TfmLighting
, we use these helper functions:
show_doc(logit)
Take the element-wise logit of x. Logit is the inverse function of the sigmoid, defined by log(x/(1-x)).
show_doc(logit_)
In-place version of logit
.
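To see why working on logits is convenient for lighting changes, here is a small NumPy sketch of the pattern sigmoid(func(logit(x))). The names toy_logit, toy_sigmoid and toy_lighting are hypothetical, and the sketch assumes pixel values strictly between 0 and 1.

```python
import numpy as np

def toy_logit(x):
    x = np.asarray(x, dtype=float)
    return np.log(x / (1 - x))

def toy_sigmoid(z):
    return 1 / (1 + np.exp(-np.asarray(z, dtype=float)))

def toy_lighting(pixels, func):
    """Apply func in logit space, then map the result back to (0, 1)."""
    return toy_sigmoid(func(toy_logit(pixels)))

pixels = np.array([0.1, 0.5, 0.9])
brighter = toy_lighting(pixels, lambda z: z + 1.0)  # shift the logits up
unchanged = toy_lighting(pixels, lambda z: z)       # identity round trip
```

Whatever func does to the logits, the sigmoid brings the result back between 0 and 1, so no clipping is needed afterwards.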
All the Image
classes have the same internal functions that deal with data augmentation.
show_doc(Image.affine, doc_string=False)
affine
[source]
affine
(func
:AffineFunc
,args
,kwargs
) →Image
Apply the affine transform given by func
to the object.
show_doc(Image.clone)
show_doc(Image.coord, doc_string=False)
Apply the coord transform given by func
to the object.
show_doc(Image.lighting, doc_string=False)
lighting
[source]
lighting
(func
:LightingFunc
,args
:Any
,kwargs
:Any
)
Apply the lighting transform given by func
to the object.
show_doc(Image.pixel, doc_string=False)
pixel
[source]
pixel
(func
:LightingFunc
,args
,kwargs
) →Image
Apply the pixel transform given by func to the object.
show_doc(Image.refresh)
show_doc(Image.resize)
resize
[source]
resize
(size
:Union
[int
,TensorImageSize
]) →Image
Resize the image to size
, size can be a single int.
show_doc(Image.save)
show_doc(Image.show, full_name='show')
show
[source]
show
(ax
:Axes
=None
,figsize
:tuple
=(3, 3)
,title
:Optional
[str
]=None
,hide_axis
:bool
=True
,cmap
:str
='viridis'
,y
:Any
=None
,kwargs
)
Show the image on ax with figsize, optional title, hide_axis will hide the axis, cmap is used and if y is passed, it is shown on the same ax.
show_doc(Image.show_batch)
Show the images in idxs
from ds
on a few rows
with figsize
.
show_doc(Image.show_results)
show_results
[source]
show_results
(xys
,preds
,figsize
:Tuple
[int
,int
]=None
)
Show the results in xys
from preds
with figsize
.
show_doc(FlowField, title_level=3)
class
FlowField
[source]
FlowField
(size
:Tuple
[int
,int
],flow
:Tensor
)
Wrap together some coords flow
with a size
.
show_doc(Image.crop_pad)
[source]
(
x
,args
,kwargs
)
show_doc(Image.contrast)
[source]
(
x
,args
,kwargs
)
show_doc(Image.brightness)
[source]
(
x
,args
,kwargs
)
show_doc(Image.flip_lr)
[source]
(
x
,args
,kwargs
)
show_doc(Image.pad)
[source]
(
x
,args
,kwargs
)
show_doc(Image.pixel)
pixel
[source]
pixel
(func
:LightingFunc
,args
,kwargs
) →Image
Equivalent to image.px = func(image.px)
.
show_doc(Image.zoom)
[source]
(
x
,args
,kwargs
)
show_doc(Image.dihedral)
[source]
(
x
,args
,kwargs
)
show_doc(ImageSegment.refresh)
show_doc(Image.jitter)
[source]
(
x
,args
,kwargs
)
show_doc(Image.squish)
[source]
(
x
,args
,kwargs
)
show_doc(Image.skew)
[source]
(
x
,args
,kwargs
)
show_doc(Image.perspective_warp)
[source]
(
x
,args
,kwargs
)
show_doc(Image.zoom_squish)
[source]
(
x
,args
,kwargs
)
show_doc(Image.crop)
[source]
(
x
,args
,kwargs
)
show_doc(Image.tilt)
[source]
(
x
,args
,kwargs
)
show_doc(Image.rotate)
[source]
(
x
,args
,kwargs
)
show_doc(ImageSegment.lighting)
lighting
[source]
lighting
(func
:LightingFunc
,args
:Any
,kwargs
:Any
) →Image
Equivalent to image = sigmoid(func(logit(image)))
.
show_doc(Image.symmetric_warp)
[source]
(
x
,args
,kwargs
)
show_doc(Image.dihedral_affine)
[source]
(
x
,args
,kwargs
)
show_doc(ImagePoints.pixel)
pixel
[source]
pixel
(func
:LightingFunc
,args
,kwargs
) →ImagePoints
Equivalent to self = func_flow(self)
.
show_doc(ImageBBox.clone)
show_doc(ImagePoints.refresh)
show_doc(ImagePoints.coord)
show_doc(Image.set_sample)
set_sample
[source]
set_sample
(kwargs
) →ImageBase
Set parameters that control how we grid_sample
the image after transforms are applied.
show_doc(ImageSegment.show)
show
[source]
show
(ax
:Axes
=None
,figsize
:tuple
=(3, 3)
,title
:Optional
[str
]=None
,hide_axis
:bool
=True
,cmap
:str
='tab20'
,alpha
:float
=0.5
,kwargs
)
show_doc(ImagePoints.show)
show
[source]
show
(ax
:Axes
=None
,figsize
:tuple
=(3, 3)
,title
:Optional
[str
]=None
,hide_axis
:bool
=True
,kwargs
)
show_doc(ImagePoints.clone)
show_doc(ImagePoints.lighting)
lighting
[source]
lighting
(func
:LightingFunc
,args
:Any
,kwargs
:Any
) →ImagePoints
Equivalent to image = sigmoid(func(logit(image)))
.
show_doc(Transform.calc)
show_doc(Image.flip_affine)
[source]
(
x
,args
,kwargs
)
show_doc(ImageBBox.show)
show_doc(ImagePoints.resize)
resize
[source]
resize
(size
:Union
[int
,TensorImageSize
]) →ImagePoints
Resize the image to size
, size can be a single int.
show_doc(ImagePoints.reconstruct_output)
reconstruct_output
[source]
reconstruct_output
(out
,x
)
show_doc(ImageSegment.reconstruct_output)
reconstruct_output
[source]
reconstruct_output
(out
,x
)